test: 增加 repository 测试

This commit is contained in:
Forest
2025-12-25 10:52:56 +08:00
parent 9d30ceae8d
commit 25a304c231
36 changed files with 7412 additions and 54 deletions

View File

@@ -0,0 +1,580 @@
//go:build integration
package repository
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type AccountRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *AccountRepository
}
func (s *AccountRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewAccountRepository(s.db)
}
func TestAccountRepoSuite(t *testing.T) {
suite.Run(t, new(AccountRepoSuite))
}
// --- Create / GetByID / Update / Delete ---
func (s *AccountRepoSuite) TestCreate() {
account := &model.Account{
Name: "test-create",
Platform: model.PlatformAnthropic,
Type: model.AccountTypeOAuth,
Status: model.StatusActive,
}
err := s.repo.Create(s.ctx, account)
s.Require().NoError(err, "Create")
s.Require().NotZero(account.ID, "expected ID to be set")
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("test-create", got.Name)
}
func (s *AccountRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *AccountRepoSuite) TestUpdate() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "original"})
account.Name = "updated"
err := s.repo.Update(s.ctx, account)
s.Require().NoError(err, "Update")
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err, "GetByID after update")
s.Require().Equal("updated", got.Name)
}
func (s *AccountRepoSuite) TestDelete() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "to-delete"})
err := s.repo.Delete(s.ctx, account.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, account.ID)
s.Require().Error(err, "expected error after delete")
}
func (s *AccountRepoSuite) TestDelete_WithGroupBindings() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-del"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-del"})
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
err := s.repo.Delete(s.ctx, account.ID)
s.Require().NoError(err, "Delete should cascade remove bindings")
var count int64
s.db.Model(&model.AccountGroup{}).Where("account_id = ?", account.ID).Count(&count)
s.Require().Zero(count, "expected bindings to be removed")
}
// --- List / ListWithFilters ---
func (s *AccountRepoSuite) TestList() {
mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc1"})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc2"})
accounts, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
s.Require().Len(accounts, 2)
s.Require().Equal(int64(2), page.Total)
}
func (s *AccountRepoSuite) TestListWithFilters() {
tests := []struct {
name string
setup func(db *gorm.DB)
platform string
accType string
status string
search string
wantCount int
validate func(accounts []model.Account)
}{
{
name: "filter_by_platform",
setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic})
mustCreateAccount(s.T(), db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI})
},
platform: model.PlatformOpenAI,
wantCount: 1,
validate: func(accounts []model.Account) {
s.Require().Equal(model.PlatformOpenAI, accounts[0].Platform)
},
},
{
name: "filter_by_type",
setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &model.Account{Name: "t1", Type: model.AccountTypeOAuth})
mustCreateAccount(s.T(), db, &model.Account{Name: "t2", Type: model.AccountTypeApiKey})
},
accType: model.AccountTypeApiKey,
wantCount: 1,
validate: func(accounts []model.Account) {
s.Require().Equal(model.AccountTypeApiKey, accounts[0].Type)
},
},
{
name: "filter_by_status",
setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &model.Account{Name: "s1", Status: model.StatusActive})
mustCreateAccount(s.T(), db, &model.Account{Name: "s2", Status: model.StatusDisabled})
},
status: model.StatusDisabled,
wantCount: 1,
validate: func(accounts []model.Account) {
s.Require().Equal(model.StatusDisabled, accounts[0].Status)
},
},
{
name: "filter_by_search",
setup: func(db *gorm.DB) {
mustCreateAccount(s.T(), db, &model.Account{Name: "alpha-account"})
mustCreateAccount(s.T(), db, &model.Account{Name: "beta-account"})
},
search: "alpha",
wantCount: 1,
validate: func(accounts []model.Account) {
s.Require().Contains(accounts[0].Name, "alpha")
},
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
// 每个 case 重新获取隔离资源
db := testTx(s.T())
repo := NewAccountRepository(db)
ctx := context.Background()
tt.setup(db)
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search)
s.Require().NoError(err)
s.Require().Len(accounts, tt.wantCount)
if tt.validate != nil {
tt.validate(accounts)
}
})
}
}
// --- ListByGroup / ListActive / ListByPlatform ---
func (s *AccountRepoSuite) TestListByGroup() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"})
acc1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Status: model.StatusActive})
acc2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Status: model.StatusActive})
mustBindAccountToGroup(s.T(), s.db, acc1.ID, group.ID, 2)
mustBindAccountToGroup(s.T(), s.db, acc2.ID, group.ID, 1)
accounts, err := s.repo.ListByGroup(s.ctx, group.ID)
s.Require().NoError(err, "ListByGroup")
s.Require().Len(accounts, 2)
// Should be ordered by priority
s.Require().Equal(acc2.ID, accounts[0].ID, "expected acc2 first (priority=1)")
}
func (s *AccountRepoSuite) TestListActive() {
mustCreateAccount(s.T(), s.db, &model.Account{Name: "active1", Status: model.StatusActive})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "inactive1", Status: model.StatusDisabled})
accounts, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive")
s.Require().Len(accounts, 1)
s.Require().Equal("active1", accounts[0].Name)
}
func (s *AccountRepoSuite) TestListByPlatform() {
mustCreateAccount(s.T(), s.db, &model.Account{Name: "p1", Platform: model.PlatformAnthropic, Status: model.StatusActive})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "p2", Platform: model.PlatformOpenAI, Status: model.StatusActive})
accounts, err := s.repo.ListByPlatform(s.ctx, model.PlatformAnthropic)
s.Require().NoError(err, "ListByPlatform")
s.Require().Len(accounts, 1)
s.Require().Equal(model.PlatformAnthropic, accounts[0].Platform)
}
// --- Preload and VirtualFields ---
func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"})
account := mustCreateAccount(s.T(), s.db, &model.Account{
Name: "acc1",
ProxyID: &proxy.ID,
})
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err, "GetByID")
s.Require().NotNil(got.Proxy, "expected Proxy preload")
s.Require().Equal(proxy.ID, got.Proxy.ID)
s.Require().Len(got.GroupIDs, 1, "expected GroupIDs to be populated")
s.Require().Equal(group.ID, got.GroupIDs[0])
s.Require().Len(got.Groups, 1, "expected Groups to be populated")
s.Require().Equal(group.ID, got.Groups[0].ID)
accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc")
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total)
s.Require().Len(accounts, 1)
s.Require().NotNil(accounts[0].Proxy, "expected Proxy preload in list")
s.Require().Equal(proxy.ID, accounts[0].Proxy.ID)
s.Require().Len(accounts[0].GroupIDs, 1, "expected GroupIDs in list")
s.Require().Equal(group.ID, accounts[0].GroupIDs[0])
}
// --- GroupBinding / AddToGroup / RemoveFromGroup / BindGroups / GetGroups ---
func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() {
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"})
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc"})
s.Require().NoError(s.repo.AddToGroup(s.ctx, account.ID, g1.ID, 10), "AddToGroup")
groups, err := s.repo.GetGroups(s.ctx, account.ID)
s.Require().NoError(err, "GetGroups")
s.Require().Len(groups, 1, "expected 1 group")
s.Require().Equal(g1.ID, groups[0].ID)
s.Require().NoError(s.repo.RemoveFromGroup(s.ctx, account.ID, g1.ID), "RemoveFromGroup")
groups, err = s.repo.GetGroups(s.ctx, account.ID)
s.Require().NoError(err, "GetGroups after remove")
s.Require().Empty(groups, "expected 0 groups after remove")
s.Require().NoError(s.repo.BindGroups(s.ctx, account.ID, []int64{g1.ID, g2.ID}), "BindGroups")
groups, err = s.repo.GetGroups(s.ctx, account.ID)
s.Require().NoError(err, "GetGroups after bind")
s.Require().Len(groups, 2, "expected 2 groups after bind")
}
func (s *AccountRepoSuite) TestBindGroups_EmptyList() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-empty"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-empty"})
mustBindAccountToGroup(s.T(), s.db, account.ID, group.ID, 1)
s.Require().NoError(s.repo.BindGroups(s.ctx, account.ID, []int64{}), "BindGroups empty")
groups, err := s.repo.GetGroups(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().Empty(groups, "expected 0 groups after binding empty list")
}
// --- Schedulable ---
func (s *AccountRepoSuite) TestListSchedulable() {
now := time.Now()
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sched"})
okAcc := mustCreateAccount(s.T(), s.db, &model.Account{Name: "ok", Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1)
future := now.Add(10 * time.Minute)
overloaded := mustCreateAccount(s.T(), s.db, &model.Account{Name: "over", Schedulable: true, OverloadUntil: &future})
mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1)
sched, err := s.repo.ListSchedulable(s.ctx)
s.Require().NoError(err, "ListSchedulable")
ids := idsOfAccounts(sched)
s.Require().Contains(ids, okAcc.ID)
s.Require().NotContains(ids, overloaded.ID)
}
func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_StatusUpdates() {
now := time.Now()
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sched"})
okAcc := mustCreateAccount(s.T(), s.db, &model.Account{Name: "ok", Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, okAcc.ID, group.ID, 1)
future := now.Add(10 * time.Minute)
overloaded := mustCreateAccount(s.T(), s.db, &model.Account{Name: "over", Schedulable: true, OverloadUntil: &future})
mustBindAccountToGroup(s.T(), s.db, overloaded.ID, group.ID, 1)
rateLimited := mustCreateAccount(s.T(), s.db, &model.Account{Name: "rl", Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, rateLimited.ID, group.ID, 1)
s.Require().NoError(s.repo.SetRateLimited(s.ctx, rateLimited.ID, now.Add(10*time.Minute)), "SetRateLimited")
s.Require().NoError(s.repo.SetError(s.ctx, overloaded.ID, "boom"), "SetError")
sched, err := s.repo.ListSchedulableByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "ListSchedulableByGroupID")
s.Require().Len(sched, 1, "expected only ok account schedulable")
s.Require().Equal(okAcc.ID, sched[0].ID)
s.Require().NoError(s.repo.ClearRateLimit(s.ctx, rateLimited.ID), "ClearRateLimit")
sched2, err := s.repo.ListSchedulableByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "ListSchedulableByGroupID after ClearRateLimit")
s.Require().Len(sched2, 2, "expected 2 schedulable accounts after ClearRateLimit")
}
func (s *AccountRepoSuite) TestListSchedulableByPlatform() {
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic, Schedulable: true})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI, Schedulable: true})
accounts, err := s.repo.ListSchedulableByPlatform(s.ctx, model.PlatformAnthropic)
s.Require().NoError(err)
s.Require().Len(accounts, 1)
s.Require().Equal(model.PlatformAnthropic, accounts[0].Platform)
}
func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sp"})
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", Platform: model.PlatformAnthropic, Schedulable: true})
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", Platform: model.PlatformOpenAI, Schedulable: true})
mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2)
accounts, err := s.repo.ListSchedulableByGroupIDAndPlatform(s.ctx, group.ID, model.PlatformAnthropic)
s.Require().NoError(err)
s.Require().Len(accounts, 1)
s.Require().Equal(a1.ID, accounts[0].ID)
}
func (s *AccountRepoSuite) TestSetSchedulable() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-sched", Schedulable: true})
s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().False(got.Schedulable)
}
// --- SetOverloaded / SetRateLimited / ClearRateLimit ---
func (s *AccountRepoSuite) TestSetOverloaded() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-over"})
until := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().NotNil(got.OverloadUntil)
s.Require().WithinDuration(until, *got.OverloadUntil, time.Second)
}
func (s *AccountRepoSuite) TestSetRateLimited() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-rl"})
resetAt := time.Date(2025, 6, 15, 14, 0, 0, 0, time.UTC)
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, resetAt))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().NotNil(got.RateLimitedAt)
s.Require().NotNil(got.RateLimitResetAt)
s.Require().WithinDuration(resetAt, *got.RateLimitResetAt, time.Second)
}
func (s *AccountRepoSuite) TestClearRateLimit() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-clear"})
until := time.Now().Add(1 * time.Hour)
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, until))
s.Require().NoError(s.repo.ClearRateLimit(s.ctx, account.ID))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().Nil(got.RateLimitedAt)
s.Require().Nil(got.RateLimitResetAt)
s.Require().Nil(got.OverloadUntil)
}
// --- UpdateLastUsed ---
func (s *AccountRepoSuite) TestUpdateLastUsed() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-used"})
s.Require().Nil(account.LastUsedAt)
s.Require().NoError(s.repo.UpdateLastUsed(s.ctx, account.ID))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().NotNil(got.LastUsedAt)
}
// --- SetError ---
func (s *AccountRepoSuite) TestSetError() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-err", Status: model.StatusActive})
s.Require().NoError(s.repo.SetError(s.ctx, account.ID, "something went wrong"))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().Equal(model.StatusError, got.Status)
s.Require().Equal("something went wrong", got.ErrorMessage)
}
// --- UpdateSessionWindow ---
func (s *AccountRepoSuite) TestUpdateSessionWindow() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-win"})
start := time.Date(2025, 6, 15, 10, 0, 0, 0, time.UTC)
end := time.Date(2025, 6, 15, 15, 0, 0, 0, time.UTC)
s.Require().NoError(s.repo.UpdateSessionWindow(s.ctx, account.ID, &start, &end, "active"))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().NotNil(got.SessionWindowStart)
s.Require().NotNil(got.SessionWindowEnd)
s.Require().Equal("active", got.SessionWindowStatus)
}
// --- UpdateExtra ---
func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() {
account := mustCreateAccount(s.T(), s.db, &model.Account{
Name: "acc-extra",
Extra: model.JSONB{"a": "1"},
})
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"b": "2"}), "UpdateExtra")
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("1", got.Extra["a"])
s.Require().Equal("2", got.Extra["b"])
}
func (s *AccountRepoSuite) TestUpdateExtra_EmptyUpdates() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-extra-empty"})
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{}))
}
func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-nil-extra", Extra: nil})
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"key": "val"}))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().Equal("val", got.Extra["key"])
}
// --- GetByCRSAccountID ---
func (s *AccountRepoSuite) TestGetByCRSAccountID() {
crsID := "crs-12345"
mustCreateAccount(s.T(), s.db, &model.Account{
Name: "acc-crs",
Extra: model.JSONB{"crs_account_id": crsID},
})
got, err := s.repo.GetByCRSAccountID(s.ctx, crsID)
s.Require().NoError(err)
s.Require().NotNil(got)
s.Require().Equal("acc-crs", got.Name)
}
func (s *AccountRepoSuite) TestGetByCRSAccountID_NotFound() {
got, err := s.repo.GetByCRSAccountID(s.ctx, "non-existent")
s.Require().NoError(err)
s.Require().Nil(got)
}
func (s *AccountRepoSuite) TestGetByCRSAccountID_EmptyString() {
got, err := s.repo.GetByCRSAccountID(s.ctx, "")
s.Require().NoError(err)
s.Require().Nil(got)
}
// --- BulkUpdate ---
func (s *AccountRepoSuite) TestBulkUpdate() {
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk1", Priority: 1})
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk2", Priority: 1})
newPriority := 99
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID, a2.ID}, ports.AccountBulkUpdate{
Priority: &newPriority,
})
s.Require().NoError(err)
s.Require().GreaterOrEqual(affected, int64(1), "expected at least one affected row")
got1, _ := s.repo.GetByID(s.ctx, a1.ID)
got2, _ := s.repo.GetByID(s.ctx, a2.ID)
s.Require().Equal(99, got1.Priority)
s.Require().Equal(99, got2.Priority)
}
func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() {
a1 := mustCreateAccount(s.T(), s.db, &model.Account{
Name: "bulk-cred",
Credentials: model.JSONB{"existing": "value"},
})
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, ports.AccountBulkUpdate{
Credentials: model.JSONB{"new_key": "new_value"},
})
s.Require().NoError(err)
got, _ := s.repo.GetByID(s.ctx, a1.ID)
s.Require().Equal("value", got.Credentials["existing"])
s.Require().Equal("new_value", got.Credentials["new_key"])
}
func (s *AccountRepoSuite) TestBulkUpdate_MergeExtra() {
a1 := mustCreateAccount(s.T(), s.db, &model.Account{
Name: "bulk-extra",
Extra: model.JSONB{"existing": "val"},
})
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, ports.AccountBulkUpdate{
Extra: model.JSONB{"new_key": "new_val"},
})
s.Require().NoError(err)
got, _ := s.repo.GetByID(s.ctx, a1.ID)
s.Require().Equal("val", got.Extra["existing"])
s.Require().Equal("new_val", got.Extra["new_key"])
}
func (s *AccountRepoSuite) TestBulkUpdate_EmptyIDs() {
affected, err := s.repo.BulkUpdate(s.ctx, []int64{}, ports.AccountBulkUpdate{})
s.Require().NoError(err)
s.Require().Zero(affected)
}
func (s *AccountRepoSuite) TestBulkUpdate_EmptyUpdates() {
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "bulk-empty"})
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, ports.AccountBulkUpdate{})
s.Require().NoError(err)
s.Require().Zero(affected)
}
func idsOfAccounts(accounts []model.Account) []int64 {
out := make([]int64, 0, len(accounts))
for i := range accounts {
out = append(out, accounts[i].ID)
}
return out
}

View File

@@ -0,0 +1,125 @@
//go:build integration
package repository
import (
"context"
"fmt"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type ApiKeyCacheSuite struct {
IntegrationRedisSuite
}
func (s *ApiKeyCacheSuite) TestCreateAttemptCount() {
tests := []struct {
name string
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
}{
{
name: "missing_key_returns_redis_nil",
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
userID := int64(1)
_, err := cache.GetCreateAttemptCount(ctx, userID)
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing key")
},
},
{
name: "increment_increases_count_and_sets_ttl",
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
userID := int64(1)
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID), "IncrementCreateAttemptCount")
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID), "IncrementCreateAttemptCount 2")
count, err := cache.GetCreateAttemptCount(ctx, userID)
require.NoError(s.T(), err, "GetCreateAttemptCount")
require.Equal(s.T(), 2, count, "count mismatch")
ttl, err := rdb.TTL(ctx, key).Result()
require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, apiKeyRateLimitDuration)
},
},
{
name: "delete_removes_key",
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
userID := int64(1)
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID))
require.NoError(s.T(), cache.DeleteCreateAttemptCount(ctx, userID), "DeleteCreateAttemptCount")
_, err := cache.GetCreateAttemptCount(ctx, userID)
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after delete")
},
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
// 每个 case 重新获取隔离资源
rdb := testRedis(s.T())
cache := &apiKeyCache{rdb: rdb}
ctx := context.Background()
tt.fn(ctx, rdb, cache)
})
}
}
func (s *ApiKeyCacheSuite) TestDailyUsage() {
tests := []struct {
name string
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
}{
{
name: "increment_increases_count",
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
dailyKey := "daily:sk-test"
require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey), "IncrementDailyUsage")
require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey), "IncrementDailyUsage 2")
n, err := rdb.Get(ctx, dailyKey).Int()
require.NoError(s.T(), err, "Get dailyKey")
require.Equal(s.T(), 2, n, "expected daily usage=2")
},
},
{
name: "set_expiry_sets_ttl",
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
dailyKey := "daily:sk-test-expiry"
require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey))
require.NoError(s.T(), cache.SetDailyUsageExpiry(ctx, dailyKey, 1*time.Hour), "SetDailyUsageExpiry")
ttl, err := rdb.TTL(ctx, dailyKey).Result()
require.NoError(s.T(), err, "TTL dailyKey")
require.Greater(s.T(), ttl, time.Duration(0), "expected ttl > 0")
},
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
rdb := testRedis(s.T())
cache := &apiKeyCache{rdb: rdb}
ctx := context.Background()
tt.fn(ctx, rdb, cache)
})
}
}
func TestApiKeyCacheSuite(t *testing.T) {
suite.Run(t, new(ApiKeyCacheSuite))
}

View File

@@ -0,0 +1,355 @@
//go:build integration
package repository
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type ApiKeyRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *ApiKeyRepository
}
func (s *ApiKeyRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewApiKeyRepository(s.db)
}
func TestApiKeyRepoSuite(t *testing.T) {
suite.Run(t, new(ApiKeyRepoSuite))
}
// --- Create / GetByID / GetByKey ---
func (s *ApiKeyRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "create@test.com"})
key := &model.ApiKey{
UserID: user.ID,
Key: "sk-create-test",
Name: "Test Key",
Status: model.StatusActive,
}
err := s.repo.Create(s.ctx, key)
s.Require().NoError(err, "Create")
s.Require().NotZero(key.ID, "expected ID to be set")
got, err := s.repo.GetByID(s.ctx, key.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("sk-create-test", got.Key)
}
func (s *ApiKeyRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *ApiKeyRepoSuite) TestGetByKey() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "getbykey@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-key"})
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
UserID: user.ID,
Key: "sk-getbykey",
Name: "My Key",
GroupID: &group.ID,
Status: model.StatusActive,
})
got, err := s.repo.GetByKey(s.ctx, key.Key)
s.Require().NoError(err, "GetByKey")
s.Require().Equal(key.ID, got.ID)
s.Require().NotNil(got.User, "expected User preload")
s.Require().Equal(user.ID, got.User.ID)
s.Require().NotNil(got.Group, "expected Group preload")
s.Require().Equal(group.ID, got.Group.ID)
}
func (s *ApiKeyRepoSuite) TestGetByKey_NotFound() {
_, err := s.repo.GetByKey(s.ctx, "non-existent-key")
s.Require().Error(err, "expected error for non-existent key")
}
// --- Update ---
func (s *ApiKeyRepoSuite) TestUpdate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com"})
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
UserID: user.ID,
Key: "sk-update",
Name: "Original",
Status: model.StatusActive,
})
key.Name = "Renamed"
key.Status = model.StatusDisabled
err := s.repo.Update(s.ctx, key)
s.Require().NoError(err, "Update")
got, err := s.repo.GetByID(s.ctx, key.ID)
s.Require().NoError(err, "GetByID after update")
s.Require().Equal("sk-update", got.Key, "Update should not change key")
s.Require().Equal(user.ID, got.UserID, "Update should not change user_id")
s.Require().Equal("Renamed", got.Name)
s.Require().Equal(model.StatusDisabled, got.Status)
}
func (s *ApiKeyRepoSuite) TestUpdate_ClearGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "cleargroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-clear"})
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
UserID: user.ID,
Key: "sk-clear-group",
Name: "Group Key",
GroupID: &group.ID,
})
key.GroupID = nil
err := s.repo.Update(s.ctx, key)
s.Require().NoError(err, "Update")
got, err := s.repo.GetByID(s.ctx, key.ID)
s.Require().NoError(err)
s.Require().Nil(got.GroupID, "expected GroupID to be cleared")
}
// --- Delete ---
func (s *ApiKeyRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"})
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
UserID: user.ID,
Key: "sk-delete",
Name: "Delete Me",
})
err := s.repo.Delete(s.ctx, key.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, key.ID)
s.Require().Error(err, "expected error after delete")
}
// --- ListByUserID / CountByUserID ---
func (s *ApiKeyRepoSuite) TestListByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyuser@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-list-1", Name: "Key 1"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-list-2", Name: "Key 2"})
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByUserID")
s.Require().Len(keys, 2)
s.Require().Equal(int64(2), page.Total)
}
func (s *ApiKeyRepoSuite) TestListByUserID_Pagination() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "paging@test.com"})
for i := 0; i < 5; i++ {
mustCreateApiKey(s.T(), s.db, &model.ApiKey{
UserID: user.ID,
Key: "sk-page-" + string(rune('a'+i)),
Name: "Key",
})
}
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2})
s.Require().NoError(err)
s.Require().Len(keys, 2)
s.Require().Equal(int64(5), page.Total)
s.Require().Equal(3, page.Pages)
}
func (s *ApiKeyRepoSuite) TestCountByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "count@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-count-1", Name: "K1"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-count-2", Name: "K2"})
count, err := s.repo.CountByUserID(s.ctx, user.ID)
s.Require().NoError(err, "CountByUserID")
s.Require().Equal(int64(2), count)
}
// --- ListByGroupID / CountByGroupID ---
func (s *ApiKeyRepoSuite) TestListByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbygroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-1", Name: "K1", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-2", Name: "K2", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-grp-3", Name: "K3"}) // no group
keys, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByGroupID")
s.Require().Len(keys, 2)
s.Require().Equal(int64(2), page.Total)
// User preloaded
s.Require().NotNil(keys[0].User)
}
func (s *ApiKeyRepoSuite) TestCountByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "countgroup@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-gc-1", Name: "K1", GroupID: &group.ID})
count, err := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountByGroupID")
s.Require().Equal(int64(1), count)
}
// --- ExistsByKey ---
func (s *ApiKeyRepoSuite) TestExistsByKey() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-exists", Name: "K"})
exists, err := s.repo.ExistsByKey(s.ctx, "sk-exists")
s.Require().NoError(err, "ExistsByKey")
s.Require().True(exists)
notExists, err := s.repo.ExistsByKey(s.ctx, "sk-not-exists")
s.Require().NoError(err)
s.Require().False(notExists)
}
// --- SearchApiKeys ---
func (s *ApiKeyRepoSuite) TestSearchApiKeys() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "search@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-search-1", Name: "Production Key"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-search-2", Name: "Development Key"})
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "prod", 10)
s.Require().NoError(err, "SearchApiKeys")
s.Require().Len(found, 1)
s.Require().Contains(found[0].Name, "Production")
}
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoKeyword() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "searchnokw@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nk-1", Name: "K1"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nk-2", Name: "K2"})
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "", 10)
s.Require().NoError(err)
s.Require().Len(found, 2)
}
func (s *ApiKeyRepoSuite) TestSearchApiKeys_NoUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "searchnouid@test.com"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-nu-1", Name: "TestKey"})
found, err := s.repo.SearchApiKeys(s.ctx, 0, "testkey", 10)
s.Require().NoError(err)
s.Require().Len(found, 1)
}
// --- ClearGroupIDByGroupID ---
func (s *ApiKeyRepoSuite) TestClearGroupIDByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "cleargrp@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-clear-bulk"})
k1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-1", Name: "K1", GroupID: &group.ID})
k2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-2", Name: "K2", GroupID: &group.ID})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-clr-3", Name: "K3"}) // no group
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "ClearGroupIDByGroupID")
s.Require().Equal(int64(2), affected)
got1, _ := s.repo.GetByID(s.ctx, k1.ID)
got2, _ := s.repo.GetByID(s.ctx, k2.ID)
s.Require().Nil(got1.GroupID)
s.Require().Nil(got2.GroupID)
count, _ := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().Zero(count)
}
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
func (s *ApiKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "k@example.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-k"})
key := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
UserID: user.ID,
Key: "sk-test-1",
Name: "My Key",
GroupID: &group.ID,
Status: model.StatusActive,
})
got, err := s.repo.GetByKey(s.ctx, key.Key)
s.Require().NoError(err, "GetByKey")
s.Require().Equal(key.ID, got.ID)
s.Require().NotNil(got.User)
s.Require().Equal(user.ID, got.User.ID)
s.Require().NotNil(got.Group)
s.Require().Equal(group.ID, got.Group.ID)
key.Name = "Renamed"
key.Status = model.StatusDisabled
key.GroupID = nil
s.Require().NoError(s.repo.Update(s.ctx, key), "Update")
got2, err := s.repo.GetByID(s.ctx, key.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("sk-test-1", got2.Key, "Update should not change key")
s.Require().Equal(user.ID, got2.UserID, "Update should not change user_id")
s.Require().Equal("Renamed", got2.Name)
s.Require().Equal(model.StatusDisabled, got2.Status)
s.Require().Nil(got2.GroupID)
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByUserID")
s.Require().Equal(int64(1), page.Total)
s.Require().Len(keys, 1)
exists, err := s.repo.ExistsByKey(s.ctx, "sk-test-1")
s.Require().NoError(err, "ExistsByKey")
s.Require().True(exists, "expected key to exist")
found, err := s.repo.SearchApiKeys(s.ctx, user.ID, "renam", 10)
s.Require().NoError(err, "SearchApiKeys")
s.Require().Len(found, 1)
s.Require().Equal(key.ID, found[0].ID)
// ClearGroupIDByGroupID
k2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{
UserID: user.ID,
Key: "sk-test-2",
Name: "Group Key",
GroupID: &group.ID,
})
countBefore, err := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountByGroupID")
s.Require().Equal(int64(1), countBefore, "expected 1 key in group before clear")
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "ClearGroupIDByGroupID")
s.Require().Equal(int64(1), affected, "expected 1 affected row")
got3, err := s.repo.GetByID(s.ctx, k2.ID)
s.Require().NoError(err, "GetByID")
s.Require().Nil(got3.GroupID, "expected GroupID cleared")
countAfter, err := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountByGroupID after clear")
s.Require().Equal(int64(0), countAfter, "expected 0 keys in group after clear")
}

View File

@@ -0,0 +1,283 @@
//go:build integration
package repository
import (
"context"
"fmt"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type BillingCacheSuite struct {
IntegrationRedisSuite
}
func (s *BillingCacheSuite) TestUserBalance() {
tests := []struct {
name string
fn func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache)
}{
{
name: "missing_key_returns_redis_nil",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
_, err := cache.GetUserBalance(ctx, 1)
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing balance key")
},
},
{
name: "deduct_on_nonexistent_is_noop",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
userID := int64(1)
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 1), "DeductUserBalance should not error")
_, err := rdb.Get(ctx, balanceKey).Result()
require.ErrorIs(s.T(), err, redis.Nil, "expected missing key after deduct on non-existent")
},
},
{
name: "set_and_get_with_ttl",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
userID := int64(2)
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 10.5), "SetUserBalance")
got, err := cache.GetUserBalance(ctx, userID)
require.NoError(s.T(), err, "GetUserBalance")
require.Equal(s.T(), 10.5, got, "balance mismatch")
ttl, err := rdb.TTL(ctx, balanceKey).Result()
require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, billingCacheTTL)
},
},
{
name: "deduct_reduces_balance",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
userID := int64(3)
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 10.5), "SetUserBalance")
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 2.25), "DeductUserBalance")
got, err := cache.GetUserBalance(ctx, userID)
require.NoError(s.T(), err, "GetUserBalance after deduct")
require.Equal(s.T(), 8.25, got, "deduct mismatch")
},
},
{
name: "invalidate_removes_key",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
userID := int64(100)
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 50.0), "SetUserBalance")
exists, err := rdb.Exists(ctx, balanceKey).Result()
require.NoError(s.T(), err, "Exists")
require.Equal(s.T(), int64(1), exists, "expected balance key to exist")
require.NoError(s.T(), cache.InvalidateUserBalance(ctx, userID), "InvalidateUserBalance")
exists, err = rdb.Exists(ctx, balanceKey).Result()
require.NoError(s.T(), err, "Exists after invalidate")
require.Equal(s.T(), int64(0), exists, "expected balance key to be removed after invalidate")
_, err = cache.GetUserBalance(ctx, userID)
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after invalidate")
},
},
{
name: "deduct_refreshes_ttl",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
userID := int64(103)
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 100.0), "SetUserBalance")
ttl1, err := rdb.TTL(ctx, balanceKey).Result()
require.NoError(s.T(), err, "TTL before deduct")
s.AssertTTLWithin(ttl1, 1*time.Second, billingCacheTTL)
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 25.0), "DeductUserBalance")
balance, err := cache.GetUserBalance(ctx, userID)
require.NoError(s.T(), err, "GetUserBalance")
require.Equal(s.T(), 75.0, balance, "expected balance 75.0")
ttl2, err := rdb.TTL(ctx, balanceKey).Result()
require.NoError(s.T(), err, "TTL after deduct")
s.AssertTTLWithin(ttl2, 1*time.Second, billingCacheTTL)
},
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
rdb := testRedis(s.T())
cache := NewBillingCache(rdb)
ctx := context.Background()
tt.fn(ctx, rdb, cache)
})
}
}
func (s *BillingCacheSuite) TestSubscriptionCache() {
tests := []struct {
name string
fn func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache)
}{
{
name: "missing_key_returns_redis_nil",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
userID := int64(10)
groupID := int64(20)
_, err := cache.GetSubscriptionCache(ctx, userID, groupID)
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing subscription key")
},
},
{
name: "update_usage_on_nonexistent_is_noop",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
userID := int64(11)
groupID := int64(21)
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
require.NoError(s.T(), cache.UpdateSubscriptionUsage(ctx, userID, groupID, 1.0), "UpdateSubscriptionUsage should not error")
exists, err := rdb.Exists(ctx, subKey).Result()
require.NoError(s.T(), err, "Exists")
require.Equal(s.T(), int64(0), exists, "expected missing subscription key after UpdateSubscriptionUsage on non-existent")
},
},
{
name: "set_and_get_with_ttl",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
userID := int64(12)
groupID := int64(22)
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
data := &ports.SubscriptionCacheData{
Status: "active",
ExpiresAt: time.Now().Add(1 * time.Hour),
DailyUsage: 1.0,
WeeklyUsage: 2.0,
MonthlyUsage: 3.0,
Version: 7,
}
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
gotSub, err := cache.GetSubscriptionCache(ctx, userID, groupID)
require.NoError(s.T(), err, "GetSubscriptionCache")
require.Equal(s.T(), "active", gotSub.Status)
require.Equal(s.T(), int64(7), gotSub.Version)
require.Equal(s.T(), 1.0, gotSub.DailyUsage)
ttl, err := rdb.TTL(ctx, subKey).Result()
require.NoError(s.T(), err, "TTL subKey")
s.AssertTTLWithin(ttl, 1*time.Second, billingCacheTTL)
},
},
{
name: "update_usage_increments_all_fields",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
userID := int64(13)
groupID := int64(23)
data := &ports.SubscriptionCacheData{
Status: "active",
ExpiresAt: time.Now().Add(1 * time.Hour),
DailyUsage: 1.0,
WeeklyUsage: 2.0,
MonthlyUsage: 3.0,
Version: 1,
}
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
require.NoError(s.T(), cache.UpdateSubscriptionUsage(ctx, userID, groupID, 0.5), "UpdateSubscriptionUsage")
gotSub, err := cache.GetSubscriptionCache(ctx, userID, groupID)
require.NoError(s.T(), err, "GetSubscriptionCache after update")
require.Equal(s.T(), 1.5, gotSub.DailyUsage)
require.Equal(s.T(), 2.5, gotSub.WeeklyUsage)
require.Equal(s.T(), 3.5, gotSub.MonthlyUsage)
},
},
{
name: "invalidate_removes_key",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
userID := int64(101)
groupID := int64(10)
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
data := &ports.SubscriptionCacheData{
Status: "active",
ExpiresAt: time.Now().Add(1 * time.Hour),
DailyUsage: 1.0,
WeeklyUsage: 2.0,
MonthlyUsage: 3.0,
Version: 1,
}
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
exists, err := rdb.Exists(ctx, subKey).Result()
require.NoError(s.T(), err, "Exists")
require.Equal(s.T(), int64(1), exists, "expected subscription key to exist")
require.NoError(s.T(), cache.InvalidateSubscriptionCache(ctx, userID, groupID), "InvalidateSubscriptionCache")
exists, err = rdb.Exists(ctx, subKey).Result()
require.NoError(s.T(), err, "Exists after invalidate")
require.Equal(s.T(), int64(0), exists, "expected subscription key to be removed after invalidate")
_, err = cache.GetSubscriptionCache(ctx, userID, groupID)
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after invalidate")
},
},
{
name: "missing_status_returns_parsing_error",
fn: func(ctx context.Context, rdb *redis.Client, cache ports.BillingCache) {
userID := int64(102)
groupID := int64(11)
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
fields := map[string]any{
"expires_at": time.Now().Add(1 * time.Hour).Unix(),
"daily_usage": 1.0,
"weekly_usage": 2.0,
"monthly_usage": 3.0,
"version": 1,
}
require.NoError(s.T(), rdb.HSet(ctx, subKey, fields).Err(), "HSet")
_, err := cache.GetSubscriptionCache(ctx, userID, groupID)
require.Error(s.T(), err, "expected error for missing status field")
require.NotErrorIs(s.T(), err, redis.Nil, "expected parsing error, not redis.Nil")
require.Equal(s.T(), "invalid cache: missing status", err.Error())
},
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
rdb := testRedis(s.T())
cache := NewBillingCache(rdb)
ctx := context.Background()
tt.fn(ctx, rdb, cache)
})
}
}
func TestBillingCacheSuite(t *testing.T) {
suite.Run(t, new(BillingCacheSuite))
}

View File

@@ -16,20 +16,28 @@ import (
"github.com/imroc/req/v3"
)
type claudeOAuthService struct{}
func NewClaudeOAuthClient() service.ClaudeOAuthClient {
return &claudeOAuthService{}
return &claudeOAuthService{
baseURL: "https://claude.ai",
tokenURL: oauth.TokenURL,
clientFactory: createReqClient,
}
}
type claudeOAuthService struct {
baseURL string
tokenURL string
clientFactory func(proxyURL string) *req.Client
}
func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
client := createReqClient(proxyURL)
client := s.clientFactory(proxyURL)
var orgs []struct {
UUID string `json:"uuid"`
}
targetURL := "https://claude.ai/api/organizations"
targetURL := s.baseURL + "/api/organizations"
log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL)
resp, err := client.R().
@@ -61,9 +69,9 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
}
func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
client := createReqClient(proxyURL)
client := s.clientFactory(proxyURL)
authURL := fmt.Sprintf("https://claude.ai/v1/oauth/%s/authorize", orgUUID)
authURL := fmt.Sprintf("%s/v1/oauth/%s/authorize", s.baseURL, orgUUID)
reqBody := map[string]any{
"response_type": "code",
@@ -133,12 +141,12 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
fullCode = authCode + "#" + responseState
}
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", authCode[:20])
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code: %s...", prefix(authCode, 20))
return fullCode, nil
}
func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string) (*oauth.TokenResponse, error) {
client := createReqClient(proxyURL)
client := s.clientFactory(proxyURL)
// Parse code which may contain state in format "authCode#state"
authCode := code
@@ -161,7 +169,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
}
reqBodyJSON, _ := json.Marshal(reqBody)
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", oauth.TokenURL)
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL)
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
var tokenResp oauth.TokenResponse
@@ -171,7 +179,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
SetHeader("Content-Type", "application/json").
SetBody(reqBody).
SetSuccessResult(&tokenResp).
Post(oauth.TokenURL)
Post(s.tokenURL)
if err != nil {
log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err)
@@ -189,7 +197,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
}
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
client := createReqClient(proxyURL)
client := s.clientFactory(proxyURL)
formData := url.Values{}
formData.Set("grant_type", "refresh_token")
@@ -202,7 +210,7 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
SetContext(ctx).
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(oauth.TokenURL)
Post(s.tokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
@@ -226,3 +234,13 @@ func createReqClient(proxyURL string) *req.Client {
return client
}
func prefix(s string, n int) string {
if n <= 0 {
return ""
}
if len(s) <= n {
return s
}
return s[:n]
}

View File

@@ -0,0 +1,343 @@
package repository
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type ClaudeOAuthServiceSuite struct {
suite.Suite
srv *httptest.Server
client *claudeOAuthService
}
func (s *ClaudeOAuthServiceSuite) TearDownTest() {
if s.srv != nil {
s.srv.Close()
s.srv = nil
}
}
// requestCapture holds captured request data for assertions in the main goroutine.
type requestCapture struct {
path string
method string
cookies []*http.Cookie
body []byte
formValues url.Values
bodyJSON map[string]any
contentType string
}
func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() {
tests := []struct {
name string
handler http.HandlerFunc
wantErr bool
errContain string
wantUUID string
validate func(captured requestCapture)
}{
{
name: "success",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`[{"uuid":"org-1"}]`))
},
wantUUID: "org-1",
validate: func(captured requestCapture) {
require.Equal(s.T(), "/api/organizations", captured.path, "unexpected path")
require.Len(s.T(), captured.cookies, 1, "expected 1 cookie")
require.Equal(s.T(), "sessionKey", captured.cookies[0].Name)
require.Equal(s.T(), "sess", captured.cookies[0].Value)
},
},
{
name: "non_200_returns_error",
handler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte("unauthorized"))
},
wantErr: true,
errContain: "401",
},
{
name: "invalid_json_returns_error",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte("not-json"))
},
wantErr: true,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
var captured requestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.path = r.URL.Path
captured.cookies = r.Cookies()
tt.handler(w, r)
}))
defer s.srv.Close()
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.baseURL = s.srv.URL
got, err := s.client.GetOrganizationUUID(context.Background(), "sess", "")
if tt.wantErr {
require.Error(s.T(), err)
if tt.errContain != "" {
require.ErrorContains(s.T(), err, tt.errContain)
}
return
}
require.NoError(s.T(), err)
require.Equal(s.T(), tt.wantUUID, got)
if tt.validate != nil {
tt.validate(captured)
}
})
}
}
func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() {
tests := []struct {
name string
handler http.HandlerFunc
wantErr bool
wantCode string
validate func(captured requestCapture)
}{
{
name: "parses_redirect_uri",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]string{
"redirect_uri": oauth.RedirectURI + "?code=AUTH&state=STATE",
})
},
wantCode: "AUTH#STATE",
validate: func(captured requestCapture) {
require.True(s.T(), strings.HasPrefix(captured.path, "/v1/oauth/") && strings.HasSuffix(captured.path, "/authorize"), "unexpected path: %s", captured.path)
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
require.Len(s.T(), captured.cookies, 1, "expected 1 cookie")
require.Equal(s.T(), "sess", captured.cookies[0].Value)
require.Equal(s.T(), "org-1", captured.bodyJSON["organization_uuid"])
require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
require.Equal(s.T(), oauth.RedirectURI, captured.bodyJSON["redirect_uri"])
require.Equal(s.T(), "st", captured.bodyJSON["state"])
},
},
{
name: "missing_code_returns_error",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]string{
"redirect_uri": oauth.RedirectURI + "?state=STATE", // no code
})
},
wantErr: true,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
var captured requestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.path = r.URL.Path
captured.method = r.Method
captured.cookies = r.Cookies()
captured.body, _ = io.ReadAll(r.Body)
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
tt.handler(w, r)
}))
defer s.srv.Close()
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.baseURL = s.srv.URL
code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeProfile, "cc", "st", "")
if tt.wantErr {
require.Error(s.T(), err)
return
}
require.NoError(s.T(), err)
require.Equal(s.T(), tt.wantCode, code)
if tt.validate != nil {
tt.validate(captured)
}
})
}
}
func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() {
tests := []struct {
name string
handler http.HandlerFunc
code string
wantErr bool
wantResp *oauth.TokenResponse
validate func(captured requestCapture)
}{
{
name: "sends_state_when_embedded",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
AccessToken: "at",
TokenType: "bearer",
ExpiresIn: 3600,
RefreshToken: "rt",
Scope: "s",
})
},
code: "AUTH#STATE2",
wantResp: &oauth.TokenResponse{
AccessToken: "at",
RefreshToken: "rt",
},
validate: func(captured requestCapture) {
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
require.True(s.T(), strings.HasPrefix(captured.contentType, "application/json"), "unexpected content-type")
require.Equal(s.T(), "AUTH", captured.bodyJSON["code"])
require.Equal(s.T(), "STATE2", captured.bodyJSON["state"])
require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
require.Equal(s.T(), oauth.RedirectURI, captured.bodyJSON["redirect_uri"])
require.Equal(s.T(), "ver", captured.bodyJSON["code_verifier"])
},
},
{
name: "non_200_returns_error",
handler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte("bad request"))
},
code: "AUTH",
wantErr: true,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
var captured requestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.method = r.Method
captured.contentType = r.Header.Get("Content-Type")
captured.body, _ = io.ReadAll(r.Body)
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
tt.handler(w, r)
}))
defer s.srv.Close()
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.tokenURL = s.srv.URL
resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "")
if tt.wantErr {
require.Error(s.T(), err)
return
}
require.NoError(s.T(), err)
require.Equal(s.T(), tt.wantResp.AccessToken, resp.AccessToken)
require.Equal(s.T(), tt.wantResp.RefreshToken, resp.RefreshToken)
if tt.validate != nil {
tt.validate(captured)
}
})
}
}
func (s *ClaudeOAuthServiceSuite) TestRefreshToken() {
tests := []struct {
name string
handler http.HandlerFunc
wantErr bool
wantResp *oauth.TokenResponse
validate func(captured requestCapture)
}{
{
name: "sends_form",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{AccessToken: "at2", TokenType: "bearer", ExpiresIn: 3600})
},
wantResp: &oauth.TokenResponse{AccessToken: "at2"},
validate: func(captured requestCapture) {
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
require.Equal(s.T(), "refresh_token", captured.formValues.Get("grant_type"))
require.Equal(s.T(), "rt", captured.formValues.Get("refresh_token"))
require.Equal(s.T(), oauth.ClientID, captured.formValues.Get("client_id"))
},
},
{
name: "non_200_returns_error",
handler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte("unauthorized"))
},
wantErr: true,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
var captured requestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.method = r.Method
captured.body, _ = io.ReadAll(r.Body)
captured.formValues, _ = url.ParseQuery(string(captured.body))
tt.handler(w, r)
}))
defer s.srv.Close()
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.tokenURL = s.srv.URL
resp, err := s.client.RefreshToken(context.Background(), "rt", "")
if tt.wantErr {
require.Error(s.T(), err)
return
}
require.NoError(s.T(), err)
require.Equal(s.T(), tt.wantResp.AccessToken, resp.AccessToken)
if tt.validate != nil {
tt.validate(captured)
}
})
}
}
func TestClaudeOAuthServiceSuite(t *testing.T) {
suite.Run(t, new(ClaudeOAuthServiceSuite))
}

View File

@@ -12,10 +12,14 @@ import (
"github.com/Wei-Shaw/sub2api/internal/service"
)
type claudeUsageService struct{}
const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage"
type claudeUsageService struct {
usageURL string
}
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
return &claudeUsageService{}
return &claudeUsageService{usageURL: defaultClaudeUsageURL}
}
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
@@ -35,7 +39,7 @@ func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyU
Timeout: 30 * time.Second,
}
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.anthropic.com/api/oauth/usage", nil)
req, err := http.NewRequestWithContext(ctx, "GET", s.usageURL, nil)
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
}

View File

@@ -0,0 +1,105 @@
package repository
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type ClaudeUsageServiceSuite struct {
suite.Suite
srv *httptest.Server
fetcher *claudeUsageService
}
func (s *ClaudeUsageServiceSuite) TearDownTest() {
if s.srv != nil {
s.srv.Close()
s.srv = nil
}
}
// usageRequestCapture holds captured request data for assertions in the main goroutine.
type usageRequestCapture struct {
authorization string
anthropicBeta string
}
func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
var captured usageRequestCapture
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.authorization = r.Header.Get("Authorization")
captured.anthropicBeta = r.Header.Get("anthropic-beta")
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{
"five_hour": {"utilization": 12.5, "resets_at": "2025-01-01T00:00:00Z"},
"seven_day": {"utilization": 34.0, "resets_at": "2025-01-08T00:00:00Z"},
"seven_day_sonnet": {"utilization": 56.0, "resets_at": "2025-01-08T00:00:00Z"}
}`)
}))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
resp, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url")
require.NoError(s.T(), err, "FetchUsage")
require.Equal(s.T(), 12.5, resp.FiveHour.Utilization, "FiveHour utilization mismatch")
require.Equal(s.T(), 34.0, resp.SevenDay.Utilization, "SevenDay utilization mismatch")
require.Equal(s.T(), 56.0, resp.SevenDaySonnet.Utilization, "SevenDaySonnet utilization mismatch")
// Assertions on captured request data
require.Equal(s.T(), "Bearer at", captured.authorization, "Authorization header mismatch")
require.Equal(s.T(), "oauth-2025-04-20", captured.anthropicBeta, "anthropic-beta header mismatch")
}
func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = io.WriteString(w, "nope")
}))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "status 401")
require.ErrorContains(s.T(), err, "nope")
}
func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, "not-json")
}))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "decode response failed")
}
func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Never respond - simulate slow server
<-r.Context().Done()
}))
s.fetcher = &claudeUsageService{usageURL: s.srv.URL}
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
_, err := s.fetcher.FetchUsage(ctx, "at", "")
require.Error(s.T(), err, "expected error for cancelled context")
}
func TestClaudeUsageServiceSuite(t *testing.T) {
suite.Run(t, new(ClaudeUsageServiceSuite))
}

View File

@@ -0,0 +1,231 @@
//go:build integration
package repository
import (
"errors"
"fmt"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type ConcurrencyCacheSuite struct {
IntegrationRedisSuite
cache ports.ConcurrencyCache
}
func (s *ConcurrencyCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewConcurrencyCache(s.rdb)
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
accountID := int64(10)
reqID1, reqID2, reqID3 := "req1", "req2", "req3"
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID1)
require.NoError(s.T(), err, "AcquireAccountSlot 1")
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID2)
require.NoError(s.T(), err, "AcquireAccountSlot 2")
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID3)
require.NoError(s.T(), err, "AcquireAccountSlot 3")
require.False(s.T(), ok, "expected third acquire to fail")
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err, "GetAccountConcurrency")
require.Equal(s.T(), 2, cur, "concurrency mismatch")
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID1), "ReleaseAccountSlot")
cur, err = s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err, "GetAccountConcurrency after release")
require.Equal(s.T(), 1, cur, "expected 1 after release")
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() {
accountID := int64(11)
reqID := "req_ttl_test"
slotKey := fmt.Sprintf("%s%d:%s", accountSlotKeyPrefix, accountID, reqID)
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, reqID)
require.NoError(s.T(), err, "AcquireAccountSlot")
require.True(s.T(), ok)
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_DuplicateReqID() {
accountID := int64(12)
reqID := "dup-req"
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID)
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Acquiring with same reqID should be idempotent
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID)
require.NoError(s.T(), err)
require.True(s.T(), ok)
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 1, cur, "expected concurrency=1 (idempotent)")
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_ReleaseIdempotent() {
accountID := int64(13)
reqID := "release-test"
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 1, reqID)
require.NoError(s.T(), err)
require.True(s.T(), ok)
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID), "ReleaseAccountSlot")
// Releasing again should not error
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID), "ReleaseAccountSlot again")
// Releasing non-existent should not error
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, "non-existent"), "ReleaseAccountSlot non-existent")
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 0, cur)
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_MaxZero() {
accountID := int64(14)
reqID := "max-zero-test"
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 0, reqID)
require.NoError(s.T(), err)
require.False(s.T(), ok, "expected acquire to fail with max=0")
}
func (s *ConcurrencyCacheSuite) TestUserSlot_AcquireAndRelease() {
userID := int64(42)
reqID1, reqID2 := "req1", "req2"
ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 1, reqID1)
require.NoError(s.T(), err, "AcquireUserSlot")
require.True(s.T(), ok)
ok, err = s.cache.AcquireUserSlot(s.ctx, userID, 1, reqID2)
require.NoError(s.T(), err, "AcquireUserSlot 2")
require.False(s.T(), ok, "expected second acquire to fail at max=1")
cur, err := s.cache.GetUserConcurrency(s.ctx, userID)
require.NoError(s.T(), err, "GetUserConcurrency")
require.Equal(s.T(), 1, cur, "expected concurrency=1")
require.NoError(s.T(), s.cache.ReleaseUserSlot(s.ctx, userID, reqID1), "ReleaseUserSlot")
// Releasing a non-existent slot should not error
require.NoError(s.T(), s.cache.ReleaseUserSlot(s.ctx, userID, "non-existent"), "ReleaseUserSlot non-existent")
cur, err = s.cache.GetUserConcurrency(s.ctx, userID)
require.NoError(s.T(), err, "GetUserConcurrency after release")
require.Equal(s.T(), 0, cur, "expected concurrency=0 after release")
}
func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
userID := int64(200)
reqID := "req_ttl_test"
slotKey := fmt.Sprintf("%s%d:%s", userSlotKeyPrefix, userID, reqID)
ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 5, reqID)
require.NoError(s.T(), err, "AcquireUserSlot")
require.True(s.T(), ok)
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
}
func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
userID := int64(20)
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 2)
require.NoError(s.T(), err, "IncrementWaitCount 1")
require.True(s.T(), ok)
ok, err = s.cache.IncrementWaitCount(s.ctx, userID, 2)
require.NoError(s.T(), err, "IncrementWaitCount 2")
require.True(s.T(), ok)
ok, err = s.cache.IncrementWaitCount(s.ctx, userID, 2)
require.NoError(s.T(), err, "IncrementWaitCount 3")
require.False(s.T(), ok, "expected wait increment over max to fail")
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
require.NoError(s.T(), err, "TTL waitKey")
s.AssertTTLWithin(ttl, 1*time.Second, slotTTL)
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.Equal(s.T(), 1, val, "expected wait count 1")
}
func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
userID := int64(300)
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
// Test decrement on non-existent key - should not error and should not create negative value
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on non-existent key")
// Verify no key was created or it's not negative
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count after decrement on empty")
// Set count to 1, then decrement twice
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 5)
require.NoError(s.T(), err, "IncrementWaitCount")
require.True(s.T(), ok)
// Decrement once (1 -> 0)
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
// Decrement again on 0 - should not go negative
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on zero")
// Verify count is 0, not negative
val, err = s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey after double decrement")
}
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count")
}
func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() {
// When no slots exist, GetAccountConcurrency should return 0
cur, err := s.cache.GetAccountConcurrency(s.ctx, 999)
require.NoError(s.T(), err)
require.Equal(s.T(), 0, cur)
}
func (s *ConcurrencyCacheSuite) TestGetUserConcurrency_Missing() {
// When no slots exist, GetUserConcurrency should return 0
cur, err := s.cache.GetUserConcurrency(s.ctx, 999)
require.NoError(s.T(), err)
require.Equal(s.T(), 0, cur)
}
func TestConcurrencyCacheSuite(t *testing.T) {
suite.Run(t, new(ConcurrencyCacheSuite))
}

View File

@@ -0,0 +1,92 @@
//go:build integration
package repository
import (
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type EmailCacheSuite struct {
IntegrationRedisSuite
cache ports.EmailCache
}
func (s *EmailCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewEmailCache(s.rdb)
}
func (s *EmailCacheSuite) TestGetVerificationCode_Missing() {
_, err := s.cache.GetVerificationCode(s.ctx, "nonexistent@example.com")
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing verification code")
}
func (s *EmailCacheSuite) TestSetAndGetVerificationCode() {
email := "a@example.com"
emailTTL := 2 * time.Minute
data := &ports.VerificationCodeData{Code: "123456", Attempts: 1, CreatedAt: time.Now()}
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, emailTTL), "SetVerificationCode")
got, err := s.cache.GetVerificationCode(s.ctx, email)
require.NoError(s.T(), err, "GetVerificationCode")
require.Equal(s.T(), "123456", got.Code)
require.Equal(s.T(), 1, got.Attempts)
}
func (s *EmailCacheSuite) TestVerificationCode_TTL() {
email := "ttl@example.com"
emailTTL := 2 * time.Minute
data := &ports.VerificationCodeData{Code: "654321", Attempts: 0, CreatedAt: time.Now()}
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, emailTTL), "SetVerificationCode")
emailKey := verifyCodeKeyPrefix + email
ttl, err := s.rdb.TTL(s.ctx, emailKey).Result()
require.NoError(s.T(), err, "TTL emailKey")
s.AssertTTLWithin(ttl, 1*time.Second, emailTTL)
}
func (s *EmailCacheSuite) TestDeleteVerificationCode() {
email := "delete@example.com"
data := &ports.VerificationCodeData{Code: "999999", Attempts: 0, CreatedAt: time.Now()}
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, 2*time.Minute), "SetVerificationCode")
// Verify it exists
_, err := s.cache.GetVerificationCode(s.ctx, email)
require.NoError(s.T(), err, "GetVerificationCode before delete")
// Delete
require.NoError(s.T(), s.cache.DeleteVerificationCode(s.ctx, email), "DeleteVerificationCode")
// Verify it's gone
_, err = s.cache.GetVerificationCode(s.ctx, email)
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
}
func (s *EmailCacheSuite) TestDeleteVerificationCode_NonExistent() {
// Deleting a non-existent key should not error
require.NoError(s.T(), s.cache.DeleteVerificationCode(s.ctx, "nonexistent@example.com"), "DeleteVerificationCode non-existent")
}
func (s *EmailCacheSuite) TestGetVerificationCode_JSONCorruption() {
emailKey := verifyCodeKeyPrefix + "corrupted@example.com"
require.NoError(s.T(), s.rdb.Set(s.ctx, emailKey, "not-json", 1*time.Minute).Err(), "Set invalid JSON")
_, err := s.cache.GetVerificationCode(s.ctx, "corrupted@example.com")
require.Error(s.T(), err, "expected error for corrupted JSON")
require.False(s.T(), errors.Is(err, redis.Nil), "expected decoding error, not redis.Nil")
}
func TestEmailCacheSuite(t *testing.T) {
suite.Run(t, new(EmailCacheSuite))
}

View File

@@ -0,0 +1,172 @@
//go:build integration
package repository
import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
)
func mustCreateUser(t *testing.T, db *gorm.DB, u *model.User) *model.User {
t.Helper()
if u.PasswordHash == "" {
u.PasswordHash = "test-password-hash"
}
if u.Role == "" {
u.Role = model.RoleUser
}
if u.Status == "" {
u.Status = model.StatusActive
}
if u.CreatedAt.IsZero() {
u.CreatedAt = time.Now()
}
if u.UpdatedAt.IsZero() {
u.UpdatedAt = u.CreatedAt
}
require.NoError(t, db.Create(u).Error, "create user")
return u
}
func mustCreateGroup(t *testing.T, db *gorm.DB, g *model.Group) *model.Group {
t.Helper()
if g.Platform == "" {
g.Platform = model.PlatformAnthropic
}
if g.Status == "" {
g.Status = model.StatusActive
}
if g.SubscriptionType == "" {
g.SubscriptionType = model.SubscriptionTypeStandard
}
if g.CreatedAt.IsZero() {
g.CreatedAt = time.Now()
}
if g.UpdatedAt.IsZero() {
g.UpdatedAt = g.CreatedAt
}
require.NoError(t, db.Create(g).Error, "create group")
return g
}
func mustCreateProxy(t *testing.T, db *gorm.DB, p *model.Proxy) *model.Proxy {
t.Helper()
if p.Protocol == "" {
p.Protocol = "http"
}
if p.Host == "" {
p.Host = "127.0.0.1"
}
if p.Port == 0 {
p.Port = 8080
}
if p.Status == "" {
p.Status = model.StatusActive
}
if p.CreatedAt.IsZero() {
p.CreatedAt = time.Now()
}
if p.UpdatedAt.IsZero() {
p.UpdatedAt = p.CreatedAt
}
require.NoError(t, db.Create(p).Error, "create proxy")
return p
}
func mustCreateAccount(t *testing.T, db *gorm.DB, a *model.Account) *model.Account {
t.Helper()
if a.Platform == "" {
a.Platform = model.PlatformAnthropic
}
if a.Type == "" {
a.Type = model.AccountTypeOAuth
}
if a.Status == "" {
a.Status = model.StatusActive
}
if !a.Schedulable {
a.Schedulable = true
}
if a.Credentials == nil {
a.Credentials = model.JSONB{}
}
if a.Extra == nil {
a.Extra = model.JSONB{}
}
if a.CreatedAt.IsZero() {
a.CreatedAt = time.Now()
}
if a.UpdatedAt.IsZero() {
a.UpdatedAt = a.CreatedAt
}
require.NoError(t, db.Create(a).Error, "create account")
return a
}
func mustCreateApiKey(t *testing.T, db *gorm.DB, k *model.ApiKey) *model.ApiKey {
t.Helper()
if k.Status == "" {
k.Status = model.StatusActive
}
if k.CreatedAt.IsZero() {
k.CreatedAt = time.Now()
}
if k.UpdatedAt.IsZero() {
k.UpdatedAt = k.CreatedAt
}
require.NoError(t, db.Create(k).Error, "create api key")
return k
}
func mustCreateRedeemCode(t *testing.T, db *gorm.DB, c *model.RedeemCode) *model.RedeemCode {
t.Helper()
if c.Status == "" {
c.Status = model.StatusUnused
}
if c.Type == "" {
c.Type = model.RedeemTypeBalance
}
if c.CreatedAt.IsZero() {
c.CreatedAt = time.Now()
}
require.NoError(t, db.Create(c).Error, "create redeem code")
return c
}
func mustCreateSubscription(t *testing.T, db *gorm.DB, s *model.UserSubscription) *model.UserSubscription {
t.Helper()
if s.Status == "" {
s.Status = model.SubscriptionStatusActive
}
now := time.Now()
if s.StartsAt.IsZero() {
s.StartsAt = now.Add(-1 * time.Hour)
}
if s.ExpiresAt.IsZero() {
s.ExpiresAt = now.Add(24 * time.Hour)
}
if s.AssignedAt.IsZero() {
s.AssignedAt = now
}
if s.CreatedAt.IsZero() {
s.CreatedAt = now
}
if s.UpdatedAt.IsZero() {
s.UpdatedAt = now
}
require.NoError(t, db.Create(s).Error, "create user subscription")
return s
}
func mustBindAccountToGroup(t *testing.T, db *gorm.DB, accountID, groupID int64, priority int) {
t.Helper()
require.NoError(t, db.Create(&model.AccountGroup{
AccountID: accountID,
GroupID: groupID,
Priority: priority,
}).Error, "create account_group")
}

View File

@@ -0,0 +1,92 @@
//go:build integration
package repository
import (
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type GatewayCacheSuite struct {
IntegrationRedisSuite
cache ports.GatewayCache
}
func (s *GatewayCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewGatewayCache(s.rdb)
}
func (s *GatewayCacheSuite) TestGetSessionAccountID_Missing() {
_, err := s.cache.GetSessionAccountID(s.ctx, "nonexistent")
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing session")
}
func (s *GatewayCacheSuite) TestSetAndGetSessionAccountID() {
sessionID := "s1"
accountID := int64(99)
sessionTTL := 1 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, sessionTTL), "SetSessionAccountID")
sid, err := s.cache.GetSessionAccountID(s.ctx, sessionID)
require.NoError(s.T(), err, "GetSessionAccountID")
require.Equal(s.T(), accountID, sid, "session id mismatch")
}
func (s *GatewayCacheSuite) TestSessionAccountID_TTL() {
sessionID := "s2"
accountID := int64(100)
sessionTTL := 1 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, sessionTTL), "SetSessionAccountID")
sessionKey := stickySessionPrefix + sessionID
ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result()
require.NoError(s.T(), err, "TTL sessionKey after Set")
s.AssertTTLWithin(ttl, 1*time.Second, sessionTTL)
}
func (s *GatewayCacheSuite) TestRefreshSessionTTL() {
sessionID := "s3"
accountID := int64(101)
initialTTL := 1 * time.Minute
refreshTTL := 3 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, initialTTL), "SetSessionAccountID")
require.NoError(s.T(), s.cache.RefreshSessionTTL(s.ctx, sessionID, refreshTTL), "RefreshSessionTTL")
sessionKey := stickySessionPrefix + sessionID
ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result()
require.NoError(s.T(), err, "TTL after Refresh")
s.AssertTTLWithin(ttl, 1*time.Second, refreshTTL)
}
func (s *GatewayCacheSuite) TestRefreshSessionTTL_MissingKey() {
// RefreshSessionTTL on a missing key should not error (no-op)
err := s.cache.RefreshSessionTTL(s.ctx, "missing-session", 1*time.Minute)
require.NoError(s.T(), err, "RefreshSessionTTL on missing key should not error")
}
func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
sessionID := "corrupted"
sessionKey := stickySessionPrefix + sessionID
// Set a non-integer value
require.NoError(s.T(), s.rdb.Set(s.ctx, sessionKey, "not-a-number", 1*time.Minute).Err(), "Set invalid value")
_, err := s.cache.GetSessionAccountID(s.ctx, sessionID)
require.Error(s.T(), err, "expected error for corrupted value")
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
}
func TestGatewayCacheSuite(t *testing.T) {
suite.Run(t, new(GatewayCacheSuite))
}

View File

@@ -0,0 +1,328 @@
package repository
import (
"bytes"
"context"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type GitHubReleaseServiceSuite struct {
suite.Suite
srv *httptest.Server
client *githubReleaseClient
tempDir string
}
// testTransport redirects requests to the test server
type testTransport struct {
testServerURL string
}
func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// Rewrite the URL to point to our test server
testURL := t.testServerURL + req.URL.Path
newReq, err := http.NewRequestWithContext(req.Context(), req.Method, testURL, req.Body)
if err != nil {
return nil, err
}
newReq.Header = req.Header
return http.DefaultTransport.RoundTrip(newReq)
}
func (s *GitHubReleaseServiceSuite) SetupTest() {
s.tempDir = s.T().TempDir()
}
func (s *GitHubReleaseServiceSuite) TearDownTest() {
if s.srv != nil {
s.srv.Close()
s.srv = nil
}
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLength() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Length", "100")
w.WriteHeader(http.StatusOK)
_, _ = w.Write(bytes.Repeat([]byte("a"), 100))
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
dest := filepath.Join(s.tempDir, "file1.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
require.Error(s.T(), err, "expected error for oversized download with Content-Length")
_, statErr := os.Stat(dest)
require.Error(s.T(), statErr, "expected file to not exist for rejected download")
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Force chunked encoding (unknown Content-Length) by flushing headers before writing.
w.WriteHeader(http.StatusOK)
if fl, ok := w.(http.Flusher); ok {
fl.Flush()
}
for i := 0; i < 10; i++ {
_, _ = w.Write(bytes.Repeat([]byte("b"), 10))
if fl, ok := w.(http.Flusher); ok {
fl.Flush()
}
}
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
dest := filepath.Join(s.tempDir, "file2.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
require.Error(s.T(), err, "expected error for oversized chunked download")
_, statErr := os.Stat(dest)
require.Error(s.T(), statErr, "expected file to be cleaned up for oversized chunked download")
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
if fl, ok := w.(http.Flusher); ok {
fl.Flush()
}
for i := 0; i < 10; i++ {
_, _ = w.Write(bytes.Repeat([]byte("b"), 10))
if fl, ok := w.(http.Flusher); ok {
fl.Flush()
}
}
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
dest := filepath.Join(s.tempDir, "file3.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 200)
require.NoError(s.T(), err, "expected success")
b, err := os.ReadFile(dest)
require.NoError(s.T(), err, "read")
require.True(s.T(), strings.HasPrefix(string(b), "b"), "downloaded content should start with 'b'")
require.Len(s.T(), b, 100, "downloaded content length mismatch")
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
dest := filepath.Join(s.tempDir, "notfound.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 100)
require.Error(s.T(), err, "expected error for 404")
_, statErr := os.Stat(dest)
require.Error(s.T(), statErr, "expected file to not exist for 404")
}
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("sum"))
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
body, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
require.NoError(s.T(), err, "FetchChecksumFile")
require.Equal(s.T(), "sum", string(body), "checksum body mismatch")
}
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
_, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
require.Error(s.T(), err, "expected error for non-200")
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
<-r.Context().Done()
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
ctx, cancel := context.WithCancel(context.Background())
cancel()
dest := filepath.Join(s.tempDir, "cancelled.bin")
err := s.client.DownloadFile(ctx, s.srv.URL, dest, 100)
require.Error(s.T(), err, "expected error for cancelled context")
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidURL() {
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
dest := filepath.Join(s.tempDir, "invalid.bin")
err := s.client.DownloadFile(context.Background(), "://invalid-url", dest, 100)
require.Error(s.T(), err, "expected error for invalid URL")
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("content"))
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
// Use a path that cannot be created (directory doesn't exist)
dest := filepath.Join(s.tempDir, "nonexistent", "subdir", "file.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 100)
require.Error(s.T(), err, "expected error for invalid destination path")
}
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_InvalidURL() {
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
_, err := s.client.FetchChecksumFile(context.Background(), "://invalid-url")
require.Error(s.T(), err, "expected error for invalid URL")
}
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
releaseJSON := `{
"tag_name": "v1.0.0",
"name": "Release 1.0.0",
"body": "Release notes",
"html_url": "https://github.com/test/repo/releases/v1.0.0",
"assets": [
{
"name": "app-linux-amd64.tar.gz",
"browser_download_url": "https://github.com/test/repo/releases/download/v1.0.0/app-linux-amd64.tar.gz"
}
]
}`
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(s.T(), "/repos/test/repo/releases/latest", r.URL.Path)
require.Equal(s.T(), "application/vnd.github.v3+json", r.Header.Get("Accept"))
require.Equal(s.T(), "Sub2API-Updater", r.Header.Get("User-Agent"))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(releaseJSON))
}))
// Use custom transport to redirect requests to test server
s.client = &githubReleaseClient{
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
}
release, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
require.NoError(s.T(), err)
require.Equal(s.T(), "v1.0.0", release.TagName)
require.Equal(s.T(), "Release 1.0.0", release.Name)
require.Len(s.T(), release.Assets, 1)
require.Equal(s.T(), "app-linux-amd64.tar.gz", release.Assets[0].Name)
}
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
s.client = &githubReleaseClient{
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
}
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
require.Error(s.T(), err)
require.Contains(s.T(), err.Error(), "404")
}
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("not valid json"))
}))
s.client = &githubReleaseClient{
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
}
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
require.Error(s.T(), err)
}
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
<-r.Context().Done()
}))
s.client = &githubReleaseClient{
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
}
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := s.client.FetchLatestRelease(ctx, "test/repo")
require.Error(s.T(), err)
}
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_ContextCancel() {
s.srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
<-r.Context().Done()
}))
client, ok := NewGitHubReleaseClient().(*githubReleaseClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := s.client.FetchChecksumFile(ctx, s.srv.URL)
require.Error(s.T(), err)
}
func TestGitHubReleaseServiceSuite(t *testing.T) {
suite.Run(t, new(GitHubReleaseServiceSuite))
}

View File

@@ -0,0 +1,244 @@
//go:build integration
package repository
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type GroupRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *GroupRepository
}
func (s *GroupRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewGroupRepository(s.db)
}
func TestGroupRepoSuite(t *testing.T) {
suite.Run(t, new(GroupRepoSuite))
}
// --- Create / GetByID / Update / Delete ---
func (s *GroupRepoSuite) TestCreate() {
group := &model.Group{
Name: "test-create",
Platform: model.PlatformAnthropic,
Status: model.StatusActive,
}
err := s.repo.Create(s.ctx, group)
s.Require().NoError(err, "Create")
s.Require().NotZero(group.ID, "expected ID to be set")
got, err := s.repo.GetByID(s.ctx, group.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("test-create", got.Name)
}
func (s *GroupRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *GroupRepoSuite) TestUpdate() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "original"})
group.Name = "updated"
err := s.repo.Update(s.ctx, group)
s.Require().NoError(err, "Update")
got, err := s.repo.GetByID(s.ctx, group.ID)
s.Require().NoError(err, "GetByID after update")
s.Require().Equal("updated", got.Name)
}
func (s *GroupRepoSuite) TestDelete() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "to-delete"})
err := s.repo.Delete(s.ctx, group.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, group.ID)
s.Require().Error(err, "expected error after delete")
}
// --- List / ListWithFilters ---
func (s *GroupRepoSuite) TestList() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1"})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2"})
groups, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
s.Require().Len(groups, 2)
s.Require().Equal(int64(2), page.Total)
}
func (s *GroupRepoSuite) TestListWithFilters_Platform() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Platform: model.PlatformAnthropic})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Platform: model.PlatformOpenAI})
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.PlatformOpenAI, "", nil)
s.Require().NoError(err)
s.Require().Len(groups, 1)
s.Require().Equal(model.PlatformOpenAI, groups[0].Platform)
}
func (s *GroupRepoSuite) TestListWithFilters_Status() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Status: model.StatusActive})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Status: model.StatusDisabled})
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusDisabled, nil)
s.Require().NoError(err)
s.Require().Len(groups, 1)
s.Require().Equal(model.StatusDisabled, groups[0].Status)
}
func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", IsExclusive: false})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", IsExclusive: true})
isExclusive := true
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", &isExclusive)
s.Require().NoError(err)
s.Require().Len(groups, 1)
s.Require().True(groups[0].IsExclusive)
}
func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
g1 := mustCreateGroup(s.T(), s.db, &model.Group{
Name: "g1",
Platform: model.PlatformAnthropic,
Status: model.StatusActive,
})
g2 := mustCreateGroup(s.T(), s.db, &model.Group{
Name: "g2",
Platform: model.PlatformAnthropic,
Status: model.StatusActive,
IsExclusive: true,
})
a := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc1"})
mustBindAccountToGroup(s.T(), s.db, a.ID, g1.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a.ID, g2.ID, 1)
isExclusive := true
groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.PlatformAnthropic, model.StatusActive, &isExclusive)
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total)
s.Require().Len(groups, 1)
s.Require().Equal(g2.ID, groups[0].ID, "ListWithFilters returned wrong group")
s.Require().Equal(int64(1), groups[0].AccountCount, "AccountCount mismatch")
}
// --- ListActive / ListActiveByPlatform ---
func (s *GroupRepoSuite) TestListActive() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "active1", Status: model.StatusActive})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "inactive1", Status: model.StatusDisabled})
groups, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive")
s.Require().Len(groups, 1)
s.Require().Equal("active1", groups[0].Name)
}
func (s *GroupRepoSuite) TestListActiveByPlatform() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g1", Platform: model.PlatformAnthropic, Status: model.StatusActive})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g2", Platform: model.PlatformOpenAI, Status: model.StatusActive})
mustCreateGroup(s.T(), s.db, &model.Group{Name: "g3", Platform: model.PlatformAnthropic, Status: model.StatusDisabled})
groups, err := s.repo.ListActiveByPlatform(s.ctx, model.PlatformAnthropic)
s.Require().NoError(err, "ListActiveByPlatform")
s.Require().Len(groups, 1)
s.Require().Equal("g1", groups[0].Name)
}
// --- ExistsByName ---
func (s *GroupRepoSuite) TestExistsByName() {
mustCreateGroup(s.T(), s.db, &model.Group{Name: "existing-group"})
exists, err := s.repo.ExistsByName(s.ctx, "existing-group")
s.Require().NoError(err, "ExistsByName")
s.Require().True(exists)
notExists, err := s.repo.ExistsByName(s.ctx, "non-existing")
s.Require().NoError(err)
s.Require().False(notExists)
}
// --- GetAccountCount ---
func (s *GroupRepoSuite) TestGetAccountCount() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"})
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1"})
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2"})
mustBindAccountToGroup(s.T(), s.db, a1.ID, group.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a2.ID, group.ID, 2)
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
s.Require().NoError(err, "GetAccountCount")
s.Require().Equal(int64(2), count)
}
func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-empty"})
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
s.Require().NoError(err)
s.Require().Zero(count)
}
// --- DeleteAccountGroupsByGroupID ---
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
g := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-del"})
a := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-del"})
mustBindAccountToGroup(s.T(), s.db, a.ID, g.ID, 1)
affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID)
s.Require().NoError(err, "DeleteAccountGroupsByGroupID")
s.Require().Equal(int64(1), affected, "expected 1 affected row")
count, err := s.repo.GetAccountCount(s.ctx, g.ID)
s.Require().NoError(err, "GetAccountCount")
s.Require().Equal(int64(0), count, "expected 0 account groups")
}
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
g := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-multi"})
a1 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1"})
a2 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2"})
a3 := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3"})
mustBindAccountToGroup(s.T(), s.db, a1.ID, g.ID, 1)
mustBindAccountToGroup(s.T(), s.db, a2.ID, g.ID, 2)
mustBindAccountToGroup(s.T(), s.db, a3.ID, g.ID, 3)
affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID)
s.Require().NoError(err)
s.Require().Equal(int64(3), affected)
count, _ := s.repo.GetAccountCount(s.ctx, g.ID)
s.Require().Zero(count)
}
// --- DB ---
func (s *GroupRepoSuite) TestDB() {
db := s.repo.DB()
s.Require().NotNil(db, "DB should return non-nil")
s.Require().Equal(s.db, db, "DB should return the underlying gorm.DB")
}

View File

@@ -0,0 +1,115 @@
package repository
import (
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type HTTPUpstreamSuite struct {
suite.Suite
cfg *config.Config
}
func (s *HTTPUpstreamSuite) SetupTest() {
s.cfg = &config.Config{}
}
func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() {
up := NewHTTPUpstream(s.cfg)
svc, ok := up.(*httpUpstreamService)
require.True(s.T(), ok, "expected *httpUpstreamService")
transport, ok := svc.defaultClient.Transport.(*http.Transport)
require.True(s.T(), ok, "expected *http.Transport")
require.Equal(s.T(), 300*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch")
}
func (s *HTTPUpstreamSuite) TestCustomResponseHeaderTimeout() {
s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 7}
up := NewHTTPUpstream(s.cfg)
svc, ok := up.(*httpUpstreamService)
require.True(s.T(), ok, "expected *httpUpstreamService")
transport, ok := svc.defaultClient.Transport.(*http.Transport)
require.True(s.T(), ok, "expected *http.Transport")
require.Equal(s.T(), 7*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch")
}
func (s *HTTPUpstreamSuite) TestCreateProxyClient_InvalidURLFallsBackToDefault() {
s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 5}
up := NewHTTPUpstream(s.cfg)
svc, ok := up.(*httpUpstreamService)
require.True(s.T(), ok, "expected *httpUpstreamService")
got := svc.createProxyClient("://bad-proxy-url")
require.Equal(s.T(), svc.defaultClient, got, "expected defaultClient fallback")
}
func (s *HTTPUpstreamSuite) TestDo_WithoutProxy_GoesDirect() {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "direct")
}))
s.T().Cleanup(upstream.Close)
up := NewHTTPUpstream(s.cfg)
req, err := http.NewRequest(http.MethodGet, upstream.URL+"/x", nil)
require.NoError(s.T(), err, "NewRequest")
resp, err := up.Do(req, "")
require.NoError(s.T(), err, "Do")
defer func() { _ = resp.Body.Close() }()
b, _ := io.ReadAll(resp.Body)
require.Equal(s.T(), "direct", string(b), "unexpected body")
}
func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() {
seen := make(chan string, 1)
proxySrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
seen <- r.RequestURI
_, _ = io.WriteString(w, "proxied")
}))
s.T().Cleanup(proxySrv.Close)
s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 1}
up := NewHTTPUpstream(s.cfg)
req, err := http.NewRequest(http.MethodGet, "http://example.com/test", nil)
require.NoError(s.T(), err, "NewRequest")
resp, err := up.Do(req, proxySrv.URL)
require.NoError(s.T(), err, "Do")
defer func() { _ = resp.Body.Close() }()
b, _ := io.ReadAll(resp.Body)
require.Equal(s.T(), "proxied", string(b), "unexpected body")
select {
case uri := <-seen:
require.Equal(s.T(), "http://example.com/test", uri, "expected absolute-form request URI")
default:
require.Fail(s.T(), "expected proxy to receive request")
}
}
func (s *HTTPUpstreamSuite) TestDo_EmptyProxy_UsesDirect() {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "direct-empty")
}))
s.T().Cleanup(upstream.Close)
up := NewHTTPUpstream(s.cfg)
req, err := http.NewRequest(http.MethodGet, upstream.URL+"/y", nil)
require.NoError(s.T(), err, "NewRequest")
resp, err := up.Do(req, "")
require.NoError(s.T(), err, "Do with empty proxy")
defer func() { _ = resp.Body.Close() }()
b, _ := io.ReadAll(resp.Body)
require.Equal(s.T(), "direct-empty", string(b))
}
func TestHTTPUpstreamSuite(t *testing.T) {
suite.Run(t, new(HTTPUpstreamSuite))
}

View File

@@ -0,0 +1,67 @@
//go:build integration
package repository
import (
"errors"
"fmt"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type IdentityCacheSuite struct {
IntegrationRedisSuite
cache *identityCache
}
func (s *IdentityCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewIdentityCache(s.rdb).(*identityCache)
}
func (s *IdentityCacheSuite) TestGetFingerprint_Missing() {
_, err := s.cache.GetFingerprint(s.ctx, 1)
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing fingerprint")
}
func (s *IdentityCacheSuite) TestSetAndGetFingerprint() {
fp := &ports.Fingerprint{ClientID: "c1", UserAgent: "ua"}
require.NoError(s.T(), s.cache.SetFingerprint(s.ctx, 1, fp), "SetFingerprint")
gotFP, err := s.cache.GetFingerprint(s.ctx, 1)
require.NoError(s.T(), err, "GetFingerprint")
require.Equal(s.T(), "c1", gotFP.ClientID)
require.Equal(s.T(), "ua", gotFP.UserAgent)
}
func (s *IdentityCacheSuite) TestFingerprint_TTL() {
fp := &ports.Fingerprint{ClientID: "c1", UserAgent: "ua"}
require.NoError(s.T(), s.cache.SetFingerprint(s.ctx, 2, fp))
fpKey := fmt.Sprintf("%s%d", fingerprintKeyPrefix, 2)
ttl, err := s.rdb.TTL(s.ctx, fpKey).Result()
require.NoError(s.T(), err, "TTL fpKey")
s.AssertTTLWithin(ttl, 1*time.Second, fingerprintTTL)
}
func (s *IdentityCacheSuite) TestGetFingerprint_JSONCorruption() {
fpKey := fmt.Sprintf("%s%d", fingerprintKeyPrefix, 999)
require.NoError(s.T(), s.rdb.Set(s.ctx, fpKey, "invalid-json-data", 1*time.Minute).Err(), "Set invalid JSON")
_, err := s.cache.GetFingerprint(s.ctx, 999)
require.Error(s.T(), err, "expected error for corrupted JSON")
require.False(s.T(), errors.Is(err, redis.Nil), "expected decoding error, not redis.Nil")
}
func (s *IdentityCacheSuite) TestSetFingerprint_Nil() {
err := s.cache.SetFingerprint(s.ctx, 100, nil)
require.NoError(s.T(), err, "SetFingerprint(nil) should succeed")
}
func TestIdentityCacheSuite(t *testing.T) {
suite.Run(t, new(IdentityCacheSuite))
}

View File

@@ -0,0 +1,369 @@
//go:build integration
package repository
import (
"context"
"database/sql"
"fmt"
"log"
"os"
"os/exec"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
redisclient "github.com/redis/go-redis/v9"
tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres"
tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
gormpostgres "gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
const (
redisImageTag = "redis:8.4-alpine"
postgresImageTag = "postgres:18.1-alpine3.23"
)
var (
integrationDB *gorm.DB
integrationRedis *redisclient.Client
redisNamespaceSeq uint64
)
func TestMain(m *testing.M) {
ctx := context.Background()
if err := timezone.Init("UTC"); err != nil {
log.Printf("failed to init timezone: %v", err)
os.Exit(1)
}
if !dockerIsAvailable(ctx) {
// In CI we expect Docker to be available so integration tests should fail loudly.
if os.Getenv("CI") != "" {
log.Printf("docker is not available (CI=true); failing integration tests")
os.Exit(1)
}
log.Printf("docker is not available; skipping integration tests (start Docker to enable)")
os.Exit(0)
}
postgresImage := selectDockerImage(ctx, postgresImageTag)
pgContainer, err := tcpostgres.Run(
ctx,
postgresImage,
tcpostgres.WithDatabase("sub2api_test"),
tcpostgres.WithUsername("postgres"),
tcpostgres.WithPassword("postgres"),
tcpostgres.BasicWaitStrategies(),
)
if err != nil {
log.Printf("failed to start postgres container: %v", err)
os.Exit(1)
}
defer func() { _ = pgContainer.Terminate(ctx) }()
redisContainer, err := tcredis.Run(
ctx,
redisImageTag,
)
if err != nil {
log.Printf("failed to start redis container: %v", err)
os.Exit(1)
}
defer func() { _ = redisContainer.Terminate(ctx) }()
dsn, err := pgContainer.ConnectionString(ctx, "sslmode=disable", "TimeZone=UTC")
if err != nil {
log.Printf("failed to get postgres dsn: %v", err)
os.Exit(1)
}
integrationDB, err = openGormWithRetry(ctx, dsn, 30*time.Second)
if err != nil {
log.Printf("failed to open gorm db: %v", err)
os.Exit(1)
}
if err := model.AutoMigrate(integrationDB); err != nil {
log.Printf("failed to automigrate db: %v", err)
os.Exit(1)
}
redisHost, err := redisContainer.Host(ctx)
if err != nil {
log.Printf("failed to get redis host: %v", err)
os.Exit(1)
}
redisPort, err := redisContainer.MappedPort(ctx, "6379/tcp")
if err != nil {
log.Printf("failed to get redis port: %v", err)
os.Exit(1)
}
integrationRedis = redisclient.NewClient(&redisclient.Options{
Addr: fmt.Sprintf("%s:%d", redisHost, redisPort.Int()),
DB: 0,
})
if err := integrationRedis.Ping(ctx).Err(); err != nil {
log.Printf("failed to ping redis: %v", err)
os.Exit(1)
}
code := m.Run()
_ = integrationRedis.Close()
os.Exit(code)
}
func dockerIsAvailable(ctx context.Context) bool {
cmd := exec.CommandContext(ctx, "docker", "info")
cmd.Env = os.Environ()
return cmd.Run() == nil
}
func selectDockerImage(ctx context.Context, preferred string) string {
if dockerImageExists(ctx, preferred) {
return preferred
}
return preferred
}
func dockerImageExists(ctx context.Context, image string) bool {
cmd := exec.CommandContext(ctx, "docker", "image", "inspect", image)
cmd.Env = os.Environ()
cmd.Stdout = nil
cmd.Stderr = nil
return cmd.Run() == nil
}
func openGormWithRetry(ctx context.Context, dsn string, timeout time.Duration) (*gorm.DB, error) {
deadline := time.Now().Add(timeout)
var lastErr error
for time.Now().Before(deadline) {
db, err := gorm.Open(gormpostgres.Open(dsn), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
lastErr = err
time.Sleep(250 * time.Millisecond)
continue
}
sqlDB, err := db.DB()
if err != nil {
lastErr = err
time.Sleep(250 * time.Millisecond)
continue
}
if err := pingWithTimeout(ctx, sqlDB, 2*time.Second); err != nil {
lastErr = err
time.Sleep(250 * time.Millisecond)
continue
}
return db, nil
}
return nil, fmt.Errorf("db not ready after %s: %w", timeout, lastErr)
}
func pingWithTimeout(ctx context.Context, db *sql.DB, timeout time.Duration) error {
pingCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
return db.PingContext(pingCtx)
}
func testTx(t *testing.T) *gorm.DB {
t.Helper()
tx := integrationDB.Begin()
require.NoError(t, tx.Error, "begin tx")
t.Cleanup(func() {
_ = tx.Rollback().Error
})
return tx
}
func testRedis(t *testing.T) *redisclient.Client {
t.Helper()
prefix := fmt.Sprintf(
"it:%s:%d:%d:",
sanitizeRedisNamespace(t.Name()),
time.Now().UnixNano(),
atomic.AddUint64(&redisNamespaceSeq, 1),
)
opts := *integrationRedis.Options()
rdb := redisclient.NewClient(&opts)
rdb.AddHook(prefixHook{prefix: prefix})
t.Cleanup(func() {
ctx := context.Background()
var cursor uint64
for {
keys, nextCursor, err := integrationRedis.Scan(ctx, cursor, prefix+"*", 500).Result()
require.NoError(t, err, "scan redis keys for cleanup")
if len(keys) > 0 {
require.NoError(t, integrationRedis.Unlink(ctx, keys...).Err(), "unlink redis keys for cleanup")
}
cursor = nextCursor
if cursor == 0 {
break
}
}
_ = rdb.Close()
})
return rdb
}
func assertTTLWithin(t *testing.T, ttl time.Duration, min, max time.Duration) {
t.Helper()
require.GreaterOrEqual(t, ttl, min, "ttl should be >= min")
require.LessOrEqual(t, ttl, max, "ttl should be <= max")
}
func sanitizeRedisNamespace(name string) string {
name = strings.ReplaceAll(name, "/", "_")
name = strings.ReplaceAll(name, " ", "_")
return name
}
type prefixHook struct {
prefix string
}
func (h prefixHook) DialHook(next redisclient.DialHook) redisclient.DialHook { return next }
func (h prefixHook) ProcessHook(next redisclient.ProcessHook) redisclient.ProcessHook {
return func(ctx context.Context, cmd redisclient.Cmder) error {
h.prefixCmd(cmd)
return next(ctx, cmd)
}
}
func (h prefixHook) ProcessPipelineHook(next redisclient.ProcessPipelineHook) redisclient.ProcessPipelineHook {
return func(ctx context.Context, cmds []redisclient.Cmder) error {
for _, cmd := range cmds {
h.prefixCmd(cmd)
}
return next(ctx, cmds)
}
}
func (h prefixHook) prefixCmd(cmd redisclient.Cmder) {
args := cmd.Args()
if len(args) < 2 {
return
}
prefixOne := func(i int) {
if i < 0 || i >= len(args) {
return
}
switch v := args[i].(type) {
case string:
if v != "" && !strings.HasPrefix(v, h.prefix) {
args[i] = h.prefix + v
}
case []byte:
s := string(v)
if s != "" && !strings.HasPrefix(s, h.prefix) {
args[i] = []byte(h.prefix + s)
}
}
}
switch strings.ToLower(cmd.Name()) {
case "get", "set", "setnx", "setex", "psetex", "incr", "decr", "incrby", "expire", "pexpire", "ttl", "pttl",
"hgetall", "hget", "hset", "hdel", "hincrbyfloat", "exists":
prefixOne(1)
case "del", "unlink":
for i := 1; i < len(args); i++ {
prefixOne(i)
}
case "eval", "evalsha", "eval_ro", "evalsha_ro":
if len(args) < 3 {
return
}
numKeys, err := strconv.Atoi(fmt.Sprint(args[2]))
if err != nil || numKeys <= 0 {
return
}
for i := 0; i < numKeys && 3+i < len(args); i++ {
prefixOne(3 + i)
}
case "scan":
for i := 2; i+1 < len(args); i++ {
if strings.EqualFold(fmt.Sprint(args[i]), "match") {
prefixOne(i + 1)
break
}
}
}
}
// IntegrationRedisSuite provides a base suite for Redis integration tests.
// Embedding suites should call SetupTest to initialize ctx and rdb.
type IntegrationRedisSuite struct {
suite.Suite
ctx context.Context
rdb *redisclient.Client
}
// SetupTest initializes ctx and rdb for each test method.
func (s *IntegrationRedisSuite) SetupTest() {
s.ctx = context.Background()
s.rdb = testRedis(s.T())
}
// RequireNoError is a convenience method wrapping require.NoError with s.T().
func (s *IntegrationRedisSuite) RequireNoError(err error, msgAndArgs ...any) {
s.T().Helper()
require.NoError(s.T(), err, msgAndArgs...)
}
// AssertTTLWithin asserts that ttl is within [min, max].
func (s *IntegrationRedisSuite) AssertTTLWithin(ttl, min, max time.Duration) {
s.T().Helper()
assertTTLWithin(s.T(), ttl, min, max)
}
// IntegrationDBSuite provides a base suite for DB (Gorm) integration tests.
// Embedding suites should call SetupTest to initialize ctx and db.
type IntegrationDBSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
}
// SetupTest initializes ctx and db for each test method.
func (s *IntegrationDBSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
}
// RequireNoError is a convenience method wrapping require.NoError with s.T().
func (s *IntegrationDBSuite) RequireNoError(err error, msgAndArgs ...any) {
s.T().Helper()
require.NoError(s.T(), err, msgAndArgs...)
}

View File

@@ -12,11 +12,13 @@ import (
"github.com/imroc/req/v3"
)
type openaiOAuthService struct{}
// NewOpenAIOAuthClient creates a new OpenAI OAuth client
func NewOpenAIOAuthClient() ports.OpenAIOAuthClient {
return &openaiOAuthService{}
return &openaiOAuthService{tokenURL: openai.TokenURL}
}
type openaiOAuthService struct {
tokenURL string
}
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
@@ -39,7 +41,7 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
SetContext(ctx).
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(openai.TokenURL)
Post(s.tokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
@@ -67,7 +69,7 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
SetContext(ctx).
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(openai.TokenURL)
Post(s.tokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)

View File

@@ -0,0 +1,249 @@
package repository
import (
"context"
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type OpenAIOAuthServiceSuite struct {
suite.Suite
ctx context.Context
srv *httptest.Server
svc *openaiOAuthService
received chan url.Values
}
func (s *OpenAIOAuthServiceSuite) SetupTest() {
s.ctx = context.Background()
s.received = make(chan url.Values, 1)
}
func (s *OpenAIOAuthServiceSuite) TearDownTest() {
if s.srv != nil {
s.srv.Close()
s.srv = nil
}
}
func (s *OpenAIOAuthServiceSuite) setupServer(handler http.HandlerFunc) {
s.srv = httptest.NewServer(handler)
s.svc = &openaiOAuthService{tokenURL: s.srv.URL}
}
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_DefaultRedirectURI() {
errCh := make(chan string, 1)
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
errCh <- "method mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if err := r.ParseForm(); err != nil {
errCh <- "ParseForm failed"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("grant_type"); got != "authorization_code" {
errCh <- "grant_type mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("client_id"); got != openai.ClientID {
errCh <- "client_id mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("code"); got != "code" {
errCh <- "code mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("redirect_uri"); got != openai.DefaultRedirectURI {
errCh <- "redirect_uri mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("code_verifier"); got != "ver" {
errCh <- "code_verifier mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`)
}))
resp, err := s.svc.ExchangeCode(s.ctx, "code", "ver", "", "")
require.NoError(s.T(), err, "ExchangeCode")
select {
case msg := <-errCh:
require.Fail(s.T(), msg)
default:
}
require.Equal(s.T(), "at", resp.AccessToken)
require.Equal(s.T(), "rt", resp.RefreshToken)
}
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() {
errCh := make(chan string, 1)
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
errCh <- "ParseForm failed"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("grant_type"); got != "refresh_token" {
errCh <- "grant_type mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("refresh_token"); got != "rt" {
errCh <- "refresh_token mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("client_id"); got != openai.ClientID {
errCh <- "client_id mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("scope"); got != openai.RefreshScopes {
errCh <- "scope mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at2","refresh_token":"rt2","token_type":"bearer","expires_in":3600}`)
}))
resp, err := s.svc.RefreshToken(s.ctx, "rt", "")
require.NoError(s.T(), err, "RefreshToken")
select {
case msg := <-errCh:
require.Fail(s.T(), msg)
default:
}
require.Equal(s.T(), "at2", resp.AccessToken)
require.Equal(s.T(), "rt2", resp.RefreshToken)
}
func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
_, _ = io.WriteString(w, "bad")
}))
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "status 400")
require.ErrorContains(s.T(), err, "bad")
}
func (s *OpenAIOAuthServiceSuite) TestRequestError_ClosedServer() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
s.srv.Close()
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "request failed")
}
func (s *OpenAIOAuthServiceSuite) TestContextCancel() {
started := make(chan struct{})
block := make(chan struct{})
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
close(started)
<-block
}))
ctx, cancel := context.WithCancel(s.ctx)
done := make(chan error, 1)
go func() {
_, err := s.svc.ExchangeCode(ctx, "code", "ver", openai.DefaultRedirectURI, "")
done <- err
}()
<-started
cancel()
close(block)
err := <-done
require.Error(s.T(), err)
}
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() {
want := "http://localhost:9999/cb"
errCh := make(chan string, 1)
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = r.ParseForm()
if got := r.PostForm.Get("redirect_uri"); got != want {
errCh <- "redirect_uri mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
}))
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", want, "")
require.NoError(s.T(), err, "ExchangeCode")
select {
case msg := <-errCh:
require.Fail(s.T(), msg)
default:
}
}
func (s *OpenAIOAuthServiceSuite) TestTokenURL_CanBeOverriddenWithQuery() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = r.ParseForm()
s.received <- r.PostForm
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
}))
s.svc.tokenURL = s.srv.URL + "?x=1"
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
require.NoError(s.T(), err, "ExchangeCode")
select {
case <-s.received:
default:
require.Fail(s.T(), "expected server to receive request")
}
}
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_SuccessButInvalidJSON() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = io.WriteString(w, "not-valid-json")
}))
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
require.Error(s.T(), err, "expected error for invalid JSON response")
}
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_NonSuccessStatus() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = io.WriteString(w, "unauthorized")
}))
_, err := s.svc.RefreshToken(s.ctx, "rt", "")
require.Error(s.T(), err, "expected error for non-2xx status")
require.ErrorContains(s.T(), err, "status 401")
}
func TestOpenAIOAuthServiceSuite(t *testing.T) {
suite.Run(t, new(OpenAIOAuthServiceSuite))
}

View File

@@ -0,0 +1,147 @@
package repository
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type PricingServiceSuite struct {
suite.Suite
ctx context.Context
srv *httptest.Server
client *pricingRemoteClient
}
func (s *PricingServiceSuite) SetupTest() {
s.ctx = context.Background()
client, ok := NewPricingRemoteClient().(*pricingRemoteClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
}
func (s *PricingServiceSuite) TearDownTest() {
if s.srv != nil {
s.srv.Close()
s.srv = nil
}
}
func (s *PricingServiceSuite) setupServer(handler http.HandlerFunc) {
s.srv = httptest.NewServer(handler)
}
func (s *PricingServiceSuite) TestFetchPricingJSON_Success() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/ok" {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"ok":true}`))
return
}
w.WriteHeader(http.StatusInternalServerError)
}))
body, err := s.client.FetchPricingJSON(s.ctx, s.srv.URL+"/ok")
require.NoError(s.T(), err, "FetchPricingJSON")
require.Equal(s.T(), `{"ok":true}`, string(body), "body mismatch")
}
func (s *PricingServiceSuite) TestFetchPricingJSON_NonOKStatus() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
_, err := s.client.FetchPricingJSON(s.ctx, s.srv.URL+"/err")
require.Error(s.T(), err, "expected error for non-200 status")
}
func (s *PricingServiceSuite) TestFetchHashText_ParsesFields() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/hashfile":
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("abc123 model_prices.json\n"))
case "/hashonly":
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("def456\n"))
default:
w.WriteHeader(http.StatusNotFound)
}
}))
hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/hashfile")
require.NoError(s.T(), err, "FetchHashText")
require.Equal(s.T(), "abc123", hash, "hash mismatch")
hash2, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/hashonly")
require.NoError(s.T(), err, "FetchHashText")
require.Equal(s.T(), "def456", hash2, "hash mismatch")
}
func (s *PricingServiceSuite) TestFetchHashText_NonOKStatus() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
_, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/nope")
require.Error(s.T(), err, "expected error for non-200 status")
}
func (s *PricingServiceSuite) TestFetchPricingJSON_InvalidURL() {
_, err := s.client.FetchPricingJSON(s.ctx, "://invalid-url")
require.Error(s.T(), err, "expected error for invalid URL")
}
func (s *PricingServiceSuite) TestFetchHashText_EmptyBody() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
// empty body
}))
hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/empty")
require.NoError(s.T(), err, "FetchHashText empty body should not error")
require.Equal(s.T(), "", hash, "expected empty hash")
}
func (s *PricingServiceSuite) TestFetchHashText_WhitespaceOnly() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(" \n"))
}))
hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/ws")
require.NoError(s.T(), err, "FetchHashText whitespace body should not error")
require.Equal(s.T(), "", hash, "expected empty hash after trimming")
}
func (s *PricingServiceSuite) TestFetchPricingJSON_ContextCancel() {
started := make(chan struct{})
block := make(chan struct{})
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
close(started)
<-block
}))
ctx, cancel := context.WithCancel(s.ctx)
done := make(chan error, 1)
go func() {
_, err := s.client.FetchPricingJSON(ctx, s.srv.URL+"/block")
done <- err
}()
<-started
cancel()
close(block)
err := <-done
require.Error(s.T(), err)
}
func TestPricingServiceSuite(t *testing.T) {
suite.Run(t, new(PricingServiceSuite))
}

View File

@@ -16,10 +16,14 @@ import (
"golang.org/x/net/proxy"
)
type proxyProbeService struct{}
func NewProxyExitInfoProber() service.ProxyExitInfoProber {
return &proxyProbeService{}
return &proxyProbeService{ipInfoURL: defaultIPInfoURL}
}
const defaultIPInfoURL = "https://ipinfo.io/json"
type proxyProbeService struct {
ipInfoURL string
}
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
@@ -34,7 +38,7 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
}
startTime := time.Now()
req, err := http.NewRequestWithContext(ctx, "GET", "https://ipinfo.io/json", nil)
req, err := http.NewRequestWithContext(ctx, "GET", s.ipInfoURL, nil)
if err != nil {
return nil, 0, fmt.Errorf("failed to create request: %w", err)
}

View File

@@ -0,0 +1,121 @@
package repository
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type ProxyProbeServiceSuite struct {
suite.Suite
ctx context.Context
proxySrv *httptest.Server
prober *proxyProbeService
}
func (s *ProxyProbeServiceSuite) SetupTest() {
s.ctx = context.Background()
s.prober = &proxyProbeService{ipInfoURL: "http://ipinfo.test/json"}
}
func (s *ProxyProbeServiceSuite) TearDownTest() {
if s.proxySrv != nil {
s.proxySrv.Close()
s.proxySrv = nil
}
}
func (s *ProxyProbeServiceSuite) setupProxyServer(handler http.HandlerFunc) {
s.proxySrv = httptest.NewServer(handler)
}
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_InvalidURL() {
_, err := createProxyTransport("://bad")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "invalid proxy URL")
}
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_UnsupportedScheme() {
_, err := createProxyTransport("ftp://127.0.0.1:1")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "unsupported proxy protocol")
}
func (s *ProxyProbeServiceSuite) TestCreateProxyTransport_Socks5SetsDialer() {
tr, err := createProxyTransport("socks5://127.0.0.1:1080")
require.NoError(s.T(), err, "createProxyTransport")
require.NotNil(s.T(), tr.DialContext, "expected DialContext to be set for socks5")
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
seen := make(chan string, 1)
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
seen <- r.RequestURI
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"ip":"1.2.3.4","city":"c","region":"r","country":"cc"}`)
}))
info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
require.NoError(s.T(), err, "ProbeProxy")
require.GreaterOrEqual(s.T(), latencyMs, int64(0), "unexpected latency")
require.Equal(s.T(), "1.2.3.4", info.IP)
require.Equal(s.T(), "c", info.City)
require.Equal(s.T(), "r", info.Region)
require.Equal(s.T(), "cc", info.Country)
// Verify proxy received the request
select {
case uri := <-seen:
require.Contains(s.T(), uri, "ipinfo.test", "expected request to go through proxy")
default:
require.Fail(s.T(), "expected proxy to receive request")
}
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_NonOKStatus() {
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable)
}))
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "status: 503")
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidJSON() {
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, "not-json")
}))
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "failed to parse response")
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidIPInfoURL() {
s.prober.ipInfoURL = "://invalid-url"
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
require.Error(s.T(), err, "expected error for invalid ipInfoURL")
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_ProxyServerClosed() {
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
s.proxySrv.Close()
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
require.Error(s.T(), err, "expected error when proxy server is closed")
}
func TestProxyProbeServiceSuite(t *testing.T) {
suite.Run(t, new(ProxyProbeServiceSuite))
}

View File

@@ -0,0 +1,302 @@
//go:build integration
package repository
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type ProxyRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *ProxyRepository
}
func (s *ProxyRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewProxyRepository(s.db)
}
func TestProxyRepoSuite(t *testing.T) {
suite.Run(t, new(ProxyRepoSuite))
}
// --- Create / GetByID / Update / Delete ---
func (s *ProxyRepoSuite) TestCreate() {
proxy := &model.Proxy{
Name: "test-create",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Status: model.StatusActive,
}
err := s.repo.Create(s.ctx, proxy)
s.Require().NoError(err, "Create")
s.Require().NotZero(proxy.ID, "expected ID to be set")
got, err := s.repo.GetByID(s.ctx, proxy.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("test-create", got.Name)
}
func (s *ProxyRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *ProxyRepoSuite) TestUpdate() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "original"})
proxy.Name = "updated"
err := s.repo.Update(s.ctx, proxy)
s.Require().NoError(err, "Update")
got, err := s.repo.GetByID(s.ctx, proxy.ID)
s.Require().NoError(err, "GetByID after update")
s.Require().Equal("updated", got.Name)
}
func (s *ProxyRepoSuite) TestDelete() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "to-delete"})
err := s.repo.Delete(s.ctx, proxy.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, proxy.ID)
s.Require().Error(err, "expected error after delete")
}
// --- List / ListWithFilters ---
func (s *ProxyRepoSuite) TestList() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2"})
proxies, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
s.Require().Len(proxies, 2)
s.Require().Equal(int64(2), page.Total)
}
func (s *ProxyRepoSuite) TestListWithFilters_Protocol() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1", Protocol: "http"})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2", Protocol: "socks5"})
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "socks5", "", "")
s.Require().NoError(err)
s.Require().Len(proxies, 1)
s.Require().Equal("socks5", proxies[0].Protocol)
}
func (s *ProxyRepoSuite) TestListWithFilters_Status() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1", Status: model.StatusActive})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2", Status: model.StatusDisabled})
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusDisabled, "")
s.Require().NoError(err)
s.Require().Len(proxies, 1)
s.Require().Equal(model.StatusDisabled, proxies[0].Status)
}
func (s *ProxyRepoSuite) TestListWithFilters_Search() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "production-proxy"})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "dev-proxy"})
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "prod")
s.Require().NoError(err)
s.Require().Len(proxies, 1)
s.Require().Contains(proxies[0].Name, "production")
}
// --- ListActive ---
func (s *ProxyRepoSuite) TestListActive() {
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "active1", Status: model.StatusActive})
mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "inactive1", Status: model.StatusDisabled})
proxies, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive")
s.Require().Len(proxies, 1)
s.Require().Equal("active1", proxies[0].Name)
}
// --- ExistsByHostPortAuth ---
func (s *ProxyRepoSuite) TestExistsByHostPortAuth() {
mustCreateProxy(s.T(), s.db, &model.Proxy{
Name: "p1",
Protocol: "http",
Host: "1.2.3.4",
Port: 8080,
Username: "user",
Password: "pass",
})
exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "user", "pass")
s.Require().NoError(err, "ExistsByHostPortAuth")
s.Require().True(exists)
notExists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "wrong", "creds")
s.Require().NoError(err)
s.Require().False(notExists)
}
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_NoAuth() {
mustCreateProxy(s.T(), s.db, &model.Proxy{
Name: "p-noauth",
Protocol: "http",
Host: "5.6.7.8",
Port: 8081,
Username: "",
Password: "",
})
exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "5.6.7.8", 8081, "", "")
s.Require().NoError(err)
s.Require().True(exists)
}
// --- CountAccountsByProxyID ---
func (s *ProxyRepoSuite) TestCountAccountsByProxyID() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p-count"})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &proxy.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &proxy.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3"}) // no proxy
count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
s.Require().NoError(err, "CountAccountsByProxyID")
s.Require().Equal(int64(2), count)
}
func (s *ProxyRepoSuite) TestCountAccountsByProxyID_Zero() {
proxy := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p-zero"})
count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
s.Require().NoError(err)
s.Require().Zero(count)
}
// --- GetAccountCountsForProxies ---
func (s *ProxyRepoSuite) TestGetAccountCountsForProxies() {
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p1"})
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{Name: "p2"})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID})
counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
s.Require().NoError(err, "GetAccountCountsForProxies")
s.Require().Equal(int64(2), counts[p1.ID])
s.Require().Equal(int64(1), counts[p2.ID])
}
func (s *ProxyRepoSuite) TestGetAccountCountsForProxies_Empty() {
counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
s.Require().NoError(err)
s.Require().Empty(counts)
}
// --- ListActiveWithAccountCount ---
func (s *ProxyRepoSuite) TestListActiveWithAccountCount() {
base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{
Name: "p1",
Status: model.StatusActive,
CreatedAt: base.Add(-1 * time.Hour),
})
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{
Name: "p2",
Status: model.StatusActive,
CreatedAt: base,
})
mustCreateProxy(s.T(), s.db, &model.Proxy{
Name: "p3-inactive",
Status: model.StatusDisabled,
})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID})
withCounts, err := s.repo.ListActiveWithAccountCount(s.ctx)
s.Require().NoError(err, "ListActiveWithAccountCount")
s.Require().Len(withCounts, 2, "expected 2 active proxies")
// Sorted by created_at DESC, so p2 first
s.Require().Equal(p2.ID, withCounts[0].ID)
s.Require().Equal(int64(1), withCounts[0].AccountCount)
s.Require().Equal(p1.ID, withCounts[1].ID)
s.Require().Equal(int64(2), withCounts[1].AccountCount)
}
// --- Combined original test ---
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() {
p1 := mustCreateProxy(s.T(), s.db, &model.Proxy{
Name: "p1",
Protocol: "http",
Host: "1.2.3.4",
Port: 8080,
Username: "u",
Password: "p",
CreatedAt: time.Now().Add(-1 * time.Hour),
UpdatedAt: time.Now().Add(-1 * time.Hour),
})
p2 := mustCreateProxy(s.T(), s.db, &model.Proxy{
Name: "p2",
Protocol: "http",
Host: "5.6.7.8",
Port: 8081,
Username: "",
Password: "",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
})
exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "u", "p")
s.Require().NoError(err, "ExistsByHostPortAuth")
s.Require().True(exists, "expected proxy to exist")
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a1", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a2", ProxyID: &p1.ID})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a3", ProxyID: &p2.ID})
count1, err := s.repo.CountAccountsByProxyID(s.ctx, p1.ID)
s.Require().NoError(err, "CountAccountsByProxyID")
s.Require().Equal(int64(2), count1, "expected 2 accounts for p1")
counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
s.Require().NoError(err, "GetAccountCountsForProxies")
s.Require().Equal(int64(2), counts[p1.ID])
s.Require().Equal(int64(1), counts[p2.ID])
withCounts, err := s.repo.ListActiveWithAccountCount(s.ctx)
s.Require().NoError(err, "ListActiveWithAccountCount")
s.Require().Len(withCounts, 2, "expected 2 proxies")
for _, pc := range withCounts {
switch pc.ID {
case p1.ID:
s.Require().Equal(int64(2), pc.AccountCount, "p1 count mismatch")
case p2.ID:
s.Require().Equal(int64(1), pc.AccountCount, "p2 count mismatch")
default:
s.Require().Fail("unexpected proxy id", pc.ID)
}
}
}

View File

@@ -0,0 +1,105 @@
//go:build integration
package repository
import (
"errors"
"fmt"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type RedeemCacheSuite struct {
IntegrationRedisSuite
cache *redeemCache
}
func (s *RedeemCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewRedeemCache(s.rdb).(*redeemCache)
}
func (s *RedeemCacheSuite) TestGetRedeemAttemptCount_Missing() {
missingUserID := int64(99999)
_, err := s.cache.GetRedeemAttemptCount(s.ctx, missingUserID)
require.Error(s.T(), err, "expected redis.Nil for missing rate-limit key")
require.True(s.T(), errors.Is(err, redis.Nil))
}
func (s *RedeemCacheSuite) TestIncrementAndGetRedeemAttemptCount() {
userID := int64(1)
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID), "IncrementRedeemAttemptCount")
count, err := s.cache.GetRedeemAttemptCount(s.ctx, userID)
require.NoError(s.T(), err, "GetRedeemAttemptCount")
require.Equal(s.T(), 1, count, "count mismatch")
ttl, err := s.rdb.TTL(s.ctx, key).Result()
require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, redeemRateLimitDuration)
}
func (s *RedeemCacheSuite) TestMultipleIncrements() {
userID := int64(2)
require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID))
require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID))
require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID))
count, err := s.cache.GetRedeemAttemptCount(s.ctx, userID)
require.NoError(s.T(), err)
require.Equal(s.T(), 3, count, "count after 3 increments")
}
func (s *RedeemCacheSuite) TestAcquireAndReleaseRedeemLock() {
ok, err := s.cache.AcquireRedeemLock(s.ctx, "CODE", 10*time.Second)
require.NoError(s.T(), err, "AcquireRedeemLock")
require.True(s.T(), ok)
// Second acquire should fail
ok, err = s.cache.AcquireRedeemLock(s.ctx, "CODE", 10*time.Second)
require.NoError(s.T(), err, "AcquireRedeemLock 2")
require.False(s.T(), ok, "expected lock to be held")
// Release
require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "CODE"), "ReleaseRedeemLock")
// Now acquire should succeed
ok, err = s.cache.AcquireRedeemLock(s.ctx, "CODE", 10*time.Second)
require.NoError(s.T(), err, "AcquireRedeemLock after release")
require.True(s.T(), ok)
}
func (s *RedeemCacheSuite) TestAcquireRedeemLock_TTL() {
lockKey := redeemLockKeyPrefix + "CODE2"
lockTTL := 15 * time.Second
ok, err := s.cache.AcquireRedeemLock(s.ctx, "CODE2", lockTTL)
require.NoError(s.T(), err, "AcquireRedeemLock CODE2")
require.True(s.T(), ok)
ttl, err := s.rdb.TTL(s.ctx, lockKey).Result()
require.NoError(s.T(), err, "TTL lock key")
s.AssertTTLWithin(ttl, 1*time.Second, lockTTL)
}
func (s *RedeemCacheSuite) TestReleaseRedeemLock_Idempotent() {
// Release a lock that doesn't exist should not error
require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "NONEXISTENT"))
// Acquire, release, release again
ok, err := s.cache.AcquireRedeemLock(s.ctx, "IDEMPOTENT", 10*time.Second)
require.NoError(s.T(), err)
require.True(s.T(), ok)
require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "IDEMPOTENT"))
require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "IDEMPOTENT"), "second release should be idempotent")
}
func TestRedeemCacheSuite(t *testing.T) {
suite.Run(t, new(RedeemCacheSuite))
}

View File

@@ -0,0 +1,315 @@
//go:build integration
package repository
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type RedeemCodeRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *RedeemCodeRepository
}
func (s *RedeemCodeRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewRedeemCodeRepository(s.db)
}
func TestRedeemCodeRepoSuite(t *testing.T) {
suite.Run(t, new(RedeemCodeRepoSuite))
}
// --- Create / CreateBatch / GetByID / GetByCode ---
func (s *RedeemCodeRepoSuite) TestCreate() {
code := &model.RedeemCode{
Code: "TEST-CREATE",
Type: model.RedeemTypeBalance,
Value: 100,
Status: model.StatusUnused,
}
err := s.repo.Create(s.ctx, code)
s.Require().NoError(err, "Create")
s.Require().NotZero(code.ID, "expected ID to be set")
got, err := s.repo.GetByID(s.ctx, code.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("TEST-CREATE", got.Code)
}
func (s *RedeemCodeRepoSuite) TestCreateBatch() {
codes := []model.RedeemCode{
{Code: "BATCH-1", Type: model.RedeemTypeBalance, Value: 10, Status: model.StatusUnused},
{Code: "BATCH-2", Type: model.RedeemTypeBalance, Value: 20, Status: model.StatusUnused},
}
err := s.repo.CreateBatch(s.ctx, codes)
s.Require().NoError(err, "CreateBatch")
got1, err := s.repo.GetByCode(s.ctx, "BATCH-1")
s.Require().NoError(err)
s.Require().Equal(float64(10), got1.Value)
got2, err := s.repo.GetByCode(s.ctx, "BATCH-2")
s.Require().NoError(err)
s.Require().Equal(float64(20), got2.Value)
}
func (s *RedeemCodeRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *RedeemCodeRepoSuite) TestGetByCode() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "GET-BY-CODE", Type: model.RedeemTypeBalance})
got, err := s.repo.GetByCode(s.ctx, "GET-BY-CODE")
s.Require().NoError(err, "GetByCode")
s.Require().Equal("GET-BY-CODE", got.Code)
}
func (s *RedeemCodeRepoSuite) TestGetByCode_NotFound() {
_, err := s.repo.GetByCode(s.ctx, "NON-EXISTENT")
s.Require().Error(err, "expected error for non-existent code")
}
// --- Delete ---
func (s *RedeemCodeRepoSuite) TestDelete() {
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TO-DELETE", Type: model.RedeemTypeBalance})
err := s.repo.Delete(s.ctx, code.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, code.ID)
s.Require().Error(err, "expected error after delete")
}
// --- List / ListWithFilters ---
func (s *RedeemCodeRepoSuite) TestList() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "LIST-1", Type: model.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "LIST-2", Type: model.RedeemTypeBalance})
codes, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
s.Require().Len(codes, 2)
s.Require().Equal(int64(2), page.Total)
}
func (s *RedeemCodeRepoSuite) TestListWithFilters_Type() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TYPE-BAL", Type: model.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "TYPE-SUB", Type: model.RedeemTypeSubscription})
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.RedeemTypeSubscription, "", "")
s.Require().NoError(err)
s.Require().Len(codes, 1)
s.Require().Equal(model.RedeemTypeSubscription, codes[0].Type)
}
func (s *RedeemCodeRepoSuite) TestListWithFilters_Status() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "STAT-UNUSED", Type: model.RedeemTypeBalance, Status: model.StatusUnused})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "STAT-USED", Type: model.RedeemTypeBalance, Status: model.StatusUsed})
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.StatusUsed, "")
s.Require().NoError(err)
s.Require().Len(codes, 1)
s.Require().Equal(model.StatusUsed, codes[0].Status)
}
func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() {
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "ALPHA-CODE", Type: model.RedeemTypeBalance})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "BETA-CODE", Type: model.RedeemTypeBalance})
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alpha")
s.Require().NoError(err)
s.Require().Len(codes, 1)
s.Require().Contains(codes[0].Code, "ALPHA")
}
func (s *RedeemCodeRepoSuite) TestListWithFilters_GroupPreload() {
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-preload"})
mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
Code: "WITH-GROUP",
Type: model.RedeemTypeSubscription,
GroupID: &group.ID,
})
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "")
s.Require().NoError(err)
s.Require().Len(codes, 1)
s.Require().NotNil(codes[0].Group, "expected Group preload")
s.Require().Equal(group.ID, codes[0].Group.ID)
}
// --- Update ---
func (s *RedeemCodeRepoSuite) TestUpdate() {
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "UPDATE-ME", Type: model.RedeemTypeBalance, Value: 10})
code.Value = 50
err := s.repo.Update(s.ctx, code)
s.Require().NoError(err, "Update")
got, err := s.repo.GetByID(s.ctx, code.ID)
s.Require().NoError(err)
s.Require().Equal(float64(50), got.Value)
}
// --- Use ---
func (s *RedeemCodeRepoSuite) TestUse() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "use@test.com"})
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "USE-ME", Type: model.RedeemTypeBalance, Status: model.StatusUnused})
err := s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().NoError(err, "Use")
got, err := s.repo.GetByID(s.ctx, code.ID)
s.Require().NoError(err)
s.Require().Equal(model.StatusUsed, got.Status)
s.Require().NotNil(got.UsedBy)
s.Require().Equal(user.ID, *got.UsedBy)
s.Require().NotNil(got.UsedAt)
}
func (s *RedeemCodeRepoSuite) TestUse_Idempotency() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "idem@test.com"})
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "IDEM-CODE", Type: model.RedeemTypeBalance, Status: model.StatusUnused})
err := s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().NoError(err, "Use first time")
// Second use should fail
err = s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().Error(err, "Use expected error on second call")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
}
func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "already@test.com"})
code := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{Code: "ALREADY-USED", Type: model.RedeemTypeBalance, Status: model.StatusUsed})
err := s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().Error(err, "expected error for already used code")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
}
// --- ListByUser ---
func (s *RedeemCodeRepoSuite) TestListByUser() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listby@test.com"})
base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
// Create codes with explicit used_at for ordering
c1 := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
Code: "USER-1",
Type: model.RedeemTypeBalance,
Status: model.StatusUsed,
UsedBy: &user.ID,
})
s.db.Model(c1).Update("used_at", base)
c2 := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
Code: "USER-2",
Type: model.RedeemTypeBalance,
Status: model.StatusUsed,
UsedBy: &user.ID,
})
s.db.Model(c2).Update("used_at", base.Add(1*time.Hour))
codes, err := s.repo.ListByUser(s.ctx, user.ID, 10)
s.Require().NoError(err, "ListByUser")
s.Require().Len(codes, 2)
// Ordered by used_at DESC, so USER-2 first
s.Require().Equal("USER-2", codes[0].Code)
s.Require().Equal("USER-1", codes[1].Code)
}
func (s *RedeemCodeRepoSuite) TestListByUser_WithGroupPreload() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "grp@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listby"})
c := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
Code: "WITH-GRP",
Type: model.RedeemTypeSubscription,
Status: model.StatusUsed,
UsedBy: &user.ID,
GroupID: &group.ID,
})
s.db.Model(c).Update("used_at", time.Now())
codes, err := s.repo.ListByUser(s.ctx, user.ID, 10)
s.Require().NoError(err)
s.Require().Len(codes, 1)
s.Require().NotNil(codes[0].Group)
s.Require().Equal(group.ID, codes[0].Group.ID)
}
func (s *RedeemCodeRepoSuite) TestListByUser_DefaultLimit() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "deflimit@test.com"})
c := mustCreateRedeemCode(s.T(), s.db, &model.RedeemCode{
Code: "DEF-LIM",
Type: model.RedeemTypeBalance,
Status: model.StatusUsed,
UsedBy: &user.ID,
})
s.db.Model(c).Update("used_at", time.Now())
// limit <= 0 should default to 10
codes, err := s.repo.ListByUser(s.ctx, user.ID, 0)
s.Require().NoError(err)
s.Require().Len(codes, 1)
}
// --- Combined original test ---
func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "rc@example.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-rc"})
codes := []model.RedeemCode{
{Code: "CODEA", Type: model.RedeemTypeBalance, Value: 1, Status: model.StatusUnused, CreatedAt: time.Now()},
{Code: "CODEB", Type: model.RedeemTypeSubscription, Value: 0, Status: model.StatusUnused, GroupID: &group.ID, ValidityDays: 7, CreatedAt: time.Now()},
}
s.Require().NoError(s.repo.CreateBatch(s.ctx, codes), "CreateBatch")
list, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.RedeemTypeSubscription, model.StatusUnused, "code")
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total)
s.Require().Len(list, 1)
s.Require().NotNil(list[0].Group, "expected Group preload")
s.Require().Equal(group.ID, list[0].Group.ID)
codeB, err := s.repo.GetByCode(s.ctx, "CODEB")
s.Require().NoError(err, "GetByCode")
s.Require().NoError(s.repo.Use(s.ctx, codeB.ID, user.ID), "Use")
err = s.repo.Use(s.ctx, codeB.ID, user.ID)
s.Require().Error(err, "Use expected error on second call")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
codeA, err := s.repo.GetByCode(s.ctx, "CODEA")
s.Require().NoError(err, "GetByCode")
// Use fixed time instead of time.Sleep for deterministic ordering
s.db.Model(&model.RedeemCode{}).Where("id = ?", codeB.ID).Update("used_at", time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC))
s.Require().NoError(s.repo.Use(s.ctx, codeA.ID, user.ID), "Use codeA")
s.db.Model(&model.RedeemCode{}).Where("id = ?", codeA.ID).Update("used_at", time.Date(2025, 1, 1, 13, 0, 0, 0, time.UTC))
used, err := s.repo.ListByUser(s.ctx, user.ID, 10)
s.Require().NoError(err, "ListByUser")
s.Require().Len(used, 2, "expected 2 used codes")
s.Require().Equal("CODEA", used[0].Code, "expected newest used code first")
}

View File

@@ -0,0 +1,108 @@
//go:build integration
package repository
import (
"context"
"testing"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type SettingRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *SettingRepository
}
func (s *SettingRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewSettingRepository(s.db)
}
func TestSettingRepoSuite(t *testing.T) {
suite.Run(t, new(SettingRepoSuite))
}
func (s *SettingRepoSuite) TestSetAndGetValue() {
s.Require().NoError(s.repo.Set(s.ctx, "k1", "v1"), "Set")
got, err := s.repo.GetValue(s.ctx, "k1")
s.Require().NoError(err, "GetValue")
s.Require().Equal("v1", got, "GetValue mismatch")
}
func (s *SettingRepoSuite) TestSet_Upsert() {
s.Require().NoError(s.repo.Set(s.ctx, "k1", "v1"), "Set")
s.Require().NoError(s.repo.Set(s.ctx, "k1", "v2"), "Set upsert")
got, err := s.repo.GetValue(s.ctx, "k1")
s.Require().NoError(err, "GetValue after upsert")
s.Require().Equal("v2", got, "upsert mismatch")
}
func (s *SettingRepoSuite) TestGetValue_Missing() {
_, err := s.repo.GetValue(s.ctx, "nonexistent")
s.Require().Error(err, "expected error for missing key")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
}
func (s *SettingRepoSuite) TestSetMultiple_AndGetMultiple() {
s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"k2": "v2", "k3": "v3"}), "SetMultiple")
m, err := s.repo.GetMultiple(s.ctx, []string{"k2", "k3"})
s.Require().NoError(err, "GetMultiple")
s.Require().Equal("v2", m["k2"])
s.Require().Equal("v3", m["k3"])
}
func (s *SettingRepoSuite) TestGetMultiple_EmptyKeys() {
m, err := s.repo.GetMultiple(s.ctx, []string{})
s.Require().NoError(err, "GetMultiple with empty keys")
s.Require().Empty(m, "expected empty map")
}
func (s *SettingRepoSuite) TestGetMultiple_Subset() {
s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"a": "1", "b": "2", "c": "3"}))
m, err := s.repo.GetMultiple(s.ctx, []string{"a", "c", "nonexistent"})
s.Require().NoError(err, "GetMultiple subset")
s.Require().Equal("1", m["a"])
s.Require().Equal("3", m["c"])
_, exists := m["nonexistent"]
s.Require().False(exists, "nonexistent key should not be in map")
}
func (s *SettingRepoSuite) TestGetAll() {
s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"x": "1", "y": "2"}))
all, err := s.repo.GetAll(s.ctx)
s.Require().NoError(err, "GetAll")
s.Require().GreaterOrEqual(len(all), 2, "expected at least 2 settings")
s.Require().Equal("1", all["x"])
s.Require().Equal("2", all["y"])
}
func (s *SettingRepoSuite) TestDelete() {
s.Require().NoError(s.repo.Set(s.ctx, "todelete", "val"))
s.Require().NoError(s.repo.Delete(s.ctx, "todelete"), "Delete")
_, err := s.repo.GetValue(s.ctx, "todelete")
s.Require().Error(err, "expected missing key error after Delete")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
}
func (s *SettingRepoSuite) TestDelete_Idempotent() {
// Delete a key that doesn't exist should not error
s.Require().NoError(s.repo.Delete(s.ctx, "nonexistent_delete"), "Delete nonexistent should be idempotent")
}
func (s *SettingRepoSuite) TestSetMultiple_Upsert() {
s.Require().NoError(s.repo.Set(s.ctx, "upsert_key", "old_value"))
s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"upsert_key": "new_value", "new_key": "new_val"}))
got, err := s.repo.GetValue(s.ctx, "upsert_key")
s.Require().NoError(err)
s.Require().Equal("new_value", got, "SetMultiple should upsert existing key")
got2, err := s.repo.GetValue(s.ctx, "new_key")
s.Require().NoError(err)
s.Require().Equal("new_val", got2)
}

View File

@@ -16,6 +16,7 @@ const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/sitev
type turnstileVerifier struct {
httpClient *http.Client
verifyURL string
}
func NewTurnstileVerifier() service.TurnstileVerifier {
@@ -23,6 +24,7 @@ func NewTurnstileVerifier() service.TurnstileVerifier {
httpClient: &http.Client{
Timeout: 10 * time.Second,
},
verifyURL: turnstileVerifyURL,
}
}
@@ -34,7 +36,7 @@ func (v *turnstileVerifier) VerifyToken(ctx context.Context, secretKey, token, r
formData.Set("remoteip", remoteIP)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, turnstileVerifyURL, strings.NewReader(formData.Encode()))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, v.verifyURL, strings.NewReader(formData.Encode()))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}

View File

@@ -0,0 +1,143 @@
package repository
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type TurnstileServiceSuite struct {
suite.Suite
ctx context.Context
srv *httptest.Server
verifier *turnstileVerifier
received chan url.Values
}
func (s *TurnstileServiceSuite) SetupTest() {
s.ctx = context.Background()
s.received = make(chan url.Values, 1)
verifier, ok := NewTurnstileVerifier().(*turnstileVerifier)
require.True(s.T(), ok, "type assertion failed")
s.verifier = verifier
}
func (s *TurnstileServiceSuite) TearDownTest() {
if s.srv != nil {
s.srv.Close()
s.srv = nil
}
}
func (s *TurnstileServiceSuite) setupServer(handler http.HandlerFunc) {
s.srv = httptest.NewServer(handler)
s.verifier.verifyURL = s.srv.URL
}
func (s *TurnstileServiceSuite) TestVerifyToken_SendsFormAndDecodesJSON() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Capture form data in main goroutine context later
body, _ := io.ReadAll(r.Body)
values, _ := url.ParseQuery(string(body))
s.received <- values
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{Success: true})
}))
resp, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
require.NoError(s.T(), err, "VerifyToken")
require.NotNil(s.T(), resp)
require.True(s.T(), resp.Success, "expected success response")
// Assert form fields in main goroutine
select {
case values := <-s.received:
require.Equal(s.T(), "sk", values.Get("secret"))
require.Equal(s.T(), "token", values.Get("response"))
require.Equal(s.T(), "1.1.1.1", values.Get("remoteip"))
default:
require.Fail(s.T(), "expected server to receive request")
}
}
func (s *TurnstileServiceSuite) TestVerifyToken_ContentType() {
var contentType string
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
contentType = r.Header.Get("Content-Type")
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{Success: true})
}))
_, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
require.NoError(s.T(), err)
require.True(s.T(), strings.HasPrefix(contentType, "application/x-www-form-urlencoded"), "unexpected content-type: %s", contentType)
}
func (s *TurnstileServiceSuite) TestVerifyToken_EmptyRemoteIP_NotSent() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
values, _ := url.ParseQuery(string(body))
s.received <- values
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{Success: true})
}))
_, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "")
require.NoError(s.T(), err)
select {
case values := <-s.received:
require.Equal(s.T(), "", values.Get("remoteip"), "remoteip should be empty or not sent")
default:
require.Fail(s.T(), "expected server to receive request")
}
}
func (s *TurnstileServiceSuite) TestVerifyToken_RequestError() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
s.srv.Close()
_, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
require.Error(s.T(), err, "expected error when server is closed")
}
func (s *TurnstileServiceSuite) TestVerifyToken_InvalidJSON() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, "not-valid-json")
}))
_, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
require.Error(s.T(), err, "expected error for invalid JSON response")
}
func (s *TurnstileServiceSuite) TestVerifyToken_SuccessFalse() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{
Success: false,
ErrorCodes: []string{"invalid-input-response"},
})
}))
resp, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
require.NoError(s.T(), err, "VerifyToken should not error on success=false")
require.NotNil(s.T(), resp)
require.False(s.T(), resp.Success)
require.Contains(s.T(), resp.ErrorCodes, "invalid-input-response")
}
func TestTurnstileServiceSuite(t *testing.T) {
suite.Run(t, new(TurnstileServiceSuite))
}

View File

@@ -0,0 +1,73 @@
//go:build integration
package repository
import (
"errors"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type UpdateCacheSuite struct {
IntegrationRedisSuite
cache *updateCache
}
func (s *UpdateCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewUpdateCache(s.rdb).(*updateCache)
}
func (s *UpdateCacheSuite) TestGetUpdateInfo_Missing() {
_, err := s.cache.GetUpdateInfo(s.ctx)
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing update info")
}
func (s *UpdateCacheSuite) TestSetAndGetUpdateInfo() {
updateTTL := 5 * time.Minute
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v1.2.3", updateTTL), "SetUpdateInfo")
info, err := s.cache.GetUpdateInfo(s.ctx)
require.NoError(s.T(), err, "GetUpdateInfo")
require.Equal(s.T(), "v1.2.3", info, "update info mismatch")
}
func (s *UpdateCacheSuite) TestSetUpdateInfo_TTL() {
updateTTL := 5 * time.Minute
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v1.2.3", updateTTL))
ttl, err := s.rdb.TTL(s.ctx, updateCacheKey).Result()
require.NoError(s.T(), err, "TTL updateCacheKey")
s.AssertTTLWithin(ttl, 1*time.Second, updateTTL)
}
func (s *UpdateCacheSuite) TestSetUpdateInfo_Overwrite() {
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v1.0.0", 5*time.Minute))
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v2.0.0", 5*time.Minute))
info, err := s.cache.GetUpdateInfo(s.ctx)
require.NoError(s.T(), err)
require.Equal(s.T(), "v2.0.0", info, "expected overwritten value")
}
func (s *UpdateCacheSuite) TestSetUpdateInfo_ZeroTTL() {
// TTL=0 means persist forever (no expiry) in Redis SET command
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v0.0.0", 0))
info, err := s.cache.GetUpdateInfo(s.ctx)
require.NoError(s.T(), err)
require.Equal(s.T(), "v0.0.0", info)
ttl, err := s.rdb.TTL(s.ctx, updateCacheKey).Result()
require.NoError(s.T(), err)
// TTL=-1 means no expiry, TTL=-2 means key doesn't exist
require.Equal(s.T(), time.Duration(-1), ttl, "expected TTL=-1 for key with no expiry")
}
func TestUpdateCacheSuite(t *testing.T) {
suite.Run(t, new(UpdateCacheSuite))
}

View File

@@ -0,0 +1,890 @@
//go:build integration
package repository
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type UsageLogRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *UsageLogRepository
}
func (s *UsageLogRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewUsageLogRepository(s.db)
}
func TestUsageLogRepoSuite(t *testing.T) {
suite.Run(t, new(UsageLogRepoSuite))
}
func (s *UsageLogRepoSuite) createUsageLog(user *model.User, apiKey *model.ApiKey, account *model.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *model.UsageLog {
log := &model.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3",
InputTokens: inputTokens,
OutputTokens: outputTokens,
TotalCost: cost,
ActualCost: cost,
CreatedAt: createdAt,
}
s.Require().NoError(s.repo.Create(s.ctx, log))
return log
}
// --- Create / GetByID ---
func (s *UsageLogRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "create@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-create", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-create"})
log := &model.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.4,
}
err := s.repo.Create(s.ctx, log)
s.Require().NoError(err, "Create")
s.Require().NotZero(log.ID)
}
func (s *UsageLogRepoSuite) TestGetByID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "getbyid@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-getbyid", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-getbyid"})
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
got, err := s.repo.GetByID(s.ctx, log.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal(log.ID, got.ID)
s.Require().Equal(10, got.InputTokens)
}
func (s *UsageLogRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
// --- Delete ---
func (s *UsageLogRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-delete", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-delete"})
log := s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
err := s.repo.Delete(s.ctx, log.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, log.ID)
s.Require().Error(err, "expected error after delete")
}
// --- ListByUser ---
func (s *UsageLogRepoSuite) TestListByUser() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyuser@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyuser", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyuser"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
logs, page, err := s.repo.ListByUser(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByUser")
s.Require().Len(logs, 2)
s.Require().Equal(int64(2), page.Total)
}
// --- ListByApiKey ---
func (s *UsageLogRepoSuite) TestListByApiKey() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyapikey@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyapikey", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyapikey"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, time.Now())
logs, page, err := s.repo.ListByApiKey(s.ctx, apiKey.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByApiKey")
s.Require().Len(logs, 2)
s.Require().Equal(int64(2), page.Total)
}
// --- ListByAccount ---
func (s *UsageLogRepoSuite) TestListByAccount() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listbyaccount@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-listbyaccount", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-listbyaccount"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
logs, page, err := s.repo.ListByAccount(s.ctx, account.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByAccount")
s.Require().Len(logs, 1)
s.Require().Equal(int64(1), page.Total)
}
// --- GetUserStats ---
func (s *UsageLogRepoSuite) TestGetUserStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "userstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-userstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-userstats"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
stats, err := s.repo.GetUserStats(s.ctx, user.ID, startTime, endTime)
s.Require().NoError(err, "GetUserStats")
s.Require().Equal(int64(2), stats.TotalRequests)
s.Require().Equal(int64(25), stats.InputTokens)
s.Require().Equal(int64(45), stats.OutputTokens)
}
// --- ListWithFilters ---
func (s *UsageLogRepoSuite) TestListWithFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filters", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filters"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
filters := usagestats.UsageLogFilters{UserID: user.ID}
logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters)
s.Require().NoError(err, "ListWithFilters")
s.Require().Len(logs, 1)
s.Require().Equal(int64(1), page.Total)
}
// --- GetDashboardStats ---
func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
now := time.Now()
todayStart := timezone.Today()
userToday := mustCreateUser(s.T(), s.db, &model.User{
Email: "today@example.com",
CreatedAt: maxTime(todayStart.Add(10*time.Second), now.Add(-10*time.Second)),
UpdatedAt: now,
})
userOld := mustCreateUser(s.T(), s.db, &model.User{
Email: "old@example.com",
CreatedAt: todayStart.Add(-24 * time.Hour),
UpdatedAt: todayStart.Add(-24 * time.Hour),
})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-ul"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: userToday.ID, Key: "sk-ul-1", Name: "ul1"})
mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: userOld.ID, Key: "sk-ul-2", Name: "ul2", Status: model.StatusDisabled})
resetAt := now.Add(10 * time.Minute)
accNormal := mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-normal", Schedulable: true})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-error", Status: model.StatusError, Schedulable: true})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-rl", RateLimitedAt: &now, RateLimitResetAt: &resetAt, Schedulable: true})
mustCreateAccount(s.T(), s.db, &model.Account{Name: "a-ov", OverloadUntil: &resetAt, Schedulable: true})
d1, d2, d3 := 100, 200, 300
logToday := &model.UsageLog{
UserID: userToday.ID,
ApiKeyID: apiKey1.ID,
AccountID: accNormal.ID,
Model: "claude-3",
GroupID: &group.ID,
InputTokens: 10,
OutputTokens: 20,
CacheCreationTokens: 3,
CacheReadTokens: 4,
TotalCost: 1.5,
ActualCost: 1.2,
DurationMs: &d1,
CreatedAt: maxTime(todayStart.Add(2*time.Minute), now.Add(-2*time.Minute)),
}
s.Require().NoError(s.repo.Create(s.ctx, logToday), "Create logToday")
logOld := &model.UsageLog{
UserID: userOld.ID,
ApiKeyID: apiKey1.ID,
AccountID: accNormal.ID,
Model: "claude-3",
InputTokens: 5,
OutputTokens: 6,
TotalCost: 0.7,
ActualCost: 0.7,
DurationMs: &d2,
CreatedAt: todayStart.Add(-1 * time.Hour),
}
s.Require().NoError(s.repo.Create(s.ctx, logOld), "Create logOld")
logPerf := &model.UsageLog{
UserID: userToday.ID,
ApiKeyID: apiKey1.ID,
AccountID: accNormal.ID,
Model: "claude-3",
InputTokens: 1,
OutputTokens: 2,
TotalCost: 0.1,
ActualCost: 0.1,
DurationMs: &d3,
CreatedAt: now.Add(-30 * time.Second),
}
s.Require().NoError(s.repo.Create(s.ctx, logPerf), "Create logPerf")
stats, err := s.repo.GetDashboardStats(s.ctx)
s.Require().NoError(err, "GetDashboardStats")
s.Require().Equal(int64(2), stats.TotalUsers, "TotalUsers mismatch")
s.Require().Equal(int64(1), stats.TodayNewUsers, "TodayNewUsers mismatch")
s.Require().Equal(int64(1), stats.ActiveUsers, "ActiveUsers mismatch")
s.Require().Equal(int64(2), stats.TotalApiKeys, "TotalApiKeys mismatch")
s.Require().Equal(int64(1), stats.ActiveApiKeys, "ActiveApiKeys mismatch")
s.Require().Equal(int64(4), stats.TotalAccounts, "TotalAccounts mismatch")
s.Require().Equal(int64(1), stats.ErrorAccounts, "ErrorAccounts mismatch")
s.Require().Equal(int64(1), stats.RateLimitAccounts, "RateLimitAccounts mismatch")
s.Require().Equal(int64(1), stats.OverloadAccounts, "OverloadAccounts mismatch")
s.Require().Equal(int64(3), stats.TotalRequests, "TotalRequests mismatch")
s.Require().Equal(int64(16), stats.TotalInputTokens, "TotalInputTokens mismatch")
s.Require().Equal(int64(28), stats.TotalOutputTokens, "TotalOutputTokens mismatch")
s.Require().Equal(int64(3), stats.TotalCacheCreationTokens, "TotalCacheCreationTokens mismatch")
s.Require().Equal(int64(4), stats.TotalCacheReadTokens, "TotalCacheReadTokens mismatch")
s.Require().Equal(int64(51), stats.TotalTokens, "TotalTokens mismatch")
s.Require().Equal(2.3, stats.TotalCost, "TotalCost mismatch")
s.Require().Equal(2.0, stats.TotalActualCost, "TotalActualCost mismatch")
s.Require().GreaterOrEqual(stats.TodayRequests, int64(1), "expected TodayRequests >= 1")
s.Require().GreaterOrEqual(stats.TodayCost, 0.0, "expected TodayCost >= 0")
wantRpm, wantTpm := s.repo.getPerformanceStats(s.ctx, 0)
s.Require().Equal(wantRpm, stats.Rpm, "Rpm mismatch")
s.Require().Equal(wantTpm, stats.Tpm, "Tpm mismatch")
}
// --- GetUserDashboardStats ---
func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "userdash@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-userdash", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-userdash"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
stats, err := s.repo.GetUserDashboardStats(s.ctx, user.ID)
s.Require().NoError(err, "GetUserDashboardStats")
s.Require().Equal(int64(1), stats.TotalApiKeys)
s.Require().Equal(int64(1), stats.TotalRequests)
}
// --- GetAccountTodayStats ---
func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "acctoday@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-today"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
stats, err := s.repo.GetAccountTodayStats(s.ctx, account.ID)
s.Require().NoError(err, "GetAccountTodayStats")
s.Require().Equal(int64(1), stats.Requests)
s.Require().Equal(int64(30), stats.Tokens)
}
// --- GetBatchUserUsageStats ---
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "batch1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "batch2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user1.ID, Key: "sk-batch1", Name: "k"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user2.ID, Key: "sk-batch2", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-batch"})
s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user2, apiKey2, account, 15, 25, 0.6, time.Now())
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{user1.ID, user2.ID})
s.Require().NoError(err, "GetBatchUserUsageStats")
s.Require().Len(stats, 2)
s.Require().NotNil(stats[user1.ID])
s.Require().NotNil(stats[user2.ID])
}
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() {
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{})
s.Require().NoError(err)
s.Require().Empty(stats)
}
// --- GetBatchApiKeyUsageStats ---
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "batchkey@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-batchkey1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-batchkey2", Name: "k2"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-batchkey"})
s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now())
s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now())
stats, err := s.repo.GetBatchApiKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID})
s.Require().NoError(err, "GetBatchApiKeyUsageStats")
s.Require().Len(stats, 2)
}
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
stats, err := s.repo.GetBatchApiKeyUsageStats(s.ctx, []int64{})
s.Require().NoError(err)
s.Require().Empty(stats)
}
// --- GetGlobalStats ---
func (s *UsageLogRepoSuite) TestGetGlobalStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "global@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-global", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-global"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
stats, err := s.repo.GetGlobalStats(s.ctx, base.Add(-1*time.Hour), base.Add(2*time.Hour))
s.Require().NoError(err, "GetGlobalStats")
s.Require().Equal(int64(2), stats.TotalRequests)
s.Require().Equal(int64(25), stats.TotalInputTokens)
s.Require().Equal(int64(45), stats.TotalOutputTokens)
}
func maxTime(a, b time.Time) time.Time {
if a.After(b) {
return a
}
return b
}
// --- ListByUserAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByUserAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "timerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-timerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-timerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(-24*time.Hour)) // outside range
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
logs, _, err := s.repo.ListByUserAndTimeRange(s.ctx, user.ID, startTime, endTime)
s.Require().NoError(err, "ListByUserAndTimeRange")
s.Require().Len(logs, 2)
}
// --- ListByApiKeyAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByApiKeyAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(30*time.Minute))
s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(-24*time.Hour)) // outside range
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
logs, _, err := s.repo.ListByApiKeyAndTimeRange(s.ctx, apiKey.ID, startTime, endTime)
s.Require().NoError(err, "ListByApiKeyAndTimeRange")
s.Require().Len(logs, 2)
}
// --- ListByAccountAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByAccountAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "acctimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-acctimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-acctimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(45*time.Minute))
s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(-24*time.Hour)) // outside range
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
logs, _, err := s.repo.ListByAccountAndTimeRange(s.ctx, account.ID, startTime, endTime)
s.Require().NoError(err, "ListByAccountAndTimeRange")
s.Require().Len(logs, 2)
}
// --- ListByModelAndTimeRange ---
func (s *UsageLogRepoSuite) TestListByModelAndTimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "modeltimerange@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modeltimerange", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modeltimerange"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
// Create logs with different models
log1 := &model.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: base,
}
s.Require().NoError(s.repo.Create(s.ctx, log1))
log2 := &model.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 15,
OutputTokens: 25,
TotalCost: 0.6,
ActualCost: 0.6,
CreatedAt: base.Add(30 * time.Minute),
}
s.Require().NoError(s.repo.Create(s.ctx, log2))
log3 := &model.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-sonnet",
InputTokens: 20,
OutputTokens: 30,
TotalCost: 0.7,
ActualCost: 0.7,
CreatedAt: base.Add(1 * time.Hour),
}
s.Require().NoError(s.repo.Create(s.ctx, log3))
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
logs, _, err := s.repo.ListByModelAndTimeRange(s.ctx, "claude-3-opus", startTime, endTime)
s.Require().NoError(err, "ListByModelAndTimeRange")
s.Require().Len(logs, 2)
}
// --- GetAccountWindowStats ---
func (s *UsageLogRepoSuite) TestGetAccountWindowStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "windowstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-windowstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-windowstats"})
now := time.Now()
windowStart := now.Add(-10 * time.Minute)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, now.Add(-5*time.Minute))
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, now.Add(-3*time.Minute))
s.createUsageLog(user, apiKey, account, 20, 30, 0.7, now.Add(-30*time.Minute)) // outside window
stats, err := s.repo.GetAccountWindowStats(s.ctx, account.ID, windowStart)
s.Require().NoError(err, "GetAccountWindowStats")
s.Require().Equal(int64(2), stats.Requests)
s.Require().Equal(int64(70), stats.Tokens) // (10+20) + (15+25)
}
// --- GetUserUsageTrendByUserID ---
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-usertrend", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrend"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(24*time.Hour)) // next day
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(48 * time.Hour)
trend, err := s.repo.GetUserUsageTrendByUserID(s.ctx, user.ID, startTime, endTime, "day")
s.Require().NoError(err, "GetUserUsageTrendByUserID")
s.Require().Len(trend, 2) // 2 different days
}
func (s *UsageLogRepoSuite) TestGetUserUsageTrendByUserID_HourlyGranularity() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrendhourly@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-usertrendhourly", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrendhourly"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(2*time.Hour))
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(3 * time.Hour)
trend, err := s.repo.GetUserUsageTrendByUserID(s.ctx, user.ID, startTime, endTime, "hour")
s.Require().NoError(err, "GetUserUsageTrendByUserID hourly")
s.Require().Len(trend, 3) // 3 different hours
}
// --- GetUserModelStats ---
func (s *UsageLogRepoSuite) TestGetUserModelStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "modelstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modelstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modelstats"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
// Create logs with different models
log1 := &model.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 100,
OutputTokens: 200,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: base,
}
s.Require().NoError(s.repo.Create(s.ctx, log1))
log2 := &model.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-sonnet",
InputTokens: 50,
OutputTokens: 100,
TotalCost: 0.2,
ActualCost: 0.2,
CreatedAt: base.Add(1 * time.Hour),
}
s.Require().NoError(s.repo.Create(s.ctx, log2))
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
stats, err := s.repo.GetUserModelStats(s.ctx, user.ID, startTime, endTime)
s.Require().NoError(err, "GetUserModelStats")
s.Require().Len(stats, 2)
// Should be ordered by total_tokens DESC
s.Require().Equal("claude-3-opus", stats[0].Model)
s.Require().Equal(int64(300), stats[0].TotalTokens)
}
// --- GetUsageTrendWithFilters ---
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "trendfilters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-trendfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-trendfilters"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(24*time.Hour))
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(48 * time.Hour)
// Test with user filter
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0)
s.Require().NoError(err, "GetUsageTrendWithFilters user filter")
s.Require().Len(trend, 2)
// Test with apiKey filter
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID)
s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter")
s.Require().Len(trend, 2)
// Test with both filters
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID)
s.Require().NoError(err, "GetUsageTrendWithFilters both filters")
s.Require().Len(trend, 2)
}
func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "trendfilters-h@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-trendfilters-h", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-trendfilters-h"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(3 * time.Hour)
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0)
s.Require().NoError(err, "GetUsageTrendWithFilters hourly")
s.Require().Len(trend, 2)
}
// --- GetModelStatsWithFilters ---
func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "modelfilters@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-modelfilters", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-modelfilters"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
log1 := &model.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 100,
OutputTokens: 200,
TotalCost: 0.5,
ActualCost: 0.5,
CreatedAt: base,
}
s.Require().NoError(s.repo.Create(s.ctx, log1))
log2 := &model.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-sonnet",
InputTokens: 50,
OutputTokens: 100,
TotalCost: 0.2,
ActualCost: 0.2,
CreatedAt: base.Add(1 * time.Hour),
}
s.Require().NoError(s.repo.Create(s.ctx, log2))
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
// Test with user filter
stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0)
s.Require().NoError(err, "GetModelStatsWithFilters user filter")
s.Require().Len(stats, 2)
// Test with apiKey filter
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0)
s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter")
s.Require().Len(stats, 2)
// Test with account filter
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID)
s.Require().NoError(err, "GetModelStatsWithFilters account filter")
s.Require().Len(stats, 2)
}
// --- GetAccountUsageStats ---
func (s *UsageLogRepoSuite) TestGetAccountUsageStats() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "accstats@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-accstats", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-accstats"})
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
// Create logs on different days
log1 := &model.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-opus",
InputTokens: 100,
OutputTokens: 200,
TotalCost: 0.5,
ActualCost: 0.4,
CreatedAt: base.Add(12 * time.Hour),
}
s.Require().NoError(s.repo.Create(s.ctx, log1))
log2 := &model.UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
Model: "claude-3-sonnet",
InputTokens: 50,
OutputTokens: 100,
TotalCost: 0.2,
ActualCost: 0.15,
CreatedAt: base.Add(36 * time.Hour), // next day
}
s.Require().NoError(s.repo.Create(s.ctx, log2))
startTime := base
endTime := base.Add(72 * time.Hour)
resp, err := s.repo.GetAccountUsageStats(s.ctx, account.ID, startTime, endTime)
s.Require().NoError(err, "GetAccountUsageStats")
s.Require().Len(resp.History, 2, "expected 2 days of history")
s.Require().Equal(int64(2), resp.Summary.TotalRequests)
s.Require().Equal(int64(450), resp.Summary.TotalTokens)
s.Require().Len(resp.Models, 2)
}
func (s *UsageLogRepoSuite) TestGetAccountUsageStats_EmptyRange() {
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-emptystats"})
base := time.Date(2025, 1, 15, 0, 0, 0, 0, time.UTC)
startTime := base
endTime := base.Add(72 * time.Hour)
resp, err := s.repo.GetAccountUsageStats(s.ctx, account.ID, startTime, endTime)
s.Require().NoError(err, "GetAccountUsageStats empty")
s.Require().Len(resp.History, 0)
s.Require().Equal(int64(0), resp.Summary.TotalRequests)
}
// --- GetUserUsageTrend ---
func (s *UsageLogRepoSuite) TestGetUserUsageTrend() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "usertrend2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user1.ID, Key: "sk-usertrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user2.ID, Key: "sk-usertrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-usertrends"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user1, apiKey1, account, 100, 200, 1.0, base)
s.createUsageLog(user2, apiKey2, account, 50, 100, 0.5, base)
s.createUsageLog(user1, apiKey1, account, 100, 200, 1.0, base.Add(24*time.Hour))
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(48 * time.Hour)
trend, err := s.repo.GetUserUsageTrend(s.ctx, startTime, endTime, "day", 10)
s.Require().NoError(err, "GetUserUsageTrend")
s.Require().GreaterOrEqual(len(trend), 2)
}
// --- GetApiKeyUsageTrend ---
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytrend@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytrends"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey1, account, 100, 200, 1.0, base)
s.createUsageLog(user, apiKey2, account, 50, 100, 0.5, base)
s.createUsageLog(user, apiKey1, account, 100, 200, 1.0, base.Add(24*time.Hour))
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(48 * time.Hour)
trend, err := s.repo.GetApiKeyUsageTrend(s.ctx, startTime, endTime, "day", 10)
s.Require().NoError(err, "GetApiKeyUsageTrend")
s.Require().GreaterOrEqual(len(trend), 2)
}
func (s *UsageLogRepoSuite) TestGetApiKeyUsageTrend_HourlyGranularity() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "keytrendh@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-keytrendh"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 100, 200, 1.0, base)
s.createUsageLog(user, apiKey, account, 50, 100, 0.5, base.Add(1*time.Hour))
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(3 * time.Hour)
trend, err := s.repo.GetApiKeyUsageTrend(s.ctx, startTime, endTime, "hour", 10)
s.Require().NoError(err, "GetApiKeyUsageTrend hourly")
s.Require().Len(trend, 2)
}
// --- ListWithFilters (additional filter tests) ---
func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterskey@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterskey", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterskey"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
filters := usagestats.UsageLogFilters{ApiKeyID: apiKey.ID}
logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters)
s.Require().NoError(err, "ListWithFilters apiKey")
s.Require().Len(logs, 1)
s.Require().Equal(int64(1), page.Total)
}
func (s *UsageLogRepoSuite) TestListWithFilters_TimeRange() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterstime@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterstime", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterstime"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
s.createUsageLog(user, apiKey, account, 20, 30, 0.7, base.Add(-24*time.Hour)) // outside range
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
filters := usagestats.UsageLogFilters{StartTime: &startTime, EndTime: &endTime}
logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters)
s.Require().NoError(err, "ListWithFilters time range")
s.Require().Len(logs, 2)
s.Require().Equal(int64(2), page.Total)
}
func (s *UsageLogRepoSuite) TestListWithFilters_CombinedFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "filterscombined@test.com"})
apiKey := mustCreateApiKey(s.T(), s.db, &model.ApiKey{UserID: user.ID, Key: "sk-filterscombined", Name: "k"})
account := mustCreateAccount(s.T(), s.db, &model.Account{Name: "acc-filterscombined"})
base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC)
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, base)
s.createUsageLog(user, apiKey, account, 15, 25, 0.6, base.Add(1*time.Hour))
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(2 * time.Hour)
filters := usagestats.UsageLogFilters{
UserID: user.ID,
ApiKeyID: apiKey.ID,
StartTime: &startTime,
EndTime: &endTime,
}
logs, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, filters)
s.Require().NoError(err, "ListWithFilters combined")
s.Require().Len(logs, 2)
s.Require().Equal(int64(2), page.Total)
}

View File

@@ -0,0 +1,448 @@
//go:build integration
package repository
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/lib/pq"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type UserRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *UserRepository
}
func (s *UserRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewUserRepository(s.db)
}
func TestUserRepoSuite(t *testing.T) {
suite.Run(t, new(UserRepoSuite))
}
// --- Create / GetByID / GetByEmail / Update / Delete ---
func (s *UserRepoSuite) TestCreate() {
user := &model.User{
Email: "create@test.com",
Username: "testuser",
Role: model.RoleUser,
Status: model.StatusActive,
}
err := s.repo.Create(s.ctx, user)
s.Require().NoError(err, "Create")
s.Require().NotZero(user.ID, "expected ID to be set")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("create@test.com", got.Email)
}
func (s *UserRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *UserRepoSuite) TestGetByEmail() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "byemail@test.com"})
got, err := s.repo.GetByEmail(s.ctx, user.Email)
s.Require().NoError(err, "GetByEmail")
s.Require().Equal(user.ID, got.ID)
}
func (s *UserRepoSuite) TestGetByEmail_NotFound() {
_, err := s.repo.GetByEmail(s.ctx, "nonexistent@test.com")
s.Require().Error(err, "expected error for non-existent email")
}
func (s *UserRepoSuite) TestUpdate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com", Username: "original"})
user.Username = "updated"
err := s.repo.Update(s.ctx, user)
s.Require().NoError(err, "Update")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err, "GetByID after update")
s.Require().Equal("updated", got.Username)
}
func (s *UserRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"})
err := s.repo.Delete(s.ctx, user.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, user.ID)
s.Require().Error(err, "expected error after delete")
}
// --- List / ListWithFilters ---
func (s *UserRepoSuite) TestList() {
mustCreateUser(s.T(), s.db, &model.User{Email: "list1@test.com"})
mustCreateUser(s.T(), s.db, &model.User{Email: "list2@test.com"})
users, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
s.Require().Len(users, 2)
s.Require().Equal(int64(2), page.Total)
}
func (s *UserRepoSuite) TestListWithFilters_Status() {
mustCreateUser(s.T(), s.db, &model.User{Email: "active@test.com", Status: model.StatusActive})
mustCreateUser(s.T(), s.db, &model.User{Email: "disabled@test.com", Status: model.StatusDisabled})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.StatusActive, "", "")
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Equal(model.StatusActive, users[0].Status)
}
func (s *UserRepoSuite) TestListWithFilters_Role() {
mustCreateUser(s.T(), s.db, &model.User{Email: "user@test.com", Role: model.RoleUser})
mustCreateUser(s.T(), s.db, &model.User{Email: "admin@test.com", Role: model.RoleAdmin})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", model.RoleAdmin, "")
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Equal(model.RoleAdmin, users[0].Role)
}
func (s *UserRepoSuite) TestListWithFilters_Search() {
mustCreateUser(s.T(), s.db, &model.User{Email: "alice@test.com", Username: "Alice"})
mustCreateUser(s.T(), s.db, &model.User{Email: "bob@test.com", Username: "Bob"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alice")
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Contains(users[0].Email, "alice")
}
func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() {
mustCreateUser(s.T(), s.db, &model.User{Email: "u1@test.com", Username: "JohnDoe"})
mustCreateUser(s.T(), s.db, &model.User{Email: "u2@test.com", Username: "JaneSmith"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "john")
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Equal("JohnDoe", users[0].Username)
}
func (s *UserRepoSuite) TestListWithFilters_SearchByWechat() {
mustCreateUser(s.T(), s.db, &model.User{Email: "w1@test.com", Wechat: "wx_hello"})
mustCreateUser(s.T(), s.db, &model.User{Email: "w2@test.com", Wechat: "wx_world"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "wx_hello")
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Equal("wx_hello", users[0].Wechat)
}
func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "sub@test.com", Status: model.StatusActive})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-sub"})
_ = mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(1 * time.Hour),
})
_ = mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-1 * time.Hour),
})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "sub@")
s.Require().NoError(err, "ListWithFilters")
s.Require().Len(users, 1, "expected 1 user")
s.Require().Len(users[0].Subscriptions, 1, "expected 1 active subscription")
s.Require().NotNil(users[0].Subscriptions[0].Group, "expected subscription group preload")
s.Require().Equal(group.ID, users[0].Subscriptions[0].Group.ID, "group ID mismatch")
}
func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
mustCreateUser(s.T(), s.db, &model.User{
Email: "a@example.com",
Username: "Alice",
Wechat: "wx_a",
Role: model.RoleUser,
Status: model.StatusActive,
Balance: 10,
})
target := mustCreateUser(s.T(), s.db, &model.User{
Email: "b@example.com",
Username: "Bob",
Wechat: "wx_b",
Role: model.RoleAdmin,
Status: model.StatusActive,
Balance: 1,
})
mustCreateUser(s.T(), s.db, &model.User{
Email: "c@example.com",
Role: model.RoleAdmin,
Status: model.StatusDisabled,
})
users, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, model.StatusActive, model.RoleAdmin, "b@")
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
s.Require().Len(users, 1, "ListWithFilters len mismatch")
s.Require().Equal(target.ID, users[0].ID, "ListWithFilters result mismatch")
}
// --- Balance operations ---
func (s *UserRepoSuite) TestUpdateBalance() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "bal@test.com", Balance: 10})
err := s.repo.UpdateBalance(s.ctx, user.ID, 2.5)
s.Require().NoError(err, "UpdateBalance")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().Equal(12.5, got.Balance)
}
func (s *UserRepoSuite) TestUpdateBalance_Negative() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "balneg@test.com", Balance: 10})
err := s.repo.UpdateBalance(s.ctx, user.ID, -3)
s.Require().NoError(err, "UpdateBalance with negative")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().Equal(7.0, got.Balance)
}
func (s *UserRepoSuite) TestDeductBalance() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "deduct@test.com", Balance: 10})
err := s.repo.DeductBalance(s.ctx, user.ID, 5)
s.Require().NoError(err, "DeductBalance")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().Equal(5.0, got.Balance)
}
func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "insuf@test.com", Balance: 5})
err := s.repo.DeductBalance(s.ctx, user.ID, 999)
s.Require().Error(err, "expected error for insufficient balance")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound)
}
func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "exact@test.com", Balance: 10})
err := s.repo.DeductBalance(s.ctx, user.ID, 10)
s.Require().NoError(err, "DeductBalance exact amount")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().Zero(got.Balance)
}
// --- Concurrency ---
func (s *UserRepoSuite) TestUpdateConcurrency() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "conc@test.com", Concurrency: 5})
err := s.repo.UpdateConcurrency(s.ctx, user.ID, 3)
s.Require().NoError(err, "UpdateConcurrency")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().Equal(8, got.Concurrency)
}
func (s *UserRepoSuite) TestUpdateConcurrency_Negative() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "concneg@test.com", Concurrency: 5})
err := s.repo.UpdateConcurrency(s.ctx, user.ID, -2)
s.Require().NoError(err, "UpdateConcurrency negative")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().Equal(3, got.Concurrency)
}
// --- ExistsByEmail ---
func (s *UserRepoSuite) TestExistsByEmail() {
mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"})
exists, err := s.repo.ExistsByEmail(s.ctx, "exists@test.com")
s.Require().NoError(err, "ExistsByEmail")
s.Require().True(exists)
notExists, err := s.repo.ExistsByEmail(s.ctx, "notexists@test.com")
s.Require().NoError(err)
s.Require().False(notExists)
}
// --- RemoveGroupFromAllowedGroups ---
func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups() {
groupID := int64(42)
userA := mustCreateUser(s.T(), s.db, &model.User{
Email: "a1@example.com",
AllowedGroups: pq.Int64Array{groupID, 7},
})
mustCreateUser(s.T(), s.db, &model.User{
Email: "a2@example.com",
AllowedGroups: pq.Int64Array{7},
})
affected, err := s.repo.RemoveGroupFromAllowedGroups(s.ctx, groupID)
s.Require().NoError(err, "RemoveGroupFromAllowedGroups")
s.Require().Equal(int64(1), affected, "expected 1 affected row")
got, err := s.repo.GetByID(s.ctx, userA.ID)
s.Require().NoError(err, "GetByID")
for _, id := range got.AllowedGroups {
s.Require().NotEqual(groupID, id, "expected groupID to be removed from allowed_groups")
}
}
func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups_NoMatch() {
mustCreateUser(s.T(), s.db, &model.User{
Email: "nomatch@test.com",
AllowedGroups: pq.Int64Array{1, 2, 3},
})
affected, err := s.repo.RemoveGroupFromAllowedGroups(s.ctx, 999)
s.Require().NoError(err)
s.Require().Zero(affected, "expected no affected rows")
}
// --- GetFirstAdmin ---
func (s *UserRepoSuite) TestGetFirstAdmin() {
admin1 := mustCreateUser(s.T(), s.db, &model.User{
Email: "admin1@example.com",
Role: model.RoleAdmin,
Status: model.StatusActive,
})
mustCreateUser(s.T(), s.db, &model.User{
Email: "admin2@example.com",
Role: model.RoleAdmin,
Status: model.StatusActive,
})
got, err := s.repo.GetFirstAdmin(s.ctx)
s.Require().NoError(err, "GetFirstAdmin")
s.Require().Equal(admin1.ID, got.ID, "GetFirstAdmin mismatch")
}
func (s *UserRepoSuite) TestGetFirstAdmin_NoAdmin() {
mustCreateUser(s.T(), s.db, &model.User{
Email: "user@example.com",
Role: model.RoleUser,
Status: model.StatusActive,
})
_, err := s.repo.GetFirstAdmin(s.ctx)
s.Require().Error(err, "expected error when no admin exists")
}
func (s *UserRepoSuite) TestGetFirstAdmin_DisabledAdminIgnored() {
mustCreateUser(s.T(), s.db, &model.User{
Email: "disabled@example.com",
Role: model.RoleAdmin,
Status: model.StatusDisabled,
})
activeAdmin := mustCreateUser(s.T(), s.db, &model.User{
Email: "active@example.com",
Role: model.RoleAdmin,
Status: model.StatusActive,
})
got, err := s.repo.GetFirstAdmin(s.ctx)
s.Require().NoError(err, "GetFirstAdmin")
s.Require().Equal(activeAdmin.ID, got.ID, "should return only active admin")
}
// --- Combined original test ---
func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
user1 := mustCreateUser(s.T(), s.db, &model.User{
Email: "a@example.com",
Username: "Alice",
Wechat: "wx_a",
Role: model.RoleUser,
Status: model.StatusActive,
Balance: 10,
})
user2 := mustCreateUser(s.T(), s.db, &model.User{
Email: "b@example.com",
Username: "Bob",
Wechat: "wx_b",
Role: model.RoleAdmin,
Status: model.StatusActive,
Balance: 1,
})
_ = mustCreateUser(s.T(), s.db, &model.User{
Email: "c@example.com",
Role: model.RoleAdmin,
Status: model.StatusDisabled,
})
got, err := s.repo.GetByID(s.ctx, user1.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal(user1.Email, got.Email, "GetByID email mismatch")
gotByEmail, err := s.repo.GetByEmail(s.ctx, user2.Email)
s.Require().NoError(err, "GetByEmail")
s.Require().Equal(user2.ID, gotByEmail.ID, "GetByEmail ID mismatch")
got.Username = "Alice2"
s.Require().NoError(s.repo.Update(s.ctx, got), "Update")
got2, err := s.repo.GetByID(s.ctx, user1.ID)
s.Require().NoError(err, "GetByID after update")
s.Require().Equal("Alice2", got2.Username, "Update did not persist")
s.Require().NoError(s.repo.UpdateBalance(s.ctx, user1.ID, 2.5), "UpdateBalance")
got3, err := s.repo.GetByID(s.ctx, user1.ID)
s.Require().NoError(err, "GetByID after UpdateBalance")
s.Require().Equal(12.5, got3.Balance, "UpdateBalance mismatch")
s.Require().NoError(s.repo.DeductBalance(s.ctx, user1.ID, 5), "DeductBalance")
got4, err := s.repo.GetByID(s.ctx, user1.ID)
s.Require().NoError(err, "GetByID after DeductBalance")
s.Require().Equal(7.5, got4.Balance, "DeductBalance mismatch")
err = s.repo.DeductBalance(s.ctx, user1.ID, 999)
s.Require().Error(err, "DeductBalance expected error for insufficient balance")
s.Require().ErrorIs(err, gorm.ErrRecordNotFound, "DeductBalance unexpected error")
s.Require().NoError(s.repo.UpdateConcurrency(s.ctx, user1.ID, 3), "UpdateConcurrency")
got5, err := s.repo.GetByID(s.ctx, user1.ID)
s.Require().NoError(err, "GetByID after UpdateConcurrency")
s.Require().Equal(user1.Concurrency+3, got5.Concurrency, "UpdateConcurrency mismatch")
params := pagination.PaginationParams{Page: 1, PageSize: 10}
users, page, err := s.repo.ListWithFilters(s.ctx, params, model.StatusActive, model.RoleAdmin, "b@")
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
s.Require().Len(users, 1, "ListWithFilters len mismatch")
s.Require().Equal(user2.ID, users[0].ID, "ListWithFilters result mismatch")
}

View File

@@ -0,0 +1,733 @@
//go:build integration
package repository
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/suite"
"gorm.io/gorm"
)
type UserSubscriptionRepoSuite struct {
suite.Suite
ctx context.Context
db *gorm.DB
repo *UserSubscriptionRepository
}
func (s *UserSubscriptionRepoSuite) SetupTest() {
s.ctx = context.Background()
s.db = testTx(s.T())
s.repo = NewUserSubscriptionRepository(s.db)
}
func TestUserSubscriptionRepoSuite(t *testing.T) {
suite.Run(t, new(UserSubscriptionRepoSuite))
}
// --- Create / GetByID / Update / Delete ---
func (s *UserSubscriptionRepoSuite) TestCreate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "sub-create@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-create"})
sub := &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
}
err := s.repo.Create(s.ctx, sub)
s.Require().NoError(err, "Create")
s.Require().NotZero(sub.ID, "expected ID to be set")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal(sub.UserID, got.UserID)
s.Require().Equal(sub.GroupID, got.GroupID)
}
func (s *UserSubscriptionRepoSuite) TestGetByID_WithPreloads() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "preload@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-preload"})
admin := mustCreateUser(s.T(), s.db, &model.User{Email: "admin@test.com", Role: model.RoleAdmin})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
AssignedBy: &admin.ID,
})
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err, "GetByID")
s.Require().NotNil(got.User, "expected User preload")
s.Require().NotNil(got.Group, "expected Group preload")
s.Require().NotNil(got.AssignedByUser, "expected AssignedByUser preload")
s.Require().Equal(user.ID, got.User.ID)
s.Require().Equal(group.ID, got.Group.ID)
s.Require().Equal(admin.ID, got.AssignedByUser.ID)
}
func (s *UserSubscriptionRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *UserSubscriptionRepoSuite) TestUpdate() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "update@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-update"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
sub.Notes = "updated notes"
err := s.repo.Update(s.ctx, sub)
s.Require().NoError(err, "Update")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err, "GetByID after update")
s.Require().Equal("updated notes", got.Notes)
}
func (s *UserSubscriptionRepoSuite) TestDelete() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delete@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-delete"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
err := s.repo.Delete(s.ctx, sub.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, sub.ID)
s.Require().Error(err, "expected error after delete")
}
// --- GetByUserIDAndGroupID / GetActiveByUserIDAndGroupID ---
func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "byuser@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-byuser"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
got, err := s.repo.GetByUserIDAndGroupID(s.ctx, user.ID, group.ID)
s.Require().NoError(err, "GetByUserIDAndGroupID")
s.Require().Equal(sub.ID, got.ID)
s.Require().NotNil(got.Group, "expected Group preload")
}
func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID_NotFound() {
_, err := s.repo.GetByUserIDAndGroupID(s.ctx, 999999, 999999)
s.Require().Error(err, "expected error for non-existent pair")
}
func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "active@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-active"})
// Create active subscription (future expiry)
active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(2 * time.Hour),
})
got, err := s.repo.GetActiveByUserIDAndGroupID(s.ctx, user.ID, group.ID)
s.Require().NoError(err, "GetActiveByUserIDAndGroupID")
s.Require().Equal(active.ID, got.ID)
}
func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID_ExpiredIgnored() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "expired@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-expired"})
// Create expired subscription (past expiry but active status)
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-2 * time.Hour),
})
_, err := s.repo.GetActiveByUserIDAndGroupID(s.ctx, user.ID, group.ID)
s.Require().Error(err, "expected error for expired subscription")
}
// --- ListByUserID / ListActiveByUserID ---
func (s *UserSubscriptionRepoSuite) TestListByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listby@test.com"})
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list1"})
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list2"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: g1.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: g2.ID,
Status: model.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
subs, err := s.repo.ListByUserID(s.ctx, user.ID)
s.Require().NoError(err, "ListByUserID")
s.Require().Len(subs, 2)
for _, sub := range subs {
s.Require().NotNil(sub.Group, "expected Group preload")
}
}
func (s *UserSubscriptionRepoSuite) TestListActiveByUserID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listactive@test.com"})
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-act1"})
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-act2"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: g1.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: g2.ID,
Status: model.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
subs, err := s.repo.ListActiveByUserID(s.ctx, user.ID)
s.Require().NoError(err, "ListActiveByUserID")
s.Require().Len(subs, 1)
s.Require().Equal(model.SubscriptionStatusActive, subs[0].Status)
}
// --- ListByGroupID ---
func (s *UserSubscriptionRepoSuite) TestListByGroupID() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "u1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "u2@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listgrp"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user1.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user2.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
subs, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByGroupID")
s.Require().Len(subs, 2)
s.Require().Equal(int64(2), page.Total)
for _, sub := range subs {
s.Require().NotNil(sub.User, "expected User preload")
s.Require().NotNil(sub.Group, "expected Group preload")
}
}
// --- List with filters ---
func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "list@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-list"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "")
s.Require().NoError(err, "List")
s.Require().Len(subs, 1)
s.Require().Equal(int64(1), page.Total)
}
func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "filter1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "filter2@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-filter"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user1.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user2.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(user1.ID, subs[0].UserID)
}
func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "grpfilter@test.com"})
g1 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-f1"})
g2 := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-f2"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: g1.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: g2.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(g1.ID, subs[0].GroupID)
}
func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "statfilter@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-stat"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, model.SubscriptionStatusExpired)
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(model.SubscriptionStatusExpired, subs[0].Status)
}
// --- Usage tracking ---
func (s *UserSubscriptionRepoSuite) TestIncrementUsage() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "usage@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-usage"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
err := s.repo.IncrementUsage(s.ctx, sub.ID, 1.25)
s.Require().NoError(err, "IncrementUsage")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().Equal(1.25, got.DailyUsageUSD)
s.Require().Equal(1.25, got.WeeklyUsageUSD)
s.Require().Equal(1.25, got.MonthlyUsageUSD)
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Accumulates() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "accum@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-accum"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
s.Require().NoError(s.repo.IncrementUsage(s.ctx, sub.ID, 1.0))
s.Require().NoError(s.repo.IncrementUsage(s.ctx, sub.ID, 2.5))
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().Equal(3.5, got.DailyUsageUSD)
}
func (s *UserSubscriptionRepoSuite) TestActivateWindows() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "activate@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-activate"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
activateAt := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
err := s.repo.ActivateWindows(s.ctx, sub.ID, activateAt)
s.Require().NoError(err, "ActivateWindows")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().NotNil(got.DailyWindowStart)
s.Require().NotNil(got.WeeklyWindowStart)
s.Require().NotNil(got.MonthlyWindowStart)
s.Require().True(got.DailyWindowStart.Equal(activateAt))
}
func (s *UserSubscriptionRepoSuite) TestResetDailyUsage() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetd@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetd"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
DailyUsageUSD: 10.0,
WeeklyUsageUSD: 20.0,
})
resetAt := time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC)
err := s.repo.ResetDailyUsage(s.ctx, sub.ID, resetAt)
s.Require().NoError(err, "ResetDailyUsage")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().Zero(got.DailyUsageUSD)
s.Require().Equal(20.0, got.WeeklyUsageUSD, "weekly should remain unchanged")
s.Require().True(got.DailyWindowStart.Equal(resetAt))
}
func (s *UserSubscriptionRepoSuite) TestResetWeeklyUsage() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetw@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetw"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
WeeklyUsageUSD: 15.0,
MonthlyUsageUSD: 30.0,
})
resetAt := time.Date(2025, 1, 6, 0, 0, 0, 0, time.UTC)
err := s.repo.ResetWeeklyUsage(s.ctx, sub.ID, resetAt)
s.Require().NoError(err, "ResetWeeklyUsage")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().Zero(got.WeeklyUsageUSD)
s.Require().Equal(30.0, got.MonthlyUsageUSD, "monthly should remain unchanged")
s.Require().True(got.WeeklyWindowStart.Equal(resetAt))
}
func (s *UserSubscriptionRepoSuite) TestResetMonthlyUsage() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "resetm@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-resetm"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
MonthlyUsageUSD: 100.0,
})
resetAt := time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC)
err := s.repo.ResetMonthlyUsage(s.ctx, sub.ID, resetAt)
s.Require().NoError(err, "ResetMonthlyUsage")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().Zero(got.MonthlyUsageUSD)
s.Require().True(got.MonthlyWindowStart.Equal(resetAt))
}
// --- UpdateStatus / ExtendExpiry / UpdateNotes ---
func (s *UserSubscriptionRepoSuite) TestUpdateStatus() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "status@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-status"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
err := s.repo.UpdateStatus(s.ctx, sub.ID, model.SubscriptionStatusExpired)
s.Require().NoError(err, "UpdateStatus")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().Equal(model.SubscriptionStatusExpired, got.Status)
}
func (s *UserSubscriptionRepoSuite) TestExtendExpiry() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "extend@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-extend"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
newExpiry := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
err := s.repo.ExtendExpiry(s.ctx, sub.ID, newExpiry)
s.Require().NoError(err, "ExtendExpiry")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().True(got.ExpiresAt.Equal(newExpiry))
}
func (s *UserSubscriptionRepoSuite) TestUpdateNotes() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "notes@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-notes"})
sub := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
err := s.repo.UpdateNotes(s.ctx, sub.ID, "VIP user")
s.Require().NoError(err, "UpdateNotes")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().Equal("VIP user", got.Notes)
}
// --- ListExpired / BatchUpdateExpiredStatus ---
func (s *UserSubscriptionRepoSuite) TestListExpired() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "listexp@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-listexp"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
expired, err := s.repo.ListExpired(s.ctx)
s.Require().NoError(err, "ListExpired")
s.Require().Len(expired, 1)
}
func (s *UserSubscriptionRepoSuite) TestBatchUpdateExpiredStatus() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "batch@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-batch"})
active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
expiredActive := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
affected, err := s.repo.BatchUpdateExpiredStatus(s.ctx)
s.Require().NoError(err, "BatchUpdateExpiredStatus")
s.Require().Equal(int64(1), affected)
gotActive, _ := s.repo.GetByID(s.ctx, active.ID)
s.Require().Equal(model.SubscriptionStatusActive, gotActive.Status)
gotExpired, _ := s.repo.GetByID(s.ctx, expiredActive.ID)
s.Require().Equal(model.SubscriptionStatusExpired, gotExpired.Status)
}
// --- ExistsByUserIDAndGroupID ---
func (s *UserSubscriptionRepoSuite) TestExistsByUserIDAndGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "exists@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-exists"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
exists, err := s.repo.ExistsByUserIDAndGroupID(s.ctx, user.ID, group.ID)
s.Require().NoError(err, "ExistsByUserIDAndGroupID")
s.Require().True(exists)
notExists, err := s.repo.ExistsByUserIDAndGroupID(s.ctx, user.ID, 999999)
s.Require().NoError(err)
s.Require().False(notExists)
}
// --- CountByGroupID / CountActiveByGroupID ---
func (s *UserSubscriptionRepoSuite) TestCountByGroupID() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "cnt1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "cnt2@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-count"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user1.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user2.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
count, err := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountByGroupID")
s.Require().Equal(int64(2), count)
}
func (s *UserSubscriptionRepoSuite) TestCountActiveByGroupID() {
user1 := mustCreateUser(s.T(), s.db, &model.User{Email: "cntact1@test.com"})
user2 := mustCreateUser(s.T(), s.db, &model.User{Email: "cntact2@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-cntact"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user1.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user2.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-24 * time.Hour), // expired by time
})
count, err := s.repo.CountActiveByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountActiveByGroupID")
s.Require().Equal(int64(1), count, "only future expiry counts as active")
}
// --- DeleteByGroupID ---
func (s *UserSubscriptionRepoSuite) TestDeleteByGroupID() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "delgrp@test.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-delgrp"})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
})
mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusExpired,
ExpiresAt: time.Now().Add(-24 * time.Hour),
})
affected, err := s.repo.DeleteByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "DeleteByGroupID")
s.Require().Equal(int64(2), affected)
count, _ := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().Zero(count)
}
// --- Combined original test ---
func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_BatchUpdateExpiredStatus() {
user := mustCreateUser(s.T(), s.db, &model.User{Email: "subr@example.com"})
group := mustCreateGroup(s.T(), s.db, &model.Group{Name: "g-subr"})
active := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(2 * time.Hour),
})
expiredActive := mustCreateSubscription(s.T(), s.db, &model.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: model.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(-2 * time.Hour),
})
got, err := s.repo.GetActiveByUserIDAndGroupID(s.ctx, user.ID, group.ID)
s.Require().NoError(err, "GetActiveByUserIDAndGroupID")
s.Require().Equal(active.ID, got.ID, "expected active subscription")
activateAt := time.Now().Add(-25 * time.Hour)
s.Require().NoError(s.repo.ActivateWindows(s.ctx, active.ID, activateAt), "ActivateWindows")
s.Require().NoError(s.repo.IncrementUsage(s.ctx, active.ID, 1.25), "IncrementUsage")
after, err := s.repo.GetByID(s.ctx, active.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal(1.25, after.DailyUsageUSD, "DailyUsageUSD mismatch")
s.Require().Equal(1.25, after.WeeklyUsageUSD, "WeeklyUsageUSD mismatch")
s.Require().Equal(1.25, after.MonthlyUsageUSD, "MonthlyUsageUSD mismatch")
s.Require().NotNil(after.DailyWindowStart, "expected DailyWindowStart activated")
s.Require().NotNil(after.WeeklyWindowStart, "expected WeeklyWindowStart activated")
s.Require().NotNil(after.MonthlyWindowStart, "expected MonthlyWindowStart activated")
resetAt := time.Now().Truncate(time.Microsecond) // truncate to microsecond for DB precision
s.Require().NoError(s.repo.ResetDailyUsage(s.ctx, active.ID, resetAt), "ResetDailyUsage")
afterReset, err := s.repo.GetByID(s.ctx, active.ID)
s.Require().NoError(err, "GetByID after reset")
s.Require().Equal(0.0, afterReset.DailyUsageUSD, "expected daily usage reset to 0")
s.Require().NotNil(afterReset.DailyWindowStart, "expected DailyWindowStart not nil")
s.Require().True(afterReset.DailyWindowStart.Equal(resetAt), "expected daily window start updated")
affected, err := s.repo.BatchUpdateExpiredStatus(s.ctx)
s.Require().NoError(err, "BatchUpdateExpiredStatus")
s.Require().Equal(int64(1), affected, "expected 1 affected row")
updated, err := s.repo.GetByID(s.ctx, expiredActive.ID)
s.Require().NoError(err, "GetByID expired")
s.Require().Equal(model.SubscriptionStatusExpired, updated.Status, "expected status expired")
}