package service import ( "bufio" "bytes" "context" "errors" "io" "net/http" "net/http/httptest" "strings" "testing" "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) // 编译期接口断言 var _ AccountRepository = (*stubOpenAIAccountRepo)(nil) var _ GatewayCache = (*stubGatewayCache)(nil) type stubOpenAIAccountRepo struct { AccountRepository accounts []Account } func (r stubOpenAIAccountRepo) GetByID(ctx context.Context, id int64) (*Account, error) { for i := range r.accounts { if r.accounts[i].ID == id { return &r.accounts[i], nil } } return nil, errors.New("account not found") } func (r stubOpenAIAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) { var result []Account for _, acc := range r.accounts { if acc.Platform == platform { result = append(result, acc) } } return result, nil } func (r stubOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) { var result []Account for _, acc := range r.accounts { if acc.Platform == platform { result = append(result, acc) } } return result, nil } type stubConcurrencyCache struct { ConcurrencyCache loadBatchErr error loadMap map[int64]*AccountLoadInfo acquireResults map[int64]bool waitCounts map[int64]int skipDefaultLoad bool } type cancelReadCloser struct{} func (c cancelReadCloser) Read(p []byte) (int, error) { return 0, context.Canceled } func (c cancelReadCloser) Close() error { return nil } type failingGinWriter struct { gin.ResponseWriter failAfter int writes int } func (w *failingGinWriter) Write(p []byte) (int, error) { if w.writes >= w.failAfter { return 0, errors.New("write failed") } w.writes++ return w.ResponseWriter.Write(p) } func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { if c.acquireResults != nil { if result, ok := c.acquireResults[accountID]; ok { return result, nil } } return true, nil } func (c stubConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error { return nil } func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) { if c.loadBatchErr != nil { return nil, c.loadBatchErr } out := make(map[int64]*AccountLoadInfo, len(accounts)) if c.skipDefaultLoad && c.loadMap != nil { for _, acc := range accounts { if load, ok := c.loadMap[acc.ID]; ok { out[acc.ID] = load } } return out, nil } for _, acc := range accounts { if c.loadMap != nil { if load, ok := c.loadMap[acc.ID]; ok { out[acc.ID] = load continue } } out[acc.ID] = &AccountLoadInfo{AccountID: acc.ID, LoadRate: 0} } return out, nil } func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) { gin.SetMode(gin.TestMode) rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) svc := &OpenAIGatewayService{} bodyWithKey := []byte(`{"prompt_cache_key":"ses_aaa"}`) // 1) session_id header wins c.Request.Header.Set("session_id", "sess-123") c.Request.Header.Set("conversation_id", "conv-456") h1 := svc.GenerateSessionHash(c, bodyWithKey) if h1 == "" { t.Fatalf("expected non-empty hash") } // 2) conversation_id used when session_id absent c.Request.Header.Del("session_id") h2 := svc.GenerateSessionHash(c, bodyWithKey) if h2 == "" { t.Fatalf("expected non-empty hash") } if h1 == h2 { t.Fatalf("expected different hashes for different keys") } // 3) prompt_cache_key used when both headers absent c.Request.Header.Del("conversation_id") h3 := svc.GenerateSessionHash(c, bodyWithKey) if h3 == "" { t.Fatalf("expected non-empty hash") } if h2 == h3 { t.Fatalf("expected different hashes for different keys") } // 4) empty when no signals h4 := svc.GenerateSessionHash(c, []byte(`{}`)) if h4 != "" { t.Fatalf("expected empty hash when no signals") } } func (c stubConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { if c.waitCounts != nil { if count, ok := c.waitCounts[accountID]; ok { return count, nil } } return 0, nil } type stubGatewayCache struct { sessionBindings map[string]int64 deletedSessions map[string]int } func (c *stubGatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) { if id, ok := c.sessionBindings[sessionHash]; ok { return id, nil } return 0, errors.New("not found") } func (c *stubGatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error { if c.sessionBindings == nil { c.sessionBindings = make(map[string]int64) } c.sessionBindings[sessionHash] = accountID return nil } func (c *stubGatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error { return nil } func (c *stubGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error { if c.sessionBindings == nil { return nil } if c.deletedSessions == nil { c.deletedSessions = make(map[string]int) } c.deletedSessions[sessionHash]++ delete(c.sessionBindings, sessionHash) return nil } func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) { now := time.Now() resetAt := now.Add(10 * time.Minute) groupID := int64(1) rateLimited := Account{ ID: 1, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &resetAt, } available := Account{ ID: 2, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1, } svc := &OpenAIGatewayService{ accountRepo: stubOpenAIAccountRepo{accounts: []Account{rateLimited, available}}, concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), } selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-5.2", nil) if err != nil { t.Fatalf("SelectAccountWithLoadAwareness error: %v", err) } if selection == nil || selection.Account == nil { t.Fatalf("expected selection with account") } if selection.Account.ID != available.ID { t.Fatalf("expected account %d, got %d", available.ID, selection.Account.ID) } if selection.ReleaseFunc != nil { selection.ReleaseFunc() } } func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulableWhenNoConcurrencyService(t *testing.T) { now := time.Now() resetAt := now.Add(10 * time.Minute) groupID := int64(1) rateLimited := Account{ ID: 1, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 0, RateLimitResetAt: &resetAt, } available := Account{ ID: 2, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1, } svc := &OpenAIGatewayService{ accountRepo: stubOpenAIAccountRepo{accounts: []Account{rateLimited, available}}, // concurrencyService is nil, forcing the non-load-batch selection path. } selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-5.2", nil) if err != nil { t.Fatalf("SelectAccountWithLoadAwareness error: %v", err) } if selection == nil || selection.Account == nil { t.Fatalf("expected selection with account") } if selection.Account.ID != available.ID { t.Fatalf("expected account %d, got %d", available.ID, selection.Account.ID) } if selection.ReleaseFunc != nil { selection.ReleaseFunc() } } func TestOpenAISelectAccountForModelWithExclusions_StickyUnschedulableClearsSession(t *testing.T) { sessionHash := "session-1" repo := stubOpenAIAccountRepo{ accounts: []Account{ {ID: 1, Platform: PlatformOpenAI, Status: StatusDisabled, Schedulable: true, Concurrency: 1}, {ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1}, }, } cache := &stubGatewayCache{ sessionBindings: map[string]int64{"openai:" + sessionHash: 1}, } svc := &OpenAIGatewayService{ accountRepo: repo, cache: cache, } acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", nil) if err != nil { t.Fatalf("SelectAccountForModelWithExclusions error: %v", err) } if acc == nil || acc.ID != 2 { t.Fatalf("expected account 2, got %+v", acc) } if cache.deletedSessions["openai:"+sessionHash] != 1 { t.Fatalf("expected sticky session to be deleted") } if cache.sessionBindings["openai:"+sessionHash] != 2 { t.Fatalf("expected sticky session to bind to account 2") } } func TestOpenAISelectAccountWithLoadAwareness_StickyUnschedulableClearsSession(t *testing.T) { sessionHash := "session-2" groupID := int64(1) repo := stubOpenAIAccountRepo{ accounts: []Account{ {ID: 1, Platform: PlatformOpenAI, Status: StatusDisabled, Schedulable: true, Concurrency: 1}, {ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1}, }, } cache := &stubGatewayCache{ sessionBindings: map[string]int64{"openai:" + sessionHash: 1}, } svc := &OpenAIGatewayService{ accountRepo: repo, cache: cache, concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), } selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, sessionHash, "gpt-4", nil) if err != nil { t.Fatalf("SelectAccountWithLoadAwareness error: %v", err) } if selection == nil || selection.Account == nil || selection.Account.ID != 2 { t.Fatalf("expected account 2, got %+v", selection) } if cache.deletedSessions["openai:"+sessionHash] != 1 { t.Fatalf("expected sticky session to be deleted") } if cache.sessionBindings["openai:"+sessionHash] != 2 { t.Fatalf("expected sticky session to bind to account 2") } if selection.ReleaseFunc != nil { selection.ReleaseFunc() } } func TestOpenAISelectAccountForModelWithExclusions_NoModelSupport(t *testing.T) { repo := stubOpenAIAccountRepo{ accounts: []Account{ { ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Credentials: map[string]any{"model_mapping": map[string]any{"gpt-3.5-turbo": "gpt-3.5-turbo"}}, }, }, } cache := &stubGatewayCache{} svc := &OpenAIGatewayService{ accountRepo: repo, cache: cache, } acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "gpt-4", nil) if err == nil { t.Fatalf("expected error for unsupported model") } if acc != nil { t.Fatalf("expected nil account for unsupported model") } if !strings.Contains(err.Error(), "supporting model") { t.Fatalf("unexpected error: %v", err) } } func TestOpenAISelectAccountWithLoadAwareness_LoadBatchErrorFallback(t *testing.T) { groupID := int64(1) repo := stubOpenAIAccountRepo{ accounts: []Account{ {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 2}, {ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1}, }, } cache := &stubGatewayCache{} concurrencyCache := stubConcurrencyCache{ loadBatchErr: errors.New("load batch failed"), } svc := &OpenAIGatewayService{ accountRepo: repo, cache: cache, concurrencyService: NewConcurrencyService(concurrencyCache), } selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "fallback", "gpt-4", nil) if err != nil { t.Fatalf("SelectAccountWithLoadAwareness error: %v", err) } if selection == nil || selection.Account == nil { t.Fatalf("expected selection") } if selection.Account.ID != 2 { t.Fatalf("expected account 2, got %d", selection.Account.ID) } if cache.sessionBindings["openai:fallback"] != 2 { t.Fatalf("expected sticky session updated") } if selection.ReleaseFunc != nil { selection.ReleaseFunc() } } func TestOpenAISelectAccountWithLoadAwareness_NoSlotFallbackWait(t *testing.T) { groupID := int64(1) repo := stubOpenAIAccountRepo{ accounts: []Account{ {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1}, }, } cache := &stubGatewayCache{} concurrencyCache := stubConcurrencyCache{ acquireResults: map[int64]bool{1: false}, loadMap: map[int64]*AccountLoadInfo{ 1: {AccountID: 1, LoadRate: 10}, }, } svc := &OpenAIGatewayService{ accountRepo: repo, cache: cache, concurrencyService: NewConcurrencyService(concurrencyCache), } selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil) if err != nil { t.Fatalf("SelectAccountWithLoadAwareness error: %v", err) } if selection == nil || selection.WaitPlan == nil { t.Fatalf("expected wait plan fallback") } if selection.Account == nil || selection.Account.ID != 1 { t.Fatalf("expected account 1") } } func TestOpenAISelectAccountForModelWithExclusions_SetsStickyBinding(t *testing.T) { sessionHash := "bind" repo := stubOpenAIAccountRepo{ accounts: []Account{ {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1}, }, } cache := &stubGatewayCache{} svc := &OpenAIGatewayService{ accountRepo: repo, cache: cache, } acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", nil) if err != nil { t.Fatalf("SelectAccountForModelWithExclusions error: %v", err) } if acc == nil || acc.ID != 1 { t.Fatalf("expected account 1") } if cache.sessionBindings["openai:"+sessionHash] != 1 { t.Fatalf("expected sticky session binding") } } func TestOpenAISelectAccountWithLoadAwareness_StickyWaitPlan(t *testing.T) { sessionHash := "sticky-wait" groupID := int64(1) repo := stubOpenAIAccountRepo{ accounts: []Account{ {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1}, }, } cache := &stubGatewayCache{ sessionBindings: map[string]int64{"openai:" + sessionHash: 1}, } concurrencyCache := stubConcurrencyCache{ acquireResults: map[int64]bool{1: false}, waitCounts: map[int64]int{1: 0}, } svc := &OpenAIGatewayService{ accountRepo: repo, cache: cache, concurrencyService: NewConcurrencyService(concurrencyCache), } selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, sessionHash, "gpt-4", nil) if err != nil { t.Fatalf("SelectAccountWithLoadAwareness error: %v", err) } if selection == nil || selection.WaitPlan == nil { t.Fatalf("expected sticky wait plan") } if selection.Account == nil || selection.Account.ID != 1 { t.Fatalf("expected account 1") } } func TestOpenAISelectAccountWithLoadAwareness_PrefersLowerLoad(t *testing.T) { groupID := int64(1) repo := stubOpenAIAccountRepo{ accounts: []Account{ {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1}, {ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1}, }, } cache := &stubGatewayCache{} concurrencyCache := stubConcurrencyCache{ loadMap: map[int64]*AccountLoadInfo{ 1: {AccountID: 1, LoadRate: 80}, 2: {AccountID: 2, LoadRate: 10}, }, } svc := &OpenAIGatewayService{ accountRepo: repo, cache: cache, concurrencyService: NewConcurrencyService(concurrencyCache), } selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "load", "gpt-4", nil) if err != nil { t.Fatalf("SelectAccountWithLoadAwareness error: %v", err) } if selection == nil || selection.Account == nil || selection.Account.ID != 2 { t.Fatalf("expected account 2") } if cache.sessionBindings["openai:load"] != 2 { t.Fatalf("expected sticky session updated") } } func TestOpenAISelectAccountForModelWithExclusions_StickyExcludedFallback(t *testing.T) { sessionHash := "excluded" repo := stubOpenAIAccountRepo{ accounts: []Account{ {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1}, {ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 2}, }, } cache := &stubGatewayCache{ sessionBindings: map[string]int64{"openai:" + sessionHash: 1}, } svc := &OpenAIGatewayService{ accountRepo: repo, cache: cache, } excluded := map[int64]struct{}{1: {}} acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", excluded) if err != nil { t.Fatalf("SelectAccountForModelWithExclusions error: %v", err) } if acc == nil || acc.ID != 2 { t.Fatalf("expected account 2") } } func TestOpenAISelectAccountForModelWithExclusions_StickyNonOpenAI(t *testing.T) { sessionHash := "non-openai" repo := stubOpenAIAccountRepo{ accounts: []Account{ {ID: 1, Platform: PlatformAnthropic, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1}, {ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 2}, }, } cache := &stubGatewayCache{ sessionBindings: map[string]int64{"openai:" + sessionHash: 1}, } svc := &OpenAIGatewayService{ accountRepo: repo, cache: cache, } acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, sessionHash, "gpt-4", nil) if err != nil { t.Fatalf("SelectAccountForModelWithExclusions error: %v", err) } if acc == nil || acc.ID != 2 { t.Fatalf("expected account 2") } } func TestOpenAISelectAccountForModelWithExclusions_NoAccounts(t *testing.T) { repo := stubOpenAIAccountRepo{accounts: []Account{}} cache := &stubGatewayCache{} svc := &OpenAIGatewayService{ accountRepo: repo, cache: cache, } acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "", nil) if err == nil { t.Fatalf("expected error for no accounts") } if acc != nil { t.Fatalf("expected nil account") } if !strings.Contains(err.Error(), "no available OpenAI accounts") { t.Fatalf("unexpected error: %v", err) } } func TestOpenAISelectAccountWithLoadAwareness_NoCandidates(t *testing.T) { groupID := int64(1) resetAt := time.Now().Add(1 * time.Hour) repo := stubOpenAIAccountRepo{ accounts: []Account{ {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1, RateLimitResetAt: &resetAt}, }, } cache := &stubGatewayCache{} concurrencyCache := stubConcurrencyCache{} svc := &OpenAIGatewayService{ accountRepo: repo, cache: cache, concurrencyService: NewConcurrencyService(concurrencyCache), } selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil) if err == nil { t.Fatalf("expected error for no candidates") } if selection != nil { t.Fatalf("expected nil selection") } } func TestOpenAISelectAccountWithLoadAwareness_AllFullWaitPlan(t *testing.T) { groupID := int64(1) repo := stubOpenAIAccountRepo{ accounts: []Account{ {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1}, }, } cache := &stubGatewayCache{} concurrencyCache := stubConcurrencyCache{ loadMap: map[int64]*AccountLoadInfo{ 1: {AccountID: 1, LoadRate: 100}, }, } svc := &OpenAIGatewayService{ accountRepo: repo, cache: cache, concurrencyService: NewConcurrencyService(concurrencyCache), } selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil) if err != nil { t.Fatalf("SelectAccountWithLoadAwareness error: %v", err) } if selection == nil || selection.WaitPlan == nil { t.Fatalf("expected wait plan") } } func TestOpenAISelectAccountWithLoadAwareness_LoadBatchErrorNoAcquire(t *testing.T) { groupID := int64(1) repo := stubOpenAIAccountRepo{ accounts: []Account{ {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1}, }, } cache := &stubGatewayCache{} concurrencyCache := stubConcurrencyCache{ loadBatchErr: errors.New("load batch failed"), acquireResults: map[int64]bool{1: false}, } svc := &OpenAIGatewayService{ accountRepo: repo, cache: cache, concurrencyService: NewConcurrencyService(concurrencyCache), } selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil) if err != nil { t.Fatalf("SelectAccountWithLoadAwareness error: %v", err) } if selection == nil || selection.WaitPlan == nil { t.Fatalf("expected wait plan") } } func TestOpenAISelectAccountWithLoadAwareness_MissingLoadInfo(t *testing.T) { groupID := int64(1) repo := stubOpenAIAccountRepo{ accounts: []Account{ {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1}, {ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1}, }, } cache := &stubGatewayCache{} concurrencyCache := stubConcurrencyCache{ loadMap: map[int64]*AccountLoadInfo{ 1: {AccountID: 1, LoadRate: 50}, }, skipDefaultLoad: true, } svc := &OpenAIGatewayService{ accountRepo: repo, cache: cache, concurrencyService: NewConcurrencyService(concurrencyCache), } selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil) if err != nil { t.Fatalf("SelectAccountWithLoadAwareness error: %v", err) } if selection == nil || selection.Account == nil || selection.Account.ID != 2 { t.Fatalf("expected account 2") } } func TestOpenAISelectAccountForModelWithExclusions_LeastRecentlyUsed(t *testing.T) { oldTime := time.Now().Add(-2 * time.Hour) newTime := time.Now().Add(-1 * time.Hour) repo := stubOpenAIAccountRepo{ accounts: []Account{ {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Priority: 1, LastUsedAt: &newTime}, {ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Priority: 1, LastUsedAt: &oldTime}, }, } cache := &stubGatewayCache{} svc := &OpenAIGatewayService{ accountRepo: repo, cache: cache, } acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "gpt-4", nil) if err != nil { t.Fatalf("SelectAccountForModelWithExclusions error: %v", err) } if acc == nil || acc.ID != 2 { t.Fatalf("expected account 2") } } func TestOpenAISelectAccountWithLoadAwareness_PreferNeverUsed(t *testing.T) { groupID := int64(1) lastUsed := time.Now().Add(-1 * time.Hour) repo := stubOpenAIAccountRepo{ accounts: []Account{ {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1, LastUsedAt: &lastUsed}, {ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Concurrency: 1, Priority: 1}, }, } cache := &stubGatewayCache{} concurrencyCache := stubConcurrencyCache{ loadMap: map[int64]*AccountLoadInfo{ 1: {AccountID: 1, LoadRate: 10}, 2: {AccountID: 2, LoadRate: 10}, }, } svc := &OpenAIGatewayService{ accountRepo: repo, cache: cache, concurrencyService: NewConcurrencyService(concurrencyCache), } selection, err := svc.SelectAccountWithLoadAwareness(context.Background(), &groupID, "", "gpt-4", nil) if err != nil { t.Fatalf("SelectAccountWithLoadAwareness error: %v", err) } if selection == nil || selection.Account == nil || selection.Account.ID != 2 { t.Fatalf("expected account 2") } } func TestOpenAIStreamingTimeout(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{ Gateway: config.GatewayConfig{ StreamDataIntervalTimeout: 1, StreamKeepaliveInterval: 0, MaxLineSize: defaultMaxLineSize, }, } svc := &OpenAIGatewayService{cfg: cfg} rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/", nil) pr, pw := io.Pipe() resp := &http.Response{ StatusCode: http.StatusOK, Body: pr, Header: http.Header{}, } start := time.Now() _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, start, "model", "model") _ = pw.Close() _ = pr.Close() if err == nil || !strings.Contains(err.Error(), "stream data interval timeout") { t.Fatalf("expected stream timeout error, got %v", err) } if !strings.Contains(rec.Body.String(), "\"type\":\"error\"") || !strings.Contains(rec.Body.String(), "stream_timeout") { t.Fatalf("expected OpenAI-compatible error SSE event, got %q", rec.Body.String()) } } func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{ Gateway: config.GatewayConfig{ StreamDataIntervalTimeout: 0, StreamKeepaliveInterval: 0, MaxLineSize: defaultMaxLineSize, }, } svc := &OpenAIGatewayService{cfg: cfg} rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) ctx, cancel := context.WithCancel(context.Background()) cancel() c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx) resp := &http.Response{ StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}, } _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model") if err != nil { t.Fatalf("expected nil error, got %v", err) } if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "stream_read_error") { t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String()) } } func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{ Gateway: config.GatewayConfig{ StreamDataIntervalTimeout: 0, StreamKeepaliveInterval: 0, MaxLineSize: defaultMaxLineSize, }, } svc := &OpenAIGatewayService{cfg: cfg} rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/", nil) c.Writer = &failingGinWriter{ResponseWriter: c.Writer, failAfter: 0} pr, pw := io.Pipe() resp := &http.Response{ StatusCode: http.StatusOK, Body: pr, Header: http.Header{}, } go func() { defer func() { _ = pw.Close() }() _, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n")) _, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":3,\"output_tokens\":5,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n")) }() result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model") _ = pr.Close() if err != nil { t.Fatalf("expected nil error, got %v", err) } if result == nil || result.usage == nil { t.Fatalf("expected usage result") } if result.usage.InputTokens != 3 || result.usage.OutputTokens != 5 || result.usage.CacheReadInputTokens != 1 { t.Fatalf("unexpected usage: %+v", *result.usage) } if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "write_failed") { t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String()) } } func TestOpenAIStreamingTooLong(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{ Gateway: config.GatewayConfig{ StreamDataIntervalTimeout: 0, StreamKeepaliveInterval: 0, MaxLineSize: 64 * 1024, }, } svc := &OpenAIGatewayService{cfg: cfg} rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/", nil) pr, pw := io.Pipe() resp := &http.Response{ StatusCode: http.StatusOK, Body: pr, Header: http.Header{}, } go func() { defer func() { _ = pw.Close() }() // 写入超过 MaxLineSize 的单行数据,触发 ErrTooLong payload := "data: " + strings.Repeat("a", 128*1024) + "\n" _, _ = pw.Write([]byte(payload)) }() _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 2}, time.Now(), "model", "model") _ = pr.Close() if !errors.Is(err, bufio.ErrTooLong) { t.Fatalf("expected ErrTooLong, got %v", err) } if !strings.Contains(rec.Body.String(), "\"type\":\"error\"") || !strings.Contains(rec.Body.String(), "response_too_large") { t.Fatalf("expected OpenAI-compatible error SSE event, got %q", rec.Body.String()) } } func TestOpenAINonStreamingContentTypePassThrough(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{ Security: config.SecurityConfig{ ResponseHeaders: config.ResponseHeaderConfig{Enabled: false}, }, } svc := &OpenAIGatewayService{cfg: cfg} rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/", nil) body := []byte(`{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`) resp := &http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(body)), Header: http.Header{"Content-Type": []string{"application/vnd.test+json"}}, } _, err := svc.handleNonStreamingResponse(c.Request.Context(), resp, c, &Account{}, "model", "model") if err != nil { t.Fatalf("handleNonStreamingResponse error: %v", err) } if !strings.Contains(rec.Header().Get("Content-Type"), "application/vnd.test+json") { t.Fatalf("expected Content-Type passthrough, got %q", rec.Header().Get("Content-Type")) } } func TestOpenAINonStreamingContentTypeDefault(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{ Security: config.SecurityConfig{ ResponseHeaders: config.ResponseHeaderConfig{Enabled: false}, }, } svc := &OpenAIGatewayService{cfg: cfg} rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/", nil) body := []byte(`{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`) resp := &http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewReader(body)), Header: http.Header{}, } _, err := svc.handleNonStreamingResponse(c.Request.Context(), resp, c, &Account{}, "model", "model") if err != nil { t.Fatalf("handleNonStreamingResponse error: %v", err) } if !strings.Contains(rec.Header().Get("Content-Type"), "application/json") { t.Fatalf("expected default Content-Type, got %q", rec.Header().Get("Content-Type")) } } func TestOpenAIStreamingHeadersOverride(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{ Security: config.SecurityConfig{ ResponseHeaders: config.ResponseHeaderConfig{Enabled: false}, }, Gateway: config.GatewayConfig{ StreamDataIntervalTimeout: 0, StreamKeepaliveInterval: 0, MaxLineSize: defaultMaxLineSize, }, } svc := &OpenAIGatewayService{cfg: cfg} rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/", nil) pr, pw := io.Pipe() resp := &http.Response{ StatusCode: http.StatusOK, Body: pr, Header: http.Header{ "Cache-Control": []string{"upstream"}, "X-Request-Id": []string{"req-123"}, "Content-Type": []string{"application/custom"}, }, } go func() { defer func() { _ = pw.Close() }() _, _ = pw.Write([]byte("data: {}\n\n")) }() _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model") _ = pr.Close() if err != nil { t.Fatalf("handleStreamingResponse error: %v", err) } if rec.Header().Get("Cache-Control") != "no-cache" { t.Fatalf("expected Cache-Control override, got %q", rec.Header().Get("Cache-Control")) } if rec.Header().Get("Content-Type") != "text/event-stream" { t.Fatalf("expected Content-Type override, got %q", rec.Header().Get("Content-Type")) } if rec.Header().Get("X-Request-Id") != "req-123" { t.Fatalf("expected X-Request-Id passthrough, got %q", rec.Header().Get("X-Request-Id")) } } func TestOpenAIStreamingReuseScannerBufferAndStillWorks(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{ Gateway: config.GatewayConfig{ StreamDataIntervalTimeout: 0, StreamKeepaliveInterval: 0, MaxLineSize: defaultMaxLineSize, }, } svc := &OpenAIGatewayService{cfg: cfg} rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/", nil) pr, pw := io.Pipe() resp := &http.Response{ StatusCode: http.StatusOK, Body: pr, Header: http.Header{}, } go func() { defer func() { _ = pw.Close() }() _, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":2,\"input_tokens_details\":{\"cached_tokens\":3}}}}\n\n")) }() result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model") _ = pr.Close() require.NoError(t, err) require.NotNil(t, result) require.NotNil(t, result.usage) require.Equal(t, 1, result.usage.InputTokens) require.Equal(t, 2, result.usage.OutputTokens) require.Equal(t, 3, result.usage.CacheReadInputTokens) } func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{ Security: config.SecurityConfig{ URLAllowlist: config.URLAllowlistConfig{Enabled: false}, }, } svc := &OpenAIGatewayService{cfg: cfg} rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/", nil) account := &Account{ Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Credentials: map[string]any{"base_url": "://invalid-url"}, } _, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false, "", false) if err == nil { t.Fatalf("expected error for invalid base_url when allowlist disabled") } } func TestOpenAIValidateUpstreamBaseURLDisabledRequiresHTTPS(t *testing.T) { cfg := &config.Config{ Security: config.SecurityConfig{ URLAllowlist: config.URLAllowlistConfig{Enabled: false}, }, } svc := &OpenAIGatewayService{cfg: cfg} if _, err := svc.validateUpstreamBaseURL("http://not-https.example.com"); err == nil { t.Fatalf("expected http to be rejected when allow_insecure_http is false") } normalized, err := svc.validateUpstreamBaseURL("https://example.com") if err != nil { t.Fatalf("expected https to be allowed when allowlist disabled, got %v", err) } if normalized != "https://example.com" { t.Fatalf("expected raw url passthrough, got %q", normalized) } } func TestOpenAIValidateUpstreamBaseURLDisabledAllowsHTTP(t *testing.T) { cfg := &config.Config{ Security: config.SecurityConfig{ URLAllowlist: config.URLAllowlistConfig{ Enabled: false, AllowInsecureHTTP: true, }, }, } svc := &OpenAIGatewayService{cfg: cfg} normalized, err := svc.validateUpstreamBaseURL("http://not-https.example.com") if err != nil { t.Fatalf("expected http allowed when allow_insecure_http is true, got %v", err) } if normalized != "http://not-https.example.com" { t.Fatalf("expected raw url passthrough, got %q", normalized) } } func TestOpenAIValidateUpstreamBaseURLEnabledEnforcesAllowlist(t *testing.T) { cfg := &config.Config{ Security: config.SecurityConfig{ URLAllowlist: config.URLAllowlistConfig{ Enabled: true, UpstreamHosts: []string{"example.com"}, }, }, } svc := &OpenAIGatewayService{cfg: cfg} if _, err := svc.validateUpstreamBaseURL("https://example.com"); err != nil { t.Fatalf("expected allowlisted host to pass, got %v", err) } if _, err := svc.validateUpstreamBaseURL("https://evil.com"); err == nil { t.Fatalf("expected non-allowlisted host to fail") } } // ==================== P1-08 修复:model 替换性能优化测试 ==================== func TestReplaceModelInSSELine(t *testing.T) { svc := &OpenAIGatewayService{} tests := []struct { name string line string from string to string expected string }{ { name: "顶层 model 字段替换", line: `data: {"id":"chatcmpl-123","model":"gpt-4o","choices":[]}`, from: "gpt-4o", to: "my-custom-model", expected: `data: {"id":"chatcmpl-123","model":"my-custom-model","choices":[]}`, }, { name: "嵌套 response.model 替换", line: `data: {"type":"response","response":{"id":"resp-1","model":"gpt-4o","output":[]}}`, from: "gpt-4o", to: "my-model", expected: `data: {"type":"response","response":{"id":"resp-1","model":"my-model","output":[]}}`, }, { name: "model 不匹配时不替换", line: `data: {"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`, from: "gpt-4o", to: "my-model", expected: `data: {"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`, }, { name: "无 model 字段时不替换", line: `data: {"id":"chatcmpl-123","choices":[]}`, from: "gpt-4o", to: "my-model", expected: `data: {"id":"chatcmpl-123","choices":[]}`, }, { name: "空 data 行", line: `data: `, from: "gpt-4o", to: "my-model", expected: `data: `, }, { name: "[DONE] 行", line: `data: [DONE]`, from: "gpt-4o", to: "my-model", expected: `data: [DONE]`, }, { name: "非 data: 前缀行", line: `event: message`, from: "gpt-4o", to: "my-model", expected: `event: message`, }, { name: "非法 JSON 不替换", line: `data: {invalid json}`, from: "gpt-4o", to: "my-model", expected: `data: {invalid json}`, }, { name: "无空格 data: 格式", line: `data:{"id":"x","model":"gpt-4o"}`, from: "gpt-4o", to: "my-model", expected: `data: {"id":"x","model":"my-model"}`, }, { name: "model 名含特殊字符", line: `data: {"model":"org/model-v2.1-beta"}`, from: "org/model-v2.1-beta", to: "custom/alias", expected: `data: {"model":"custom/alias"}`, }, { name: "空行", line: "", from: "gpt-4o", to: "my-model", expected: "", }, { name: "保持其他字段不变", line: `data: {"id":"abc","object":"chat.completion.chunk","model":"gpt-4o","created":1234567890,"choices":[{"index":0,"delta":{"content":"hi"}}]}`, from: "gpt-4o", to: "alias", expected: `data: {"id":"abc","object":"chat.completion.chunk","model":"alias","created":1234567890,"choices":[{"index":0,"delta":{"content":"hi"}}]}`, }, { name: "顶层优先于嵌套:同时存在两个 model", line: `data: {"model":"gpt-4o","response":{"model":"gpt-4o"}}`, from: "gpt-4o", to: "replaced", expected: `data: {"model":"replaced","response":{"model":"gpt-4o"}}`, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := svc.replaceModelInSSELine(tt.line, tt.from, tt.to) require.Equal(t, tt.expected, got) }) } } func TestReplaceModelInSSEBody(t *testing.T) { svc := &OpenAIGatewayService{} tests := []struct { name string body string from string to string expected string }{ { name: "多行 SSE body 替换", body: "data: {\"model\":\"gpt-4o\",\"choices\":[]}\n\ndata: {\"model\":\"gpt-4o\",\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\ndata: [DONE]\n", from: "gpt-4o", to: "alias", expected: "data: {\"model\":\"alias\",\"choices\":[]}\n\ndata: {\"model\":\"alias\",\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\ndata: [DONE]\n", }, { name: "无需替换的 body", body: "data: {\"model\":\"gpt-3.5-turbo\"}\n\ndata: [DONE]\n", from: "gpt-4o", to: "alias", expected: "data: {\"model\":\"gpt-3.5-turbo\"}\n\ndata: [DONE]\n", }, { name: "混合 event 和 data 行", body: "event: message\ndata: {\"model\":\"gpt-4o\"}\n\n", from: "gpt-4o", to: "alias", expected: "event: message\ndata: {\"model\":\"alias\"}\n\n", }, { name: "空 body", body: "", from: "gpt-4o", to: "alias", expected: "", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := svc.replaceModelInSSEBody(tt.body, tt.from, tt.to) require.Equal(t, tt.expected, got) }) } } func TestReplaceModelInResponseBody(t *testing.T) { svc := &OpenAIGatewayService{} tests := []struct { name string body string from string to string expected string }{ { name: "替换顶层 model", body: `{"id":"chatcmpl-123","model":"gpt-4o","choices":[]}`, from: "gpt-4o", to: "alias", expected: `{"id":"chatcmpl-123","model":"alias","choices":[]}`, }, { name: "model 不匹配不替换", body: `{"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`, from: "gpt-4o", to: "alias", expected: `{"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`, }, { name: "无 model 字段不替换", body: `{"id":"chatcmpl-123","choices":[]}`, from: "gpt-4o", to: "alias", expected: `{"id":"chatcmpl-123","choices":[]}`, }, { name: "非法 JSON 返回原值", body: `not json`, from: "gpt-4o", to: "alias", expected: `not json`, }, { name: "空 body 返回原值", body: ``, from: "gpt-4o", to: "alias", expected: ``, }, { name: "保持嵌套结构不变", body: `{"model":"gpt-4o","usage":{"prompt_tokens":10,"completion_tokens":20},"choices":[{"message":{"role":"assistant","content":"hello"}}]}`, from: "gpt-4o", to: "alias", expected: `{"model":"alias","usage":{"prompt_tokens":10,"completion_tokens":20},"choices":[{"message":{"role":"assistant","content":"hello"}}]}`, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := svc.replaceModelInResponseBody([]byte(tt.body), tt.from, tt.to) require.Equal(t, tt.expected, string(got)) }) } } func TestExtractOpenAISSEDataLine(t *testing.T) { tests := []struct { name string line string wantData string wantOK bool }{ {name: "标准格式", line: `data: {"type":"x"}`, wantData: `{"type":"x"}`, wantOK: true}, {name: "无空格格式", line: `data:{"type":"x"}`, wantData: `{"type":"x"}`, wantOK: true}, {name: "纯空数据", line: `data: `, wantData: ``, wantOK: true}, {name: "非 data 行", line: `event: message`, wantData: ``, wantOK: false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, ok := extractOpenAISSEDataLine(tt.line) require.Equal(t, tt.wantOK, ok) require.Equal(t, tt.wantData, got) }) } } func TestParseSSEUsage_SelectiveParsing(t *testing.T) { svc := &OpenAIGatewayService{} usage := &OpenAIUsage{InputTokens: 9, OutputTokens: 8, CacheReadInputTokens: 7} // 非 completed 事件,不应覆盖 usage svc.parseSSEUsage(`{"type":"response.in_progress","response":{"usage":{"input_tokens":1,"output_tokens":2}}}`, usage) require.Equal(t, 9, usage.InputTokens) require.Equal(t, 8, usage.OutputTokens) require.Equal(t, 7, usage.CacheReadInputTokens) // completed 事件,应提取 usage svc.parseSSEUsage(`{"type":"response.completed","response":{"usage":{"input_tokens":3,"output_tokens":5,"input_tokens_details":{"cached_tokens":2}}}}`, usage) require.Equal(t, 3, usage.InputTokens) require.Equal(t, 5, usage.OutputTokens) require.Equal(t, 2, usage.CacheReadInputTokens) } func TestExtractCodexFinalResponse_SampleReplay(t *testing.T) { body := strings.Join([]string{ `event: message`, `data: {"type":"response.in_progress","response":{"id":"resp_1"}}`, `data: {"type":"response.completed","response":{"id":"resp_1","model":"gpt-4o","usage":{"input_tokens":11,"output_tokens":22,"input_tokens_details":{"cached_tokens":3}}}}`, `data: [DONE]`, }, "\n") finalResp, ok := extractCodexFinalResponse(body) require.True(t, ok) require.Contains(t, string(finalResp), `"id":"resp_1"`) require.Contains(t, string(finalResp), `"input_tokens":11`) } func TestHandleOAuthSSEToJSON_CompletedEventReturnsJSON(t *testing.T) { gin.SetMode(gin.TestMode) rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/", nil) svc := &OpenAIGatewayService{cfg: &config.Config{}} resp := &http.Response{ StatusCode: http.StatusOK, Header: http.Header{"Content-Type": []string{"text/event-stream"}}, } body := []byte(strings.Join([]string{ `data: {"type":"response.in_progress","response":{"id":"resp_2"}}`, `data: {"type":"response.completed","response":{"id":"resp_2","model":"gpt-4o","usage":{"input_tokens":7,"output_tokens":9,"input_tokens_details":{"cached_tokens":1}}}}`, `data: [DONE]`, }, "\n")) usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o") require.NoError(t, err) require.NotNil(t, usage) require.Equal(t, 7, usage.InputTokens) require.Equal(t, 9, usage.OutputTokens) require.Equal(t, 1, usage.CacheReadInputTokens) // Header 可能由上游 Content-Type 透传;关键是 body 已转换为最终 JSON 响应。 require.NotContains(t, rec.Body.String(), "event:") require.Contains(t, rec.Body.String(), `"id":"resp_2"`) require.NotContains(t, rec.Body.String(), "data:") } func TestHandleOAuthSSEToJSON_NoFinalResponseKeepsSSEBody(t *testing.T) { gin.SetMode(gin.TestMode) rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/", nil) svc := &OpenAIGatewayService{cfg: &config.Config{}} resp := &http.Response{ StatusCode: http.StatusOK, Header: http.Header{"Content-Type": []string{"text/event-stream"}}, } body := []byte(strings.Join([]string{ `data: {"type":"response.in_progress","response":{"id":"resp_3"}}`, `data: [DONE]`, }, "\n")) usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o") require.NoError(t, err) require.NotNil(t, usage) require.Equal(t, 0, usage.InputTokens) require.Contains(t, rec.Header().Get("Content-Type"), "text/event-stream") require.Contains(t, rec.Body.String(), `data: {"type":"response.in_progress"`) }