This reverts commit 8d252303fc.
This commit is contained in:
@@ -358,15 +358,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
return nil, fmt.Errorf("transform request: %w", err)
|
||||
}
|
||||
|
||||
// 调试:记录转换后的请求体(仅记录前 2000 字符)
|
||||
if bodyJSON, err := json.Marshal(geminiBody); err == nil {
|
||||
truncated := string(bodyJSON)
|
||||
if len(truncated) > 2000 {
|
||||
truncated = truncated[:2000] + "..."
|
||||
}
|
||||
log.Printf("[Debug] Transformed Gemini request: %s", truncated)
|
||||
}
|
||||
|
||||
// 构建上游 action
|
||||
action := "generateContent"
|
||||
if claudeReq.Stream {
|
||||
|
||||
@@ -18,11 +18,6 @@ type ConcurrencyCache interface {
|
||||
ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error
|
||||
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
|
||||
|
||||
// 账号等待队列(账号级)
|
||||
IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error)
|
||||
DecrementAccountWaitCount(ctx context.Context, accountID int64) error
|
||||
GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error)
|
||||
|
||||
// 用户槽位管理
|
||||
// 键格式: concurrency:user:{userID}(有序集合,成员为 requestID)
|
||||
AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error)
|
||||
@@ -32,12 +27,6 @@ type ConcurrencyCache interface {
|
||||
// 等待队列计数(只在首次创建时设置 TTL)
|
||||
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
|
||||
DecrementWaitCount(ctx context.Context, userID int64) error
|
||||
|
||||
// 批量负载查询(只读)
|
||||
GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error)
|
||||
|
||||
// 清理过期槽位(后台任务)
|
||||
CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error
|
||||
}
|
||||
|
||||
// generateRequestID generates a unique request ID for concurrency slot tracking
|
||||
@@ -72,18 +61,6 @@ type AcquireResult struct {
|
||||
ReleaseFunc func() // Must be called when done (typically via defer)
|
||||
}
|
||||
|
||||
type AccountWithConcurrency struct {
|
||||
ID int64
|
||||
MaxConcurrency int
|
||||
}
|
||||
|
||||
type AccountLoadInfo struct {
|
||||
AccountID int64
|
||||
CurrentConcurrency int
|
||||
WaitingCount int
|
||||
LoadRate int // 0-100+ (percent)
|
||||
}
|
||||
|
||||
// AcquireAccountSlot attempts to acquire a concurrency slot for an account.
|
||||
// If the account is at max concurrency, it waits until a slot is available or timeout.
|
||||
// Returns a release function that MUST be called when the request completes.
|
||||
@@ -200,42 +177,6 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6
|
||||
}
|
||||
}
|
||||
|
||||
// IncrementAccountWaitCount increments the wait queue counter for an account.
|
||||
func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||
if s.cache == nil {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
result, err := s.cache.IncrementAccountWaitCount(ctx, accountID, maxWait)
|
||||
if err != nil {
|
||||
log.Printf("Warning: increment wait count failed for account %d: %v", accountID, err)
|
||||
return true, nil
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DecrementAccountWaitCount decrements the wait queue counter for an account.
|
||||
func (s *ConcurrencyService) DecrementAccountWaitCount(ctx context.Context, accountID int64) {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := s.cache.DecrementAccountWaitCount(bgCtx, accountID); err != nil {
|
||||
log.Printf("Warning: decrement wait count failed for account %d: %v", accountID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// GetAccountWaitingCount gets current wait queue count for an account.
|
||||
func (s *ConcurrencyService) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||||
if s.cache == nil {
|
||||
return 0, nil
|
||||
}
|
||||
return s.cache.GetAccountWaitingCount(ctx, accountID)
|
||||
}
|
||||
|
||||
// CalculateMaxWait calculates the maximum wait queue size for a user
|
||||
// maxWait = userConcurrency + defaultExtraWaitSlots
|
||||
func CalculateMaxWait(userConcurrency int) int {
|
||||
@@ -245,57 +186,6 @@ func CalculateMaxWait(userConcurrency int) int {
|
||||
return userConcurrency + defaultExtraWaitSlots
|
||||
}
|
||||
|
||||
// GetAccountsLoadBatch returns load info for multiple accounts.
|
||||
func (s *ConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||||
if s.cache == nil {
|
||||
return map[int64]*AccountLoadInfo{}, nil
|
||||
}
|
||||
return s.cache.GetAccountsLoadBatch(ctx, accounts)
|
||||
}
|
||||
|
||||
// CleanupExpiredAccountSlots removes expired slots for one account (background task).
|
||||
func (s *ConcurrencyService) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
return s.cache.CleanupExpiredAccountSlots(ctx, accountID)
|
||||
}
|
||||
|
||||
// StartSlotCleanupWorker starts a background cleanup worker for expired account slots.
|
||||
func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepository, interval time.Duration) {
|
||||
if s == nil || s.cache == nil || accountRepo == nil || interval <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
runCleanup := func() {
|
||||
listCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
accounts, err := accountRepo.ListSchedulable(listCtx)
|
||||
cancel()
|
||||
if err != nil {
|
||||
log.Printf("Warning: list schedulable accounts failed: %v", err)
|
||||
return
|
||||
}
|
||||
for _, account := range accounts {
|
||||
accountCtx, accountCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
err := s.cache.CleanupExpiredAccountSlots(accountCtx, account.ID)
|
||||
accountCancel()
|
||||
if err != nil {
|
||||
log.Printf("Warning: cleanup expired slots failed for account %d: %v", account.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
runCleanup()
|
||||
for range ticker.C {
|
||||
runCleanup()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
|
||||
// Returns a map of accountID -> current concurrency count
|
||||
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||
|
||||
@@ -261,34 +261,6 @@ func TestGatewayService_SelectAccountForModelWithPlatform_PriorityAndLastUsed(t
|
||||
require.Equal(t, int64(2), acc.ID, "同优先级应选择最久未用的账户")
|
||||
}
|
||||
|
||||
func TestGatewayService_SelectAccountForModelWithPlatform_GeminiOAuthPreference(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "gemini-2.5-pro", nil, PlatformGemini)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID, "同优先级且未使用时应优先选择OAuth账户")
|
||||
}
|
||||
|
||||
// TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts 测试无可用账户
|
||||
func TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
@@ -604,32 +576,6 @@ func TestGatewayService_isModelSupportedByAccount(t *testing.T) {
|
||||
func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("混合调度-Gemini优先选择OAuth账户", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "gemini-2.5-pro", nil, PlatformGemini)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, int64(2), acc.ID, "同优先级且未使用时应优先选择OAuth账户")
|
||||
})
|
||||
|
||||
t.Run("混合调度-包含启用mixed_scheduling的antigravity账户", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
@@ -837,160 +783,3 @@ func TestAccount_IsMixedSchedulingEnabled(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// mockConcurrencyService for testing
|
||||
type mockConcurrencyService struct {
|
||||
accountLoads map[int64]*AccountLoadInfo
|
||||
accountWaitCounts map[int64]int
|
||||
acquireResults map[int64]bool
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||||
if m.accountLoads == nil {
|
||||
return map[int64]*AccountLoadInfo{}, nil
|
||||
}
|
||||
result := make(map[int64]*AccountLoadInfo)
|
||||
for _, acc := range accounts {
|
||||
if load, ok := m.accountLoads[acc.ID]; ok {
|
||||
result[acc.ID] = load
|
||||
} else {
|
||||
result[acc.ID] = &AccountLoadInfo{
|
||||
AccountID: acc.ID,
|
||||
CurrentConcurrency: 0,
|
||||
WaitingCount: 0,
|
||||
LoadRate: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyService) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||||
if m.accountWaitCounts == nil {
|
||||
return 0, nil
|
||||
}
|
||||
return m.accountWaitCounts[accountID], nil
|
||||
}
|
||||
|
||||
// TestGatewayService_SelectAccountWithLoadAwareness tests load-aware account selection
|
||||
func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("禁用负载批量查询-降级到传统选择", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
|
||||
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
cfg := testConfig()
|
||||
cfg.Gateway.Scheduling.LoadBatchEnabled = false
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: nil, // No concurrency service
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
require.Equal(t, int64(1), result.Account.ID, "应选择优先级最高的账号")
|
||||
})
|
||||
|
||||
t.Run("无ConcurrencyService-降级到传统选择", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
|
||||
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
cfg := testConfig()
|
||||
cfg.Gateway.Scheduling.LoadBatchEnabled = true
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: nil,
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
require.Equal(t, int64(2), result.Account.ID, "应选择优先级最高的账号")
|
||||
})
|
||||
|
||||
t.Run("排除账号-不选择被排除的账号", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
|
||||
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
cfg := testConfig()
|
||||
cfg.Gateway.Scheduling.LoadBatchEnabled = false
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: nil,
|
||||
}
|
||||
|
||||
excludedIDs := map[int64]struct{}{1: {}}
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
require.Equal(t, int64(2), result.Account.ID, "不应选择被排除的账号")
|
||||
})
|
||||
|
||||
t.Run("无可用账号-返回错误", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{}
|
||||
|
||||
cfg := testConfig()
|
||||
cfg.Gateway.Scheduling.LoadBatchEnabled = false
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: nil,
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
require.Contains(t, err.Error(), "no available accounts")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -13,14 +13,12 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -68,20 +66,6 @@ type GatewayCache interface {
|
||||
RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error
|
||||
}
|
||||
|
||||
type AccountWaitPlan struct {
|
||||
AccountID int64
|
||||
MaxConcurrency int
|
||||
Timeout time.Duration
|
||||
MaxWaiting int
|
||||
}
|
||||
|
||||
type AccountSelectionResult struct {
|
||||
Account *Account
|
||||
Acquired bool
|
||||
ReleaseFunc func()
|
||||
WaitPlan *AccountWaitPlan // nil means no wait allowed
|
||||
}
|
||||
|
||||
// ClaudeUsage 表示Claude API返回的usage信息
|
||||
type ClaudeUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
@@ -124,7 +108,6 @@ type GatewayService struct {
|
||||
identityService *IdentityService
|
||||
httpUpstream HTTPUpstream
|
||||
deferredService *DeferredService
|
||||
concurrencyService *ConcurrencyService
|
||||
}
|
||||
|
||||
// NewGatewayService creates a new GatewayService
|
||||
@@ -136,7 +119,6 @@ func NewGatewayService(
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
cache GatewayCache,
|
||||
cfg *config.Config,
|
||||
concurrencyService *ConcurrencyService,
|
||||
billingService *BillingService,
|
||||
rateLimitService *RateLimitService,
|
||||
billingCacheService *BillingCacheService,
|
||||
@@ -152,7 +134,6 @@ func NewGatewayService(
|
||||
userSubRepo: userSubRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: concurrencyService,
|
||||
billingService: billingService,
|
||||
rateLimitService: rateLimitService,
|
||||
billingCacheService: billingCacheService,
|
||||
@@ -202,14 +183,6 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// BindStickySession sets session -> account binding with standard TTL.
|
||||
func (s *GatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error {
|
||||
if sessionHash == "" || accountID <= 0 {
|
||||
return nil
|
||||
}
|
||||
return s.cache.SetSessionAccountID(ctx, sessionHash, accountID, stickySessionTTL)
|
||||
}
|
||||
|
||||
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
|
||||
if parsed == nil {
|
||||
return ""
|
||||
@@ -359,354 +332,8 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
||||
return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
||||
}
|
||||
|
||||
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
|
||||
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
|
||||
cfg := s.schedulingConfig()
|
||||
var stickyAccountID int64
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
if accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash); err == nil {
|
||||
stickyAccountID = accountID
|
||||
}
|
||||
}
|
||||
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
|
||||
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: account.ID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: account.ID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.FallbackWaitTimeout,
|
||||
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
preferOAuth := platform == PlatformGemini
|
||||
|
||||
accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(accounts) == 0 {
|
||||
return nil, errors.New("no available accounts")
|
||||
}
|
||||
|
||||
isExcluded := func(accountID int64) bool {
|
||||
if excludedIDs == nil {
|
||||
return false
|
||||
}
|
||||
_, excluded := excludedIDs[accountID]
|
||||
return excluded
|
||||
}
|
||||
|
||||
// ============ Layer 1: 粘性会话优先 ============
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
||||
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err == nil && s.isAccountAllowedForPlatform(account, platform, useMixed) &&
|
||||
account.IsSchedulable() &&
|
||||
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL)
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: accountID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============ Layer 2: 负载感知选择 ============
|
||||
candidates := make([]*Account, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if isExcluded(acc.ID) {
|
||||
continue
|
||||
}
|
||||
if !s.isAccountAllowedForPlatform(acc, platform, useMixed) {
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
candidates = append(candidates, acc)
|
||||
}
|
||||
|
||||
if len(candidates) == 0 {
|
||||
return nil, errors.New("no available accounts")
|
||||
}
|
||||
|
||||
accountLoads := make([]AccountWithConcurrency, 0, len(candidates))
|
||||
for _, acc := range candidates {
|
||||
accountLoads = append(accountLoads, AccountWithConcurrency{
|
||||
ID: acc.ID,
|
||||
MaxConcurrency: acc.Concurrency,
|
||||
})
|
||||
}
|
||||
|
||||
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
|
||||
if err != nil {
|
||||
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, sessionHash, preferOAuth); ok {
|
||||
return result, nil
|
||||
}
|
||||
} else {
|
||||
type accountWithLoad struct {
|
||||
account *Account
|
||||
loadInfo *AccountLoadInfo
|
||||
}
|
||||
var available []accountWithLoad
|
||||
for _, acc := range candidates {
|
||||
loadInfo := loadMap[acc.ID]
|
||||
if loadInfo == nil {
|
||||
loadInfo = &AccountLoadInfo{AccountID: acc.ID}
|
||||
}
|
||||
if loadInfo.LoadRate < 100 {
|
||||
available = append(available, accountWithLoad{
|
||||
account: acc,
|
||||
loadInfo: loadInfo,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(available) > 0 {
|
||||
sort.SliceStable(available, func(i, j int) bool {
|
||||
a, b := available[i], available[j]
|
||||
if a.account.Priority != b.account.Priority {
|
||||
return a.account.Priority < b.account.Priority
|
||||
}
|
||||
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
|
||||
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
|
||||
}
|
||||
switch {
|
||||
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
|
||||
return true
|
||||
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
|
||||
return false
|
||||
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
|
||||
if preferOAuth && a.account.Type != b.account.Type {
|
||||
return a.account.Type == AccountTypeOAuth
|
||||
}
|
||||
return false
|
||||
default:
|
||||
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
|
||||
}
|
||||
})
|
||||
|
||||
for _, item := range available {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
if sessionHash != "" {
|
||||
_ = s.cache.SetSessionAccountID(ctx, sessionHash, item.account.ID, stickySessionTTL)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: item.account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============ Layer 3: 兜底排队 ============
|
||||
sortAccountsByPriorityAndLastUsed(candidates, preferOAuth)
|
||||
for _, acc := range candidates {
|
||||
return &AccountSelectionResult{
|
||||
Account: acc,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: acc.ID,
|
||||
MaxConcurrency: acc.Concurrency,
|
||||
Timeout: cfg.FallbackWaitTimeout,
|
||||
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
return nil, errors.New("no available accounts")
|
||||
}
|
||||
|
||||
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
|
||||
ordered := append([]*Account(nil), candidates...)
|
||||
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
|
||||
|
||||
for _, acc := range ordered {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
if sessionHash != "" {
|
||||
_ = s.cache.SetSessionAccountID(ctx, sessionHash, acc.ID, stickySessionTTL)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: acc,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, true
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
|
||||
if s.cfg != nil {
|
||||
return s.cfg.Gateway.Scheduling
|
||||
}
|
||||
return config.GatewaySchedulingConfig{
|
||||
StickySessionMaxWaiting: 3,
|
||||
StickySessionWaitTimeout: 45 * time.Second,
|
||||
FallbackWaitTimeout: 30 * time.Second,
|
||||
FallbackMaxWaiting: 100,
|
||||
LoadBatchEnabled: true,
|
||||
SlotCleanupInterval: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64) (string, bool, error) {
|
||||
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
||||
if hasForcePlatform && forcePlatform != "" {
|
||||
return forcePlatform, true, nil
|
||||
}
|
||||
if groupID != nil {
|
||||
group, err := s.groupRepo.GetByID(ctx, *groupID)
|
||||
if err != nil {
|
||||
return "", false, fmt.Errorf("get group failed: %w", err)
|
||||
}
|
||||
return group.Platform, false, nil
|
||||
}
|
||||
return PlatformAnthropic, false, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
|
||||
useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform
|
||||
if useMixed {
|
||||
platforms := []string{platform, PlatformAntigravity}
|
||||
var accounts []Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, useMixed, err
|
||||
}
|
||||
filtered := make([]Account, 0, len(accounts))
|
||||
for _, acc := range accounts {
|
||||
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, acc)
|
||||
}
|
||||
return filtered, useMixed, nil
|
||||
}
|
||||
|
||||
var accounts []Account
|
||||
var err error
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
|
||||
} else if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform)
|
||||
if err == nil && len(accounts) == 0 && hasForcePlatform {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
|
||||
}
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, useMixed, err
|
||||
}
|
||||
return accounts, useMixed, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform string, useMixed bool) bool {
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
if useMixed {
|
||||
if account.Platform == platform {
|
||||
return true
|
||||
}
|
||||
return account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()
|
||||
}
|
||||
return account.Platform == platform
|
||||
}
|
||||
|
||||
func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
|
||||
if s.concurrencyService == nil {
|
||||
return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil
|
||||
}
|
||||
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||
}
|
||||
|
||||
func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
|
||||
sort.SliceStable(accounts, func(i, j int) bool {
|
||||
a, b := accounts[i], accounts[j]
|
||||
if a.Priority != b.Priority {
|
||||
return a.Priority < b.Priority
|
||||
}
|
||||
switch {
|
||||
case a.LastUsedAt == nil && b.LastUsedAt != nil:
|
||||
return true
|
||||
case a.LastUsedAt != nil && b.LastUsedAt == nil:
|
||||
return false
|
||||
case a.LastUsedAt == nil && b.LastUsedAt == nil:
|
||||
if preferOAuth && a.Type != b.Type {
|
||||
return a.Type == AccountTypeOAuth
|
||||
}
|
||||
return false
|
||||
default:
|
||||
return a.LastUsedAt.Before(*b.LastUsedAt)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
|
||||
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
|
||||
preferOAuth := platform == PlatformGemini
|
||||
// 1. 查询粘性会话
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
||||
@@ -762,9 +389,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
|
||||
// keep selected (never used is preferred)
|
||||
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
|
||||
if preferOAuth && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
|
||||
selected = acc
|
||||
}
|
||||
// keep selected (both never used)
|
||||
default:
|
||||
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
|
||||
selected = acc
|
||||
@@ -794,7 +419,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
// 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户
|
||||
func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) {
|
||||
platforms := []string{nativePlatform, PlatformAntigravity}
|
||||
preferOAuth := nativePlatform == PlatformGemini
|
||||
|
||||
// 1. 查询粘性会话
|
||||
if sessionHash != "" {
|
||||
@@ -854,9 +478,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
|
||||
// keep selected (never used is preferred)
|
||||
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
|
||||
if preferOAuth && acc.Platform == PlatformGemini && selected.Platform == PlatformGemini && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
|
||||
selected = acc
|
||||
}
|
||||
// keep selected (both never used)
|
||||
default:
|
||||
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
|
||||
selected = acc
|
||||
@@ -1062,30 +684,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
|
||||
// 处理错误响应(不可重试的错误)
|
||||
if resp.StatusCode >= 400 {
|
||||
// 可选:对部分 400 触发 failover(默认关闭以保持语义)
|
||||
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
|
||||
respBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
// ReadAll failed, fall back to normal error handling without consuming the stream
|
||||
return s.handleErrorResponse(ctx, resp, c, account)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
|
||||
if s.shouldFailoverOn400(respBody) {
|
||||
if s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
log.Printf(
|
||||
"Account %d: 400 error, attempting failover: %s",
|
||||
account.ID,
|
||||
truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
|
||||
)
|
||||
} else {
|
||||
log.Printf("Account %d: 400 error, attempting failover", account.ID)
|
||||
}
|
||||
s.handleFailoverSideEffects(ctx, resp, account)
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
}
|
||||
return s.handleErrorResponse(ctx, resp, c, account)
|
||||
}
|
||||
|
||||
@@ -1188,13 +786,6 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
// 处理anthropic-beta header(OAuth账号需要特殊处理)
|
||||
if tokenType == "oauth" {
|
||||
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
|
||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
|
||||
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
|
||||
if requestNeedsBetaFeatures(body) {
|
||||
if beta := defaultApiKeyBetaHeader(body); beta != "" {
|
||||
req.Header.Set("anthropic-beta", beta)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return req, nil
|
||||
@@ -1247,83 +838,6 @@ func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string)
|
||||
return claude.DefaultBetaHeader
|
||||
}
|
||||
|
||||
func requestNeedsBetaFeatures(body []byte) bool {
|
||||
tools := gjson.GetBytes(body, "tools")
|
||||
if tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 {
|
||||
return true
|
||||
}
|
||||
if strings.EqualFold(gjson.GetBytes(body, "thinking.type").String(), "enabled") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func defaultApiKeyBetaHeader(body []byte) string {
|
||||
modelID := gjson.GetBytes(body, "model").String()
|
||||
if strings.Contains(strings.ToLower(modelID), "haiku") {
|
||||
return claude.ApiKeyHaikuBetaHeader
|
||||
}
|
||||
return claude.ApiKeyBetaHeader
|
||||
}
|
||||
|
||||
func truncateForLog(b []byte, maxBytes int) string {
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 2048
|
||||
}
|
||||
if len(b) > maxBytes {
|
||||
b = b[:maxBytes]
|
||||
}
|
||||
s := string(b)
|
||||
// 保持一行,避免污染日志格式
|
||||
s = strings.ReplaceAll(s, "\n", "\\n")
|
||||
s = strings.ReplaceAll(s, "\r", "\\r")
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
|
||||
// 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。
|
||||
// 默认保守:无法识别则不切换。
|
||||
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
|
||||
if msg == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// 缺少/错误的 beta header:换账号/链路可能成功(尤其是混合调度时)。
|
||||
// 更精确匹配 beta 相关的兼容性问题,避免误触发切换。
|
||||
if strings.Contains(msg, "anthropic-beta") ||
|
||||
strings.Contains(msg, "beta feature") ||
|
||||
strings.Contains(msg, "requires beta") {
|
||||
return true
|
||||
}
|
||||
|
||||
// thinking/tool streaming 等兼容性约束(常见于中间转换链路)
|
||||
if strings.Contains(msg, "thinking") || strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(msg, "tool_use") || strings.Contains(msg, "tool_result") || strings.Contains(msg, "tools") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func extractUpstreamErrorMessage(body []byte) string {
|
||||
// Claude 风格:{"type":"error","error":{"type":"...","message":"..."}}
|
||||
if m := gjson.GetBytes(body, "error.message").String(); strings.TrimSpace(m) != "" {
|
||||
inner := strings.TrimSpace(m)
|
||||
// 有些上游会把完整 JSON 作为字符串塞进 message
|
||||
if strings.HasPrefix(inner, "{") {
|
||||
if innerMsg := gjson.Get(inner, "error.message").String(); strings.TrimSpace(innerMsg) != "" {
|
||||
return innerMsg
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// 兜底:尝试顶层 message
|
||||
return gjson.GetBytes(body, "message").String()
|
||||
}
|
||||
|
||||
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
@@ -1336,16 +850,6 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
|
||||
|
||||
switch resp.StatusCode {
|
||||
case 400:
|
||||
// 仅记录上游错误摘要(避免输出请求内容);需要时可通过配置打开
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
log.Printf(
|
||||
"Upstream 400 error (account=%d platform=%s type=%s): %s",
|
||||
account.ID,
|
||||
account.Platform,
|
||||
account.Type,
|
||||
truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
|
||||
)
|
||||
}
|
||||
c.Data(http.StatusBadRequest, "application/json", body)
|
||||
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
|
||||
case 401:
|
||||
@@ -1825,18 +1329,6 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
// 标记账号状态(429/529等)
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
|
||||
// 记录上游错误摘要便于排障(不回显请求内容)
|
||||
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||
log.Printf(
|
||||
"count_tokens upstream error %d (account=%d platform=%s type=%s): %s",
|
||||
resp.StatusCode,
|
||||
account.ID,
|
||||
account.Platform,
|
||||
account.Type,
|
||||
truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
|
||||
)
|
||||
}
|
||||
|
||||
// 返回简化的错误响应
|
||||
errMsg := "Upstream request failed"
|
||||
switch resp.StatusCode {
|
||||
@@ -1917,13 +1409,6 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
// OAuth 账号:处理 anthropic-beta header
|
||||
if tokenType == "oauth" {
|
||||
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
|
||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
|
||||
// API-key:与 messages 同步的按需 beta 注入(默认关闭)
|
||||
if requestNeedsBetaFeatures(body) {
|
||||
if beta := defaultApiKeyBetaHeader(body); beta != "" {
|
||||
req.Header.Set("anthropic-beta", beta)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return req, nil
|
||||
|
||||
@@ -2278,13 +2278,11 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
|
||||
"properties": map[string]any{},
|
||||
}
|
||||
}
|
||||
// 清理 JSON Schema
|
||||
cleanedParams := cleanToolSchema(params)
|
||||
|
||||
funcDecls = append(funcDecls, map[string]any{
|
||||
"name": name,
|
||||
"description": desc,
|
||||
"parameters": cleanedParams,
|
||||
"parameters": params,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2298,41 +2296,6 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
|
||||
}
|
||||
}
|
||||
|
||||
// cleanToolSchema 清理工具的 JSON Schema,移除 Gemini 不支持的字段
|
||||
func cleanToolSchema(schema any) any {
|
||||
if schema == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch v := schema.(type) {
|
||||
case map[string]any:
|
||||
cleaned := make(map[string]any)
|
||||
for key, value := range v {
|
||||
// 跳过不支持的字段
|
||||
if key == "$schema" || key == "$id" || key == "$ref" ||
|
||||
key == "additionalProperties" || key == "minLength" ||
|
||||
key == "maxLength" || key == "minItems" || key == "maxItems" {
|
||||
continue
|
||||
}
|
||||
// 递归清理嵌套对象
|
||||
cleaned[key] = cleanToolSchema(value)
|
||||
}
|
||||
// 规范化 type 字段为大写
|
||||
if typeVal, ok := cleaned["type"].(string); ok {
|
||||
cleaned["type"] = strings.ToUpper(typeVal)
|
||||
}
|
||||
return cleaned
|
||||
case []any:
|
||||
cleaned := make([]any, len(v))
|
||||
for i, item := range v {
|
||||
cleaned[i] = cleanToolSchema(item)
|
||||
}
|
||||
return cleaned
|
||||
default:
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
func convertClaudeGenerationConfig(req map[string]any) map[string]any {
|
||||
out := make(map[string]any)
|
||||
if mt, ok := asInt(req["max_tokens"]); ok && mt > 0 {
|
||||
|
||||
@@ -1,128 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换
|
||||
func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tools any
|
||||
expectedLen int
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "Standard tools",
|
||||
tools: []any{
|
||||
map[string]any{
|
||||
"name": "get_weather",
|
||||
"description": "Get weather info",
|
||||
"input_schema": map[string]any{"type": "object"},
|
||||
},
|
||||
},
|
||||
expectedLen: 1,
|
||||
description: "标准工具格式应该正常转换",
|
||||
},
|
||||
{
|
||||
name: "Custom type tool (MCP format)",
|
||||
tools: []any{
|
||||
map[string]any{
|
||||
"type": "custom",
|
||||
"name": "mcp_tool",
|
||||
"custom": map[string]any{
|
||||
"description": "MCP tool description",
|
||||
"input_schema": map[string]any{"type": "object"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedLen: 1,
|
||||
description: "Custom类型工具应该从custom字段读取",
|
||||
},
|
||||
{
|
||||
name: "Mixed standard and custom tools",
|
||||
tools: []any{
|
||||
map[string]any{
|
||||
"name": "standard_tool",
|
||||
"description": "Standard",
|
||||
"input_schema": map[string]any{"type": "object"},
|
||||
},
|
||||
map[string]any{
|
||||
"type": "custom",
|
||||
"name": "custom_tool",
|
||||
"custom": map[string]any{
|
||||
"description": "Custom",
|
||||
"input_schema": map[string]any{"type": "object"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedLen: 1,
|
||||
description: "混合工具应该都能正确转换",
|
||||
},
|
||||
{
|
||||
name: "Custom tool without custom field",
|
||||
tools: []any{
|
||||
map[string]any{
|
||||
"type": "custom",
|
||||
"name": "invalid_custom",
|
||||
// 缺少 custom 字段
|
||||
},
|
||||
},
|
||||
expectedLen: 0, // 应该被跳过
|
||||
description: "缺少custom字段的custom工具应该被跳过",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := convertClaudeToolsToGeminiTools(tt.tools)
|
||||
|
||||
if tt.expectedLen == 0 {
|
||||
if result != nil {
|
||||
t.Errorf("%s: expected nil result, got %v", tt.description, result)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatalf("%s: expected non-nil result", tt.description)
|
||||
}
|
||||
|
||||
if len(result) != 1 {
|
||||
t.Errorf("%s: expected 1 tool declaration, got %d", tt.description, len(result))
|
||||
return
|
||||
}
|
||||
|
||||
toolDecl, ok := result[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("%s: result[0] is not map[string]any", tt.description)
|
||||
}
|
||||
|
||||
funcDecls, ok := toolDecl["functionDeclarations"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("%s: functionDeclarations is not []any", tt.description)
|
||||
}
|
||||
|
||||
toolsArr, _ := tt.tools.([]any)
|
||||
expectedFuncCount := 0
|
||||
for _, tool := range toolsArr {
|
||||
toolMap, _ := tool.(map[string]any)
|
||||
if toolMap["name"] != "" {
|
||||
// 检查是否为有效的custom工具
|
||||
if toolMap["type"] == "custom" {
|
||||
if toolMap["custom"] != nil {
|
||||
expectedFuncCount++
|
||||
}
|
||||
} else {
|
||||
expectedFuncCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(funcDecls) != expectedFuncCount {
|
||||
t.Errorf("%s: expected %d function declarations, got %d",
|
||||
tt.description, expectedFuncCount, len(funcDecls))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -164,45 +163,6 @@ type GeminiTokenInfo struct {
|
||||
Scope string `json:"scope,omitempty"`
|
||||
ProjectID string `json:"project_id,omitempty"`
|
||||
OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio"
|
||||
TierID string `json:"tier_id,omitempty"` // Gemini Code Assist tier: LEGACY/PRO/ULTRA
|
||||
}
|
||||
|
||||
// validateTierID validates tier_id format and length
|
||||
func validateTierID(tierID string) error {
|
||||
if tierID == "" {
|
||||
return nil // Empty is allowed
|
||||
}
|
||||
if len(tierID) > 64 {
|
||||
return fmt.Errorf("tier_id exceeds maximum length of 64 characters")
|
||||
}
|
||||
// Allow alphanumeric, underscore, hyphen, and slash (for tier paths)
|
||||
if !regexp.MustCompile(`^[a-zA-Z0-9_/-]+$`).MatchString(tierID) {
|
||||
return fmt.Errorf("tier_id contains invalid characters")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractTierIDFromAllowedTiers extracts tierID from LoadCodeAssist response
|
||||
// Prioritizes IsDefault tier, falls back to first non-empty tier
|
||||
func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string {
|
||||
tierID := "LEGACY"
|
||||
// First pass: look for default tier
|
||||
for _, tier := range allowedTiers {
|
||||
if tier.IsDefault && strings.TrimSpace(tier.ID) != "" {
|
||||
tierID = strings.TrimSpace(tier.ID)
|
||||
break
|
||||
}
|
||||
}
|
||||
// Second pass: if still LEGACY, take first non-empty tier
|
||||
if tierID == "LEGACY" {
|
||||
for _, tier := range allowedTiers {
|
||||
if strings.TrimSpace(tier.ID) != "" {
|
||||
tierID = strings.TrimSpace(tier.ID)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return tierID
|
||||
}
|
||||
|
||||
func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) {
|
||||
@@ -263,14 +223,13 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300
|
||||
|
||||
projectID := sessionProjectID
|
||||
var tierID string
|
||||
|
||||
// 对于 code_assist 模式,project_id 是必需的
|
||||
// 对于 ai_studio 模式,project_id 是可选的(不影响使用 AI Studio API)
|
||||
if oauthType == "code_assist" {
|
||||
if projectID == "" {
|
||||
var err error
|
||||
projectID, tierID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
|
||||
projectID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
|
||||
if err != nil {
|
||||
// 记录警告但不阻断流程,允许后续补充 project_id
|
||||
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err)
|
||||
@@ -289,7 +248,6 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
ExpiresAt: expiresAt,
|
||||
Scope: tokenResp.Scope,
|
||||
ProjectID: projectID,
|
||||
TierID: tierID,
|
||||
OAuthType: oauthType,
|
||||
}, nil
|
||||
}
|
||||
@@ -399,7 +357,7 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
|
||||
// For Code Assist, project_id is required. Auto-detect if missing.
|
||||
// For AI Studio OAuth, project_id is optional and should not block refresh.
|
||||
if oauthType == "code_assist" && strings.TrimSpace(tokenInfo.ProjectID) == "" {
|
||||
projectID, tierID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL)
|
||||
projectID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to auto-detect project_id: %w", err)
|
||||
}
|
||||
@@ -408,7 +366,6 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
|
||||
return nil, fmt.Errorf("failed to auto-detect project_id: empty result")
|
||||
}
|
||||
tokenInfo.ProjectID = projectID
|
||||
tokenInfo.TierID = tierID
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
@@ -431,13 +388,6 @@ func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo)
|
||||
if tokenInfo.ProjectID != "" {
|
||||
creds["project_id"] = tokenInfo.ProjectID
|
||||
}
|
||||
if tokenInfo.TierID != "" {
|
||||
// Validate tier_id before storing
|
||||
if err := validateTierID(tokenInfo.TierID); err == nil {
|
||||
creds["tier_id"] = tokenInfo.TierID
|
||||
}
|
||||
// Silently skip invalid tier_id (don't block account creation)
|
||||
}
|
||||
if tokenInfo.OAuthType != "" {
|
||||
creds["oauth_type"] = tokenInfo.OAuthType
|
||||
}
|
||||
@@ -448,26 +398,34 @@ func (s *GeminiOAuthService) Stop() {
|
||||
s.sessionStore.Stop()
|
||||
}
|
||||
|
||||
func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, string, error) {
|
||||
func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, error) {
|
||||
if s.codeAssist == nil {
|
||||
return "", "", errors.New("code assist client not configured")
|
||||
return "", errors.New("code assist client not configured")
|
||||
}
|
||||
|
||||
loadResp, loadErr := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil)
|
||||
|
||||
// Extract tierID from response (works whether CloudAICompanionProject is set or not)
|
||||
tierID := "LEGACY"
|
||||
if loadResp != nil {
|
||||
tierID = extractTierIDFromAllowedTiers(loadResp.AllowedTiers)
|
||||
}
|
||||
|
||||
// If LoadCodeAssist returned a project, use it
|
||||
if loadErr == nil && loadResp != nil && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" {
|
||||
return strings.TrimSpace(loadResp.CloudAICompanionProject), tierID, nil
|
||||
return strings.TrimSpace(loadResp.CloudAICompanionProject), nil
|
||||
}
|
||||
|
||||
// Pick tier from allowedTiers; if no default tier is marked, pick the first non-empty tier ID.
|
||||
// (tierID already extracted above, reuse it)
|
||||
tierID := "LEGACY"
|
||||
if loadResp != nil {
|
||||
for _, tier := range loadResp.AllowedTiers {
|
||||
if tier.IsDefault && strings.TrimSpace(tier.ID) != "" {
|
||||
tierID = strings.TrimSpace(tier.ID)
|
||||
break
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(tierID) == "" || tierID == "LEGACY" {
|
||||
for _, tier := range loadResp.AllowedTiers {
|
||||
if strings.TrimSpace(tier.ID) != "" {
|
||||
tierID = strings.TrimSpace(tier.ID)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
req := &geminicli.OnboardUserRequest{
|
||||
TierID: tierID,
|
||||
@@ -485,39 +443,39 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
|
||||
// If Code Assist onboarding fails (e.g. INVALID_ARGUMENT), fallback to Cloud Resource Manager projects.
|
||||
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
|
||||
if fbErr == nil && strings.TrimSpace(fallback) != "" {
|
||||
return strings.TrimSpace(fallback), tierID, nil
|
||||
return strings.TrimSpace(fallback), nil
|
||||
}
|
||||
return "", "", err
|
||||
return "", err
|
||||
}
|
||||
if resp.Done {
|
||||
if resp.Response != nil && resp.Response.CloudAICompanionProject != nil {
|
||||
switch v := resp.Response.CloudAICompanionProject.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(v), tierID, nil
|
||||
return strings.TrimSpace(v), nil
|
||||
case map[string]any:
|
||||
if id, ok := v["id"].(string); ok {
|
||||
return strings.TrimSpace(id), tierID, nil
|
||||
return strings.TrimSpace(id), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
|
||||
if fbErr == nil && strings.TrimSpace(fallback) != "" {
|
||||
return strings.TrimSpace(fallback), tierID, nil
|
||||
return strings.TrimSpace(fallback), nil
|
||||
}
|
||||
return "", "", errors.New("onboardUser completed but no project_id returned")
|
||||
return "", errors.New("onboardUser completed but no project_id returned")
|
||||
}
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
|
||||
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
|
||||
if fbErr == nil && strings.TrimSpace(fallback) != "" {
|
||||
return strings.TrimSpace(fallback), tierID, nil
|
||||
return strings.TrimSpace(fallback), nil
|
||||
}
|
||||
if loadErr != nil {
|
||||
return "", "", fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts)
|
||||
return "", fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts)
|
||||
}
|
||||
return "", "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts)
|
||||
return "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts)
|
||||
}
|
||||
|
||||
type googleCloudProject struct {
|
||||
|
||||
@@ -112,7 +112,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
}
|
||||
}
|
||||
|
||||
detected, tierID, err := p.geminiOAuthService.fetchProjectID(ctx, accessToken, proxyURL)
|
||||
detected, err := p.geminiOAuthService.fetchProjectID(ctx, accessToken, proxyURL)
|
||||
if err != nil {
|
||||
log.Printf("[GeminiTokenProvider] Auto-detect project_id failed: %v, fallback to AI Studio API mode", err)
|
||||
return accessToken, nil
|
||||
@@ -123,9 +123,6 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
account.Credentials = make(map[string]any)
|
||||
}
|
||||
account.Credentials["project_id"] = detected
|
||||
if tierID != "" {
|
||||
account.Credentials["tier_id"] = tierID
|
||||
}
|
||||
_ = p.accountRepo.Update(ctx, account)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -81,7 +80,6 @@ type OpenAIGatewayService struct {
|
||||
userSubRepo UserSubscriptionRepository
|
||||
cache GatewayCache
|
||||
cfg *config.Config
|
||||
concurrencyService *ConcurrencyService
|
||||
billingService *BillingService
|
||||
rateLimitService *RateLimitService
|
||||
billingCacheService *BillingCacheService
|
||||
@@ -97,7 +95,6 @@ func NewOpenAIGatewayService(
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
cache GatewayCache,
|
||||
cfg *config.Config,
|
||||
concurrencyService *ConcurrencyService,
|
||||
billingService *BillingService,
|
||||
rateLimitService *RateLimitService,
|
||||
billingCacheService *BillingCacheService,
|
||||
@@ -111,7 +108,6 @@ func NewOpenAIGatewayService(
|
||||
userSubRepo: userSubRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: concurrencyService,
|
||||
billingService: billingService,
|
||||
rateLimitService: rateLimitService,
|
||||
billingCacheService: billingCacheService,
|
||||
@@ -130,14 +126,6 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string {
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// BindStickySession sets session -> account binding with standard TTL.
|
||||
func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error {
|
||||
if sessionHash == "" || accountID <= 0 {
|
||||
return nil
|
||||
}
|
||||
return s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, accountID, openaiStickySessionTTL)
|
||||
}
|
||||
|
||||
// SelectAccount selects an OpenAI account with sticky session support
|
||||
func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) {
|
||||
return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
|
||||
@@ -230,254 +218,6 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan.
|
||||
func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
|
||||
cfg := s.schedulingConfig()
|
||||
var stickyAccountID int64
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
if accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash); err == nil {
|
||||
stickyAccountID = accountID
|
||||
}
|
||||
}
|
||||
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
|
||||
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: account.ID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: account.ID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.FallbackWaitTimeout,
|
||||
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
accounts, err := s.listSchedulableAccounts(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(accounts) == 0 {
|
||||
return nil, errors.New("no available accounts")
|
||||
}
|
||||
|
||||
isExcluded := func(accountID int64) bool {
|
||||
if excludedIDs == nil {
|
||||
return false
|
||||
}
|
||||
_, excluded := excludedIDs[accountID]
|
||||
return excluded
|
||||
}
|
||||
|
||||
// ============ Layer 1: Sticky session ============
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash)
|
||||
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err == nil && account.IsSchedulable() && account.IsOpenAI() &&
|
||||
(requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL)
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: accountID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============ Layer 2: Load-aware selection ============
|
||||
candidates := make([]*Account, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if isExcluded(acc.ID) {
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||
continue
|
||||
}
|
||||
candidates = append(candidates, acc)
|
||||
}
|
||||
|
||||
if len(candidates) == 0 {
|
||||
return nil, errors.New("no available accounts")
|
||||
}
|
||||
|
||||
accountLoads := make([]AccountWithConcurrency, 0, len(candidates))
|
||||
for _, acc := range candidates {
|
||||
accountLoads = append(accountLoads, AccountWithConcurrency{
|
||||
ID: acc.ID,
|
||||
MaxConcurrency: acc.Concurrency,
|
||||
})
|
||||
}
|
||||
|
||||
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
|
||||
if err != nil {
|
||||
ordered := append([]*Account(nil), candidates...)
|
||||
sortAccountsByPriorityAndLastUsed(ordered, false)
|
||||
for _, acc := range ordered {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
if sessionHash != "" {
|
||||
_ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, acc.ID, openaiStickySessionTTL)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: acc,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
type accountWithLoad struct {
|
||||
account *Account
|
||||
loadInfo *AccountLoadInfo
|
||||
}
|
||||
var available []accountWithLoad
|
||||
for _, acc := range candidates {
|
||||
loadInfo := loadMap[acc.ID]
|
||||
if loadInfo == nil {
|
||||
loadInfo = &AccountLoadInfo{AccountID: acc.ID}
|
||||
}
|
||||
if loadInfo.LoadRate < 100 {
|
||||
available = append(available, accountWithLoad{
|
||||
account: acc,
|
||||
loadInfo: loadInfo,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(available) > 0 {
|
||||
sort.SliceStable(available, func(i, j int) bool {
|
||||
a, b := available[i], available[j]
|
||||
if a.account.Priority != b.account.Priority {
|
||||
return a.account.Priority < b.account.Priority
|
||||
}
|
||||
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
|
||||
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
|
||||
}
|
||||
switch {
|
||||
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
|
||||
return true
|
||||
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
|
||||
return false
|
||||
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
|
||||
return false
|
||||
default:
|
||||
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
|
||||
}
|
||||
})
|
||||
|
||||
for _, item := range available {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
if sessionHash != "" {
|
||||
_ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, item.account.ID, openaiStickySessionTTL)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: item.account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============ Layer 3: Fallback wait ============
|
||||
sortAccountsByPriorityAndLastUsed(candidates, false)
|
||||
for _, acc := range candidates {
|
||||
return &AccountSelectionResult{
|
||||
Account: acc,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: acc.ID,
|
||||
MaxConcurrency: acc.Concurrency,
|
||||
Timeout: cfg.FallbackWaitTimeout,
|
||||
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("no available accounts")
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, error) {
|
||||
var accounts []Account
|
||||
var err error
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
|
||||
} else if groupID != nil {
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI)
|
||||
} else {
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
}
|
||||
return accounts, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
|
||||
if s.concurrencyService == nil {
|
||||
return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil
|
||||
}
|
||||
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig {
|
||||
if s.cfg != nil {
|
||||
return s.cfg.Gateway.Scheduling
|
||||
}
|
||||
return config.GatewaySchedulingConfig{
|
||||
StickySessionMaxWaiting: 3,
|
||||
StickySessionWaitTimeout: 45 * time.Second,
|
||||
FallbackWaitTimeout: 30 * time.Second,
|
||||
FallbackMaxWaiting: 100,
|
||||
LoadBatchEnabled: true,
|
||||
SlotCleanupInterval: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// GetAccessToken gets the access token for an OpenAI account
|
||||
func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
|
||||
switch account.Type {
|
||||
|
||||
@@ -73,15 +73,6 @@ func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWh
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvideConcurrencyService creates ConcurrencyService and starts slot cleanup worker.
|
||||
func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountRepository, cfg *config.Config) *ConcurrencyService {
|
||||
svc := NewConcurrencyService(cache)
|
||||
if cfg != nil {
|
||||
svc.StartSlotCleanupWorker(accountRepo, cfg.Gateway.Scheduling.SlotCleanupInterval)
|
||||
}
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProviderSet is the Wire provider set for all services
|
||||
var ProviderSet = wire.NewSet(
|
||||
// Core services
|
||||
@@ -116,7 +107,7 @@ var ProviderSet = wire.NewSet(
|
||||
ProvideEmailQueueService,
|
||||
NewTurnstileService,
|
||||
NewSubscriptionService,
|
||||
ProvideConcurrencyService,
|
||||
NewConcurrencyService,
|
||||
NewIdentityService,
|
||||
NewCRSSyncService,
|
||||
ProvideUpdateService,
|
||||
|
||||
Reference in New Issue
Block a user