Merge branch 'main' of github.com:Wei-Shaw/sub2api
This commit is contained in:
@@ -21,16 +21,42 @@ type stubOpenAIAccountRepo struct {
|
||||
accounts []Account
|
||||
}
|
||||
|
||||
func (r stubOpenAIAccountRepo) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
for i := range r.accounts {
|
||||
if r.accounts[i].ID == id {
|
||||
return &r.accounts[i], nil
|
||||
}
|
||||
}
|
||||
return nil, errors.New("account not found")
|
||||
}
|
||||
|
||||
func (r stubOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
|
||||
return append([]Account(nil), r.accounts...), nil
|
||||
var result []Account
|
||||
for _, acc := range r.accounts {
|
||||
if acc.Platform == platform {
|
||||
result = append(result, acc)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r stubOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
return append([]Account(nil), r.accounts...), nil
|
||||
var result []Account
|
||||
for _, acc := range r.accounts {
|
||||
if acc.Platform == platform {
|
||||
result = append(result, acc)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
type stubConcurrencyCache struct {
|
||||
ConcurrencyCache
|
||||
loadBatchErr error
|
||||
loadMap map[int64]*AccountLoadInfo
|
||||
acquireResults map[int64]bool
|
||||
waitCounts map[int64]int
|
||||
skipDefaultLoad bool
|
||||
}
|
||||
|
||||
type cancelReadCloser struct{}
|
||||
@@ -53,6 +79,11 @@ func (w *failingGinWriter) Write(p []byte) (int, error) {
|
||||
}
|
||||
|
||||
func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
if c.acquireResults != nil {
|
||||
if result, ok := c.acquireResults[accountID]; ok {
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
@@ -61,8 +92,25 @@ func (c stubConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID
|
||||
}
|
||||
|
||||
func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||||
if c.loadBatchErr != nil {
|
||||
return nil, c.loadBatchErr
|
||||
}
|
||||
out := make(map[int64]*AccountLoadInfo, len(accounts))
|
||||
if c.skipDefaultLoad && c.loadMap != nil {
|
||||
for _, acc := range accounts {
|
||||
if load, ok := c.loadMap[acc.ID]; ok {
|
||||
out[acc.ID] = load
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
for _, acc := range accounts {
|
||||
if c.loadMap != nil {
|
||||
if load, ok := c.loadMap[acc.ID]; ok {
|
||||
out[acc.ID] = load
|
||||
continue
|
||||
}
|
||||
}
|
||||
out[acc.ID] = &AccountLoadInfo{AccountID: acc.ID, LoadRate: 0}
|
||||
}
|
||||
return out, nil
|
||||
@@ -111,6 +159,51 @@ func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c stubConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||||
if c.waitCounts != nil {
|
||||
if count, ok := c.waitCounts[accountID]; ok {
|
||||
return count, nil
|
||||
}
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
type stubGatewayCache struct {
|
||||
sessionBindings map[string]int64
|
||||
deletedSessions map[string]int
|
||||
}
|
||||
|
||||
func (c *stubGatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
|
||||
if id, ok := c.sessionBindings[sessionHash]; ok {
|
||||
return id, nil
|
||||
}
|
||||
return 0, errors.New("not found")
|
||||
}
|
||||
|
||||
func (c *stubGatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
|
||||
if c.sessionBindings == nil {
|
||||
c.sessionBindings = make(map[string]int64)
|
||||
}
|
||||
c.sessionBindings[sessionHash] = accountID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *stubGatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *stubGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
|
||||
if c.sessionBindings == nil {
|
||||
return nil
|
||||
}
|
||||
if c.deletedSessions == nil {
|
||||
c.deletedSessions = make(map[string]int)
|
||||
}
|
||||
c.deletedSessions[sessionHash]++
|
||||
delete(c.sessionBindings, sessionHash)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
|
||||
now := time.Now()
|
||||
resetAt := now.Add(10 * time.Minute)
|
||||
@@ -201,6 +294,515 @@ func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulableWhenNoConcurre
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountForModelWithExclusions_StickyUnschedulableClearsSession(t *testing.T) {
|
||||
sessionHash := "session-1"
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusDisabled, Schedulable: true, Concurrency: 1},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{
|
||||
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
|
||||
}
|
||||
if acc == nil || acc.ID != 2 {
|
||||
t.Fatalf("expected account 2, got %+v", acc)
|
||||
}
|
||||
if cache.deletedSessions["openai:"+sessionHash] != 1 {
|
||||
t.Fatalf("expected sticky session to be deleted")
|
||||
}
|
||||
if cache.sessionBindings["openai:"+sessionHash] != 2 {
|
||||
t.Fatalf("expected sticky session to bind to account 2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_StickyUnschedulableClearsSession(t *testing.T) {
|
||||
sessionHash := "session-2"
|
||||
groupID := int64(1)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusDisabled, Schedulable: true, Concurrency: 1},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{
|
||||
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, sessionHash, "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
|
||||
t.Fatalf("expected account 2, got %+v", selection)
|
||||
}
|
||||
if cache.deletedSessions["openai:"+sessionHash] != 1 {
|
||||
t.Fatalf("expected sticky session to be deleted")
|
||||
}
|
||||
if cache.sessionBindings["openai:"+sessionHash] != 2 {
|
||||
t.Fatalf("expected sticky session to bind to account 2")
|
||||
}
|
||||
if selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountForModelWithExclusions_NoModelSupport(t *testing.T) {
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{
|
||||
ID: 1,
|
||||
Platform: PlatformOpenAI,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{"model_mapping": map[string]any{"gpt-3.5-turbo": "gpt-3.5-turbo"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "gpt-4", nil)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for unsupported model")
|
||||
}
|
||||
if acc != nil {
|
||||
t.Fatalf("expected nil account for unsupported model")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "supporting model") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_LoadBatchErrorFallback(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 2},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
loadBatchErr: errors.New("load batch failed"),
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "fallback", "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
t.Fatalf("expected selection")
|
||||
}
|
||||
if selection.Account.ID != 2 {
|
||||
t.Fatalf("expected account 2, got %d", selection.Account.ID)
|
||||
}
|
||||
if cache.sessionBindings["openai:fallback"] != 2 {
|
||||
t.Fatalf("expected sticky session updated")
|
||||
}
|
||||
if selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_NoSlotFallbackWait(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
acquireResults: map[int64]bool{1: false},
|
||||
loadMap: map[int64]*AccountLoadInfo{
|
||||
1: {AccountID: 1, LoadRate: 10},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.WaitPlan == nil {
|
||||
t.Fatalf("expected wait plan fallback")
|
||||
}
|
||||
if selection.Account == nil || selection.Account.ID != 1 {
|
||||
t.Fatalf("expected account 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountForModelWithExclusions_SetsStickyBinding(t *testing.T) {
|
||||
sessionHash := "bind"
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
|
||||
}
|
||||
if acc == nil || acc.ID != 1 {
|
||||
t.Fatalf("expected account 1")
|
||||
}
|
||||
if cache.sessionBindings["openai:"+sessionHash] != 1 {
|
||||
t.Fatalf("expected sticky session binding")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_StickyWaitPlan(t *testing.T) {
|
||||
sessionHash := "sticky-wait"
|
||||
groupID := int64(1)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{
|
||||
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
|
||||
}
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
acquireResults: map[int64]bool{1: false},
|
||||
waitCounts: map[int64]int{1: 0},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, sessionHash, "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.WaitPlan == nil {
|
||||
t.Fatalf("expected sticky wait plan")
|
||||
}
|
||||
if selection.Account == nil || selection.Account.ID != 1 {
|
||||
t.Fatalf("expected account 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_PrefersLowerLoad(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
loadMap: map[int64]*AccountLoadInfo{
|
||||
1: {AccountID: 1, LoadRate: 80},
|
||||
2: {AccountID: 2, LoadRate: 10},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "load", "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
|
||||
t.Fatalf("expected account 2")
|
||||
}
|
||||
if cache.sessionBindings["openai:load"] != 2 {
|
||||
t.Fatalf("expected sticky session updated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountForModelWithExclusions_StickyExcludedFallback(t *testing.T) {
|
||||
sessionHash := "excluded"
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 2},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{
|
||||
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
excluded := map[int64]struct{}{1: {}}
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", excluded)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
|
||||
}
|
||||
if acc == nil || acc.ID != 2 {
|
||||
t.Fatalf("expected account 2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountForModelWithExclusions_StickyNonOpenAI(t *testing.T) {
|
||||
sessionHash := "non-openai"
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 2},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{
|
||||
sessionBindings: map[string]int64{"openai:" + sessionHash: 1},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
|
||||
}
|
||||
if acc == nil || acc.ID != 2 {
|
||||
t.Fatalf("expected account 2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountForModelWithExclusions_NoAccounts(t *testing.T) {
|
||||
repo := stubOpenAIAccountRepo{accounts: []Account{}}
|
||||
cache := &stubGatewayCache{}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "", nil)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for no accounts")
|
||||
}
|
||||
if acc != nil {
|
||||
t.Fatalf("expected nil account")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "no available OpenAI accounts") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_NoCandidates(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
resetAt := time.Now().Add(1 * time.Hour)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1, RateLimitResetAt: &resetAt},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
concurrencyCache := stubConcurrencyCache{}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for no candidates")
|
||||
}
|
||||
if selection != nil {
|
||||
t.Fatalf("expected nil selection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_AllFullWaitPlan(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
loadMap: map[int64]*AccountLoadInfo{
|
||||
1: {AccountID: 1, LoadRate: 100},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.WaitPlan == nil {
|
||||
t.Fatalf("expected wait plan")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_LoadBatchErrorNoAcquire(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
loadBatchErr: errors.New("load batch failed"),
|
||||
acquireResults: map[int64]bool{1: false},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.WaitPlan == nil {
|
||||
t.Fatalf("expected wait plan")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_MissingLoadInfo(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
loadMap: map[int64]*AccountLoadInfo{
|
||||
1: {AccountID: 1, LoadRate: 50},
|
||||
},
|
||||
skipDefaultLoad: true,
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
|
||||
t.Fatalf("expected account 2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountForModelWithExclusions_LeastRecentlyUsed(t *testing.T) {
|
||||
oldTime := time.Now().Add(-2 * time.Hour)
|
||||
newTime := time.Now().Add(-1 * time.Hour)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Priority: 1, LastUsedAt: &newTime},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Priority: 1, LastUsedAt: &oldTime},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountForModelWithExclusions error: %v", err)
|
||||
}
|
||||
if acc == nil || acc.ID != 2 {
|
||||
t.Fatalf("expected account 2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_PreferNeverUsed(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
lastUsed := time.Now().Add(-1 * time.Hour)
|
||||
repo := stubOpenAIAccountRepo{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1, LastUsedAt: &lastUsed},
|
||||
{ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1},
|
||||
},
|
||||
}
|
||||
cache := &stubGatewayCache{}
|
||||
concurrencyCache := stubConcurrencyCache{
|
||||
loadMap: map[int64]*AccountLoadInfo{
|
||||
1: {AccountID: 1, LoadRate: 10},
|
||||
2: {AccountID: 2, LoadRate: 10},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &OpenAIGatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("SelectAccountWithLoadAwareness error: %v", err)
|
||||
}
|
||||
if selection == nil || selection.Account == nil || selection.Account.ID != 2 {
|
||||
t.Fatalf("expected account 2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStreamingTimeout(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
cfg := &config.Config{
|
||||
|
||||
Reference in New Issue
Block a user