40 lines
1009 B
Go
40 lines
1009 B
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// openAI403CounterResetStub is a test double installed as the OpenAI 403
// counter cache. It records the account ID of every ResetOpenAI403Count call
// so a test can assert which accounts were reset, and in what order.
type openAI403CounterResetStub struct {
	// resetCalls holds the account IDs passed to ResetOpenAI403Count, in call order.
	resetCalls []int64
}

// IncrementOpenAI403Count is a no-op: it ignores all of its arguments and
// always returns a count of 0 and a nil error.
func (s *openAI403CounterResetStub) IncrementOpenAI403Count(context.Context, int64, int) (int64, error) {
	return 0, nil
}

// ResetOpenAI403Count appends accountID to resetCalls and returns nil.
// The context argument is ignored.
func (s *openAI403CounterResetStub) ResetOpenAI403Count(_ context.Context, accountID int64) error {
	s.resetCalls = append(s.resetCalls, accountID)
	return nil
}
|
|
|
|
func TestOpenAIGatewayServiceRecordUsage_ResetsOpenAI403CounterBeforeZeroUsageReturn(t *testing.T) {
|
|
counter := &openAI403CounterResetStub{}
|
|
rateLimitSvc := NewRateLimitService(nil, nil, nil, nil, nil)
|
|
rateLimitSvc.SetOpenAI403CounterCache(counter)
|
|
|
|
svc := &OpenAIGatewayService{
|
|
rateLimitService: rateLimitSvc,
|
|
}
|
|
|
|
err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
|
|
Result: &OpenAIForwardResult{},
|
|
Account: &Account{ID: 777, Platform: PlatformOpenAI},
|
|
})
|
|
|
|
require.NoError(t, err)
|
|
require.Equal(t, []int64{777}, counter.resetCalls)
|
|
}
|