Merge remote-tracking branch 'upstream/main'
# Conflicts: # frontend/src/components/account/CreateAccountModal.vue
This commit is contained in:
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -78,6 +79,36 @@ func (a *Account) IsGemini() bool {
|
||||
return a.Platform == PlatformGemini
|
||||
}
|
||||
|
||||
func (a *Account) GeminiOAuthType() string {
|
||||
if a.Platform != PlatformGemini || a.Type != AccountTypeOAuth {
|
||||
return ""
|
||||
}
|
||||
oauthType := strings.TrimSpace(a.GetCredential("oauth_type"))
|
||||
if oauthType == "" && strings.TrimSpace(a.GetCredential("project_id")) != "" {
|
||||
return "code_assist"
|
||||
}
|
||||
return oauthType
|
||||
}
|
||||
|
||||
func (a *Account) GeminiTierID() string {
|
||||
tierID := strings.TrimSpace(a.GetCredential("tier_id"))
|
||||
if tierID == "" {
|
||||
return ""
|
||||
}
|
||||
return strings.ToUpper(tierID)
|
||||
}
|
||||
|
||||
func (a *Account) IsGeminiCodeAssist() bool {
|
||||
if a.Platform != PlatformGemini || a.Type != AccountTypeOAuth {
|
||||
return false
|
||||
}
|
||||
oauthType := a.GeminiOAuthType()
|
||||
if oauthType == "" {
|
||||
return strings.TrimSpace(a.GetCredential("project_id")) != ""
|
||||
}
|
||||
return oauthType == "code_assist"
|
||||
}
|
||||
|
||||
func (a *Account) CanGetUsage() bool {
|
||||
return a.Type == AccountTypeOAuth
|
||||
}
|
||||
@@ -110,6 +141,28 @@ func (a *Account) GetCredential(key string) string {
|
||||
}
|
||||
}
|
||||
|
||||
// GetCredentialAsTime 解析凭证中的时间戳字段,支持多种格式
|
||||
// 兼容以下格式:
|
||||
// - RFC3339 字符串: "2025-01-01T00:00:00Z"
|
||||
// - Unix 时间戳字符串: "1735689600"
|
||||
// - Unix 时间戳数字: 1735689600 (float64/int64/json.Number)
|
||||
func (a *Account) GetCredentialAsTime(key string) *time.Time {
|
||||
s := a.GetCredential(key)
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
// 尝试 RFC3339 格式
|
||||
if t, err := time.Parse(time.RFC3339, s); err == nil {
|
||||
return &t
|
||||
}
|
||||
// 尝试 Unix 时间戳(纯数字字符串)
|
||||
if ts, err := strconv.ParseInt(s, 10, 64); err == nil {
|
||||
t := time.Unix(ts, 0)
|
||||
return &t
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Account) GetModelMapping() map[string]string {
|
||||
if a.Credentials == nil {
|
||||
return nil
|
||||
@@ -324,19 +377,7 @@ func (a *Account) GetOpenAITokenExpiresAt() *time.Time {
|
||||
if !a.IsOpenAIOAuth() {
|
||||
return nil
|
||||
}
|
||||
expiresAtStr := a.GetCredential("expires_at")
|
||||
if expiresAtStr == "" {
|
||||
return nil
|
||||
}
|
||||
t, err := time.Parse(time.RFC3339, expiresAtStr)
|
||||
if err != nil {
|
||||
if v, ok := a.Credentials["expires_at"].(float64); ok {
|
||||
tt := time.Unix(int64(v), 0)
|
||||
return &tt
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return &t
|
||||
return a.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
|
||||
func (a *Account) IsOpenAITokenExpired() bool {
|
||||
|
||||
@@ -5,12 +5,13 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrAccountNotFound = infraerrors.NotFound("ACCOUNT_NOT_FOUND", "account not found")
|
||||
ErrAccountNilInput = infraerrors.BadRequest("ACCOUNT_NIL_INPUT", "account input cannot be nil")
|
||||
)
|
||||
|
||||
type AccountRepository interface {
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -187,9 +186,8 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
|
||||
|
||||
// Check if token needs refresh
|
||||
needRefresh := false
|
||||
if expiresAtStr := account.GetCredential("expires_at"); expiresAtStr != "" {
|
||||
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
|
||||
if err == nil && time.Now().Unix()+300 > expiresAt {
|
||||
if expiresAt := account.GetCredentialAsTime("expires_at"); expiresAt != nil {
|
||||
if time.Now().Add(5 * time.Minute).After(*expiresAt) {
|
||||
needRefresh = true
|
||||
}
|
||||
}
|
||||
@@ -263,7 +261,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL)
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
@@ -378,7 +376,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL)
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
@@ -449,7 +447,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL)
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
|
||||
@@ -52,6 +52,9 @@ type UsageLogRepository interface {
|
||||
// Aggregated stats (optimized)
|
||||
GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error)
|
||||
}
|
||||
|
||||
// apiUsageCache 缓存从 Anthropic API 获取的使用率数据(utilization, resets_at)
|
||||
@@ -90,10 +93,12 @@ type UsageProgress struct {
|
||||
|
||||
// UsageInfo 账号使用量信息
|
||||
type UsageInfo struct {
|
||||
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
|
||||
FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口
|
||||
SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口
|
||||
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
|
||||
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
|
||||
FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口
|
||||
SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口
|
||||
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
|
||||
GeminiProDaily *UsageProgress `json:"gemini_pro_daily,omitempty"` // Gemini Pro 日配额
|
||||
GeminiFlashDaily *UsageProgress `json:"gemini_flash_daily,omitempty"` // Gemini Flash 日配额
|
||||
}
|
||||
|
||||
// ClaudeUsageResponse Anthropic API返回的usage结构
|
||||
@@ -119,17 +124,19 @@ type ClaudeUsageFetcher interface {
|
||||
|
||||
// AccountUsageService 账号使用量查询服务
|
||||
type AccountUsageService struct {
|
||||
accountRepo AccountRepository
|
||||
usageLogRepo UsageLogRepository
|
||||
usageFetcher ClaudeUsageFetcher
|
||||
accountRepo AccountRepository
|
||||
usageLogRepo UsageLogRepository
|
||||
usageFetcher ClaudeUsageFetcher
|
||||
geminiQuotaService *GeminiQuotaService
|
||||
}
|
||||
|
||||
// NewAccountUsageService 创建AccountUsageService实例
|
||||
func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLogRepository, usageFetcher ClaudeUsageFetcher) *AccountUsageService {
|
||||
func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLogRepository, usageFetcher ClaudeUsageFetcher, geminiQuotaService *GeminiQuotaService) *AccountUsageService {
|
||||
return &AccountUsageService{
|
||||
accountRepo: accountRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
usageFetcher: usageFetcher,
|
||||
accountRepo: accountRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
usageFetcher: usageFetcher,
|
||||
geminiQuotaService: geminiQuotaService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -143,6 +150,10 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
|
||||
return nil, fmt.Errorf("get account failed: %w", err)
|
||||
}
|
||||
|
||||
if account.Platform == PlatformGemini {
|
||||
return s.getGeminiUsage(ctx, account)
|
||||
}
|
||||
|
||||
// 只有oauth类型账号可以通过API获取usage(有profile scope)
|
||||
if account.CanGetUsage() {
|
||||
var apiResp *ClaudeUsageResponse
|
||||
@@ -189,6 +200,36 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
|
||||
return nil, fmt.Errorf("account type %s does not support usage query", account.Type)
|
||||
}
|
||||
|
||||
func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
|
||||
now := time.Now()
|
||||
usage := &UsageInfo{
|
||||
UpdatedAt: &now,
|
||||
}
|
||||
|
||||
if s.geminiQuotaService == nil || s.usageLogRepo == nil {
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
quota, ok := s.geminiQuotaService.QuotaForAccount(ctx, account)
|
||||
if !ok {
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
start := geminiDailyWindowStart(now)
|
||||
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
|
||||
}
|
||||
|
||||
totals := geminiAggregateUsage(stats)
|
||||
resetAt := geminiDailyResetTime(now)
|
||||
|
||||
usage.GeminiProDaily = buildGeminiUsageProgress(totals.ProRequests, quota.ProRPD, resetAt, totals.ProTokens, totals.ProCost, now)
|
||||
usage.GeminiFlashDaily = buildGeminiUsageProgress(totals.FlashRequests, quota.FlashRPD, resetAt, totals.FlashTokens, totals.FlashCost, now)
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
// addWindowStats 为 usage 数据添加窗口期统计
|
||||
// 使用独立缓存(1 分钟),与 API 缓存分离
|
||||
func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) {
|
||||
@@ -385,3 +426,25 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn
|
||||
// Setup Token无法获取7d数据
|
||||
return info
|
||||
}
|
||||
|
||||
func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64, cost float64, now time.Time) *UsageProgress {
|
||||
if limit <= 0 {
|
||||
return nil
|
||||
}
|
||||
utilization := (float64(used) / float64(limit)) * 100
|
||||
remainingSeconds := int(resetAt.Sub(now).Seconds())
|
||||
if remainingSeconds < 0 {
|
||||
remainingSeconds = 0
|
||||
}
|
||||
resetCopy := resetAt
|
||||
return &UsageProgress{
|
||||
Utilization: utilization,
|
||||
ResetsAt: &resetCopy,
|
||||
RemainingSeconds: remainingSeconds,
|
||||
WindowStats: &WindowStats{
|
||||
Requests: used,
|
||||
Tokens: tokens,
|
||||
Cost: cost,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -488,6 +488,11 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
subscriptionType = SubscriptionTypeStandard
|
||||
}
|
||||
|
||||
// 限额字段:0 和 nil 都表示"无限制"
|
||||
dailyLimit := normalizeLimit(input.DailyLimitUSD)
|
||||
weeklyLimit := normalizeLimit(input.WeeklyLimitUSD)
|
||||
monthlyLimit := normalizeLimit(input.MonthlyLimitUSD)
|
||||
|
||||
group := &Group{
|
||||
Name: input.Name,
|
||||
Description: input.Description,
|
||||
@@ -496,9 +501,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
IsExclusive: input.IsExclusive,
|
||||
Status: StatusActive,
|
||||
SubscriptionType: subscriptionType,
|
||||
DailyLimitUSD: input.DailyLimitUSD,
|
||||
WeeklyLimitUSD: input.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: input.MonthlyLimitUSD,
|
||||
DailyLimitUSD: dailyLimit,
|
||||
WeeklyLimitUSD: weeklyLimit,
|
||||
MonthlyLimitUSD: monthlyLimit,
|
||||
}
|
||||
if err := s.groupRepo.Create(ctx, group); err != nil {
|
||||
return nil, err
|
||||
@@ -506,6 +511,14 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
return group, nil
|
||||
}
|
||||
|
||||
// normalizeLimit maps "no limit" inputs to nil: a nil pointer, zero, or a
// negative value all yield nil. Any other value is returned as the same pointer.
func normalizeLimit(limit *float64) *float64 {
	if limit == nil {
		return nil
	}
	if value := *limit; value <= 0 {
		return nil
	}
	return limit
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
|
||||
group, err := s.groupRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
@@ -535,15 +548,15 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
if input.SubscriptionType != "" {
|
||||
group.SubscriptionType = input.SubscriptionType
|
||||
}
|
||||
// 限额字段支持设置为nil(清除限额)或具体值
|
||||
// 限额字段:0 和 nil 都表示"无限制",正数表示具体限额
|
||||
if input.DailyLimitUSD != nil {
|
||||
group.DailyLimitUSD = input.DailyLimitUSD
|
||||
group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD)
|
||||
}
|
||||
if input.WeeklyLimitUSD != nil {
|
||||
group.WeeklyLimitUSD = input.WeeklyLimitUSD
|
||||
group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD)
|
||||
}
|
||||
if input.MonthlyLimitUSD != nil {
|
||||
group.MonthlyLimitUSD = input.MonthlyLimitUSD
|
||||
group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
|
||||
}
|
||||
|
||||
if err := s.groupRepo.Update(ctx, group); err != nil {
|
||||
|
||||
@@ -25,7 +25,7 @@ const (
|
||||
antigravityRetryMaxDelay = 16 * time.Second
|
||||
)
|
||||
|
||||
// Antigravity 直接支持的模型
|
||||
// Antigravity 直接支持的模型(精确匹配透传)
|
||||
var antigravitySupportedModels = map[string]bool{
|
||||
"claude-opus-4-5-thinking": true,
|
||||
"claude-sonnet-4-5": true,
|
||||
@@ -36,23 +36,26 @@ var antigravitySupportedModels = map[string]bool{
|
||||
"gemini-3-flash": true,
|
||||
"gemini-3-pro-low": true,
|
||||
"gemini-3-pro-high": true,
|
||||
"gemini-3-pro-preview": true,
|
||||
"gemini-3-pro-image": true,
|
||||
}
|
||||
|
||||
// Antigravity 系统默认模型映射表(不支持 → 支持)
|
||||
var antigravityModelMapping = map[string]string{
|
||||
"claude-3-5-sonnet-20241022": "claude-sonnet-4-5",
|
||||
"claude-3-5-sonnet-20240620": "claude-sonnet-4-5",
|
||||
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5-thinking",
|
||||
"claude-opus-4": "claude-opus-4-5-thinking",
|
||||
"claude-opus-4-5-20251101": "claude-opus-4-5-thinking",
|
||||
"claude-haiku-4": "gemini-3-flash",
|
||||
"claude-haiku-4-5": "gemini-3-flash",
|
||||
"claude-3-haiku-20240307": "gemini-3-flash",
|
||||
"claude-haiku-4-5-20251001": "gemini-3-flash",
|
||||
// 生图模型:官方名 → Antigravity 内部名
|
||||
"gemini-3-pro-image-preview": "gemini-3-pro-image",
|
||||
// Antigravity 前缀映射表(按前缀长度降序排列,确保最长匹配优先)
|
||||
// 用于处理模型版本号变化(如 -20251111, -thinking, -preview 等后缀)
|
||||
var antigravityPrefixMapping = []struct {
|
||||
prefix string
|
||||
target string
|
||||
}{
|
||||
// 长前缀优先
|
||||
{"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview 等
|
||||
{"claude-3-5-sonnet", "claude-sonnet-4-5"}, // 旧版 claude-3-5-sonnet-xxx
|
||||
{"claude-sonnet-4-5", "claude-sonnet-4-5"}, // claude-sonnet-4-5-xxx
|
||||
{"claude-haiku-4-5", "gemini-3-flash"}, // claude-haiku-4-5-xxx
|
||||
{"claude-opus-4-5", "claude-opus-4-5-thinking"},
|
||||
{"claude-3-haiku", "gemini-3-flash"}, // 旧版 claude-3-haiku-xxx
|
||||
{"claude-sonnet-4", "claude-sonnet-4-5"},
|
||||
{"claude-haiku-4", "gemini-3-flash"},
|
||||
{"claude-opus-4", "claude-opus-4-5-thinking"},
|
||||
{"gemini-3-pro", "gemini-3-pro-high"}, // gemini-3-pro, gemini-3-pro-preview 等
|
||||
}
|
||||
|
||||
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发
|
||||
@@ -84,24 +87,27 @@ func (s *AntigravityGatewayService) GetTokenProvider() *AntigravityTokenProvider
|
||||
}
|
||||
|
||||
// getMappedModel 获取映射后的模型名
|
||||
// 逻辑:账户映射 → 直接支持透传 → 前缀映射 → gemini透传 → 默认值
|
||||
func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedModel string) string {
|
||||
// 1. 优先使用账户级映射(复用现有方法)
|
||||
// 1. 账户级映射(用户自定义优先)
|
||||
if mapped := account.GetMappedModel(requestedModel); mapped != requestedModel {
|
||||
return mapped
|
||||
}
|
||||
|
||||
// 2. 系统默认映射
|
||||
if mapped, ok := antigravityModelMapping[requestedModel]; ok {
|
||||
return mapped
|
||||
}
|
||||
|
||||
// 3. Gemini 模型透传
|
||||
if strings.HasPrefix(requestedModel, "gemini-") {
|
||||
// 2. 直接支持的模型透传
|
||||
if antigravitySupportedModels[requestedModel] {
|
||||
return requestedModel
|
||||
}
|
||||
|
||||
// 4. Claude 前缀透传直接支持的模型
|
||||
if antigravitySupportedModels[requestedModel] {
|
||||
// 3. 前缀映射(处理版本号变化,如 -20251111, -thinking, -preview)
|
||||
for _, pm := range antigravityPrefixMapping {
|
||||
if strings.HasPrefix(requestedModel, pm.prefix) {
|
||||
return pm.target
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Gemini 模型透传(未匹配到前缀的 gemini 模型)
|
||||
if strings.HasPrefix(requestedModel, "gemini-") {
|
||||
return requestedModel
|
||||
}
|
||||
|
||||
@@ -110,24 +116,10 @@ func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedMo
|
||||
}
|
||||
|
||||
// IsModelSupported 检查模型是否被支持
|
||||
// 所有 claude- 和 gemini- 前缀的模型都能通过映射或透传支持
|
||||
func (s *AntigravityGatewayService) IsModelSupported(requestedModel string) bool {
|
||||
// 直接支持的模型
|
||||
if antigravitySupportedModels[requestedModel] {
|
||||
return true
|
||||
}
|
||||
// 可映射的模型
|
||||
if _, ok := antigravityModelMapping[requestedModel]; ok {
|
||||
return true
|
||||
}
|
||||
// Gemini 前缀透传
|
||||
if strings.HasPrefix(requestedModel, "gemini-") {
|
||||
return true
|
||||
}
|
||||
// Claude 模型支持(通过默认映射)
|
||||
if strings.HasPrefix(requestedModel, "claude-") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
return strings.HasPrefix(requestedModel, "claude-") ||
|
||||
strings.HasPrefix(requestedModel, "gemini-")
|
||||
}
|
||||
|
||||
// TestConnectionResult 测试连接结果
|
||||
@@ -180,7 +172,7 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL)
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("请求失败: %w", err)
|
||||
}
|
||||
@@ -358,6 +350,15 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
return nil, fmt.Errorf("transform request: %w", err)
|
||||
}
|
||||
|
||||
// 调试:记录转换后的请求体(仅记录前 2000 字符)
|
||||
if bodyJSON, err := json.Marshal(geminiBody); err == nil {
|
||||
truncated := string(bodyJSON)
|
||||
if len(truncated) > 2000 {
|
||||
truncated = truncated[:2000] + "..."
|
||||
}
|
||||
log.Printf("[Debug] Transformed Gemini request: %s", truncated)
|
||||
}
|
||||
|
||||
// 构建上游 action
|
||||
action := "generateContent"
|
||||
if claudeReq.Stream {
|
||||
@@ -372,7 +373,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL)
|
||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("Antigravity account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, antigravityMaxRetries, err)
|
||||
@@ -515,7 +516,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL)
|
||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
if attempt < antigravityMaxRetries {
|
||||
log.Printf("Antigravity account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, antigravityMaxRetries, err)
|
||||
|
||||
@@ -131,7 +131,7 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
|
||||
name: "系统映射 - claude-sonnet-4-5-20250929",
|
||||
requestedModel: "claude-sonnet-4-5-20250929",
|
||||
accountMapping: nil,
|
||||
expected: "claude-sonnet-4-5-thinking",
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
|
||||
// 3. Gemini 透传
|
||||
|
||||
@@ -191,7 +191,7 @@ func (r *AntigravityQuotaRefresher) refreshAccountQuota(ctx context.Context, acc
|
||||
|
||||
// isTokenExpired 检查 token 是否过期
|
||||
func (r *AntigravityQuotaRefresher) isTokenExpired(account *Account) bool {
|
||||
expiresAt := parseAntigravityExpiresAt(account)
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
||||
}
|
||||
|
||||
// 2. 如果即将过期则刷新
|
||||
expiresAt := parseAntigravityExpiresAt(account)
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew
|
||||
if needsRefresh && p.tokenCache != nil {
|
||||
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
@@ -72,7 +72,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
||||
if err == nil && fresh != nil {
|
||||
account = fresh
|
||||
}
|
||||
expiresAt = parseAntigravityExpiresAt(account)
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew {
|
||||
if p.antigravityOAuthService == nil {
|
||||
return "", errors.New("antigravity oauth service not configured")
|
||||
@@ -91,7 +91,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
|
||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||
log.Printf("[AntigravityTokenProvider] Failed to update account credentials: %v", updateErr)
|
||||
}
|
||||
expiresAt = parseAntigravityExpiresAt(account)
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -128,18 +128,3 @@ func antigravityTokenCacheKey(account *Account) string {
|
||||
}
|
||||
return "ag:account:" + strconv.FormatInt(account.ID, 10)
|
||||
}
|
||||
|
||||
func parseAntigravityExpiresAt(account *Account) *time.Time {
|
||||
raw := strings.TrimSpace(account.GetCredential("expires_at"))
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
if unixSec, err := strconv.ParseInt(raw, 10, 64); err == nil && unixSec > 0 {
|
||||
t := time.Unix(unixSec, 0)
|
||||
return &t
|
||||
}
|
||||
if t, err := time.Parse(time.RFC3339, raw); err == nil {
|
||||
return &t
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -29,21 +29,22 @@ func (r *AntigravityTokenRefresher) CanRefresh(account *Account) bool {
|
||||
}
|
||||
|
||||
// NeedsRefresh 检查账户是否需要刷新
|
||||
// Antigravity 使用固定的10分钟刷新窗口,忽略全局配置
|
||||
// Antigravity 使用固定的15分钟刷新窗口,忽略全局配置
|
||||
func (r *AntigravityTokenRefresher) NeedsRefresh(account *Account, _ time.Duration) bool {
|
||||
if !r.CanRefresh(account) {
|
||||
return false
|
||||
}
|
||||
expiresAtStr := account.GetCredential("expires_at")
|
||||
if expiresAtStr == "" {
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil {
|
||||
return false
|
||||
}
|
||||
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
|
||||
if err != nil {
|
||||
return false
|
||||
timeUntilExpiry := time.Until(*expiresAt)
|
||||
needsRefresh := timeUntilExpiry < antigravityRefreshWindow
|
||||
if needsRefresh {
|
||||
fmt.Printf("[AntigravityTokenRefresher] Account %d needs refresh: expires_at=%s, time_until_expiry=%v, window=%v\n",
|
||||
account.ID, expiresAt.Format("2006-01-02 15:04:05"), timeUntilExpiry, antigravityRefreshWindow)
|
||||
}
|
||||
expiryTime := time.Unix(expiresAt, 0)
|
||||
return time.Until(expiryTime) < antigravityRefreshWindow
|
||||
return needsRefresh
|
||||
}
|
||||
|
||||
// Refresh 执行 token 刷新
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
)
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
@@ -4,10 +4,12 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
// 错误定义
|
||||
@@ -27,6 +29,46 @@ type subscriptionCacheData struct {
|
||||
Version int64
|
||||
}
|
||||
|
||||
// 缓存写入任务类型
|
||||
type cacheWriteKind int
|
||||
|
||||
const (
|
||||
cacheWriteSetBalance cacheWriteKind = iota
|
||||
cacheWriteSetSubscription
|
||||
cacheWriteUpdateSubscriptionUsage
|
||||
cacheWriteDeductBalance
|
||||
)
|
||||
|
||||
// 异步缓存写入工作池配置
|
||||
//
|
||||
// 性能优化说明:
|
||||
// 原实现在请求热路径中使用 goroutine 异步更新缓存,存在以下问题:
|
||||
// 1. 每次请求创建新 goroutine,高并发下产生大量短生命周期 goroutine
|
||||
// 2. 无法控制并发数量,可能导致 Redis 连接耗尽
|
||||
// 3. goroutine 创建/销毁带来额外开销
|
||||
//
|
||||
// 新实现使用固定大小的工作池:
|
||||
// 1. 预创建 10 个 worker goroutine,避免频繁创建销毁
|
||||
// 2. 使用带缓冲的 channel(1000)作为任务队列,平滑写入峰值
|
||||
// 3. 非阻塞写入,队列满时关键任务同步回退,非关键任务丢弃并告警
|
||||
// 4. 统一超时控制,避免慢操作阻塞工作池
|
||||
const (
|
||||
cacheWriteWorkerCount = 10 // 工作协程数量
|
||||
cacheWriteBufferSize = 1000 // 任务队列缓冲大小
|
||||
cacheWriteTimeout = 2 * time.Second // 单个写入操作超时
|
||||
cacheWriteDropLogInterval = 5 * time.Second // 丢弃日志节流间隔
|
||||
)
|
||||
|
||||
// cacheWriteTask 缓存写入任务
|
||||
type cacheWriteTask struct {
|
||||
kind cacheWriteKind
|
||||
userID int64
|
||||
groupID int64
|
||||
balance float64
|
||||
amount float64
|
||||
subscriptionData *subscriptionCacheData
|
||||
}
|
||||
|
||||
// BillingCacheService 计费缓存服务
|
||||
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
|
||||
type BillingCacheService struct {
|
||||
@@ -34,16 +76,151 @@ type BillingCacheService struct {
|
||||
userRepo UserRepository
|
||||
subRepo UserSubscriptionRepository
|
||||
cfg *config.Config
|
||||
|
||||
cacheWriteChan chan cacheWriteTask
|
||||
cacheWriteWg sync.WaitGroup
|
||||
cacheWriteStopOnce sync.Once
|
||||
// 丢弃日志节流计数器(减少高负载下日志噪音)
|
||||
cacheWriteDropFullCount uint64
|
||||
cacheWriteDropFullLastLog int64
|
||||
cacheWriteDropClosedCount uint64
|
||||
cacheWriteDropClosedLastLog int64
|
||||
}
|
||||
|
||||
// NewBillingCacheService 创建计费缓存服务
|
||||
func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, cfg *config.Config) *BillingCacheService {
|
||||
return &BillingCacheService{
|
||||
svc := &BillingCacheService{
|
||||
cache: cache,
|
||||
userRepo: userRepo,
|
||||
subRepo: subRepo,
|
||||
cfg: cfg,
|
||||
}
|
||||
svc.startCacheWriteWorkers()
|
||||
return svc
|
||||
}
|
||||
|
||||
// Stop 关闭缓存写入工作池
|
||||
func (s *BillingCacheService) Stop() {
|
||||
s.cacheWriteStopOnce.Do(func() {
|
||||
if s.cacheWriteChan == nil {
|
||||
return
|
||||
}
|
||||
close(s.cacheWriteChan)
|
||||
s.cacheWriteWg.Wait()
|
||||
s.cacheWriteChan = nil
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BillingCacheService) startCacheWriteWorkers() {
|
||||
s.cacheWriteChan = make(chan cacheWriteTask, cacheWriteBufferSize)
|
||||
for i := 0; i < cacheWriteWorkerCount; i++ {
|
||||
s.cacheWriteWg.Add(1)
|
||||
go s.cacheWriteWorker()
|
||||
}
|
||||
}
|
||||
|
||||
// enqueueCacheWrite 尝试将任务入队,队列满时返回 false(并记录告警)。
|
||||
func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued bool) {
|
||||
if s.cacheWriteChan == nil {
|
||||
return false
|
||||
}
|
||||
defer func() {
|
||||
if recovered := recover(); recovered != nil {
|
||||
// 队列已关闭时可能触发 panic,记录后静默失败。
|
||||
s.logCacheWriteDrop(task, "closed")
|
||||
enqueued = false
|
||||
}
|
||||
}()
|
||||
select {
|
||||
case s.cacheWriteChan <- task:
|
||||
return true
|
||||
default:
|
||||
// 队列满时不阻塞主流程,交由调用方决定是否同步回退。
|
||||
s.logCacheWriteDrop(task, "full")
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BillingCacheService) cacheWriteWorker() {
|
||||
defer s.cacheWriteWg.Done()
|
||||
for task := range s.cacheWriteChan {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
|
||||
switch task.kind {
|
||||
case cacheWriteSetBalance:
|
||||
s.setBalanceCache(ctx, task.userID, task.balance)
|
||||
case cacheWriteSetSubscription:
|
||||
s.setSubscriptionCache(ctx, task.userID, task.groupID, task.subscriptionData)
|
||||
case cacheWriteUpdateSubscriptionUsage:
|
||||
if s.cache != nil {
|
||||
if err := s.cache.UpdateSubscriptionUsage(ctx, task.userID, task.groupID, task.amount); err != nil {
|
||||
log.Printf("Warning: update subscription cache failed for user %d group %d: %v", task.userID, task.groupID, err)
|
||||
}
|
||||
}
|
||||
case cacheWriteDeductBalance:
|
||||
if s.cache != nil {
|
||||
if err := s.cache.DeductUserBalance(ctx, task.userID, task.amount); err != nil {
|
||||
log.Printf("Warning: deduct balance cache failed for user %d: %v", task.userID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// cacheWriteKindName 用于日志中的任务类型标识,便于排查丢弃原因。
|
||||
func cacheWriteKindName(kind cacheWriteKind) string {
|
||||
switch kind {
|
||||
case cacheWriteSetBalance:
|
||||
return "set_balance"
|
||||
case cacheWriteSetSubscription:
|
||||
return "set_subscription"
|
||||
case cacheWriteUpdateSubscriptionUsage:
|
||||
return "update_subscription_usage"
|
||||
case cacheWriteDeductBalance:
|
||||
return "deduct_balance"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// logCacheWriteDrop 使用节流方式记录丢弃情况,并汇总丢弃数量。
|
||||
func (s *BillingCacheService) logCacheWriteDrop(task cacheWriteTask, reason string) {
|
||||
var (
|
||||
countPtr *uint64
|
||||
lastPtr *int64
|
||||
)
|
||||
switch reason {
|
||||
case "full":
|
||||
countPtr = &s.cacheWriteDropFullCount
|
||||
lastPtr = &s.cacheWriteDropFullLastLog
|
||||
case "closed":
|
||||
countPtr = &s.cacheWriteDropClosedCount
|
||||
lastPtr = &s.cacheWriteDropClosedLastLog
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
atomic.AddUint64(countPtr, 1)
|
||||
now := time.Now().UnixNano()
|
||||
last := atomic.LoadInt64(lastPtr)
|
||||
if now-last < int64(cacheWriteDropLogInterval) {
|
||||
return
|
||||
}
|
||||
if !atomic.CompareAndSwapInt64(lastPtr, last, now) {
|
||||
return
|
||||
}
|
||||
dropped := atomic.SwapUint64(countPtr, 0)
|
||||
if dropped == 0 {
|
||||
return
|
||||
}
|
||||
log.Printf("Warning: cache write queue %s, dropped %d tasks in last %s (latest kind=%s user %d group %d)",
|
||||
reason,
|
||||
dropped,
|
||||
cacheWriteDropLogInterval,
|
||||
cacheWriteKindName(task.kind),
|
||||
task.userID,
|
||||
task.groupID,
|
||||
)
|
||||
}
|
||||
|
||||
// ============================================
|
||||
@@ -70,11 +247,11 @@ func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64)
|
||||
}
|
||||
|
||||
// 异步建立缓存
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
s.setBalanceCache(cacheCtx, userID, balance)
|
||||
}()
|
||||
_ = s.enqueueCacheWrite(cacheWriteTask{
|
||||
kind: cacheWriteSetBalance,
|
||||
userID: userID,
|
||||
balance: balance,
|
||||
})
|
||||
|
||||
return balance, nil
|
||||
}
|
||||
@@ -98,7 +275,7 @@ func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64,
|
||||
}
|
||||
}
|
||||
|
||||
// DeductBalanceCache 扣减余额缓存(异步调用,用于扣费后更新缓存)
|
||||
// DeductBalanceCache 扣减余额缓存(同步调用)
|
||||
func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int64, amount float64) error {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
@@ -106,6 +283,26 @@ func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int
|
||||
return s.cache.DeductUserBalance(ctx, userID, amount)
|
||||
}
|
||||
|
||||
// QueueDeductBalance 异步扣减余额缓存
|
||||
func (s *BillingCacheService) QueueDeductBalance(userID int64, amount float64) {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
// 队列满时同步回退,避免关键扣减被静默丢弃。
|
||||
if s.enqueueCacheWrite(cacheWriteTask{
|
||||
kind: cacheWriteDeductBalance,
|
||||
userID: userID,
|
||||
amount: amount,
|
||||
}) {
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
|
||||
defer cancel()
|
||||
if err := s.DeductBalanceCache(ctx, userID, amount); err != nil {
|
||||
log.Printf("Warning: deduct balance cache fallback failed for user %d: %v", userID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateUserBalance 失效用户余额缓存
|
||||
func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID int64) error {
|
||||
if s.cache == nil {
|
||||
@@ -141,11 +338,12 @@ func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID,
|
||||
}
|
||||
|
||||
// 异步建立缓存
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
s.setSubscriptionCache(cacheCtx, userID, groupID, data)
|
||||
}()
|
||||
_ = s.enqueueCacheWrite(cacheWriteTask{
|
||||
kind: cacheWriteSetSubscription,
|
||||
userID: userID,
|
||||
groupID: groupID,
|
||||
subscriptionData: data,
|
||||
})
|
||||
|
||||
return data, nil
|
||||
}
|
||||
@@ -199,7 +397,7 @@ func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID,
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateSubscriptionUsage 更新订阅用量缓存(异步调用,用于扣费后更新缓存)
|
||||
// UpdateSubscriptionUsage 更新订阅用量缓存(同步调用)
|
||||
func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, costUSD float64) error {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
@@ -207,6 +405,27 @@ func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userI
|
||||
return s.cache.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD)
|
||||
}
|
||||
|
||||
// QueueUpdateSubscriptionUsage 异步更新订阅用量缓存
|
||||
func (s *BillingCacheService) QueueUpdateSubscriptionUsage(userID, groupID int64, costUSD float64) {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
// 队列满时同步回退,确保订阅用量及时更新。
|
||||
if s.enqueueCacheWrite(cacheWriteTask{
|
||||
kind: cacheWriteUpdateSubscriptionUsage,
|
||||
userID: userID,
|
||||
groupID: groupID,
|
||||
amount: costUSD,
|
||||
}) {
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
|
||||
defer cancel()
|
||||
if err := s.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD); err != nil {
|
||||
log.Printf("Warning: update subscription cache fallback failed for user %d group %d: %v", userID, groupID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateSubscription 失效指定订阅缓存
|
||||
func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID, groupID int64) error {
|
||||
if s.cache == nil {
|
||||
|
||||
75
backend/internal/service/billing_cache_service_test.go
Normal file
75
backend/internal/service/billing_cache_service_test.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type billingCacheWorkerStub struct {
|
||||
balanceUpdates int64
|
||||
subscriptionUpdates int64
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
|
||||
atomic.AddInt64(&b.balanceUpdates, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
|
||||
atomic.AddInt64(&b.balanceUpdates, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) InvalidateUserBalance(ctx context.Context, userID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error {
|
||||
atomic.AddInt64(&b.subscriptionUpdates, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
|
||||
atomic.AddInt64(&b.subscriptionUpdates, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
|
||||
cache := &billingCacheWorkerStub{}
|
||||
svc := NewBillingCacheService(cache, nil, nil, &config.Config{})
|
||||
t.Cleanup(svc.Stop)
|
||||
|
||||
start := time.Now()
|
||||
for i := 0; i < cacheWriteBufferSize*2; i++ {
|
||||
svc.QueueDeductBalance(1, 1)
|
||||
}
|
||||
require.Less(t, time.Since(start), 2*time.Second)
|
||||
|
||||
svc.QueueUpdateSubscriptionUsage(1, 2, 1.5)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
return atomic.LoadInt64(&cache.balanceUpdates) > 0
|
||||
}, 2*time.Second, 10*time.Millisecond)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
return atomic.LoadInt64(&cache.subscriptionUpdates) > 0
|
||||
}, 2*time.Second, 10*time.Millisecond)
|
||||
}
|
||||
@@ -9,24 +9,35 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// ConcurrencyCache defines cache operations for concurrency service
|
||||
// Uses independent keys per request slot with native Redis TTL for automatic cleanup
|
||||
// ConcurrencyCache 定义并发控制的缓存接口
|
||||
// 使用有序集合存储槽位,按时间戳清理过期条目
|
||||
type ConcurrencyCache interface {
|
||||
// Account slot management - each slot is a separate key with independent TTL
|
||||
// Key format: concurrency:account:{accountID}:{requestID}
|
||||
// 账号槽位管理
|
||||
// 键格式: concurrency:account:{accountID}(有序集合,成员为 requestID)
|
||||
AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error)
|
||||
ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error
|
||||
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
|
||||
|
||||
// User slot management - each slot is a separate key with independent TTL
|
||||
// Key format: concurrency:user:{userID}:{requestID}
|
||||
// 账号等待队列(账号级)
|
||||
IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error)
|
||||
DecrementAccountWaitCount(ctx context.Context, accountID int64) error
|
||||
GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error)
|
||||
|
||||
// 用户槽位管理
|
||||
// 键格式: concurrency:user:{userID}(有序集合,成员为 requestID)
|
||||
AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error)
|
||||
ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error
|
||||
GetUserConcurrency(ctx context.Context, userID int64) (int, error)
|
||||
|
||||
// Wait queue - uses counter with TTL set only on creation
|
||||
// 等待队列计数(只在首次创建时设置 TTL)
|
||||
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
|
||||
DecrementWaitCount(ctx context.Context, userID int64) error
|
||||
|
||||
// 批量负载查询(只读)
|
||||
GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error)
|
||||
|
||||
// 清理过期槽位(后台任务)
|
||||
CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error
|
||||
}
|
||||
|
||||
// generateRequestID generates a unique request ID for concurrency slot tracking
|
||||
@@ -61,6 +72,18 @@ type AcquireResult struct {
|
||||
ReleaseFunc func() // Must be called when done (typically via defer)
|
||||
}
|
||||
|
||||
// AccountWithConcurrency pairs an account ID with its concurrency limit. It is
// the per-account input of GetAccountsLoadBatch.
type AccountWithConcurrency struct {
	ID             int64 // account ID
	MaxConcurrency int   // concurrency limit for the account
}
|
||||
|
||||
// AccountLoadInfo is the per-account result of a batch load query.
type AccountLoadInfo struct {
	AccountID          int64 // account the figures belong to
	CurrentConcurrency int   // current concurrency count
	WaitingCount       int   // current wait queue count
	LoadRate           int   // load as a percentage: 0-100+
}
|
||||
|
||||
// AcquireAccountSlot attempts to acquire a concurrency slot for an account.
|
||||
// If the account is at max concurrency, it waits until a slot is available or timeout.
|
||||
// Returns a release function that MUST be called when the request completes.
|
||||
@@ -177,6 +200,42 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6
|
||||
}
|
||||
}
|
||||
|
||||
// IncrementAccountWaitCount increments the wait queue counter for an account.
|
||||
func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||
if s.cache == nil {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
result, err := s.cache.IncrementAccountWaitCount(ctx, accountID, maxWait)
|
||||
if err != nil {
|
||||
log.Printf("Warning: increment wait count failed for account %d: %v", accountID, err)
|
||||
return true, nil
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DecrementAccountWaitCount decrements the wait queue counter for an account.
|
||||
func (s *ConcurrencyService) DecrementAccountWaitCount(ctx context.Context, accountID int64) {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := s.cache.DecrementAccountWaitCount(bgCtx, accountID); err != nil {
|
||||
log.Printf("Warning: decrement wait count failed for account %d: %v", accountID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// GetAccountWaitingCount gets current wait queue count for an account.
|
||||
func (s *ConcurrencyService) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||||
if s.cache == nil {
|
||||
return 0, nil
|
||||
}
|
||||
return s.cache.GetAccountWaitingCount(ctx, accountID)
|
||||
}
|
||||
|
||||
// CalculateMaxWait calculates the maximum wait queue size for a user
|
||||
// maxWait = userConcurrency + defaultExtraWaitSlots
|
||||
func CalculateMaxWait(userConcurrency int) int {
|
||||
@@ -186,6 +245,57 @@ func CalculateMaxWait(userConcurrency int) int {
|
||||
return userConcurrency + defaultExtraWaitSlots
|
||||
}
|
||||
|
||||
// GetAccountsLoadBatch returns load info for multiple accounts.
|
||||
func (s *ConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||||
if s.cache == nil {
|
||||
return map[int64]*AccountLoadInfo{}, nil
|
||||
}
|
||||
return s.cache.GetAccountsLoadBatch(ctx, accounts)
|
||||
}
|
||||
|
||||
// CleanupExpiredAccountSlots removes expired slots for one account (background task).
|
||||
func (s *ConcurrencyService) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
return s.cache.CleanupExpiredAccountSlots(ctx, accountID)
|
||||
}
|
||||
|
||||
// StartSlotCleanupWorker starts a background cleanup worker for expired account slots.
|
||||
func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepository, interval time.Duration) {
|
||||
if s == nil || s.cache == nil || accountRepo == nil || interval <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
runCleanup := func() {
|
||||
listCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
accounts, err := accountRepo.ListSchedulable(listCtx)
|
||||
cancel()
|
||||
if err != nil {
|
||||
log.Printf("Warning: list schedulable accounts failed: %v", err)
|
||||
return
|
||||
}
|
||||
for _, account := range accounts {
|
||||
accountCtx, accountCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
err := s.cache.CleanupExpiredAccountSlots(accountCtx, account.ID)
|
||||
accountCancel()
|
||||
if err != nil {
|
||||
log.Printf("Warning: cleanup expired slots failed for account %d: %v", account.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
runCleanup()
|
||||
for range ticker.C {
|
||||
runCleanup()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
|
||||
// Returns a map of accountID -> current concurrency count
|
||||
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
)
|
||||
|
||||
type CRSSyncService struct {
|
||||
@@ -193,7 +195,12 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
return nil, errors.New("username and password are required")
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 20 * time.Second}
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
Timeout: 20 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
client = &http.Client{Timeout: 20 * time.Second}
|
||||
}
|
||||
|
||||
adminToken, err := crsLogin(ctx, client, baseURL, input.Username, input.Password)
|
||||
if err != nil {
|
||||
|
||||
@@ -91,6 +91,9 @@ const (
|
||||
|
||||
// 管理员 API Key
|
||||
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
|
||||
|
||||
// Gemini 配额策略(JSON)
|
||||
SettingKeyGeminiQuotaPolicy = "gemini_quota_policy"
|
||||
)
|
||||
|
||||
// Admin API Key prefix (distinct from user "sk-" keys)
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
|
||||
@@ -261,6 +261,34 @@ func TestGatewayService_SelectAccountForModelWithPlatform_PriorityAndLastUsed(t
|
||||
require.Equal(t, int64(2), acc.ID, "同优先级应选择最久未用的账户")
|
||||
}
|
||||
|
||||
func TestGatewayService_SelectAccountForModelWithPlatform_GeminiOAuthPreference(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "gemini-2.5-pro", nil, PlatformGemini)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID, "同优先级且未使用时应优先选择OAuth账户")
|
||||
}
|
||||
|
||||
// TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts 测试无可用账户
|
||||
func TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
@@ -576,6 +604,32 @@ func TestGatewayService_isModelSupportedByAccount(t *testing.T) {
|
||||
func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("混合调度-Gemini优先选择OAuth账户", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "gemini-2.5-pro", nil, PlatformGemini)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID, "同优先级且未使用时应优先选择OAuth账户")
|
||||
})
|
||||
|
||||
t.Run("混合调度-包含启用mixed_scheduling的antigravity账户", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
@@ -783,3 +837,160 @@ func TestAccount_IsMixedSchedulingEnabled(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// mockConcurrencyService for testing
|
||||
type mockConcurrencyService struct {
|
||||
accountLoads map[int64]*AccountLoadInfo
|
||||
accountWaitCounts map[int64]int
|
||||
acquireResults map[int64]bool
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||||
if m.accountLoads == nil {
|
||||
return map[int64]*AccountLoadInfo{}, nil
|
||||
}
|
||||
result := make(map[int64]*AccountLoadInfo)
|
||||
for _, acc := range accounts {
|
||||
if load, ok := m.accountLoads[acc.ID]; ok {
|
||||
result[acc.ID] = load
|
||||
} else {
|
||||
result[acc.ID] = &AccountLoadInfo{
|
||||
AccountID: acc.ID,
|
||||
CurrentConcurrency: 0,
|
||||
WaitingCount: 0,
|
||||
LoadRate: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyService) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||||
if m.accountWaitCounts == nil {
|
||||
return 0, nil
|
||||
}
|
||||
return m.accountWaitCounts[accountID], nil
|
||||
}
|
||||
|
||||
// TestGatewayService_SelectAccountWithLoadAwareness tests load-aware account selection
|
||||
func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("禁用负载批量查询-降级到传统选择", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
|
||||
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
cfg := testConfig()
|
||||
cfg.Gateway.Scheduling.LoadBatchEnabled = false
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: nil, // No concurrency service
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
require.Equal(t, int64(1), result.Account.ID, "应选择优先级最高的账号")
|
||||
})
|
||||
|
||||
t.Run("无ConcurrencyService-降级到传统选择", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
|
||||
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
cfg := testConfig()
|
||||
cfg.Gateway.Scheduling.LoadBatchEnabled = true
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: nil,
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
require.Equal(t, int64(2), result.Account.ID, "应选择优先级最高的账号")
|
||||
})
|
||||
|
||||
t.Run("排除账号-不选择被排除的账号", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
|
||||
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
cfg := testConfig()
|
||||
cfg.Gateway.Scheduling.LoadBatchEnabled = false
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: nil,
|
||||
}
|
||||
|
||||
excludedIDs := map[int64]struct{}{1: {}}
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
require.Equal(t, int64(2), result.Account.ID, "不应选择被排除的账号")
|
||||
})
|
||||
|
||||
t.Run("无可用账号-返回错误", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
cfg := testConfig()
|
||||
cfg.Gateway.Scheduling.LoadBatchEnabled = false
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: nil,
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
require.Contains(t, err.Error(), "no available accounts")
|
||||
})
|
||||
}
|
||||
|
||||
72
backend/internal/service/gateway_request.go
Normal file
72
backend/internal/service/gateway_request.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ParsedRequest holds the pre-parsed fields of a gateway request.
//
// Performance note: the previous implementation parsed the request body in
// several places (once in the handler, again in the service):
//  1. gateway_handler.go parsed it to read model and stream
//  2. gateway_service.go parsed it again to read system, messages and metadata
//  3. GenerateSessionHash parsed it once more for the session hash fields
//
// The current design parses once and reuses the result:
//  1. the handler layer calls ParseGatewayRequest a single time
//  2. the resulting ParsedRequest is passed down to the service layer
//  3. repeated json.Unmarshal calls are avoided, saving CPU and memory
type ParsedRequest struct {
	Body           []byte // raw request body (kept for forwarding)
	Model          string // requested model name
	Stream         bool   // whether this is a streaming request
	MetadataUserID string // metadata.user_id (used for session affinity)
	System         any    // contents of the system field
	Messages       []any  // the messages array
	HasSystem      bool   // whether a system field is present (an explicit null counts as present)
}
|
||||
|
||||
// ParseGatewayRequest 解析网关请求体并返回结构化结果
|
||||
// 性能优化:一次解析提取所有需要的字段,避免重复 Unmarshal
|
||||
func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parsed := &ParsedRequest{
|
||||
Body: body,
|
||||
}
|
||||
|
||||
if rawModel, exists := req["model"]; exists {
|
||||
model, ok := rawModel.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid model field type")
|
||||
}
|
||||
parsed.Model = model
|
||||
}
|
||||
if rawStream, exists := req["stream"]; exists {
|
||||
stream, ok := rawStream.(bool)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid stream field type")
|
||||
}
|
||||
parsed.Stream = stream
|
||||
}
|
||||
if metadata, ok := req["metadata"].(map[string]any); ok {
|
||||
if userID, ok := metadata["user_id"].(string); ok {
|
||||
parsed.MetadataUserID = userID
|
||||
}
|
||||
}
|
||||
// system 字段只要存在就视为显式提供(即使为 null),
|
||||
// 以避免客户端传 null 时被默认 system 误注入。
|
||||
if system, ok := req["system"]; ok {
|
||||
parsed.HasSystem = true
|
||||
parsed.System = system
|
||||
}
|
||||
if messages, ok := req["messages"].([]any); ok {
|
||||
parsed.Messages = messages
|
||||
}
|
||||
|
||||
return parsed, nil
|
||||
}
|
||||
40
backend/internal/service/gateway_request_test.go
Normal file
40
backend/internal/service/gateway_request_test.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParseGatewayRequest(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-3-7-sonnet","stream":true,"metadata":{"user_id":"session_123e4567-e89b-12d3-a456-426614174000"},"system":[{"type":"text","text":"hello","cache_control":{"type":"ephemeral"}}],"messages":[{"content":"hi"}]}`)
|
||||
parsed, err := ParseGatewayRequest(body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "claude-3-7-sonnet", parsed.Model)
|
||||
require.True(t, parsed.Stream)
|
||||
require.Equal(t, "session_123e4567-e89b-12d3-a456-426614174000", parsed.MetadataUserID)
|
||||
require.True(t, parsed.HasSystem)
|
||||
require.NotNil(t, parsed.System)
|
||||
require.Len(t, parsed.Messages, 1)
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_SystemNull(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-3","system":null}`)
|
||||
parsed, err := ParseGatewayRequest(body)
|
||||
require.NoError(t, err)
|
||||
// 显式传入 system:null 也应视为“字段已存在”,避免默认 system 被注入。
|
||||
require.True(t, parsed.HasSystem)
|
||||
require.Nil(t, parsed.System)
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_InvalidModelType(t *testing.T) {
|
||||
body := []byte(`{"model":123}`)
|
||||
_, err := ParseGatewayRequest(body)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestParseGatewayRequest_InvalidStreamType(t *testing.T) {
|
||||
body := []byte(`{"stream":"true"}`)
|
||||
_, err := ParseGatewayRequest(body)
|
||||
require.Error(t, err)
|
||||
}
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -33,7 +34,10 @@ const (
|
||||
|
||||
// sseDataRe matches SSE data lines with optional whitespace after colon.
|
||||
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
|
||||
var sseDataRe = regexp.MustCompile(`^data:\s*`)
|
||||
var (
|
||||
sseDataRe = regexp.MustCompile(`^data:\s*`)
|
||||
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
|
||||
)
|
||||
|
||||
// allowedHeaders 白名单headers(参考CRS项目)
|
||||
var allowedHeaders = map[string]bool{
|
||||
@@ -64,6 +68,20 @@ type GatewayCache interface {
|
||||
RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error
|
||||
}
|
||||
|
||||
// AccountWaitPlan carries the parameters for waiting on an account slot.
// NOTE(review): field meanings are inferred from their names; confirm against
// the code that consumes the plan.
type AccountWaitPlan struct {
	AccountID      int64         // account to wait for
	MaxConcurrency int           // presumably the account's concurrency limit
	Timeout        time.Duration // presumably the maximum time to wait
	MaxWaiting     int           // presumably the wait queue size limit
}
|
||||
|
||||
type AccountSelectionResult struct {
|
||||
Account *Account
|
||||
Acquired bool
|
||||
ReleaseFunc func()
|
||||
WaitPlan *AccountWaitPlan // nil means no wait allowed
|
||||
}
|
||||
|
||||
// ClaudeUsage 表示Claude API返回的usage信息
|
||||
type ClaudeUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
@@ -106,6 +124,7 @@ type GatewayService struct {
|
||||
identityService *IdentityService
|
||||
httpUpstream HTTPUpstream
|
||||
deferredService *DeferredService
|
||||
concurrencyService *ConcurrencyService
|
||||
}
|
||||
|
||||
// NewGatewayService creates a new GatewayService
|
||||
@@ -117,6 +136,7 @@ func NewGatewayService(
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
cache GatewayCache,
|
||||
cfg *config.Config,
|
||||
concurrencyService *ConcurrencyService,
|
||||
billingService *BillingService,
|
||||
rateLimitService *RateLimitService,
|
||||
billingCacheService *BillingCacheService,
|
||||
@@ -132,6 +152,7 @@ func NewGatewayService(
|
||||
userSubRepo: userSubRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: concurrencyService,
|
||||
billingService: billingService,
|
||||
rateLimitService: rateLimitService,
|
||||
billingCacheService: billingCacheService,
|
||||
@@ -141,40 +162,36 @@ func NewGatewayService(
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateSessionHash 从请求体计算粘性会话hash
|
||||
func (s *GatewayService) GenerateSessionHash(body []byte) string {
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
// GenerateSessionHash 从预解析请求计算粘性会话 hash
|
||||
func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
|
||||
if parsed == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 1. 最高优先级:从metadata.user_id提取session_xxx
|
||||
if metadata, ok := req["metadata"].(map[string]any); ok {
|
||||
if userID, ok := metadata["user_id"].(string); ok {
|
||||
re := regexp.MustCompile(`session_([a-f0-9-]{36})`)
|
||||
if match := re.FindStringSubmatch(userID); len(match) > 1 {
|
||||
return match[1]
|
||||
}
|
||||
// 1. 最高优先级:从 metadata.user_id 提取 session_xxx
|
||||
if parsed.MetadataUserID != "" {
|
||||
if match := sessionIDRegex.FindStringSubmatch(parsed.MetadataUserID); len(match) > 1 {
|
||||
return match[1]
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 提取带cache_control: {type: "ephemeral"}的内容
|
||||
cacheableContent := s.extractCacheableContent(req)
|
||||
// 2. 提取带 cache_control: {type: "ephemeral"} 的内容
|
||||
cacheableContent := s.extractCacheableContent(parsed)
|
||||
if cacheableContent != "" {
|
||||
return s.hashContent(cacheableContent)
|
||||
}
|
||||
|
||||
// 3. Fallback: 使用system内容
|
||||
if system := req["system"]; system != nil {
|
||||
systemText := s.extractTextFromSystem(system)
|
||||
// 3. Fallback: 使用 system 内容
|
||||
if parsed.System != nil {
|
||||
systemText := s.extractTextFromSystem(parsed.System)
|
||||
if systemText != "" {
|
||||
return s.hashContent(systemText)
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 最后fallback: 使用第一条消息
|
||||
if messages, ok := req["messages"].([]any); ok && len(messages) > 0 {
|
||||
if firstMsg, ok := messages[0].(map[string]any); ok {
|
||||
// 4. 最后 fallback: 使用第一条消息
|
||||
if len(parsed.Messages) > 0 {
|
||||
if firstMsg, ok := parsed.Messages[0].(map[string]any); ok {
|
||||
msgText := s.extractTextFromContent(firstMsg["content"])
|
||||
if msgText != "" {
|
||||
return s.hashContent(msgText)
|
||||
@@ -185,36 +202,46 @@ func (s *GatewayService) GenerateSessionHash(body []byte) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *GatewayService) extractCacheableContent(req map[string]any) string {
|
||||
var content string
|
||||
// BindStickySession sets session -> account binding with standard TTL.
|
||||
func (s *GatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error {
|
||||
if sessionHash == "" || accountID <= 0 {
|
||||
return nil
|
||||
}
|
||||
return s.cache.SetSessionAccountID(ctx, sessionHash, accountID, stickySessionTTL)
|
||||
}
|
||||
|
||||
// 检查system中的cacheable内容
|
||||
if system, ok := req["system"].([]any); ok {
|
||||
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
|
||||
if parsed == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var builder strings.Builder
|
||||
|
||||
// 检查 system 中的 cacheable 内容
|
||||
if system, ok := parsed.System.([]any); ok {
|
||||
for _, part := range system {
|
||||
if partMap, ok := part.(map[string]any); ok {
|
||||
if cc, ok := partMap["cache_control"].(map[string]any); ok {
|
||||
if cc["type"] == "ephemeral" {
|
||||
if text, ok := partMap["text"].(string); ok {
|
||||
content += text
|
||||
_, _ = builder.WriteString(text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
systemText := builder.String()
|
||||
|
||||
// 检查messages中的cacheable内容
|
||||
if messages, ok := req["messages"].([]any); ok {
|
||||
for _, msg := range messages {
|
||||
if msgMap, ok := msg.(map[string]any); ok {
|
||||
if msgContent, ok := msgMap["content"].([]any); ok {
|
||||
for _, part := range msgContent {
|
||||
if partMap, ok := part.(map[string]any); ok {
|
||||
if cc, ok := partMap["cache_control"].(map[string]any); ok {
|
||||
if cc["type"] == "ephemeral" {
|
||||
// 找到cacheable内容,提取第一条消息的文本
|
||||
return s.extractTextFromContent(msgMap["content"])
|
||||
}
|
||||
// 检查 messages 中的 cacheable 内容
|
||||
for _, msg := range parsed.Messages {
|
||||
if msgMap, ok := msg.(map[string]any); ok {
|
||||
if msgContent, ok := msgMap["content"].([]any); ok {
|
||||
for _, part := range msgContent {
|
||||
if partMap, ok := part.(map[string]any); ok {
|
||||
if cc, ok := partMap["cache_control"].(map[string]any); ok {
|
||||
if cc["type"] == "ephemeral" {
|
||||
return s.extractTextFromContent(msgMap["content"])
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -223,7 +250,7 @@ func (s *GatewayService) extractCacheableContent(req map[string]any) string {
|
||||
}
|
||||
}
|
||||
|
||||
return content
|
||||
return systemText
|
||||
}
|
||||
|
||||
func (s *GatewayService) extractTextFromSystem(system any) string {
|
||||
@@ -332,8 +359,354 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
||||
return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
||||
}
|
||||
|
||||
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
|
||||
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
|
||||
cfg := s.schedulingConfig()
|
||||
var stickyAccountID int64
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
if accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash); err == nil {
|
||||
stickyAccountID = accountID
|
||||
}
|
||||
}
|
||||
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
|
||||
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: account.ID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: account.ID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.FallbackWaitTimeout,
|
||||
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
preferOAuth := platform == PlatformGemini
|
||||
|
||||
accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(accounts) == 0 {
|
||||
return nil, errors.New("no available accounts")
|
||||
}
|
||||
|
||||
isExcluded := func(accountID int64) bool {
|
||||
if excludedIDs == nil {
|
||||
return false
|
||||
}
|
||||
_, excluded := excludedIDs[accountID]
|
||||
return excluded
|
||||
}
|
||||
|
||||
// ============ Layer 1: 粘性会话优先 ============
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
||||
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err == nil && s.isAccountAllowedForPlatform(account, platform, useMixed) &&
|
||||
account.IsSchedulable() &&
|
||||
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL)
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: accountID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============ Layer 2: 负载感知选择 ============
|
||||
candidates := make([]*Account, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if isExcluded(acc.ID) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountAllowedForPlatform(acc, platform, useMixed) {
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
candidates = append(candidates, acc)
|
||||
}
|
||||
|
||||
if len(candidates) == 0 {
|
||||
return nil, errors.New("no available accounts")
|
||||
}
|
||||
|
||||
accountLoads := make([]AccountWithConcurrency, 0, len(candidates))
|
||||
for _, acc := range candidates {
|
||||
accountLoads = append(accountLoads, AccountWithConcurrency{
|
||||
ID: acc.ID,
|
||||
MaxConcurrency: acc.Concurrency,
|
||||
})
|
||||
}
|
||||
|
||||
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
|
||||
if err != nil {
|
||||
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, sessionHash, preferOAuth); ok {
|
||||
return result, nil
|
||||
}
|
||||
} else {
|
||||
type accountWithLoad struct {
|
||||
account *Account
|
||||
loadInfo *AccountLoadInfo
|
||||
}
|
||||
var available []accountWithLoad
|
||||
for _, acc := range candidates {
|
||||
loadInfo := loadMap[acc.ID]
|
||||
if loadInfo == nil {
|
||||
loadInfo = &AccountLoadInfo{AccountID: acc.ID}
|
||||
}
|
||||
if loadInfo.LoadRate < 100 {
|
||||
available = append(available, accountWithLoad{
|
||||
account: acc,
|
||||
loadInfo: loadInfo,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(available) > 0 {
|
||||
sort.SliceStable(available, func(i, j int) bool {
|
||||
a, b := available[i], available[j]
|
||||
if a.account.Priority != b.account.Priority {
|
||||
return a.account.Priority < b.account.Priority
|
||||
}
|
||||
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
|
||||
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
|
||||
}
|
||||
switch {
|
||||
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
|
||||
return true
|
||||
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
|
||||
return false
|
||||
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
|
||||
if preferOAuth && a.account.Type != b.account.Type {
|
||||
return a.account.Type == AccountTypeOAuth
|
||||
}
|
||||
return false
|
||||
default:
|
||||
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
|
||||
}
|
||||
})
|
||||
|
||||
for _, item := range available {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
if sessionHash != "" {
|
||||
_ = s.cache.SetSessionAccountID(ctx, sessionHash, item.account.ID, stickySessionTTL)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: item.account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============ Layer 3: 兜底排队 ============
|
||||
sortAccountsByPriorityAndLastUsed(candidates, preferOAuth)
|
||||
for _, acc := range candidates {
|
||||
return &AccountSelectionResult{
|
||||
Account: acc,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: acc.ID,
|
||||
MaxConcurrency: acc.Concurrency,
|
||||
Timeout: cfg.FallbackWaitTimeout,
|
||||
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
return nil, errors.New("no available accounts")
|
||||
}
|
||||
|
||||
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
|
||||
ordered := append([]*Account(nil), candidates...)
|
||||
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
|
||||
|
||||
for _, acc := range ordered {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
if sessionHash != "" {
|
||||
_ = s.cache.SetSessionAccountID(ctx, sessionHash, acc.ID, stickySessionTTL)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: acc,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, true
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
|
||||
if s.cfg != nil {
|
||||
return s.cfg.Gateway.Scheduling
|
||||
}
|
||||
return config.GatewaySchedulingConfig{
|
||||
StickySessionMaxWaiting: 3,
|
||||
StickySessionWaitTimeout: 45 * time.Second,
|
||||
FallbackWaitTimeout: 30 * time.Second,
|
||||
FallbackMaxWaiting: 100,
|
||||
LoadBatchEnabled: true,
|
||||
SlotCleanupInterval: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64) (string, bool, error) {
|
||||
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
||||
if hasForcePlatform && forcePlatform != "" {
|
||||
return forcePlatform, true, nil
|
||||
}
|
||||
if groupID != nil {
|
||||
group, err := s.groupRepo.GetByID(ctx, *groupID)
|
||||
if err != nil {
|
||||
return "", false, fmt.Errorf("get group failed: %w", err)
|
||||
}
|
||||
return group.Platform, false, nil
|
||||
}
|
||||
return PlatformAnthropic, false, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
|
||||
useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform
|
||||
if useMixed {
|
||||
platforms := []string{platform, PlatformAntigravity}
|
||||
var accounts []Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, useMixed, err
|
||||
}
|
||||
filtered := make([]Account, 0, len(accounts))
|
||||
for _, acc := range accounts {
|
||||
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, acc)
|
||||
}
|
||||
return filtered, useMixed, nil
|
||||
}
|
||||
|
||||
var accounts []Account
|
||||
var err error
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
|
||||
} else if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform)
|
||||
if err == nil && len(accounts) == 0 && hasForcePlatform {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
|
||||
}
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, useMixed, err
|
||||
}
|
||||
return accounts, useMixed, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform string, useMixed bool) bool {
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
if useMixed {
|
||||
if account.Platform == platform {
|
||||
return true
|
||||
}
|
||||
return account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()
|
||||
}
|
||||
return account.Platform == platform
|
||||
}
|
||||
|
||||
func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
|
||||
if s.concurrencyService == nil {
|
||||
return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil
|
||||
}
|
||||
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||
}
|
||||
|
||||
func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
|
||||
sort.SliceStable(accounts, func(i, j int) bool {
|
||||
a, b := accounts[i], accounts[j]
|
||||
if a.Priority != b.Priority {
|
||||
return a.Priority < b.Priority
|
||||
}
|
||||
switch {
|
||||
case a.LastUsedAt == nil && b.LastUsedAt != nil:
|
||||
return true
|
||||
case a.LastUsedAt != nil && b.LastUsedAt == nil:
|
||||
return false
|
||||
case a.LastUsedAt == nil && b.LastUsedAt == nil:
|
||||
if preferOAuth && a.Type != b.Type {
|
||||
return a.Type == AccountTypeOAuth
|
||||
}
|
||||
return false
|
||||
default:
|
||||
return a.LastUsedAt.Before(*b.LastUsedAt)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
|
||||
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
|
||||
preferOAuth := platform == PlatformGemini
|
||||
// 1. 查询粘性会话
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
||||
@@ -389,7 +762,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
|
||||
// keep selected (never used is preferred)
|
||||
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
|
||||
// keep selected (both never used)
|
||||
if preferOAuth && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
|
||||
selected = acc
|
||||
}
|
||||
default:
|
||||
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
|
||||
selected = acc
|
||||
@@ -419,6 +794,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
// 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户
|
||||
func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) {
|
||||
platforms := []string{nativePlatform, PlatformAntigravity}
|
||||
preferOAuth := nativePlatform == PlatformGemini
|
||||
|
||||
// 1. 查询粘性会话
|
||||
if sessionHash != "" {
|
||||
@@ -478,7 +854,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
|
||||
// keep selected (never used is preferred)
|
||||
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
|
||||
// keep selected (both never used)
|
||||
if preferOAuth && acc.Platform == PlatformGemini && selected.Platform == PlatformGemini && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
|
||||
selected = acc
|
||||
}
|
||||
default:
|
||||
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
|
||||
selected = acc
|
||||
@@ -515,24 +893,10 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
|
||||
}
|
||||
|
||||
// IsAntigravityModelSupported 检查 Antigravity 平台是否支持指定模型
|
||||
// 所有 claude- 和 gemini- 前缀的模型都能通过映射或透传支持
|
||||
func IsAntigravityModelSupported(requestedModel string) bool {
|
||||
// 直接支持的模型
|
||||
if antigravitySupportedModels[requestedModel] {
|
||||
return true
|
||||
}
|
||||
// 可映射的模型
|
||||
if _, ok := antigravityModelMapping[requestedModel]; ok {
|
||||
return true
|
||||
}
|
||||
// Gemini 前缀透传
|
||||
if strings.HasPrefix(requestedModel, "gemini-") {
|
||||
return true
|
||||
}
|
||||
// Claude 模型支持(通过默认映射到 claude-sonnet-4-5)
|
||||
if strings.HasPrefix(requestedModel, "claude-") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
return strings.HasPrefix(requestedModel, "claude-") ||
|
||||
strings.HasPrefix(requestedModel, "gemini-")
|
||||
}
|
||||
|
||||
// GetAccessToken 获取账号凭证
|
||||
@@ -588,19 +952,17 @@ func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool {
|
||||
}
|
||||
|
||||
// Forward 转发请求到Claude API
|
||||
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
|
||||
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
// 解析请求获取model和stream
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return nil, fmt.Errorf("parse request: %w", err)
|
||||
if parsed == nil {
|
||||
return nil, fmt.Errorf("parse request: empty request")
|
||||
}
|
||||
|
||||
if !gjson.GetBytes(body, "system").Exists() {
|
||||
body := parsed.Body
|
||||
reqModel := parsed.Model
|
||||
reqStream := parsed.Stream
|
||||
|
||||
if !parsed.HasSystem {
|
||||
body, _ = sjson.SetBytes(body, "system", []any{
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
@@ -613,13 +975,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
}
|
||||
|
||||
// 应用模型映射(仅对apikey类型账号)
|
||||
originalModel := req.Model
|
||||
originalModel := reqModel
|
||||
if account.Type == AccountTypeApiKey {
|
||||
mappedModel := account.GetMappedModel(req.Model)
|
||||
if mappedModel != req.Model {
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
if mappedModel != reqModel {
|
||||
// 替换请求体中的模型名
|
||||
body = s.replaceModelInBody(body, mappedModel)
|
||||
req.Model = mappedModel
|
||||
reqModel = mappedModel
|
||||
log.Printf("Model mapping applied: %s -> %s (account: %s)", originalModel, mappedModel, account.Name)
|
||||
}
|
||||
}
|
||||
@@ -640,13 +1002,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
var resp *http.Response
|
||||
for attempt := 1; attempt <= maxRetries; attempt++ {
|
||||
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType)
|
||||
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL)
|
||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("upstream request failed: %w", err)
|
||||
}
|
||||
@@ -686,14 +1048,38 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
|
||||
// 处理错误响应(不可重试的错误)
|
||||
if resp.StatusCode >= 400 {
|
||||
// 可选:对部分 400 触发 failover(默认关闭以保持语义)
|
||||
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
|
||||
respBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
// ReadAll failed, fall back to normal error handling without consuming the stream
|
||||
return s.handleErrorResponse(ctx, resp, c, account)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
|
||||
if s.shouldFailoverOn400(respBody) {
|
||||
if s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
log.Printf(
|
||||
"Account %d: 400 error, attempting failover: %s",
|
||||
account.ID,
|
||||
truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
|
||||
)
|
||||
} else {
|
||||
log.Printf("Account %d: 400 error, attempting failover", account.ID)
|
||||
}
|
||||
s.handleFailoverSideEffects(ctx, resp, account)
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
}
|
||||
return s.handleErrorResponse(ctx, resp, c, account)
|
||||
}
|
||||
|
||||
// 处理正常响应
|
||||
var usage *ClaudeUsage
|
||||
var firstTokenMs *int
|
||||
if req.Stream {
|
||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, req.Model)
|
||||
if reqStream {
|
||||
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel)
|
||||
if err != nil {
|
||||
if err.Error() == "have error in stream" {
|
||||
return nil, &UpstreamFailoverError{
|
||||
@@ -705,7 +1091,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
usage = streamResult.usage
|
||||
firstTokenMs = streamResult.firstTokenMs
|
||||
} else {
|
||||
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, req.Model)
|
||||
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -715,13 +1101,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
RequestID: resp.Header.Get("x-request-id"),
|
||||
Usage: *usage,
|
||||
Model: originalModel, // 使用原始模型用于计费和日志
|
||||
Stream: req.Stream,
|
||||
Stream: reqStream,
|
||||
Duration: time.Since(startTime),
|
||||
FirstTokenMs: firstTokenMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType string) (*http.Request, error) {
|
||||
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
|
||||
// 确定目标URL
|
||||
targetURL := claudeAPIURL
|
||||
if account.Type == AccountTypeApiKey {
|
||||
@@ -787,7 +1173,14 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
|
||||
// 处理anthropic-beta header(OAuth账号需要特殊处理)
|
||||
if tokenType == "oauth" {
|
||||
req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta")))
|
||||
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
|
||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
|
||||
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
|
||||
if requestNeedsBetaFeatures(body) {
|
||||
if beta := defaultApiKeyBetaHeader(body); beta != "" {
|
||||
req.Header.Set("anthropic-beta", beta)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return req, nil
|
||||
@@ -795,7 +1188,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
|
||||
// getBetaHeader 处理anthropic-beta header
|
||||
// 对于OAuth账号,需要确保包含oauth-2025-04-20
|
||||
func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) string {
|
||||
func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) string {
|
||||
// 如果客户端传了anthropic-beta
|
||||
if clientBetaHeader != "" {
|
||||
// 已包含oauth beta则直接返回
|
||||
@@ -832,15 +1225,7 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
|
||||
}
|
||||
|
||||
// 客户端没传,根据模型生成
|
||||
var modelID string
|
||||
var reqMap map[string]any
|
||||
if json.Unmarshal(body, &reqMap) == nil {
|
||||
if m, ok := reqMap["model"].(string); ok {
|
||||
modelID = m
|
||||
}
|
||||
}
|
||||
|
||||
// haiku模型不需要claude-code beta
|
||||
// haiku 模型不需要 claude-code beta
|
||||
if strings.Contains(strings.ToLower(modelID), "haiku") {
|
||||
return claude.HaikuBetaHeader
|
||||
}
|
||||
@@ -848,6 +1233,83 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
|
||||
return claude.DefaultBetaHeader
|
||||
}
|
||||
|
||||
func requestNeedsBetaFeatures(body []byte) bool {
|
||||
tools := gjson.GetBytes(body, "tools")
|
||||
if tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 {
|
||||
return true
|
||||
}
|
||||
if strings.EqualFold(gjson.GetBytes(body, "thinking.type").String(), "enabled") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func defaultApiKeyBetaHeader(body []byte) string {
|
||||
modelID := gjson.GetBytes(body, "model").String()
|
||||
if strings.Contains(strings.ToLower(modelID), "haiku") {
|
||||
return claude.ApiKeyHaikuBetaHeader
|
||||
}
|
||||
return claude.ApiKeyBetaHeader
|
||||
}
|
||||
|
||||
// truncateForLog renders b as a single-line string of at most maxBytes bytes,
// suitable for embedding in a log line. A non-positive maxBytes selects the
// default limit of 2048 bytes. Line feeds and carriage returns are replaced by
// the literal sequences `\n` and `\r` so the entry stays on one line.
//
// When the limit falls inside a multi-byte UTF-8 character, the cut is moved
// back to the start of that character so the result does not end in a partial
// (invalid) sequence.
func truncateForLog(b []byte, maxBytes int) string {
	if maxBytes <= 0 {
		maxBytes = 2048
	}
	if len(b) > maxBytes {
		cut := maxBytes
		// b[cut] is the first dropped byte. If it is a UTF-8 continuation byte
		// (10xxxxxx), the cut splits a character: step back to its lead byte.
		// A character has at most 3 continuation bytes, which bounds the loop
		// even for input that is not valid UTF-8.
		for back := 0; back < 3 && cut > 0 && b[cut]&0xC0 == 0x80; back++ {
			cut--
		}
		b = b[:cut]
	}
	s := string(b)
	// Keep the output on one line so it does not break the log format.
	s = strings.ReplaceAll(s, "\n", "\\n")
	s = strings.ReplaceAll(s, "\r", "\\r")
	return s
}
|
||||
|
||||
func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
|
||||
// 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。
|
||||
// 默认保守:无法识别则不切换。
|
||||
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
|
||||
if msg == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// 缺少/错误的 beta header:换账号/链路可能成功(尤其是混合调度时)。
|
||||
// 更精确匹配 beta 相关的兼容性问题,避免误触发切换。
|
||||
if strings.Contains(msg, "anthropic-beta") ||
|
||||
strings.Contains(msg, "beta feature") ||
|
||||
strings.Contains(msg, "requires beta") {
|
||||
return true
|
||||
}
|
||||
|
||||
// thinking/tool streaming 等兼容性约束(常见于中间转换链路)
|
||||
if strings.Contains(msg, "thinking") || strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(msg, "tool_use") || strings.Contains(msg, "tool_result") || strings.Contains(msg, "tools") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func extractUpstreamErrorMessage(body []byte) string {
|
||||
// Claude 风格:{"type":"error","error":{"type":"...","message":"..."}}
|
||||
if m := gjson.GetBytes(body, "error.message").String(); strings.TrimSpace(m) != "" {
|
||||
inner := strings.TrimSpace(m)
|
||||
// 有些上游会把完整 JSON 作为字符串塞进 message
|
||||
if strings.HasPrefix(inner, "{") {
|
||||
if innerMsg := gjson.Get(inner, "error.message").String(); strings.TrimSpace(innerMsg) != "" {
|
||||
return innerMsg
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// 兜底:尝试顶层 message
|
||||
return gjson.GetBytes(body, "message").String()
|
||||
}
|
||||
|
||||
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
@@ -860,6 +1322,16 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
|
||||
|
||||
switch resp.StatusCode {
|
||||
case 400:
|
||||
// 仅记录上游错误摘要(避免输出请求内容);需要时可通过配置打开
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
log.Printf(
|
||||
"Upstream 400 error (account=%d platform=%s type=%s): %s",
|
||||
account.ID,
|
||||
account.Platform,
|
||||
account.Type,
|
||||
truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
|
||||
)
|
||||
}
|
||||
c.Data(http.StatusBadRequest, "application/json", body)
|
||||
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
|
||||
case 401:
|
||||
@@ -1248,13 +1720,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
log.Printf("Increment subscription usage failed: %v", err)
|
||||
}
|
||||
// 异步更新订阅缓存
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.billingCacheService.UpdateSubscriptionUsage(cacheCtx, user.ID, *apiKey.GroupID, cost.TotalCost); err != nil {
|
||||
log.Printf("Update subscription cache failed: %v", err)
|
||||
}
|
||||
}()
|
||||
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
|
||||
}
|
||||
} else {
|
||||
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
|
||||
@@ -1263,13 +1729,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
log.Printf("Deduct balance failed: %v", err)
|
||||
}
|
||||
// 异步更新余额缓存
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.billingCacheService.DeductBalanceCache(cacheCtx, user.ID, cost.ActualCost); err != nil {
|
||||
log.Printf("Update balance cache failed: %v", err)
|
||||
}
|
||||
}()
|
||||
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1281,7 +1741,15 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
|
||||
// ForwardCountTokens 转发 count_tokens 请求到上游 API
|
||||
// 特点:不记录使用量、仅支持非流式响应
|
||||
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, body []byte) error {
|
||||
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {
|
||||
if parsed == nil {
|
||||
s.countTokensError(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return fmt.Errorf("parse request: empty request")
|
||||
}
|
||||
|
||||
body := parsed.Body
|
||||
reqModel := parsed.Model
|
||||
|
||||
// Antigravity 账户不支持 count_tokens 转发,返回估算值
|
||||
// 参考 Antigravity-Manager 和 proxycast 实现
|
||||
if account.Platform == PlatformAntigravity {
|
||||
@@ -1291,14 +1759,12 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
|
||||
// 应用模型映射(仅对 apikey 类型账号)
|
||||
if account.Type == AccountTypeApiKey {
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &req); err == nil && req.Model != "" {
|
||||
mappedModel := account.GetMappedModel(req.Model)
|
||||
if mappedModel != req.Model {
|
||||
if reqModel != "" {
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
if mappedModel != reqModel {
|
||||
body = s.replaceModelInBody(body, mappedModel)
|
||||
log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", req.Model, mappedModel, account.Name)
|
||||
reqModel = mappedModel
|
||||
log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1311,7 +1777,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
|
||||
// 构建上游请求
|
||||
upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType)
|
||||
upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel)
|
||||
if err != nil {
|
||||
s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request")
|
||||
return err
|
||||
@@ -1324,7 +1790,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL)
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
|
||||
return fmt.Errorf("upstream request failed: %w", err)
|
||||
@@ -1345,6 +1811,18 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
// 标记账号状态(429/529等)
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
|
||||
// 记录上游错误摘要便于排障(不回显请求内容)
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
log.Printf(
|
||||
"count_tokens upstream error %d (account=%d platform=%s type=%s): %s",
|
||||
resp.StatusCode,
|
||||
account.ID,
|
||||
account.Platform,
|
||||
account.Type,
|
||||
truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
|
||||
)
|
||||
}
|
||||
|
||||
// 返回简化的错误响应
|
||||
errMsg := "Upstream request failed"
|
||||
switch resp.StatusCode {
|
||||
@@ -1363,7 +1841,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
|
||||
// buildCountTokensRequest 构建 count_tokens 上游请求
|
||||
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType string) (*http.Request, error) {
|
||||
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
|
||||
// 确定目标 URL
|
||||
targetURL := claudeAPICountTokensURL
|
||||
if account.Type == AccountTypeApiKey {
|
||||
@@ -1424,7 +1902,14 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
|
||||
// OAuth 账号:处理 anthropic-beta header
|
||||
if tokenType == "oauth" {
|
||||
req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta")))
|
||||
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
|
||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
|
||||
// API-key:与 messages 同步的按需 beta 注入(默认关闭)
|
||||
if requestNeedsBetaFeatures(body) {
|
||||
if beta := defaultApiKeyBetaHeader(body); beta != "" {
|
||||
req.Header.Set("anthropic-beta", beta)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return req, nil
|
||||
|
||||
50
backend/internal/service/gateway_service_benchmark_test.go
Normal file
50
backend/internal/service/gateway_service_benchmark_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// benchmarkStringSink is a package-level variable that the benchmarks in this
// file assign their string results to.
// NOTE(review): presumably this keeps the compiler from optimizing the
// measured calls away — confirm.
var benchmarkStringSink string
|
||||
|
||||
// BenchmarkGenerateSessionHash_Metadata 关注 JSON 解析与正则匹配开销。
|
||||
func BenchmarkGenerateSessionHash_Metadata(b *testing.B) {
|
||||
svc := &GatewayService{}
|
||||
body := []byte(`{"metadata":{"user_id":"session_123e4567-e89b-12d3-a456-426614174000"},"messages":[{"content":"hello"}]}`)
|
||||
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
parsed, err := ParseGatewayRequest(body)
|
||||
if err != nil {
|
||||
b.Fatalf("解析请求失败: %v", err)
|
||||
}
|
||||
benchmarkStringSink = svc.GenerateSessionHash(parsed)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkExtractCacheableContent_System 关注字符串拼接路径的性能。
|
||||
func BenchmarkExtractCacheableContent_System(b *testing.B) {
|
||||
svc := &GatewayService{}
|
||||
req := buildSystemCacheableRequest(12)
|
||||
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkStringSink = svc.extractCacheableContent(req)
|
||||
}
|
||||
}
|
||||
|
||||
func buildSystemCacheableRequest(parts int) *ParsedRequest {
|
||||
systemParts := make([]any, 0, parts)
|
||||
for i := 0; i < parts; i++ {
|
||||
systemParts = append(systemParts, map[string]any{
|
||||
"text": "system_part_" + strconv.Itoa(i),
|
||||
"cache_control": map[string]any{
|
||||
"type": "ephemeral",
|
||||
},
|
||||
})
|
||||
}
|
||||
return &ParsedRequest{
|
||||
System: systemParts,
|
||||
HasSystem: true,
|
||||
}
|
||||
}
|
||||
@@ -116,8 +116,20 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
||||
valid = true
|
||||
}
|
||||
if valid {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
|
||||
return account, nil
|
||||
usable := true
|
||||
if s.rateLimitService != nil && requestedModel != "" {
|
||||
ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
|
||||
if err != nil {
|
||||
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
|
||||
}
|
||||
if !ok {
|
||||
usable = false
|
||||
}
|
||||
}
|
||||
if usable {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -157,6 +169,15 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
||||
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if s.rateLimitService != nil && requestedModel != "" {
|
||||
ok, err := s.rateLimitService.PreCheckUsage(ctx, acc, requestedModel)
|
||||
if err != nil {
|
||||
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", acc.ID, err)
|
||||
}
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if selected == nil {
|
||||
selected = acc
|
||||
continue
|
||||
@@ -472,7 +493,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
}
|
||||
requestIDHeader = idHeader
|
||||
|
||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL)
|
||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
if attempt < geminiMaxRetries {
|
||||
log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err)
|
||||
@@ -725,7 +746,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
}
|
||||
requestIDHeader = idHeader
|
||||
|
||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL)
|
||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
if attempt < geminiMaxRetries {
|
||||
log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err)
|
||||
@@ -921,7 +942,10 @@ func sleepGeminiBackoff(attempt int) {
|
||||
time.Sleep(sleepFor)
|
||||
}
|
||||
|
||||
var sensitiveQueryParamRegex = regexp.MustCompile(`(?i)([?&](?:key|client_secret|access_token|refresh_token)=)[^&"\s]+`)
|
||||
var (
|
||||
sensitiveQueryParamRegex = regexp.MustCompile(`(?i)([?&](?:key|client_secret|access_token|refresh_token)=)[^&"\s]+`)
|
||||
retryInRegex = regexp.MustCompile(`Please retry in ([0-9.]+)s`)
|
||||
)
|
||||
|
||||
func sanitizeUpstreamErrorMessage(msg string) string {
|
||||
if msg == "" {
|
||||
@@ -1753,7 +1777,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
|
||||
return nil, fmt.Errorf("unsupported account type: %s", account.Type)
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL)
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1883,13 +1907,44 @@ func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Cont
|
||||
if statusCode != 429 {
|
||||
return
|
||||
}
|
||||
|
||||
oauthType := account.GeminiOAuthType()
|
||||
tierID := account.GeminiTierID()
|
||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
isCodeAssist := account.IsGeminiCodeAssist()
|
||||
|
||||
resetAt := ParseGeminiRateLimitResetTime(body)
|
||||
if resetAt == nil {
|
||||
ra := time.Now().Add(5 * time.Minute)
|
||||
// 根据账号类型使用不同的默认重置时间
|
||||
var ra time.Time
|
||||
if isCodeAssist {
|
||||
// Code Assist: fallback cooldown by tier
|
||||
cooldown := geminiCooldownForTier(tierID)
|
||||
if s.rateLimitService != nil {
|
||||
cooldown = s.rateLimitService.GeminiCooldown(ctx, account)
|
||||
}
|
||||
ra = time.Now().Add(cooldown)
|
||||
log.Printf("[Gemini 429] Account %d (Code Assist, tier=%s, project=%s) rate limited, cooldown=%v", account.ID, tierID, projectID, time.Until(ra).Truncate(time.Second))
|
||||
} else {
|
||||
// API Key / AI Studio OAuth: PST 午夜
|
||||
if ts := nextGeminiDailyResetUnix(); ts != nil {
|
||||
ra = time.Unix(*ts, 0)
|
||||
log.Printf("[Gemini 429] Account %d (API Key/AI Studio, type=%s) rate limited, reset at PST midnight (%v)", account.ID, account.Type, ra)
|
||||
} else {
|
||||
// 兜底:5 分钟
|
||||
ra = time.Now().Add(5 * time.Minute)
|
||||
log.Printf("[Gemini 429] Account %d rate limited, fallback to 5min", account.ID)
|
||||
}
|
||||
}
|
||||
_ = s.accountRepo.SetRateLimited(ctx, account.ID, ra)
|
||||
return
|
||||
}
|
||||
_ = s.accountRepo.SetRateLimited(ctx, account.ID, time.Unix(*resetAt, 0))
|
||||
|
||||
// 使用解析到的重置时间
|
||||
resetTime := time.Unix(*resetAt, 0)
|
||||
_ = s.accountRepo.SetRateLimited(ctx, account.ID, resetTime)
|
||||
log.Printf("[Gemini 429] Account %d rate limited until %v (oauth_type=%s, tier=%s)",
|
||||
account.ID, resetTime, oauthType, tierID)
|
||||
}
|
||||
|
||||
// ParseGeminiRateLimitResetTime 解析 Gemini 格式的 429 响应,返回重置时间的 Unix 时间戳
|
||||
@@ -1925,7 +1980,6 @@ func ParseGeminiRateLimitResetTime(body []byte) *int64 {
|
||||
}
|
||||
|
||||
// Match "Please retry in Xs"
|
||||
retryInRegex := regexp.MustCompile(`Please retry in ([0-9.]+)s`)
|
||||
matches := retryInRegex.FindStringSubmatch(string(body))
|
||||
if len(matches) == 2 {
|
||||
if dur, err := time.ParseDuration(matches[1] + "s"); err == nil {
|
||||
@@ -1946,16 +2000,7 @@ func looksLikeGeminiDailyQuota(message string) bool {
|
||||
}
|
||||
|
||||
func nextGeminiDailyResetUnix() *int64 {
|
||||
loc, err := time.LoadLocation("America/Los_Angeles")
|
||||
if err != nil {
|
||||
// Fallback: PST without DST.
|
||||
loc = time.FixedZone("PST", -8*3600)
|
||||
}
|
||||
now := time.Now().In(loc)
|
||||
reset := time.Date(now.Year(), now.Month(), now.Day(), 0, 5, 0, 0, loc)
|
||||
if !reset.After(now) {
|
||||
reset = reset.Add(24 * time.Hour)
|
||||
}
|
||||
reset := geminiDailyResetTime(time.Now())
|
||||
ts := reset.Unix()
|
||||
return &ts
|
||||
}
|
||||
@@ -2243,16 +2288,46 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
name, _ := tm["name"].(string)
|
||||
desc, _ := tm["description"].(string)
|
||||
params := tm["input_schema"]
|
||||
|
||||
var name, desc string
|
||||
var params any
|
||||
|
||||
// 检查是否为 custom 类型工具 (MCP)
|
||||
toolType, _ := tm["type"].(string)
|
||||
if toolType == "custom" {
|
||||
// Custom 格式: 从 custom 字段获取 description 和 input_schema
|
||||
custom, ok := tm["custom"].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
name, _ = tm["name"].(string)
|
||||
desc, _ = custom["description"].(string)
|
||||
params = custom["input_schema"]
|
||||
} else {
|
||||
// 标准格式: 从顶层字段获取
|
||||
name, _ = tm["name"].(string)
|
||||
desc, _ = tm["description"].(string)
|
||||
params = tm["input_schema"]
|
||||
}
|
||||
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 为 nil params 提供默认值
|
||||
if params == nil {
|
||||
params = map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
}
|
||||
}
|
||||
// 清理 JSON Schema
|
||||
cleanedParams := cleanToolSchema(params)
|
||||
|
||||
funcDecls = append(funcDecls, map[string]any{
|
||||
"name": name,
|
||||
"description": desc,
|
||||
"parameters": params,
|
||||
"parameters": cleanedParams,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2266,6 +2341,41 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
|
||||
}
|
||||
}
|
||||
|
||||
// cleanToolSchema returns a cleaned copy of a tool's JSON Schema with the
// fields Gemini does not support removed. The input is never mutated.
//
// Behavior by value kind:
//   - map[string]any: treated as a schema object. The keywords $schema, $id,
//     $ref, additionalProperties, minLength, maxLength, minItems and maxItems
//     are dropped; every other value is cleaned recursively; a string "type"
//     is upper-cased (e.g. "object" -> "OBJECT").
//   - the value of a "properties" keyword: treated as a property-name -> schema
//     map. Property names are user-chosen identifiers, not schema keywords, so
//     they are always kept (a parameter literally named "maxLength" must not
//     disappear); only the schemas they point to are cleaned.
//   - []any: each element is cleaned recursively.
//   - anything else (including nil): returned unchanged.
func cleanToolSchema(schema any) any {
	if schema == nil {
		return nil
	}

	switch v := schema.(type) {
	case map[string]any:
		cleaned := make(map[string]any, len(v))
		for key, value := range v {
			switch key {
			case "$schema", "$id", "$ref", "additionalProperties",
				"minLength", "maxLength", "minItems", "maxItems":
				// Unsupported schema keyword: skip it.
				continue
			case "properties":
				// Keys below "properties" are parameter names; do not run the
				// keyword filter on them, only clean the nested schemas.
				if props, ok := value.(map[string]any); ok {
					cleanedProps := make(map[string]any, len(props))
					for propName, propSchema := range props {
						cleanedProps[propName] = cleanToolSchema(propSchema)
					}
					cleaned[key] = cleanedProps
					continue
				}
			}
			// Recursively clean nested values.
			cleaned[key] = cleanToolSchema(value)
		}
		// Normalize the type field to upper case.
		if typeVal, ok := cleaned["type"].(string); ok {
			cleaned["type"] = strings.ToUpper(typeVal)
		}
		return cleaned
	case []any:
		cleaned := make([]any, len(v))
		for i, item := range v {
			cleaned[i] = cleanToolSchema(item)
		}
		return cleaned
	default:
		return v
	}
}
|
||||
|
||||
func convertClaudeGenerationConfig(req map[string]any) map[string]any {
|
||||
out := make(map[string]any)
|
||||
if mt, ok := asInt(req["max_tokens"]); ok && mt > 0 {
|
||||
|
||||
128
backend/internal/service/gemini_messages_compat_service_test.go
Normal file
128
backend/internal/service/gemini_messages_compat_service_test.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换
|
||||
func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tools any
|
||||
expectedLen int
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "Standard tools",
|
||||
tools: []any{
|
||||
map[string]any{
|
||||
"name": "get_weather",
|
||||
"description": "Get weather info",
|
||||
"input_schema": map[string]any{"type": "object"},
|
||||
},
|
||||
},
|
||||
expectedLen: 1,
|
||||
description: "标准工具格式应该正常转换",
|
||||
},
|
||||
{
|
||||
name: "Custom type tool (MCP format)",
|
||||
tools: []any{
|
||||
map[string]any{
|
||||
"type": "custom",
|
||||
"name": "mcp_tool",
|
||||
"custom": map[string]any{
|
||||
"description": "MCP tool description",
|
||||
"input_schema": map[string]any{"type": "object"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedLen: 1,
|
||||
description: "Custom类型工具应该从custom字段读取",
|
||||
},
|
||||
{
|
||||
name: "Mixed standard and custom tools",
|
||||
tools: []any{
|
||||
map[string]any{
|
||||
"name": "standard_tool",
|
||||
"description": "Standard",
|
||||
"input_schema": map[string]any{"type": "object"},
|
||||
},
|
||||
map[string]any{
|
||||
"type": "custom",
|
||||
"name": "custom_tool",
|
||||
"custom": map[string]any{
|
||||
"description": "Custom",
|
||||
"input_schema": map[string]any{"type": "object"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedLen: 1,
|
||||
description: "混合工具应该都能正确转换",
|
||||
},
|
||||
{
|
||||
name: "Custom tool without custom field",
|
||||
tools: []any{
|
||||
map[string]any{
|
||||
"type": "custom",
|
||||
"name": "invalid_custom",
|
||||
// 缺少 custom 字段
|
||||
},
|
||||
},
|
||||
expectedLen: 0, // 应该被跳过
|
||||
description: "缺少custom字段的custom工具应该被跳过",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := convertClaudeToolsToGeminiTools(tt.tools)
|
||||
|
||||
if tt.expectedLen == 0 {
|
||||
if result != nil {
|
||||
t.Errorf("%s: expected nil result, got %v", tt.description, result)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatalf("%s: expected non-nil result", tt.description)
|
||||
}
|
||||
|
||||
if len(result) != 1 {
|
||||
t.Errorf("%s: expected 1 tool declaration, got %d", tt.description, len(result))
|
||||
return
|
||||
}
|
||||
|
||||
toolDecl, ok := result[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("%s: result[0] is not map[string]any", tt.description)
|
||||
}
|
||||
|
||||
funcDecls, ok := toolDecl["functionDeclarations"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("%s: functionDeclarations is not []any", tt.description)
|
||||
}
|
||||
|
||||
toolsArr, _ := tt.tools.([]any)
|
||||
expectedFuncCount := 0
|
||||
for _, tool := range toolsArr {
|
||||
toolMap, _ := tool.(map[string]any)
|
||||
if toolMap["name"] != "" {
|
||||
// 检查是否为有效的custom工具
|
||||
if toolMap["type"] == "custom" {
|
||||
if toolMap["custom"] != nil {
|
||||
expectedFuncCount++
|
||||
}
|
||||
} else {
|
||||
expectedFuncCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(funcDecls) != expectedFuncCount {
|
||||
t.Errorf("%s: expected %d function declarations, got %d",
|
||||
tt.description, expectedFuncCount, len(funcDecls))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -7,13 +7,14 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
)
|
||||
|
||||
type GeminiOAuthService struct {
|
||||
@@ -163,6 +164,45 @@ type GeminiTokenInfo struct {
|
||||
Scope string `json:"scope,omitempty"`
|
||||
ProjectID string `json:"project_id,omitempty"`
|
||||
OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio"
|
||||
TierID string `json:"tier_id,omitempty"` // Gemini Code Assist tier: LEGACY/PRO/ULTRA
|
||||
}
|
||||
|
||||
// tierIDPattern matches a tier_id made only of ASCII letters, digits,
// underscore, hyphen and slash (slash allows tier paths). It is compiled once
// at package initialization instead of on every validateTierID call.
var tierIDPattern = regexp.MustCompile(`^[a-zA-Z0-9_/-]+$`)

// validateTierID validates the format and length of a tier_id.
//
// An empty tier_id is allowed and returns nil. Otherwise an error is returned
// when the value is longer than 64 bytes or contains any character outside
// letters, digits, '_', '-' and '/'.
func validateTierID(tierID string) error {
	if tierID == "" {
		return nil // Empty is allowed
	}
	if len(tierID) > 64 {
		return fmt.Errorf("tier_id exceeds maximum length of 64 characters")
	}
	if !tierIDPattern.MatchString(tierID) {
		return fmt.Errorf("tier_id contains invalid characters")
	}
	return nil
}
|
||||
|
||||
// extractTierIDFromAllowedTiers extracts tierID from LoadCodeAssist response
|
||||
// Prioritizes IsDefault tier, falls back to first non-empty tier
|
||||
func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string {
|
||||
tierID := "LEGACY"
|
||||
// First pass: look for default tier
|
||||
for _, tier := range allowedTiers {
|
||||
if tier.IsDefault && strings.TrimSpace(tier.ID) != "" {
|
||||
tierID = strings.TrimSpace(tier.ID)
|
||||
break
|
||||
}
|
||||
}
|
||||
// Second pass: if still LEGACY, take first non-empty tier
|
||||
if tierID == "LEGACY" {
|
||||
for _, tier := range allowedTiers {
|
||||
if strings.TrimSpace(tier.ID) != "" {
|
||||
tierID = strings.TrimSpace(tier.ID)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return tierID
|
||||
}
|
||||
|
||||
func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) {
|
||||
@@ -219,25 +259,45 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
sessionProjectID := strings.TrimSpace(session.ProjectID)
|
||||
s.sessionStore.Delete(input.SessionID)
|
||||
|
||||
// 计算过期时间时减去 5 分钟安全时间窗口,考虑网络延迟和时钟偏差
|
||||
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300
|
||||
// 计算过期时间:减去 5 分钟安全时间窗口(考虑网络延迟和时钟偏差)
|
||||
// 同时设置下界保护,防止 expires_in 过小导致过去时间(引发刷新风暴)
|
||||
const safetyWindow = 300 // 5 minutes
|
||||
const minTTL = 30 // minimum 30 seconds
|
||||
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - safetyWindow
|
||||
minExpiresAt := time.Now().Unix() + minTTL
|
||||
if expiresAt < minExpiresAt {
|
||||
expiresAt = minExpiresAt
|
||||
}
|
||||
|
||||
projectID := sessionProjectID
|
||||
var tierID string
|
||||
|
||||
// 对于 code_assist 模式,project_id 是必需的
|
||||
// 对于 ai_studio 模式,project_id 是可选的(不影响使用 AI Studio API)
|
||||
if oauthType == "code_assist" {
|
||||
if projectID == "" {
|
||||
var err error
|
||||
projectID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
|
||||
projectID, tierID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
|
||||
if err != nil {
|
||||
// 记录警告但不阻断流程,允许后续补充 project_id
|
||||
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err)
|
||||
}
|
||||
} else {
|
||||
// 用户手动填了 project_id,仍需调用 LoadCodeAssist 获取 tierID
|
||||
_, fetchedTierID, err := s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
|
||||
if err != nil {
|
||||
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch tierID: %v\n", err)
|
||||
} else {
|
||||
tierID = fetchedTierID
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(projectID) == "" {
|
||||
return nil, fmt.Errorf("missing project_id for Code Assist OAuth: please fill Project ID (optional field) and regenerate the auth URL, or ensure your Google account has an ACTIVE GCP project")
|
||||
}
|
||||
// tierID 缺失时使用默认值
|
||||
if tierID == "" {
|
||||
tierID = "LEGACY"
|
||||
}
|
||||
}
|
||||
|
||||
return &GeminiTokenInfo{
|
||||
@@ -248,6 +308,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
ExpiresAt: expiresAt,
|
||||
Scope: tokenResp.Scope,
|
||||
ProjectID: projectID,
|
||||
TierID: tierID,
|
||||
OAuthType: oauthType,
|
||||
}, nil
|
||||
}
|
||||
@@ -266,8 +327,15 @@ func (s *GeminiOAuthService) RefreshToken(ctx context.Context, oauthType, refres
|
||||
|
||||
tokenResp, err := s.oauthClient.RefreshToken(ctx, oauthType, refreshToken, proxyURL)
|
||||
if err == nil {
|
||||
// 计算过期时间时减去 5 分钟安全时间窗口,考虑网络延迟和时钟偏差
|
||||
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300
|
||||
// 计算过期时间:减去 5 分钟安全时间窗口(考虑网络延迟和时钟偏差)
|
||||
// 同时设置下界保护,防止 expires_in 过小导致过去时间(引发刷新风暴)
|
||||
const safetyWindow = 300 // 5 minutes
|
||||
const minTTL = 30 // minimum 30 seconds
|
||||
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - safetyWindow
|
||||
minExpiresAt := time.Now().Unix() + minTTL
|
||||
if expiresAt < minExpiresAt {
|
||||
expiresAt = minExpiresAt
|
||||
}
|
||||
return &GeminiTokenInfo{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
@@ -354,18 +422,39 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
|
||||
tokenInfo.ProjectID = existingProjectID
|
||||
}
|
||||
|
||||
// 尝试从账号凭证获取 tierID(向后兼容)
|
||||
existingTierID := strings.TrimSpace(account.GetCredential("tier_id"))
|
||||
|
||||
// For Code Assist, project_id is required. Auto-detect if missing.
|
||||
// For AI Studio OAuth, project_id is optional and should not block refresh.
|
||||
if oauthType == "code_assist" && strings.TrimSpace(tokenInfo.ProjectID) == "" {
|
||||
projectID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to auto-detect project_id: %w", err)
|
||||
if oauthType == "code_assist" {
|
||||
// 先设置默认值或保留旧值,确保 tier_id 始终有值
|
||||
if existingTierID != "" {
|
||||
tokenInfo.TierID = existingTierID
|
||||
} else {
|
||||
tokenInfo.TierID = "LEGACY" // 默认值
|
||||
}
|
||||
projectID = strings.TrimSpace(projectID)
|
||||
if projectID == "" {
|
||||
|
||||
// 尝试自动探测 project_id 和 tier_id
|
||||
needDetect := strings.TrimSpace(tokenInfo.ProjectID) == "" || existingTierID == ""
|
||||
if needDetect {
|
||||
projectID, tierID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL)
|
||||
if err != nil {
|
||||
fmt.Printf("[GeminiOAuth] Warning: failed to auto-detect project/tier: %v\n", err)
|
||||
} else {
|
||||
if strings.TrimSpace(tokenInfo.ProjectID) == "" && projectID != "" {
|
||||
tokenInfo.ProjectID = projectID
|
||||
}
|
||||
// 只有当原来没有 tier_id 且探测成功时才更新
|
||||
if existingTierID == "" && tierID != "" {
|
||||
tokenInfo.TierID = tierID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if strings.TrimSpace(tokenInfo.ProjectID) == "" {
|
||||
return nil, fmt.Errorf("failed to auto-detect project_id: empty result")
|
||||
}
|
||||
tokenInfo.ProjectID = projectID
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
@@ -388,6 +477,13 @@ func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo)
|
||||
if tokenInfo.ProjectID != "" {
|
||||
creds["project_id"] = tokenInfo.ProjectID
|
||||
}
|
||||
if tokenInfo.TierID != "" {
|
||||
// Validate tier_id before storing
|
||||
if err := validateTierID(tokenInfo.TierID); err == nil {
|
||||
creds["tier_id"] = tokenInfo.TierID
|
||||
}
|
||||
// Silently skip invalid tier_id (don't block account creation)
|
||||
}
|
||||
if tokenInfo.OAuthType != "" {
|
||||
creds["oauth_type"] = tokenInfo.OAuthType
|
||||
}
|
||||
@@ -398,33 +494,22 @@ func (s *GeminiOAuthService) Stop() {
|
||||
s.sessionStore.Stop()
|
||||
}
|
||||
|
||||
func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, error) {
|
||||
func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, string, error) {
|
||||
if s.codeAssist == nil {
|
||||
return "", errors.New("code assist client not configured")
|
||||
return "", "", errors.New("code assist client not configured")
|
||||
}
|
||||
|
||||
loadResp, loadErr := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil)
|
||||
if loadErr == nil && loadResp != nil && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" {
|
||||
return strings.TrimSpace(loadResp.CloudAICompanionProject), nil
|
||||
}
|
||||
|
||||
// Pick tier from allowedTiers; if no default tier is marked, pick the first non-empty tier ID.
|
||||
// Extract tierID from response (works whether CloudAICompanionProject is set or not)
|
||||
tierID := "LEGACY"
|
||||
if loadResp != nil {
|
||||
for _, tier := range loadResp.AllowedTiers {
|
||||
if tier.IsDefault && strings.TrimSpace(tier.ID) != "" {
|
||||
tierID = strings.TrimSpace(tier.ID)
|
||||
break
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(tierID) == "" || tierID == "LEGACY" {
|
||||
for _, tier := range loadResp.AllowedTiers {
|
||||
if strings.TrimSpace(tier.ID) != "" {
|
||||
tierID = strings.TrimSpace(tier.ID)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
tierID = extractTierIDFromAllowedTiers(loadResp.AllowedTiers)
|
||||
}
|
||||
|
||||
// If LoadCodeAssist returned a project, use it
|
||||
if loadErr == nil && loadResp != nil && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" {
|
||||
return strings.TrimSpace(loadResp.CloudAICompanionProject), tierID, nil
|
||||
}
|
||||
|
||||
req := &geminicli.OnboardUserRequest{
|
||||
@@ -443,39 +528,39 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
|
||||
// If Code Assist onboarding fails (e.g. INVALID_ARGUMENT), fallback to Cloud Resource Manager projects.
|
||||
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
|
||||
if fbErr == nil && strings.TrimSpace(fallback) != "" {
|
||||
return strings.TrimSpace(fallback), nil
|
||||
return strings.TrimSpace(fallback), tierID, nil
|
||||
}
|
||||
return "", err
|
||||
return "", tierID, err
|
||||
}
|
||||
if resp.Done {
|
||||
if resp.Response != nil && resp.Response.CloudAICompanionProject != nil {
|
||||
switch v := resp.Response.CloudAICompanionProject.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(v), nil
|
||||
return strings.TrimSpace(v), tierID, nil
|
||||
case map[string]any:
|
||||
if id, ok := v["id"].(string); ok {
|
||||
return strings.TrimSpace(id), nil
|
||||
return strings.TrimSpace(id), tierID, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
|
||||
if fbErr == nil && strings.TrimSpace(fallback) != "" {
|
||||
return strings.TrimSpace(fallback), nil
|
||||
return strings.TrimSpace(fallback), tierID, nil
|
||||
}
|
||||
return "", errors.New("onboardUser completed but no project_id returned")
|
||||
return "", tierID, errors.New("onboardUser completed but no project_id returned")
|
||||
}
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
|
||||
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
|
||||
if fbErr == nil && strings.TrimSpace(fallback) != "" {
|
||||
return strings.TrimSpace(fallback), nil
|
||||
return strings.TrimSpace(fallback), tierID, nil
|
||||
}
|
||||
if loadErr != nil {
|
||||
return "", fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts)
|
||||
return "", tierID, fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts)
|
||||
}
|
||||
return "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts)
|
||||
return "", tierID, fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts)
|
||||
}
|
||||
|
||||
type googleCloudProject struct {
|
||||
@@ -497,11 +582,12 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
if strings.TrimSpace(proxyURL) != "" {
|
||||
if proxyURLParsed, err := url.Parse(strings.TrimSpace(proxyURL)); err == nil {
|
||||
client.Transport = &http.Transport{Proxy: http.ProxyURL(proxyURLParsed)}
|
||||
}
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
ProxyURL: strings.TrimSpace(proxyURL),
|
||||
Timeout: 30 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
client = &http.Client{Timeout: 30 * time.Second}
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
|
||||
268
backend/internal/service/gemini_quota.go
Normal file
268
backend/internal/service/gemini_quota.go
Normal file
@@ -0,0 +1,268 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
)
|
||||
|
||||
// geminiModelClass is a string label for a class of Gemini models.
type geminiModelClass string
|
||||
|
||||
// The model classes defined in this file: "pro" and "flash".
const (
|
||||
geminiModelPro geminiModelClass = "pro"
|
||||
geminiModelFlash geminiModelClass = "flash"
|
||||
)
|
||||
|
||||
// GeminiDailyQuota holds one limit for the pro model class and one for the
// flash model class.
// NOTE(review): "RPD" presumably means requests per day — confirm.
type GeminiDailyQuota struct {
|
||||
// ProRPD is the limit for pro-class models.
ProRPD int64
|
||||
// FlashRPD is the limit for flash-class models.
FlashRPD int64
|
||||
}
|
||||
|
||||
// GeminiTierPolicy groups the settings of a single tier: its daily quota and a
// cooldown duration.
type GeminiTierPolicy struct {
|
||||
// Quota is the pro/flash quota of the tier.
Quota GeminiDailyQuota
|
||||
// Cooldown is the cooldown duration configured for the tier.
// NOTE(review): presumably applied after a rate-limit response — confirm.
Cooldown time.Duration
|
||||
}
|
||||
|
||||
// GeminiQuotaPolicy is a set of tier policies looked up by a string key.
type GeminiQuotaPolicy struct {
|
||||
// tiers maps a key to its GeminiTierPolicy.
// NOTE(review): the key is presumably the tier ID — confirm.
tiers map[string]GeminiTierPolicy
|
||||
}
|
||||
|
||||
// GeminiUsageTotals carries usage totals (requests, tokens, cost), each kept
// separately for the pro and the flash model class.
// NOTE(review): the aggregation window is not visible here — confirm.
type GeminiUsageTotals struct {
|
||||
// ProRequests is the request total for pro-class models.
ProRequests int64
|
||||
// FlashRequests is the request total for flash-class models.
FlashRequests int64
|
||||
// ProTokens is the token total for pro-class models.
ProTokens int64
|
||||
// FlashTokens is the token total for flash-class models.
FlashTokens int64
|
||||
// ProCost is the cost total for pro-class models.
ProCost float64
|
||||
// FlashCost is the cost total for flash-class models.
FlashCost float64
|
||||
}
|
||||
|
||||
const geminiQuotaCacheTTL = time.Minute
|
||||
|
||||
type geminiQuotaOverrides struct {
|
||||
Tiers map[string]config.GeminiTierQuotaConfig `json:"tiers"`
|
||||
}
|
||||
|
||||
type GeminiQuotaService struct {
|
||||
cfg *config.Config
|
||||
settingRepo SettingRepository
|
||||
mu sync.Mutex
|
||||
cachedAt time.Time
|
||||
policy *GeminiQuotaPolicy
|
||||
}
|
||||
|
||||
func NewGeminiQuotaService(cfg *config.Config, settingRepo SettingRepository) *GeminiQuotaService {
|
||||
return &GeminiQuotaService{
|
||||
cfg: cfg,
|
||||
settingRepo: settingRepo,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
|
||||
if s == nil {
|
||||
return newGeminiQuotaPolicy()
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
s.mu.Lock()
|
||||
if s.policy != nil && now.Sub(s.cachedAt) < geminiQuotaCacheTTL {
|
||||
policy := s.policy
|
||||
s.mu.Unlock()
|
||||
return policy
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
policy := newGeminiQuotaPolicy()
|
||||
if s.cfg != nil {
|
||||
policy.ApplyOverrides(s.cfg.Gemini.Quota.Tiers)
|
||||
if strings.TrimSpace(s.cfg.Gemini.Quota.Policy) != "" {
|
||||
var overrides geminiQuotaOverrides
|
||||
if err := json.Unmarshal([]byte(s.cfg.Gemini.Quota.Policy), &overrides); err != nil {
|
||||
log.Printf("gemini quota: parse config policy failed: %v", err)
|
||||
} else {
|
||||
policy.ApplyOverrides(overrides.Tiers)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if s.settingRepo != nil {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyGeminiQuotaPolicy)
|
||||
if err != nil && !errors.Is(err, ErrSettingNotFound) {
|
||||
log.Printf("gemini quota: load setting failed: %v", err)
|
||||
} else if strings.TrimSpace(value) != "" {
|
||||
var overrides geminiQuotaOverrides
|
||||
if err := json.Unmarshal([]byte(value), &overrides); err != nil {
|
||||
log.Printf("gemini quota: parse setting failed: %v", err)
|
||||
} else {
|
||||
policy.ApplyOverrides(overrides.Tiers)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.policy = policy
|
||||
s.cachedAt = now
|
||||
s.mu.Unlock()
|
||||
|
||||
return policy
|
||||
}
|
||||
|
||||
func (s *GeminiQuotaService) QuotaForAccount(ctx context.Context, account *Account) (GeminiDailyQuota, bool) {
|
||||
if account == nil || !account.IsGeminiCodeAssist() {
|
||||
return GeminiDailyQuota{}, false
|
||||
}
|
||||
policy := s.Policy(ctx)
|
||||
return policy.QuotaForTier(account.GeminiTierID())
|
||||
}
|
||||
|
||||
func (s *GeminiQuotaService) CooldownForTier(ctx context.Context, tierID string) time.Duration {
|
||||
policy := s.Policy(ctx)
|
||||
return policy.CooldownForTier(tierID)
|
||||
}
|
||||
|
||||
func newGeminiQuotaPolicy() *GeminiQuotaPolicy {
|
||||
return &GeminiQuotaPolicy{
|
||||
tiers: map[string]GeminiTierPolicy{
|
||||
"LEGACY": {Quota: GeminiDailyQuota{ProRPD: 50, FlashRPD: 1500}, Cooldown: 30 * time.Minute},
|
||||
"PRO": {Quota: GeminiDailyQuota{ProRPD: 1500, FlashRPD: 4000}, Cooldown: 5 * time.Minute},
|
||||
"ULTRA": {Quota: GeminiDailyQuota{ProRPD: 2000, FlashRPD: 0}, Cooldown: 5 * time.Minute},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *GeminiQuotaPolicy) ApplyOverrides(tiers map[string]config.GeminiTierQuotaConfig) {
|
||||
if p == nil || len(tiers) == 0 {
|
||||
return
|
||||
}
|
||||
for rawID, override := range tiers {
|
||||
tierID := normalizeGeminiTierID(rawID)
|
||||
if tierID == "" {
|
||||
continue
|
||||
}
|
||||
policy, ok := p.tiers[tierID]
|
||||
if !ok {
|
||||
policy = GeminiTierPolicy{Cooldown: 5 * time.Minute}
|
||||
}
|
||||
if override.ProRPD != nil {
|
||||
policy.Quota.ProRPD = clampGeminiQuotaInt64(*override.ProRPD)
|
||||
}
|
||||
if override.FlashRPD != nil {
|
||||
policy.Quota.FlashRPD = clampGeminiQuotaInt64(*override.FlashRPD)
|
||||
}
|
||||
if override.CooldownMinutes != nil {
|
||||
minutes := clampGeminiQuotaInt(*override.CooldownMinutes)
|
||||
policy.Cooldown = time.Duration(minutes) * time.Minute
|
||||
}
|
||||
p.tiers[tierID] = policy
|
||||
}
|
||||
}
|
||||
|
||||
func (p *GeminiQuotaPolicy) QuotaForTier(tierID string) (GeminiDailyQuota, bool) {
|
||||
policy, ok := p.policyForTier(tierID)
|
||||
if !ok {
|
||||
return GeminiDailyQuota{}, false
|
||||
}
|
||||
return policy.Quota, true
|
||||
}
|
||||
|
||||
func (p *GeminiQuotaPolicy) CooldownForTier(tierID string) time.Duration {
|
||||
policy, ok := p.policyForTier(tierID)
|
||||
if ok && policy.Cooldown > 0 {
|
||||
return policy.Cooldown
|
||||
}
|
||||
return 5 * time.Minute
|
||||
}
|
||||
|
||||
func (p *GeminiQuotaPolicy) policyForTier(tierID string) (GeminiTierPolicy, bool) {
|
||||
if p == nil {
|
||||
return GeminiTierPolicy{}, false
|
||||
}
|
||||
normalized := normalizeGeminiTierID(tierID)
|
||||
if normalized == "" {
|
||||
normalized = "LEGACY"
|
||||
}
|
||||
if policy, ok := p.tiers[normalized]; ok {
|
||||
return policy, true
|
||||
}
|
||||
policy, ok := p.tiers["LEGACY"]
|
||||
return policy, ok
|
||||
}
|
||||
|
||||
// normalizeGeminiTierID trims surrounding whitespace from tierID and
// upper-cases the result.
func normalizeGeminiTierID(tierID string) string {
	trimmed := strings.TrimSpace(tierID)
	return strings.ToUpper(trimmed)
}
|
||||
|
||||
// clampGeminiQuotaInt64 returns value, or 0 when value is negative.
func clampGeminiQuotaInt64(value int64) int64 {
	if value >= 0 {
		return value
	}
	return 0
}
|
||||
|
||||
// clampGeminiQuotaInt returns value, or 0 when value is negative.
func clampGeminiQuotaInt(value int) int {
	if value >= 0 {
		return value
	}
	return 0
}
|
||||
|
||||
func geminiCooldownForTier(tierID string) time.Duration {
|
||||
policy := newGeminiQuotaPolicy()
|
||||
return policy.CooldownForTier(tierID)
|
||||
}
|
||||
|
||||
func geminiModelClassFromName(model string) geminiModelClass {
|
||||
name := strings.ToLower(strings.TrimSpace(model))
|
||||
if strings.Contains(name, "flash") || strings.Contains(name, "lite") {
|
||||
return geminiModelFlash
|
||||
}
|
||||
return geminiModelPro
|
||||
}
|
||||
|
||||
func geminiAggregateUsage(stats []usagestats.ModelStat) GeminiUsageTotals {
|
||||
var totals GeminiUsageTotals
|
||||
for _, stat := range stats {
|
||||
switch geminiModelClassFromName(stat.Model) {
|
||||
case geminiModelFlash:
|
||||
totals.FlashRequests += stat.Requests
|
||||
totals.FlashTokens += stat.TotalTokens
|
||||
totals.FlashCost += stat.ActualCost
|
||||
default:
|
||||
totals.ProRequests += stat.Requests
|
||||
totals.ProTokens += stat.TotalTokens
|
||||
totals.ProCost += stat.ActualCost
|
||||
}
|
||||
}
|
||||
return totals
|
||||
}
|
||||
|
||||
// Lazily resolved time zone used for the daily quota window. Resolved once so
// the zoneinfo lookup is not repeated on every call.
var (
	geminiQuotaLocOnce sync.Once
	geminiQuotaLoc     *time.Location
)

// geminiQuotaLocation returns the America/Los_Angeles location, or a fixed
// UTC-8 zone named "PST" when the zone database entry cannot be loaded.
func geminiQuotaLocation() *time.Location {
	geminiQuotaLocOnce.Do(func() {
		loc, err := time.LoadLocation("America/Los_Angeles")
		if err != nil {
			loc = time.FixedZone("PST", -8*3600)
		}
		geminiQuotaLoc = loc
	})
	return geminiQuotaLoc
}

// geminiDailyWindowStart returns local midnight (in the quota time zone) of
// the calendar day that contains now.
func geminiDailyWindowStart(now time.Time) time.Time {
	loc := geminiQuotaLocation()
	year, month, day := now.In(loc).Date()
	return time.Date(year, month, day, 0, 0, 0, 0, loc)
}

// geminiDailyResetTime returns the next local midnight (in the quota time
// zone) strictly after now.
//
// The next day is computed with calendar arithmetic rather than by adding
// 24 hours: days on which daylight saving time changes are 23 or 25 hours
// long, so a fixed 24h offset would not land on midnight.
func geminiDailyResetTime(now time.Time) time.Time {
	loc := geminiQuotaLocation()
	year, month, day := now.In(loc).Date()
	// time.Date normalizes day+1 across month and year boundaries.
	return time.Date(year, month, day+1, 0, 0, 0, 0, loc)
}
|
||||
@@ -50,7 +50,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
}
|
||||
|
||||
// 2) Refresh if needed (pre-expiry skew).
|
||||
expiresAt := parseExpiresAt(account)
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= geminiTokenRefreshSkew
|
||||
if needsRefresh && p.tokenCache != nil {
|
||||
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
@@ -66,7 +66,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
if err == nil && fresh != nil {
|
||||
account = fresh
|
||||
}
|
||||
expiresAt = parseExpiresAt(account)
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= geminiTokenRefreshSkew {
|
||||
if p.geminiOAuthService == nil {
|
||||
return "", errors.New("gemini oauth service not configured")
|
||||
@@ -83,7 +83,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
_ = p.accountRepo.Update(ctx, account)
|
||||
expiresAt = parseExpiresAt(account)
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -112,17 +112,21 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
}
|
||||
}
|
||||
|
||||
detected, err := p.geminiOAuthService.fetchProjectID(ctx, accessToken, proxyURL)
|
||||
detected, tierID, err := p.geminiOAuthService.fetchProjectID(ctx, accessToken, proxyURL)
|
||||
if err != nil {
|
||||
log.Printf("[GeminiTokenProvider] Auto-detect project_id failed: %v, fallback to AI Studio API mode", err)
|
||||
return accessToken, nil
|
||||
}
|
||||
detected = strings.TrimSpace(detected)
|
||||
tierID = strings.TrimSpace(tierID)
|
||||
if detected != "" {
|
||||
if account.Credentials == nil {
|
||||
account.Credentials = make(map[string]any)
|
||||
}
|
||||
account.Credentials["project_id"] = detected
|
||||
if tierID != "" {
|
||||
account.Credentials["tier_id"] = tierID
|
||||
}
|
||||
_ = p.accountRepo.Update(ctx, account)
|
||||
}
|
||||
}
|
||||
@@ -154,18 +158,3 @@ func geminiTokenCacheKey(account *Account) string {
|
||||
}
|
||||
return "account:" + strconv.FormatInt(account.ID, 10)
|
||||
}
|
||||
|
||||
func parseExpiresAt(account *Account) *time.Time {
|
||||
raw := strings.TrimSpace(account.GetCredential("expires_at"))
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
if unixSec, err := strconv.ParseInt(raw, 10, 64); err == nil && unixSec > 0 {
|
||||
t := time.Unix(unixSec, 0)
|
||||
return &t
|
||||
}
|
||||
if t, err := time.Parse(time.RFC3339, raw); err == nil {
|
||||
return &t
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -22,16 +21,11 @@ func (r *GeminiTokenRefresher) NeedsRefresh(account *Account, refreshWindow time
|
||||
if !r.CanRefresh(account) {
|
||||
return false
|
||||
}
|
||||
expiresAtStr := account.GetCredential("expires_at")
|
||||
if expiresAtStr == "" {
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil {
|
||||
return false
|
||||
}
|
||||
expiresAt, err := strconv.ParseInt(expiresAtStr, 10, 64)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
expiryTime := time.Unix(expiresAt, 0)
|
||||
return time.Until(expiryTime) < refreshWindow
|
||||
return time.Until(*expiresAt) < refreshWindow
|
||||
}
|
||||
|
||||
func (r *GeminiTokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
|
||||
|
||||
@@ -11,10 +11,11 @@ type Group struct {
|
||||
IsExclusive bool
|
||||
Status string
|
||||
|
||||
SubscriptionType string
|
||||
DailyLimitUSD *float64
|
||||
WeeklyLimitUSD *float64
|
||||
MonthlyLimitUSD *float64
|
||||
SubscriptionType string
|
||||
DailyLimitUSD *float64
|
||||
WeeklyLimitUSD *float64
|
||||
MonthlyLimitUSD *float64
|
||||
DefaultValidityDays int
|
||||
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
|
||||
@@ -2,8 +2,29 @@ package service
|
||||
|
||||
import "net/http"
|
||||
|
||||
// HTTPUpstream interface for making HTTP requests to upstream APIs (Claude, OpenAI, etc.)
|
||||
// This is a generic interface that can be used for any HTTP-based upstream service.
|
||||
// HTTPUpstream 上游 HTTP 请求接口
|
||||
// 用于向上游 API(Claude、OpenAI、Gemini 等)发送请求
|
||||
// 这是一个通用接口,可用于任何基于 HTTP 的上游服务
|
||||
//
|
||||
// 设计说明:
|
||||
// - 支持可选代理配置
|
||||
// - 支持账户级连接池隔离
|
||||
// - 实现类负责连接池管理和复用
|
||||
type HTTPUpstream interface {
|
||||
Do(req *http.Request, proxyURL string) (*http.Response, error)
|
||||
// Do 执行 HTTP 请求
|
||||
//
|
||||
// 参数:
|
||||
// - req: HTTP 请求对象,由调用方构建
|
||||
// - proxyURL: 代理服务器地址,空字符串表示直连
|
||||
// - accountID: 账户 ID,用于连接池隔离(隔离策略为 account 或 account_proxy 时生效)
|
||||
// - accountConcurrency: 账户并发限制,用于动态调整连接池大小
|
||||
//
|
||||
// 返回:
|
||||
// - *http.Response: HTTP 响应,调用方必须关闭 Body
|
||||
// - error: 请求错误(网络错误、超时等)
|
||||
//
|
||||
// 注意:
|
||||
// - 调用方必须关闭 resp.Body,否则会导致连接泄漏
|
||||
// - 响应体可能已被包装以跟踪请求生命周期
|
||||
Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error)
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -80,6 +81,7 @@ type OpenAIGatewayService struct {
|
||||
userSubRepo UserSubscriptionRepository
|
||||
cache GatewayCache
|
||||
cfg *config.Config
|
||||
concurrencyService *ConcurrencyService
|
||||
billingService *BillingService
|
||||
rateLimitService *RateLimitService
|
||||
billingCacheService *BillingCacheService
|
||||
@@ -95,6 +97,7 @@ func NewOpenAIGatewayService(
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
cache GatewayCache,
|
||||
cfg *config.Config,
|
||||
concurrencyService *ConcurrencyService,
|
||||
billingService *BillingService,
|
||||
rateLimitService *RateLimitService,
|
||||
billingCacheService *BillingCacheService,
|
||||
@@ -108,6 +111,7 @@ func NewOpenAIGatewayService(
|
||||
userSubRepo: userSubRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: concurrencyService,
|
||||
billingService: billingService,
|
||||
rateLimitService: rateLimitService,
|
||||
billingCacheService: billingCacheService,
|
||||
@@ -126,6 +130,14 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string {
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// BindStickySession sets session -> account binding with standard TTL.
|
||||
func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error {
|
||||
if sessionHash == "" || accountID <= 0 {
|
||||
return nil
|
||||
}
|
||||
return s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, accountID, openaiStickySessionTTL)
|
||||
}
|
||||
|
||||
// SelectAccount selects an OpenAI account with sticky session support
|
||||
func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) {
|
||||
return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
|
||||
@@ -218,6 +230,254 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan.
|
||||
func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
|
||||
cfg := s.schedulingConfig()
|
||||
var stickyAccountID int64
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
if accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash); err == nil {
|
||||
stickyAccountID = accountID
|
||||
}
|
||||
}
|
||||
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
|
||||
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: account.ID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: account.ID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.FallbackWaitTimeout,
|
||||
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
accounts, err := s.listSchedulableAccounts(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(accounts) == 0 {
|
||||
return nil, errors.New("no available accounts")
|
||||
}
|
||||
|
||||
isExcluded := func(accountID int64) bool {
|
||||
if excludedIDs == nil {
|
||||
return false
|
||||
}
|
||||
_, excluded := excludedIDs[accountID]
|
||||
return excluded
|
||||
}
|
||||
|
||||
// ============ Layer 1: Sticky session ============
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash)
|
||||
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err == nil && account.IsSchedulable() && account.IsOpenAI() &&
|
||||
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL)
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: accountID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============ Layer 2: Load-aware selection ============
|
||||
candidates := make([]*Account, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if isExcluded(acc.ID) {
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||
continue
|
||||
}
|
||||
candidates = append(candidates, acc)
|
||||
}
|
||||
|
||||
if len(candidates) == 0 {
|
||||
return nil, errors.New("no available accounts")
|
||||
}
|
||||
|
||||
accountLoads := make([]AccountWithConcurrency, 0, len(candidates))
|
||||
for _, acc := range candidates {
|
||||
accountLoads = append(accountLoads, AccountWithConcurrency{
|
||||
ID: acc.ID,
|
||||
MaxConcurrency: acc.Concurrency,
|
||||
})
|
||||
}
|
||||
|
||||
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
|
||||
if err != nil {
|
||||
ordered := append([]*Account(nil), candidates...)
|
||||
sortAccountsByPriorityAndLastUsed(ordered, false)
|
||||
for _, acc := range ordered {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
if sessionHash != "" {
|
||||
_ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, acc.ID, openaiStickySessionTTL)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: acc,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
type accountWithLoad struct {
|
||||
account *Account
|
||||
loadInfo *AccountLoadInfo
|
||||
}
|
||||
var available []accountWithLoad
|
||||
for _, acc := range candidates {
|
||||
loadInfo := loadMap[acc.ID]
|
||||
if loadInfo == nil {
|
||||
loadInfo = &AccountLoadInfo{AccountID: acc.ID}
|
||||
}
|
||||
if loadInfo.LoadRate < 100 {
|
||||
available = append(available, accountWithLoad{
|
||||
account: acc,
|
||||
loadInfo: loadInfo,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(available) > 0 {
|
||||
sort.SliceStable(available, func(i, j int) bool {
|
||||
a, b := available[i], available[j]
|
||||
if a.account.Priority != b.account.Priority {
|
||||
return a.account.Priority < b.account.Priority
|
||||
}
|
||||
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
|
||||
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
|
||||
}
|
||||
switch {
|
||||
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
|
||||
return true
|
||||
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
|
||||
return false
|
||||
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
|
||||
return false
|
||||
default:
|
||||
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
|
||||
}
|
||||
})
|
||||
|
||||
for _, item := range available {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
if sessionHash != "" {
|
||||
_ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, item.account.ID, openaiStickySessionTTL)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: item.account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============ Layer 3: Fallback wait ============
|
||||
sortAccountsByPriorityAndLastUsed(candidates, false)
|
||||
for _, acc := range candidates {
|
||||
return &AccountSelectionResult{
|
||||
Account: acc,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: acc.ID,
|
||||
MaxConcurrency: acc.Concurrency,
|
||||
Timeout: cfg.FallbackWaitTimeout,
|
||||
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("no available accounts")
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, error) {
|
||||
var accounts []Account
|
||||
var err error
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
|
||||
} else if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
}
|
||||
return accounts, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
|
||||
if s.concurrencyService == nil {
|
||||
return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil
|
||||
}
|
||||
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig {
|
||||
if s.cfg != nil {
|
||||
return s.cfg.Gateway.Scheduling
|
||||
}
|
||||
return config.GatewaySchedulingConfig{
|
||||
StickySessionMaxWaiting: 3,
|
||||
StickySessionWaitTimeout: 45 * time.Second,
|
||||
FallbackWaitTimeout: 30 * time.Second,
|
||||
FallbackMaxWaiting: 100,
|
||||
LoadBatchEnabled: true,
|
||||
SlotCleanupInterval: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// GetAccessToken gets the access token for an OpenAI account
|
||||
func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
|
||||
switch account.Type {
|
||||
@@ -311,7 +571,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
}
|
||||
|
||||
// Send request
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL)
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("upstream request failed: %w", err)
|
||||
}
|
||||
@@ -772,20 +1032,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
if isSubscriptionBilling {
|
||||
if cost.TotalCost > 0 {
|
||||
_ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost)
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = s.billingCacheService.UpdateSubscriptionUsage(cacheCtx, user.ID, *apiKey.GroupID, cost.TotalCost)
|
||||
}()
|
||||
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
|
||||
}
|
||||
} else {
|
||||
if cost.ActualCost > 0 {
|
||||
_ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost)
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = s.billingCacheService.DeductBalanceCache(cacheCtx, user.ID, cost.ActualCost)
|
||||
}()
|
||||
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -18,6 +18,11 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
)
|
||||
|
||||
var (
|
||||
openAIModelDatePattern = regexp.MustCompile(`-\d{8}$`)
|
||||
openAIModelBasePattern = regexp.MustCompile(`^(gpt-\d+(?:\.\d+)?)(?:-|$)`)
|
||||
)
|
||||
|
||||
// LiteLLMModelPricing LiteLLM价格数据结构
|
||||
// 只保留我们需要的字段,使用指针来处理可能缺失的值
|
||||
type LiteLLMModelPricing struct {
|
||||
@@ -595,11 +600,8 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
|
||||
// 2. gpt-5.2-20251222 -> gpt-5.2(去掉日期版本号)
|
||||
// 3. 最终回退到 DefaultTestModel (gpt-5.1-codex)
|
||||
func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing {
|
||||
// 正则匹配日期后缀 (如 -20251222)
|
||||
datePattern := regexp.MustCompile(`-\d{8}$`)
|
||||
|
||||
// 尝试的回退变体
|
||||
variants := s.generateOpenAIModelVariants(model, datePattern)
|
||||
variants := s.generateOpenAIModelVariants(model, openAIModelDatePattern)
|
||||
|
||||
for _, variant := range variants {
|
||||
if pricing, ok := s.pricingData[variant]; ok {
|
||||
@@ -638,14 +640,13 @@ func (s *PricingService) generateOpenAIModelVariants(model string, datePattern *
|
||||
|
||||
// 2. 提取基础版本号: gpt-5.2-codex -> gpt-5.2
|
||||
// 只匹配纯数字版本号格式 gpt-X 或 gpt-X.Y,不匹配 gpt-4o 这种带字母后缀的
|
||||
basePattern := regexp.MustCompile(`^(gpt-\d+(?:\.\d+)?)(?:-|$)`)
|
||||
if matches := basePattern.FindStringSubmatch(model); len(matches) > 1 {
|
||||
if matches := openAIModelBasePattern.FindStringSubmatch(model); len(matches) > 1 {
|
||||
addVariant(matches[1])
|
||||
}
|
||||
|
||||
// 3. 同时去掉日期后再提取基础版本号
|
||||
if withoutDate != model {
|
||||
if matches := basePattern.FindStringSubmatch(withoutDate); len(matches) > 1 {
|
||||
if matches := openAIModelBasePattern.FindStringSubmatch(withoutDate); len(matches) > 1 {
|
||||
addVariant(matches[1])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
@@ -12,15 +14,30 @@ import (
|
||||
|
||||
// RateLimitService 处理限流和过载状态管理
|
||||
type RateLimitService struct {
|
||||
accountRepo AccountRepository
|
||||
cfg *config.Config
|
||||
accountRepo AccountRepository
|
||||
usageRepo UsageLogRepository
|
||||
cfg *config.Config
|
||||
geminiQuotaService *GeminiQuotaService
|
||||
usageCacheMu sync.Mutex
|
||||
usageCache map[int64]*geminiUsageCacheEntry
|
||||
}
|
||||
|
||||
type geminiUsageCacheEntry struct {
|
||||
windowStart time.Time
|
||||
cachedAt time.Time
|
||||
totals GeminiUsageTotals
|
||||
}
|
||||
|
||||
const geminiPrecheckCacheTTL = time.Minute
|
||||
|
||||
// NewRateLimitService 创建RateLimitService实例
|
||||
func NewRateLimitService(accountRepo AccountRepository, cfg *config.Config) *RateLimitService {
|
||||
func NewRateLimitService(accountRepo AccountRepository, usageRepo UsageLogRepository, cfg *config.Config, geminiQuotaService *GeminiQuotaService) *RateLimitService {
|
||||
return &RateLimitService{
|
||||
accountRepo: accountRepo,
|
||||
cfg: cfg,
|
||||
accountRepo: accountRepo,
|
||||
usageRepo: usageRepo,
|
||||
cfg: cfg,
|
||||
geminiQuotaService: geminiQuotaService,
|
||||
usageCache: make(map[int64]*geminiUsageCacheEntry),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,6 +79,106 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
}
|
||||
}
|
||||
|
||||
// PreCheckUsage proactively checks local quota before dispatching a request.
|
||||
// Returns false when the account should be skipped.
|
||||
func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, requestedModel string) (bool, error) {
|
||||
if account == nil || !account.IsGeminiCodeAssist() || strings.TrimSpace(requestedModel) == "" {
|
||||
return true, nil
|
||||
}
|
||||
if s.usageRepo == nil || s.geminiQuotaService == nil {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
quota, ok := s.geminiQuotaService.QuotaForAccount(ctx, account)
|
||||
if !ok {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
var limit int64
|
||||
switch geminiModelClassFromName(requestedModel) {
|
||||
case geminiModelFlash:
|
||||
limit = quota.FlashRPD
|
||||
default:
|
||||
limit = quota.ProRPD
|
||||
}
|
||||
if limit <= 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
start := geminiDailyWindowStart(now)
|
||||
totals, ok := s.getGeminiUsageTotals(account.ID, start, now)
|
||||
if !ok {
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
totals = geminiAggregateUsage(stats)
|
||||
s.setGeminiUsageTotals(account.ID, start, now, totals)
|
||||
}
|
||||
|
||||
var used int64
|
||||
switch geminiModelClassFromName(requestedModel) {
|
||||
case geminiModelFlash:
|
||||
used = totals.FlashRequests
|
||||
default:
|
||||
used = totals.ProRequests
|
||||
}
|
||||
|
||||
if used >= limit {
|
||||
resetAt := geminiDailyResetTime(now)
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
|
||||
}
|
||||
log.Printf("[Gemini PreCheck] Account %d reached daily quota (%d/%d), rate limited until %v", account.ID, used, limit, resetAt)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *RateLimitService) getGeminiUsageTotals(accountID int64, windowStart, now time.Time) (GeminiUsageTotals, bool) {
|
||||
s.usageCacheMu.Lock()
|
||||
defer s.usageCacheMu.Unlock()
|
||||
|
||||
if s.usageCache == nil {
|
||||
return GeminiUsageTotals{}, false
|
||||
}
|
||||
|
||||
entry, ok := s.usageCache[accountID]
|
||||
if !ok || entry == nil {
|
||||
return GeminiUsageTotals{}, false
|
||||
}
|
||||
if !entry.windowStart.Equal(windowStart) {
|
||||
return GeminiUsageTotals{}, false
|
||||
}
|
||||
if now.Sub(entry.cachedAt) >= geminiPrecheckCacheTTL {
|
||||
return GeminiUsageTotals{}, false
|
||||
}
|
||||
return entry.totals, true
|
||||
}
|
||||
|
||||
func (s *RateLimitService) setGeminiUsageTotals(accountID int64, windowStart, now time.Time, totals GeminiUsageTotals) {
|
||||
s.usageCacheMu.Lock()
|
||||
defer s.usageCacheMu.Unlock()
|
||||
if s.usageCache == nil {
|
||||
s.usageCache = make(map[int64]*geminiUsageCacheEntry)
|
||||
}
|
||||
s.usageCache[accountID] = &geminiUsageCacheEntry{
|
||||
windowStart: windowStart,
|
||||
cachedAt: now,
|
||||
totals: totals,
|
||||
}
|
||||
}
|
||||
|
||||
// GeminiCooldown returns the fallback cooldown duration for Gemini 429s based on tier.
|
||||
func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account) time.Duration {
|
||||
if account == nil {
|
||||
return 5 * time.Minute
|
||||
}
|
||||
return s.geminiQuotaService.CooldownForTier(ctx, account.GeminiTierID())
|
||||
}
|
||||
|
||||
// handleAuthError 处理认证类错误(401/403),停止账号调度
|
||||
func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) {
|
||||
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
|
||||
|
||||
@@ -9,7 +9,8 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
@@ -72,6 +73,7 @@ type RedeemService struct {
|
||||
subscriptionService *SubscriptionService
|
||||
cache RedeemCache
|
||||
billingCacheService *BillingCacheService
|
||||
entClient *dbent.Client
|
||||
}
|
||||
|
||||
// NewRedeemService 创建兑换码服务实例
|
||||
@@ -81,6 +83,7 @@ func NewRedeemService(
|
||||
subscriptionService *SubscriptionService,
|
||||
cache RedeemCache,
|
||||
billingCacheService *BillingCacheService,
|
||||
entClient *dbent.Client,
|
||||
) *RedeemService {
|
||||
return &RedeemService{
|
||||
redeemRepo: redeemRepo,
|
||||
@@ -88,6 +91,7 @@ func NewRedeemService(
|
||||
subscriptionService: subscriptionService,
|
||||
cache: cache,
|
||||
billingCacheService: billingCacheService,
|
||||
entClient: entClient,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -248,9 +252,19 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
|
||||
}
|
||||
_ = user // 使用变量避免未使用错误
|
||||
|
||||
// 使用数据库事务保证兑换码标记与权益发放的原子性
|
||||
tx, err := s.entClient.Tx(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("begin transaction: %w", err)
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
// 将事务放入 context,使 repository 方法能够使用同一事务
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
|
||||
// 【关键】先标记兑换码为已使用,确保并发安全
|
||||
// 利用数据库乐观锁(WHERE status = 'unused')保证原子性
|
||||
if err := s.redeemRepo.Use(ctx, redeemCode.ID, userID); err != nil {
|
||||
if err := s.redeemRepo.Use(txCtx, redeemCode.ID, userID); err != nil {
|
||||
if errors.Is(err, ErrRedeemCodeNotFound) || errors.Is(err, ErrRedeemCodeUsed) {
|
||||
return nil, ErrRedeemCodeUsed
|
||||
}
|
||||
@@ -261,21 +275,13 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
|
||||
switch redeemCode.Type {
|
||||
case RedeemTypeBalance:
|
||||
// 增加用户余额
|
||||
if err := s.userRepo.UpdateBalance(ctx, userID, redeemCode.Value); err != nil {
|
||||
if err := s.userRepo.UpdateBalance(txCtx, userID, redeemCode.Value); err != nil {
|
||||
return nil, fmt.Errorf("update user balance: %w", err)
|
||||
}
|
||||
// 失效余额缓存
|
||||
if s.billingCacheService != nil {
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
|
||||
}()
|
||||
}
|
||||
|
||||
case RedeemTypeConcurrency:
|
||||
// 增加用户并发数
|
||||
if err := s.userRepo.UpdateConcurrency(ctx, userID, int(redeemCode.Value)); err != nil {
|
||||
if err := s.userRepo.UpdateConcurrency(txCtx, userID, int(redeemCode.Value)); err != nil {
|
||||
return nil, fmt.Errorf("update user concurrency: %w", err)
|
||||
}
|
||||
|
||||
@@ -284,7 +290,7 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
|
||||
if validityDays <= 0 {
|
||||
validityDays = 30
|
||||
}
|
||||
_, _, err := s.subscriptionService.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{
|
||||
_, _, err := s.subscriptionService.AssignOrExtendSubscription(txCtx, &AssignSubscriptionInput{
|
||||
UserID: userID,
|
||||
GroupID: *redeemCode.GroupID,
|
||||
ValidityDays: validityDays,
|
||||
@@ -294,20 +300,19 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("assign or extend subscription: %w", err)
|
||||
}
|
||||
// 失效订阅缓存
|
||||
if s.billingCacheService != nil {
|
||||
groupID := *redeemCode.GroupID
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
}()
|
||||
}
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported redeem type: %s", redeemCode.Type)
|
||||
}
|
||||
|
||||
// 提交事务
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, fmt.Errorf("commit transaction: %w", err)
|
||||
}
|
||||
|
||||
// 事务提交成功后失效缓存
|
||||
s.invalidateRedeemCaches(ctx, userID, redeemCode)
|
||||
|
||||
// 重新获取更新后的兑换码
|
||||
redeemCode, err = s.redeemRepo.GetByID(ctx, redeemCode.ID)
|
||||
if err != nil {
|
||||
@@ -317,6 +322,31 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
|
||||
return redeemCode, nil
|
||||
}
|
||||
|
||||
// invalidateRedeemCaches 失效兑换相关的缓存
|
||||
func (s *RedeemService) invalidateRedeemCaches(ctx context.Context, userID int64, redeemCode *RedeemCode) {
|
||||
if s.billingCacheService == nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch redeemCode.Type {
|
||||
case RedeemTypeBalance:
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
|
||||
}()
|
||||
case RedeemTypeSubscription:
|
||||
if redeemCode.GroupID != nil {
|
||||
groupID := *redeemCode.GroupID
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID)
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取兑换码
|
||||
func (s *RedeemService) GetByID(ctx context.Context, id int64) (*RedeemCode, error) {
|
||||
code, err := s.redeemRepo.GetByID(ctx, id)
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
@@ -26,6 +26,7 @@ var (
|
||||
ErrDailyLimitExceeded = infraerrors.TooManyRequests("DAILY_LIMIT_EXCEEDED", "daily usage limit exceeded")
|
||||
ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded")
|
||||
ErrMonthlyLimitExceeded = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded")
|
||||
ErrSubscriptionNilInput = infraerrors.BadRequest("SUBSCRIPTION_NIL_INPUT", "subscription input cannot be nil")
|
||||
)
|
||||
|
||||
// SubscriptionService 订阅服务
|
||||
@@ -489,6 +490,7 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *Use
|
||||
}
|
||||
|
||||
// CheckUsageLimits 检查使用限额(返回错误如果超限)
|
||||
// 用于中间件的快速预检查,additionalCost 通常为 0
|
||||
func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *UserSubscription, group *Group, additionalCost float64) error {
|
||||
if !sub.CheckDailyLimit(group, additionalCost) {
|
||||
return ErrDailyLimitExceeded
|
||||
|
||||
@@ -43,17 +43,11 @@ func (r *ClaudeTokenRefresher) CanRefresh(account *Account) bool {
|
||||
// NeedsRefresh 检查token是否需要刷新
|
||||
// 基于 expires_at 字段判断是否在刷新窗口内
|
||||
func (r *ClaudeTokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool {
|
||||
s := account.GetCredential("expires_at")
|
||||
if s == "" {
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
expiresAt, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return time.Until(time.Unix(expiresAt, 0)) < refreshWindow
|
||||
return time.Until(*expiresAt) < refreshWindow
|
||||
}
|
||||
|
||||
// Refresh 执行token刷新
|
||||
|
||||
@@ -33,6 +33,13 @@ func TestClaudeTokenRefresher_NeedsRefresh(t *testing.T) {
|
||||
},
|
||||
wantRefresh: true,
|
||||
},
|
||||
{
|
||||
name: "expires_at as RFC3339 - expired",
|
||||
credentials: map[string]any{
|
||||
"expires_at": "1970-01-01T00:00:00Z", // RFC3339 格式,已过期
|
||||
},
|
||||
wantRefresh: true,
|
||||
},
|
||||
{
|
||||
name: "expires_at as string - far future",
|
||||
credentials: map[string]any{
|
||||
@@ -47,6 +54,13 @@ func TestClaudeTokenRefresher_NeedsRefresh(t *testing.T) {
|
||||
},
|
||||
wantRefresh: false,
|
||||
},
|
||||
{
|
||||
name: "expires_at as RFC3339 - far future",
|
||||
credentials: map[string]any{
|
||||
"expires_at": "2099-12-31T23:59:59Z", // RFC3339 格式,远未来
|
||||
},
|
||||
wantRefresh: false,
|
||||
},
|
||||
{
|
||||
name: "expires_at missing",
|
||||
credentials: map[string]any{},
|
||||
|
||||
@@ -5,12 +5,13 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrTurnstileVerificationFailed = infraerrors.BadRequest("TURNSTILE_VERIFICATION_FAILED", "turnstile verification failed")
|
||||
ErrTurnstileNotConfigured = infraerrors.ServiceUnavailable("TURNSTILE_NOT_CONFIGURED", "turnstile not configured")
|
||||
ErrTurnstileInvalidSecretKey = infraerrors.BadRequest("TURNSTILE_INVALID_SECRET_KEY", "invalid turnstile secret key")
|
||||
)
|
||||
|
||||
// TurnstileVerifier 验证 Turnstile token 的接口
|
||||
@@ -83,3 +84,22 @@ func (s *TurnstileService) VerifyToken(ctx context.Context, token string, remote
|
||||
func (s *TurnstileService) IsEnabled(ctx context.Context) bool {
|
||||
return s.settingService.IsTurnstileEnabled(ctx)
|
||||
}
|
||||
|
||||
// ValidateSecretKey 验证 Turnstile Secret Key 是否有效
|
||||
func (s *TurnstileService) ValidateSecretKey(ctx context.Context, secretKey string) error {
|
||||
// 发送一个测试token的验证请求来检查secret_key是否有效
|
||||
result, err := s.verifier.VerifyToken(ctx, secretKey, "test-validation", "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("validate secret key: %w", err)
|
||||
}
|
||||
|
||||
// 检查是否有 invalid-input-secret 错误
|
||||
for _, code := range result.ErrorCodes {
|
||||
if code == "invalid-input-secret" {
|
||||
return ErrTurnstileInvalidSecretKey
|
||||
}
|
||||
}
|
||||
|
||||
// 其他错误(如 invalid-input-response)说明 secret key 是有效的
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
)
|
||||
@@ -186,22 +186,40 @@ func (s *UsageService) GetStatsByApiKey(ctx context.Context, apiKeyID int64, sta
|
||||
|
||||
// GetStatsByAccount 获取账号的使用统计
|
||||
func (s *UsageService) GetStatsByAccount(ctx context.Context, accountID int64, startTime, endTime time.Time) (*UsageStats, error) {
|
||||
logs, _, err := s.usageRepo.ListByAccountAndTimeRange(ctx, accountID, startTime, endTime)
|
||||
stats, err := s.usageRepo.GetAccountStatsAggregated(ctx, accountID, startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list usage logs: %w", err)
|
||||
return nil, fmt.Errorf("get account stats: %w", err)
|
||||
}
|
||||
|
||||
return s.calculateStats(logs), nil
|
||||
return &UsageStats{
|
||||
TotalRequests: stats.TotalRequests,
|
||||
TotalInputTokens: stats.TotalInputTokens,
|
||||
TotalOutputTokens: stats.TotalOutputTokens,
|
||||
TotalCacheTokens: stats.TotalCacheTokens,
|
||||
TotalTokens: stats.TotalTokens,
|
||||
TotalCost: stats.TotalCost,
|
||||
TotalActualCost: stats.TotalActualCost,
|
||||
AverageDurationMs: stats.AverageDurationMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetStatsByModel 获取模型的使用统计
|
||||
func (s *UsageService) GetStatsByModel(ctx context.Context, modelName string, startTime, endTime time.Time) (*UsageStats, error) {
|
||||
logs, _, err := s.usageRepo.ListByModelAndTimeRange(ctx, modelName, startTime, endTime)
|
||||
stats, err := s.usageRepo.GetModelStatsAggregated(ctx, modelName, startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list usage logs: %w", err)
|
||||
return nil, fmt.Errorf("get model stats: %w", err)
|
||||
}
|
||||
|
||||
return s.calculateStats(logs), nil
|
||||
return &UsageStats{
|
||||
TotalRequests: stats.TotalRequests,
|
||||
TotalInputTokens: stats.TotalInputTokens,
|
||||
TotalOutputTokens: stats.TotalOutputTokens,
|
||||
TotalCacheTokens: stats.TotalCacheTokens,
|
||||
TotalTokens: stats.TotalTokens,
|
||||
TotalCost: stats.TotalCost,
|
||||
TotalActualCost: stats.TotalActualCost,
|
||||
AverageDurationMs: stats.AverageDurationMs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetDailyStats 获取每日使用统计(最近N天)
|
||||
@@ -209,80 +227,12 @@ func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int
|
||||
endTime := time.Now()
|
||||
startTime := endTime.AddDate(0, 0, -days)
|
||||
|
||||
logs, _, err := s.usageRepo.ListByUserAndTimeRange(ctx, userID, startTime, endTime)
|
||||
stats, err := s.usageRepo.GetDailyStatsAggregated(ctx, userID, startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list usage logs: %w", err)
|
||||
return nil, fmt.Errorf("get daily stats: %w", err)
|
||||
}
|
||||
|
||||
// 按日期分组统计
|
||||
dailyStats := make(map[string]*UsageStats)
|
||||
for _, log := range logs {
|
||||
dateKey := log.CreatedAt.Format("2006-01-02")
|
||||
if _, exists := dailyStats[dateKey]; !exists {
|
||||
dailyStats[dateKey] = &UsageStats{}
|
||||
}
|
||||
|
||||
stats := dailyStats[dateKey]
|
||||
stats.TotalRequests++
|
||||
stats.TotalInputTokens += int64(log.InputTokens)
|
||||
stats.TotalOutputTokens += int64(log.OutputTokens)
|
||||
stats.TotalCacheTokens += int64(log.CacheCreationTokens + log.CacheReadTokens)
|
||||
stats.TotalTokens += int64(log.TotalTokens())
|
||||
stats.TotalCost += log.TotalCost
|
||||
stats.TotalActualCost += log.ActualCost
|
||||
|
||||
if log.DurationMs != nil {
|
||||
stats.AverageDurationMs += float64(*log.DurationMs)
|
||||
}
|
||||
}
|
||||
|
||||
// 计算平均值并转换为数组
|
||||
result := make([]map[string]any, 0, len(dailyStats))
|
||||
for date, stats := range dailyStats {
|
||||
if stats.TotalRequests > 0 {
|
||||
stats.AverageDurationMs /= float64(stats.TotalRequests)
|
||||
}
|
||||
|
||||
result = append(result, map[string]any{
|
||||
"date": date,
|
||||
"total_requests": stats.TotalRequests,
|
||||
"total_input_tokens": stats.TotalInputTokens,
|
||||
"total_output_tokens": stats.TotalOutputTokens,
|
||||
"total_cache_tokens": stats.TotalCacheTokens,
|
||||
"total_tokens": stats.TotalTokens,
|
||||
"total_cost": stats.TotalCost,
|
||||
"total_actual_cost": stats.TotalActualCost,
|
||||
"average_duration_ms": stats.AverageDurationMs,
|
||||
})
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// calculateStats 计算统计数据
|
||||
func (s *UsageService) calculateStats(logs []UsageLog) *UsageStats {
|
||||
stats := &UsageStats{}
|
||||
|
||||
for _, log := range logs {
|
||||
stats.TotalRequests++
|
||||
stats.TotalInputTokens += int64(log.InputTokens)
|
||||
stats.TotalOutputTokens += int64(log.OutputTokens)
|
||||
stats.TotalCacheTokens += int64(log.CacheCreationTokens + log.CacheReadTokens)
|
||||
stats.TotalTokens += int64(log.TotalTokens())
|
||||
stats.TotalCost += log.TotalCost
|
||||
stats.TotalActualCost += log.ActualCost
|
||||
|
||||
if log.DurationMs != nil {
|
||||
stats.AverageDurationMs += float64(*log.DurationMs)
|
||||
}
|
||||
}
|
||||
|
||||
// 计算平均持续时间
|
||||
if stats.TotalRequests > 0 {
|
||||
stats.AverageDurationMs /= float64(stats.TotalRequests)
|
||||
}
|
||||
|
||||
return stats
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// Delete 删除使用日志(管理员功能,谨慎使用)
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
|
||||
@@ -73,6 +73,15 @@ func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWh
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvideConcurrencyService creates ConcurrencyService and starts slot cleanup worker.
|
||||
func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountRepository, cfg *config.Config) *ConcurrencyService {
|
||||
svc := NewConcurrencyService(cache)
|
||||
if cfg != nil {
|
||||
svc.StartSlotCleanupWorker(accountRepo, cfg.Gateway.Scheduling.SlotCleanupInterval)
|
||||
}
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProviderSet is the Wire provider set for all services
|
||||
var ProviderSet = wire.NewSet(
|
||||
// Core services
|
||||
@@ -94,6 +103,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewOAuthService,
|
||||
NewOpenAIOAuthService,
|
||||
NewGeminiOAuthService,
|
||||
NewGeminiQuotaService,
|
||||
NewAntigravityOAuthService,
|
||||
NewGeminiTokenProvider,
|
||||
NewGeminiMessagesCompatService,
|
||||
@@ -107,7 +117,7 @@ var ProviderSet = wire.NewSet(
|
||||
ProvideEmailQueueService,
|
||||
NewTurnstileService,
|
||||
NewSubscriptionService,
|
||||
NewConcurrencyService,
|
||||
ProvideConcurrencyService,
|
||||
NewIdentityService,
|
||||
NewCRSSyncService,
|
||||
ProvideUpdateService,
|
||||
|
||||
Reference in New Issue
Block a user