first commit
Some checks failed
CI / test (push) Has been cancelled
CI / golangci-lint (push) Has been cancelled
Security Scan / backend-security (push) Has been cancelled
Security Scan / frontend-security (push) Has been cancelled

This commit is contained in:
huangzhenpc
2026-01-15 20:29:55 +08:00
commit f52498603c
820 changed files with 320002 additions and 0 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,587 @@
//go:build integration
package repository
import (
"context"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/accountgroup"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
)
type AccountRepoSuite struct {
suite.Suite
ctx context.Context
client *dbent.Client
repo *accountRepository
}
func (s *AccountRepoSuite) SetupTest() {
s.ctx = context.Background()
tx := testEntTx(s.T())
s.client = tx.Client()
s.repo = newAccountRepositoryWithSQL(s.client, tx)
}
func TestAccountRepoSuite(t *testing.T) {
suite.Run(t, new(AccountRepoSuite))
}
// --- Create / GetByID / Update / Delete ---
func (s *AccountRepoSuite) TestCreate() {
account := &service.Account{
Name: "test-create",
Platform: service.PlatformAnthropic,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
Credentials: map[string]any{},
Extra: map[string]any{},
Concurrency: 3,
Priority: 50,
Schedulable: true,
}
err := s.repo.Create(s.ctx, account)
s.Require().NoError(err, "Create")
s.Require().NotZero(account.ID, "expected ID to be set")
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("test-create", got.Name)
}
func (s *AccountRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *AccountRepoSuite) TestUpdate() {
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "original"})
account.Name = "updated"
err := s.repo.Update(s.ctx, account)
s.Require().NoError(err, "Update")
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err, "GetByID after update")
s.Require().Equal("updated", got.Name)
}
func (s *AccountRepoSuite) TestDelete() {
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"})
err := s.repo.Delete(s.ctx, account.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, account.ID)
s.Require().Error(err, "expected error after delete")
}
func (s *AccountRepoSuite) TestDelete_WithGroupBindings() {
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-del"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-del"})
mustBindAccountToGroup(s.T(), s.client, account.ID, group.ID, 1)
err := s.repo.Delete(s.ctx, account.ID)
s.Require().NoError(err, "Delete should cascade remove bindings")
count, err := s.client.AccountGroup.Query().Where(accountgroup.AccountIDEQ(account.ID)).Count(s.ctx)
s.Require().NoError(err)
s.Require().Zero(count, "expected bindings to be removed")
}
// --- List / ListWithFilters ---
func (s *AccountRepoSuite) TestList() {
mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc1"})
mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc2"})
accounts, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
s.Require().Len(accounts, 2)
s.Require().Equal(int64(2), page.Total)
}
func (s *AccountRepoSuite) TestListWithFilters() {
tests := []struct {
name string
setup func(client *dbent.Client)
platform string
accType string
status string
search string
wantCount int
validate func(accounts []service.Account)
}{
{
name: "filter_by_platform",
setup: func(client *dbent.Client) {
mustCreateAccount(s.T(), client, &service.Account{Name: "a1", Platform: service.PlatformAnthropic})
mustCreateAccount(s.T(), client, &service.Account{Name: "a2", Platform: service.PlatformOpenAI})
},
platform: service.PlatformOpenAI,
wantCount: 1,
validate: func(accounts []service.Account) {
s.Require().Equal(service.PlatformOpenAI, accounts[0].Platform)
},
},
{
name: "filter_by_type",
setup: func(client *dbent.Client) {
mustCreateAccount(s.T(), client, &service.Account{Name: "t1", Type: service.AccountTypeOAuth})
mustCreateAccount(s.T(), client, &service.Account{Name: "t2", Type: service.AccountTypeAPIKey})
},
accType: service.AccountTypeAPIKey,
wantCount: 1,
validate: func(accounts []service.Account) {
s.Require().Equal(service.AccountTypeAPIKey, accounts[0].Type)
},
},
{
name: "filter_by_status",
setup: func(client *dbent.Client) {
mustCreateAccount(s.T(), client, &service.Account{Name: "s1", Status: service.StatusActive})
mustCreateAccount(s.T(), client, &service.Account{Name: "s2", Status: service.StatusDisabled})
},
status: service.StatusDisabled,
wantCount: 1,
validate: func(accounts []service.Account) {
s.Require().Equal(service.StatusDisabled, accounts[0].Status)
},
},
{
name: "filter_by_search",
setup: func(client *dbent.Client) {
mustCreateAccount(s.T(), client, &service.Account{Name: "alpha-account"})
mustCreateAccount(s.T(), client, &service.Account{Name: "beta-account"})
},
search: "alpha",
wantCount: 1,
validate: func(accounts []service.Account) {
s.Require().Contains(accounts[0].Name, "alpha")
},
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
// 每个 case 重新获取隔离资源
tx := testEntTx(s.T())
client := tx.Client()
repo := newAccountRepositoryWithSQL(client, tx)
ctx := context.Background()
tt.setup(client)
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search)
s.Require().NoError(err)
s.Require().Len(accounts, tt.wantCount)
if tt.validate != nil {
tt.validate(accounts)
}
})
}
}
// --- ListByGroup / ListActive / ListByPlatform ---
func (s *AccountRepoSuite) TestListByGroup() {
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-list"})
acc1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a1", Status: service.StatusActive})
acc2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a2", Status: service.StatusActive})
mustBindAccountToGroup(s.T(), s.client, acc1.ID, group.ID, 2)
mustBindAccountToGroup(s.T(), s.client, acc2.ID, group.ID, 1)
accounts, err := s.repo.ListByGroup(s.ctx, group.ID)
s.Require().NoError(err, "ListByGroup")
s.Require().Len(accounts, 2)
// Should be ordered by priority
s.Require().Equal(acc2.ID, accounts[0].ID, "expected acc2 first (priority=1)")
}
func (s *AccountRepoSuite) TestListActive() {
mustCreateAccount(s.T(), s.client, &service.Account{Name: "active1", Status: service.StatusActive})
mustCreateAccount(s.T(), s.client, &service.Account{Name: "inactive1", Status: service.StatusDisabled})
accounts, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive")
s.Require().Len(accounts, 1)
s.Require().Equal("active1", accounts[0].Name)
}
func (s *AccountRepoSuite) TestListByPlatform() {
mustCreateAccount(s.T(), s.client, &service.Account{Name: "p1", Platform: service.PlatformAnthropic, Status: service.StatusActive})
mustCreateAccount(s.T(), s.client, &service.Account{Name: "p2", Platform: service.PlatformOpenAI, Status: service.StatusActive})
accounts, err := s.repo.ListByPlatform(s.ctx, service.PlatformAnthropic)
s.Require().NoError(err, "ListByPlatform")
s.Require().Len(accounts, 1)
s.Require().Equal(service.PlatformAnthropic, accounts[0].Platform)
}
// --- Preload and VirtualFields ---
func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
proxy := mustCreateProxy(s.T(), s.client, &service.Proxy{Name: "p1"})
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g1"})
account := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "acc1",
ProxyID: &proxy.ID,
})
mustBindAccountToGroup(s.T(), s.client, account.ID, group.ID, 1)
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err, "GetByID")
s.Require().NotNil(got.Proxy, "expected Proxy preload")
s.Require().Equal(proxy.ID, got.Proxy.ID)
s.Require().Len(got.GroupIDs, 1, "expected GroupIDs to be populated")
s.Require().Equal(group.ID, got.GroupIDs[0])
s.Require().Len(got.Groups, 1, "expected Groups to be populated")
s.Require().Equal(group.ID, got.Groups[0].ID)
accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc")
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total)
s.Require().Len(accounts, 1)
s.Require().NotNil(accounts[0].Proxy, "expected Proxy preload in list")
s.Require().Equal(proxy.ID, accounts[0].Proxy.ID)
s.Require().Len(accounts[0].GroupIDs, 1, "expected GroupIDs in list")
s.Require().Equal(group.ID, accounts[0].GroupIDs[0])
}
// --- GroupBinding / AddToGroup / RemoveFromGroup / BindGroups / GetGroups ---
func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() {
g1 := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g1"})
g2 := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g2"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc"})
s.Require().NoError(s.repo.AddToGroup(s.ctx, account.ID, g1.ID, 10), "AddToGroup")
groups, err := s.repo.GetGroups(s.ctx, account.ID)
s.Require().NoError(err, "GetGroups")
s.Require().Len(groups, 1, "expected 1 group")
s.Require().Equal(g1.ID, groups[0].ID)
s.Require().NoError(s.repo.RemoveFromGroup(s.ctx, account.ID, g1.ID), "RemoveFromGroup")
groups, err = s.repo.GetGroups(s.ctx, account.ID)
s.Require().NoError(err, "GetGroups after remove")
s.Require().Empty(groups, "expected 0 groups after remove")
s.Require().NoError(s.repo.BindGroups(s.ctx, account.ID, []int64{g1.ID, g2.ID}), "BindGroups")
groups, err = s.repo.GetGroups(s.ctx, account.ID)
s.Require().NoError(err, "GetGroups after bind")
s.Require().Len(groups, 2, "expected 2 groups after bind")
}
func (s *AccountRepoSuite) TestBindGroups_EmptyList() {
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-empty"})
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-empty"})
mustBindAccountToGroup(s.T(), s.client, account.ID, group.ID, 1)
s.Require().NoError(s.repo.BindGroups(s.ctx, account.ID, []int64{}), "BindGroups empty")
groups, err := s.repo.GetGroups(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().Empty(groups, "expected 0 groups after binding empty list")
}
// --- Schedulable ---
func (s *AccountRepoSuite) TestListSchedulable() {
now := time.Now()
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-sched"})
okAcc := mustCreateAccount(s.T(), s.client, &service.Account{Name: "ok", Schedulable: true})
mustBindAccountToGroup(s.T(), s.client, okAcc.ID, group.ID, 1)
future := now.Add(10 * time.Minute)
overloaded := mustCreateAccount(s.T(), s.client, &service.Account{Name: "over", Schedulable: true, OverloadUntil: &future})
mustBindAccountToGroup(s.T(), s.client, overloaded.ID, group.ID, 1)
sched, err := s.repo.ListSchedulable(s.ctx)
s.Require().NoError(err, "ListSchedulable")
ids := idsOfAccounts(sched)
s.Require().Contains(ids, okAcc.ID)
s.Require().NotContains(ids, overloaded.ID)
}
func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_StatusUpdates() {
now := time.Now()
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-sched"})
okAcc := mustCreateAccount(s.T(), s.client, &service.Account{Name: "ok", Schedulable: true})
mustBindAccountToGroup(s.T(), s.client, okAcc.ID, group.ID, 1)
future := now.Add(10 * time.Minute)
overloaded := mustCreateAccount(s.T(), s.client, &service.Account{Name: "over", Schedulable: true, OverloadUntil: &future})
mustBindAccountToGroup(s.T(), s.client, overloaded.ID, group.ID, 1)
rateLimited := mustCreateAccount(s.T(), s.client, &service.Account{Name: "rl", Schedulable: true})
mustBindAccountToGroup(s.T(), s.client, rateLimited.ID, group.ID, 1)
s.Require().NoError(s.repo.SetRateLimited(s.ctx, rateLimited.ID, now.Add(10*time.Minute)), "SetRateLimited")
s.Require().NoError(s.repo.SetError(s.ctx, overloaded.ID, "boom"), "SetError")
sched, err := s.repo.ListSchedulableByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "ListSchedulableByGroupID")
s.Require().Len(sched, 1, "expected only ok account schedulable")
s.Require().Equal(okAcc.ID, sched[0].ID)
s.Require().NoError(s.repo.ClearRateLimit(s.ctx, rateLimited.ID), "ClearRateLimit")
sched2, err := s.repo.ListSchedulableByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "ListSchedulableByGroupID after ClearRateLimit")
s.Require().Len(sched2, 2, "expected 2 schedulable accounts after ClearRateLimit")
}
func (s *AccountRepoSuite) TestListSchedulableByPlatform() {
mustCreateAccount(s.T(), s.client, &service.Account{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true})
mustCreateAccount(s.T(), s.client, &service.Account{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true})
accounts, err := s.repo.ListSchedulableByPlatform(s.ctx, service.PlatformAnthropic)
s.Require().NoError(err)
s.Require().Len(accounts, 1)
s.Require().Equal(service.PlatformAnthropic, accounts[0].Platform)
}
func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() {
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-sp"})
a1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true})
a2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true})
mustBindAccountToGroup(s.T(), s.client, a1.ID, group.ID, 1)
mustBindAccountToGroup(s.T(), s.client, a2.ID, group.ID, 2)
accounts, err := s.repo.ListSchedulableByGroupIDAndPlatform(s.ctx, group.ID, service.PlatformAnthropic)
s.Require().NoError(err)
s.Require().Len(accounts, 1)
s.Require().Equal(a1.ID, accounts[0].ID)
}
func (s *AccountRepoSuite) TestSetSchedulable() {
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-sched", Schedulable: true})
s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().False(got.Schedulable)
}
// --- SetOverloaded / SetRateLimited / ClearRateLimit ---
func (s *AccountRepoSuite) TestSetOverloaded() {
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-over"})
until := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().NotNil(got.OverloadUntil)
s.Require().WithinDuration(until, *got.OverloadUntil, time.Second)
}
func (s *AccountRepoSuite) TestSetRateLimited() {
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-rl"})
resetAt := time.Date(2025, 6, 15, 14, 0, 0, 0, time.UTC)
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, resetAt))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().NotNil(got.RateLimitedAt)
s.Require().NotNil(got.RateLimitResetAt)
s.Require().WithinDuration(resetAt, *got.RateLimitResetAt, time.Second)
}
func (s *AccountRepoSuite) TestClearRateLimit() {
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-clear"})
until := time.Now().Add(1 * time.Hour)
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, until))
s.Require().NoError(s.repo.ClearRateLimit(s.ctx, account.ID))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().Nil(got.RateLimitedAt)
s.Require().Nil(got.RateLimitResetAt)
s.Require().Nil(got.OverloadUntil)
}
// --- UpdateLastUsed ---
func (s *AccountRepoSuite) TestUpdateLastUsed() {
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-used"})
s.Require().Nil(account.LastUsedAt)
s.Require().NoError(s.repo.UpdateLastUsed(s.ctx, account.ID))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().NotNil(got.LastUsedAt)
}
// --- SetError ---
func (s *AccountRepoSuite) TestSetError() {
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-err", Status: service.StatusActive})
s.Require().NoError(s.repo.SetError(s.ctx, account.ID, "something went wrong"))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().Equal(service.StatusError, got.Status)
s.Require().Equal("something went wrong", got.ErrorMessage)
}
// --- UpdateSessionWindow ---
func (s *AccountRepoSuite) TestUpdateSessionWindow() {
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-win"})
start := time.Date(2025, 6, 15, 10, 0, 0, 0, time.UTC)
end := time.Date(2025, 6, 15, 15, 0, 0, 0, time.UTC)
s.Require().NoError(s.repo.UpdateSessionWindow(s.ctx, account.ID, &start, &end, "active"))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().NotNil(got.SessionWindowStart)
s.Require().NotNil(got.SessionWindowEnd)
s.Require().Equal("active", got.SessionWindowStatus)
}
// --- UpdateExtra ---
func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() {
account := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "acc-extra",
Extra: map[string]any{"a": "1"},
})
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"b": "2"}), "UpdateExtra")
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("1", got.Extra["a"])
s.Require().Equal("2", got.Extra["b"])
}
func (s *AccountRepoSuite) TestUpdateExtra_EmptyUpdates() {
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-extra-empty"})
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{}))
}
func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-nil-extra", Extra: nil})
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"key": "val"}))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().Equal("val", got.Extra["key"])
}
// --- GetByCRSAccountID ---
func (s *AccountRepoSuite) TestGetByCRSAccountID() {
crsID := "crs-12345"
mustCreateAccount(s.T(), s.client, &service.Account{
Name: "acc-crs",
Extra: map[string]any{"crs_account_id": crsID},
})
got, err := s.repo.GetByCRSAccountID(s.ctx, crsID)
s.Require().NoError(err)
s.Require().NotNil(got)
s.Require().Equal("acc-crs", got.Name)
}
func (s *AccountRepoSuite) TestGetByCRSAccountID_NotFound() {
got, err := s.repo.GetByCRSAccountID(s.ctx, "non-existent")
s.Require().NoError(err)
s.Require().Nil(got)
}
func (s *AccountRepoSuite) TestGetByCRSAccountID_EmptyString() {
got, err := s.repo.GetByCRSAccountID(s.ctx, "")
s.Require().NoError(err)
s.Require().Nil(got)
}
// --- BulkUpdate ---
func (s *AccountRepoSuite) TestBulkUpdate() {
a1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk1", Priority: 1})
a2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk2", Priority: 1})
newPriority := 99
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID, a2.ID}, service.AccountBulkUpdate{
Priority: &newPriority,
})
s.Require().NoError(err)
s.Require().GreaterOrEqual(affected, int64(1), "expected at least one affected row")
got1, _ := s.repo.GetByID(s.ctx, a1.ID)
got2, _ := s.repo.GetByID(s.ctx, a2.ID)
s.Require().Equal(99, got1.Priority)
s.Require().Equal(99, got2.Priority)
}
func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() {
a1 := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "bulk-cred",
Credentials: map[string]any{"existing": "value"},
})
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{
Credentials: map[string]any{"new_key": "new_value"},
})
s.Require().NoError(err)
got, _ := s.repo.GetByID(s.ctx, a1.ID)
s.Require().Equal("value", got.Credentials["existing"])
s.Require().Equal("new_value", got.Credentials["new_key"])
}
func (s *AccountRepoSuite) TestBulkUpdate_MergeExtra() {
a1 := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "bulk-extra",
Extra: map[string]any{"existing": "val"},
})
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{
Extra: map[string]any{"new_key": "new_val"},
})
s.Require().NoError(err)
got, _ := s.repo.GetByID(s.ctx, a1.ID)
s.Require().Equal("val", got.Extra["existing"])
s.Require().Equal("new_val", got.Extra["new_key"])
}
func (s *AccountRepoSuite) TestBulkUpdate_EmptyIDs() {
affected, err := s.repo.BulkUpdate(s.ctx, []int64{}, service.AccountBulkUpdate{})
s.Require().NoError(err)
s.Require().Zero(affected)
}
func (s *AccountRepoSuite) TestBulkUpdate_EmptyUpdates() {
a1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk-empty"})
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{})
s.Require().NoError(err)
s.Require().Zero(affected)
}
func idsOfAccounts(accounts []service.Account) []int64 {
out := make([]int64, 0, len(accounts))
for i := range accounts {
out = append(out, accounts[i].ID)
}
return out
}

View File

@@ -0,0 +1,145 @@
//go:build integration
package repository
import (
"context"
"fmt"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func uniqueTestValue(t *testing.T, prefix string) string {
t.Helper()
safeName := strings.NewReplacer("/", "_", " ", "_").Replace(t.Name())
return fmt.Sprintf("%s-%s", prefix, safeName)
}
func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
entClient := tx.Client()
targetGroup, err := entClient.Group.Create().
SetName(uniqueTestValue(t, "target-group")).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
otherGroup, err := entClient.Group.Create().
SetName(uniqueTestValue(t, "other-group")).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
repo := newUserRepositoryWithSQL(entClient, tx)
u1 := &service.User{
Email: uniqueTestValue(t, "u1") + "@example.com",
PasswordHash: "test-password-hash",
Role: service.RoleUser,
Status: service.StatusActive,
Concurrency: 5,
AllowedGroups: []int64{targetGroup.ID, otherGroup.ID},
}
require.NoError(t, repo.Create(ctx, u1))
u2 := &service.User{
Email: uniqueTestValue(t, "u2") + "@example.com",
PasswordHash: "test-password-hash",
Role: service.RoleUser,
Status: service.StatusActive,
Concurrency: 5,
AllowedGroups: []int64{targetGroup.ID},
}
require.NoError(t, repo.Create(ctx, u2))
u3 := &service.User{
Email: uniqueTestValue(t, "u3") + "@example.com",
PasswordHash: "test-password-hash",
Role: service.RoleUser,
Status: service.StatusActive,
Concurrency: 5,
AllowedGroups: []int64{otherGroup.ID},
}
require.NoError(t, repo.Create(ctx, u3))
affected, err := repo.RemoveGroupFromAllowedGroups(ctx, targetGroup.ID)
require.NoError(t, err)
require.Equal(t, int64(2), affected)
u1After, err := repo.GetByID(ctx, u1.ID)
require.NoError(t, err)
require.NotContains(t, u1After.AllowedGroups, targetGroup.ID)
require.Contains(t, u1After.AllowedGroups, otherGroup.ID)
u2After, err := repo.GetByID(ctx, u2.ID)
require.NoError(t, err)
require.NotContains(t, u2After.AllowedGroups, targetGroup.ID)
}
func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
entClient := tx.Client()
targetGroup, err := entClient.Group.Create().
SetName(uniqueTestValue(t, "delete-cascade-target")).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
otherGroup, err := entClient.Group.Create().
SetName(uniqueTestValue(t, "delete-cascade-other")).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
userRepo := newUserRepositoryWithSQL(entClient, tx)
groupRepo := newGroupRepositoryWithSQL(entClient, tx)
apiKeyRepo := NewAPIKeyRepository(entClient)
u := &service.User{
Email: uniqueTestValue(t, "cascade-user") + "@example.com",
PasswordHash: "test-password-hash",
Role: service.RoleUser,
Status: service.StatusActive,
Concurrency: 5,
AllowedGroups: []int64{targetGroup.ID, otherGroup.ID},
}
require.NoError(t, userRepo.Create(ctx, u))
key := &service.APIKey{
UserID: u.ID,
Key: uniqueTestValue(t, "sk-test-delete-cascade"),
Name: "test key",
GroupID: &targetGroup.ID,
Status: service.StatusActive,
}
require.NoError(t, apiKeyRepo.Create(ctx, key))
_, err = groupRepo.DeleteCascade(ctx, targetGroup.ID)
require.NoError(t, err)
// Deleted group should be hidden by default queries (soft-delete semantics).
_, err = groupRepo.GetByID(ctx, targetGroup.ID)
require.ErrorIs(t, err, service.ErrGroupNotFound)
activeGroups, err := groupRepo.ListActive(ctx)
require.NoError(t, err)
for _, g := range activeGroups {
require.NotEqual(t, targetGroup.ID, g.ID)
}
// User.allowed_groups should no longer include the deleted group.
uAfter, err := userRepo.GetByID(ctx, u.ID)
require.NoError(t, err)
require.NotContains(t, uAfter.AllowedGroups, targetGroup.ID)
require.Contains(t, uAfter.AllowedGroups, otherGroup.ID)
// API keys bound to the deleted group should have group_id cleared.
keyAfter, err := apiKeyRepo.GetByID(ctx, key.ID)
require.NoError(t, err)
require.Nil(t, keyAfter.GroupID)
}

View File

@@ -0,0 +1,93 @@
package repository
import (
"context"
"encoding/json"
"errors"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const (
apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
apiKeyRateLimitDuration = 24 * time.Hour
apiKeyAuthCachePrefix = "apikey:auth:"
)
// apiKeyRateLimitKey generates the Redis key for API key creation rate limiting.
func apiKeyRateLimitKey(userID int64) string {
return fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
}
func apiKeyAuthCacheKey(key string) string {
return fmt.Sprintf("%s%s", apiKeyAuthCachePrefix, key)
}
type apiKeyCache struct {
rdb *redis.Client
}
func NewAPIKeyCache(rdb *redis.Client) service.APIKeyCache {
return &apiKeyCache{rdb: rdb}
}
func (c *apiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
key := apiKeyRateLimitKey(userID)
count, err := c.rdb.Get(ctx, key).Int()
if errors.Is(err, redis.Nil) {
return 0, nil
}
return count, err
}
func (c *apiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
key := apiKeyRateLimitKey(userID)
pipe := c.rdb.Pipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, apiKeyRateLimitDuration)
_, err := pipe.Exec(ctx)
return err
}
func (c *apiKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
key := apiKeyRateLimitKey(userID)
return c.rdb.Del(ctx, key).Err()
}
func (c *apiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) error {
return c.rdb.Incr(ctx, apiKey).Err()
}
func (c *apiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
return c.rdb.Expire(ctx, apiKey, ttl).Err()
}
func (c *apiKeyCache) GetAuthCache(ctx context.Context, key string) (*service.APIKeyAuthCacheEntry, error) {
val, err := c.rdb.Get(ctx, apiKeyAuthCacheKey(key)).Bytes()
if err != nil {
return nil, err
}
var entry service.APIKeyAuthCacheEntry
if err := json.Unmarshal(val, &entry); err != nil {
return nil, err
}
return &entry, nil
}
func (c *apiKeyCache) SetAuthCache(ctx context.Context, key string, entry *service.APIKeyAuthCacheEntry, ttl time.Duration) error {
if entry == nil {
return nil
}
payload, err := json.Marshal(entry)
if err != nil {
return err
}
return c.rdb.Set(ctx, apiKeyAuthCacheKey(key), payload, ttl).Err()
}
func (c *apiKeyCache) DeleteAuthCache(ctx context.Context, key string) error {
return c.rdb.Del(ctx, apiKeyAuthCacheKey(key)).Err()
}

View File

@@ -0,0 +1,127 @@
//go:build integration
package repository
import (
"context"
"fmt"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type ApiKeyCacheSuite struct {
IntegrationRedisSuite
}
func (s *ApiKeyCacheSuite) TestCreateAttemptCount() {
tests := []struct {
name string
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
}{
{
name: "missing_key_returns_zero_nil",
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
userID := int64(1)
count, err := cache.GetCreateAttemptCount(ctx, userID)
require.NoError(s.T(), err, "expected nil error for missing key")
require.Equal(s.T(), 0, count, "expected zero count for missing key")
},
},
{
name: "increment_increases_count_and_sets_ttl",
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
userID := int64(1)
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID), "IncrementCreateAttemptCount")
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID), "IncrementCreateAttemptCount 2")
count, err := cache.GetCreateAttemptCount(ctx, userID)
require.NoError(s.T(), err, "GetCreateAttemptCount")
require.Equal(s.T(), 2, count, "count mismatch")
ttl, err := rdb.TTL(ctx, key).Result()
require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, apiKeyRateLimitDuration)
},
},
{
name: "delete_removes_key",
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
userID := int64(1)
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID))
require.NoError(s.T(), cache.DeleteCreateAttemptCount(ctx, userID), "DeleteCreateAttemptCount")
count, err := cache.GetCreateAttemptCount(ctx, userID)
require.NoError(s.T(), err, "expected nil error after delete")
require.Equal(s.T(), 0, count, "expected zero count after delete")
},
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
// 每个 case 重新获取隔离资源
rdb := testRedis(s.T())
cache := &apiKeyCache{rdb: rdb}
ctx := context.Background()
tt.fn(ctx, rdb, cache)
})
}
}
func (s *ApiKeyCacheSuite) TestDailyUsage() {
tests := []struct {
name string
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
}{
{
name: "increment_increases_count",
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
dailyKey := "daily:sk-test"
require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey), "IncrementDailyUsage")
require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey), "IncrementDailyUsage 2")
n, err := rdb.Get(ctx, dailyKey).Int()
require.NoError(s.T(), err, "Get dailyKey")
require.Equal(s.T(), 2, n, "expected daily usage=2")
},
},
{
name: "set_expiry_sets_ttl",
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
dailyKey := "daily:sk-test-expiry"
require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey))
require.NoError(s.T(), cache.SetDailyUsageExpiry(ctx, dailyKey, 1*time.Hour), "SetDailyUsageExpiry")
ttl, err := rdb.TTL(ctx, dailyKey).Result()
require.NoError(s.T(), err, "TTL dailyKey")
require.Greater(s.T(), ttl, time.Duration(0), "expected ttl > 0")
},
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
rdb := testRedis(s.T())
cache := &apiKeyCache{rdb: rdb}
ctx := context.Background()
tt.fn(ctx, rdb, cache)
})
}
}
func TestApiKeyCacheSuite(t *testing.T) {
suite.Run(t, new(ApiKeyCacheSuite))
}

View File

@@ -0,0 +1,46 @@
//go:build unit
package repository
import (
"math"
"testing"
"github.com/stretchr/testify/require"
)
func TestApiKeyRateLimitKey(t *testing.T) {
tests := []struct {
name string
userID int64
expected string
}{
{
name: "normal_user_id",
userID: 123,
expected: "apikey:ratelimit:123",
},
{
name: "zero_user_id",
userID: 0,
expected: "apikey:ratelimit:0",
},
{
name: "negative_user_id",
userID: -1,
expected: "apikey:ratelimit:-1",
},
{
name: "max_int64",
userID: math.MaxInt64,
expected: "apikey:ratelimit:9223372036854775807",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := apiKeyRateLimitKey(tc.userID)
require.Equal(t, tc.expected, got)
})
}
}

View File

@@ -0,0 +1,435 @@
package repository
import (
"context"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
type apiKeyRepository struct {
client *dbent.Client
}
func NewAPIKeyRepository(client *dbent.Client) service.APIKeyRepository {
return &apiKeyRepository{client: client}
}
func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery {
// 默认过滤已软删除记录,避免删除后仍被查询到。
return r.client.APIKey.Query().Where(apikey.DeletedAtIsNil())
}
func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) error {
builder := r.client.APIKey.Create().
SetUserID(key.UserID).
SetKey(key.Key).
SetName(key.Name).
SetStatus(key.Status).
SetNillableGroupID(key.GroupID)
if len(key.IPWhitelist) > 0 {
builder.SetIPWhitelist(key.IPWhitelist)
}
if len(key.IPBlacklist) > 0 {
builder.SetIPBlacklist(key.IPBlacklist)
}
created, err := builder.Save(ctx)
if err == nil {
key.ID = created.ID
key.CreatedAt = created.CreatedAt
key.UpdatedAt = created.UpdatedAt
}
return translatePersistenceError(err, nil, service.ErrAPIKeyExists)
}
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
m, err := r.activeQuery().
Where(apikey.IDEQ(id)).
WithUser().
WithGroup().
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrAPIKeyNotFound
}
return nil, err
}
return apiKeyEntityToService(m), nil
}
// GetKeyAndOwnerID 根据 API Key ID 获取其 key 与所有者用户ID。
// 相比 GetByID此方法性能更优因为
// - 使用 Select() 只查询必要字段,减少数据传输量
// - 不加载完整的 API Key 实体及其关联数据User、Group 等)
// - 适用于删除等只需 key 与用户 ID 的场景
func (r *apiKeyRepository) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
m, err := r.activeQuery().
Where(apikey.IDEQ(id)).
Select(apikey.FieldKey, apikey.FieldUserID).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return "", 0, service.ErrAPIKeyNotFound
}
return "", 0, err
}
return m.Key, m.UserID, nil
}
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
m, err := r.activeQuery().
Where(apikey.KeyEQ(key)).
WithUser().
WithGroup().
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrAPIKeyNotFound
}
return nil, err
}
return apiKeyEntityToService(m), nil
}
func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) {
m, err := r.activeQuery().
Where(apikey.KeyEQ(key)).
Select(
apikey.FieldID,
apikey.FieldUserID,
apikey.FieldGroupID,
apikey.FieldStatus,
apikey.FieldIPWhitelist,
apikey.FieldIPBlacklist,
).
WithUser(func(q *dbent.UserQuery) {
q.Select(
user.FieldID,
user.FieldStatus,
user.FieldRole,
user.FieldBalance,
user.FieldConcurrency,
)
}).
WithGroup(func(q *dbent.GroupQuery) {
q.Select(
group.FieldID,
group.FieldName,
group.FieldPlatform,
group.FieldStatus,
group.FieldSubscriptionType,
group.FieldRateMultiplier,
group.FieldDailyLimitUsd,
group.FieldWeeklyLimitUsd,
group.FieldMonthlyLimitUsd,
group.FieldImagePrice1k,
group.FieldImagePrice2k,
group.FieldImagePrice4k,
group.FieldClaudeCodeOnly,
group.FieldFallbackGroupID,
)
}).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrAPIKeyNotFound
}
return nil, err
}
return apiKeyEntityToService(m), nil
}
func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) error {
// 使用原子操作:将软删除检查与更新合并到同一语句,避免竞态条件。
// 之前的实现先检查 Exist 再 UpdateOneID若在两步之间发生软删除
// 则会更新已删除的记录。
// 这里选择 Update().Where(),确保只有未软删除记录能被更新。
// 同时显式设置 updated_at避免二次查询带来的并发可见性问题。
now := time.Now()
builder := r.client.APIKey.Update().
Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()).
SetName(key.Name).
SetStatus(key.Status).
SetUpdatedAt(now)
if key.GroupID != nil {
builder.SetGroupID(*key.GroupID)
} else {
builder.ClearGroupID()
}
// IP 限制字段
if len(key.IPWhitelist) > 0 {
builder.SetIPWhitelist(key.IPWhitelist)
} else {
builder.ClearIPWhitelist()
}
if len(key.IPBlacklist) > 0 {
builder.SetIPBlacklist(key.IPBlacklist)
} else {
builder.ClearIPBlacklist()
}
affected, err := builder.Save(ctx)
if err != nil {
return err
}
if affected == 0 {
// 更新影响行数为 0说明记录不存在或已被软删除。
return service.ErrAPIKeyNotFound
}
// 使用同一时间戳回填,避免并发删除导致二次查询失败。
key.UpdatedAt = now
return nil
}
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
// 显式软删除:避免依赖 Hook 行为,确保 deleted_at 一定被设置。
affected, err := r.client.APIKey.Update().
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
SetDeletedAt(time.Now()).
Save(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return service.ErrAPIKeyNotFound
}
return err
}
if affected == 0 {
exists, err := r.client.APIKey.Query().
Where(apikey.IDEQ(id)).
Exist(mixins.SkipSoftDelete(ctx))
if err != nil {
return err
}
if exists {
return nil
}
return service.ErrAPIKeyNotFound
}
return nil
}
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
q := r.activeQuery().Where(apikey.UserIDEQ(userID))
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
}
keys, err := q.
WithGroup().
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(apikey.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
outKeys := make([]service.APIKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
}
return outKeys, paginationResultFromTotal(int64(total), params), nil
}
func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
if len(apiKeyIDs) == 0 {
return []int64{}, nil
}
ids, err := r.client.APIKey.Query().
Where(apikey.UserIDEQ(userID), apikey.IDIn(apiKeyIDs...), apikey.DeletedAtIsNil()).
IDs(ctx)
if err != nil {
return nil, err
}
return ids, nil
}
func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
count, err := r.activeQuery().Where(apikey.UserIDEQ(userID)).Count(ctx)
return int64(count), err
}
func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
count, err := r.activeQuery().Where(apikey.KeyEQ(key)).Count(ctx)
return count > 0, err
}
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
q := r.activeQuery().Where(apikey.GroupIDEQ(groupID))
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
}
keys, err := q.
WithUser().
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(apikey.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
outKeys := make([]service.APIKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
}
return outKeys, paginationResultFromTotal(int64(total), params), nil
}
// SearchAPIKeys searches API keys by user ID and/or keyword (name)
func (r *apiKeyRepository) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
q := r.activeQuery()
if userID > 0 {
q = q.Where(apikey.UserIDEQ(userID))
}
if keyword != "" {
q = q.Where(apikey.NameContainsFold(keyword))
}
keys, err := q.Limit(limit).Order(dbent.Desc(apikey.FieldID)).All(ctx)
if err != nil {
return nil, err
}
outKeys := make([]service.APIKey, 0, len(keys))
for i := range keys {
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
}
return outKeys, nil
}
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
n, err := r.client.APIKey.Update().
Where(apikey.GroupIDEQ(groupID), apikey.DeletedAtIsNil()).
ClearGroupID().
Save(ctx)
return int64(n), err
}
// CountByGroupID 获取分组的 API Key 数量
func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
count, err := r.activeQuery().Where(apikey.GroupIDEQ(groupID)).Count(ctx)
return int64(count), err
}
func (r *apiKeyRepository) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
keys, err := r.activeQuery().
Where(apikey.UserIDEQ(userID)).
Select(apikey.FieldKey).
Strings(ctx)
if err != nil {
return nil, err
}
return keys, nil
}
func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
keys, err := r.activeQuery().
Where(apikey.GroupIDEQ(groupID)).
Select(apikey.FieldKey).
Strings(ctx)
if err != nil {
return nil, err
}
return keys, nil
}
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
if m == nil {
return nil
}
out := &service.APIKey{
ID: m.ID,
UserID: m.UserID,
Key: m.Key,
Name: m.Name,
Status: m.Status,
IPWhitelist: m.IPWhitelist,
IPBlacklist: m.IPBlacklist,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
GroupID: m.GroupID,
}
if m.Edges.User != nil {
out.User = userEntityToService(m.Edges.User)
}
if m.Edges.Group != nil {
out.Group = groupEntityToService(m.Edges.Group)
}
return out
}
func userEntityToService(u *dbent.User) *service.User {
if u == nil {
return nil
}
return &service.User{
ID: u.ID,
Email: u.Email,
Username: u.Username,
Notes: u.Notes,
PasswordHash: u.PasswordHash,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
}
}
func groupEntityToService(g *dbent.Group) *service.Group {
if g == nil {
return nil
}
return &service.Group{
ID: g.ID,
Name: g.Name,
Description: derefString(g.Description),
Platform: g.Platform,
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
Status: g.Status,
Hydrated: true,
SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUsd,
WeeklyLimitUSD: g.WeeklyLimitUsd,
MonthlyLimitUSD: g.MonthlyLimitUsd,
ImagePrice1K: g.ImagePrice1k,
ImagePrice2K: g.ImagePrice2k,
ImagePrice4K: g.ImagePrice4k,
DefaultValidityDays: g.DefaultValidityDays,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
}
}
func derefString(s *string) string {
if s == nil {
return ""
}
return *s
}

View File

@@ -0,0 +1,385 @@
//go:build integration
package repository
import (
"context"
"testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
)
type APIKeyRepoSuite struct {
suite.Suite
ctx context.Context
client *dbent.Client
repo *apiKeyRepository
}
func (s *APIKeyRepoSuite) SetupTest() {
s.ctx = context.Background()
tx := testEntTx(s.T())
s.client = tx.Client()
s.repo = NewAPIKeyRepository(s.client).(*apiKeyRepository)
}
func TestAPIKeyRepoSuite(t *testing.T) {
suite.Run(t, new(APIKeyRepoSuite))
}
// --- Create / GetByID / GetByKey ---
func (s *APIKeyRepoSuite) TestCreate() {
user := s.mustCreateUser("create@test.com")
key := &service.APIKey{
UserID: user.ID,
Key: "sk-create-test",
Name: "Test Key",
Status: service.StatusActive,
}
err := s.repo.Create(s.ctx, key)
s.Require().NoError(err, "Create")
s.Require().NotZero(key.ID, "expected ID to be set")
got, err := s.repo.GetByID(s.ctx, key.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("sk-create-test", got.Key)
}
func (s *APIKeyRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *APIKeyRepoSuite) TestGetByKey() {
user := s.mustCreateUser("getbykey@test.com")
group := s.mustCreateGroup("g-key")
key := &service.APIKey{
UserID: user.ID,
Key: "sk-getbykey",
Name: "My Key",
GroupID: &group.ID,
Status: service.StatusActive,
}
s.Require().NoError(s.repo.Create(s.ctx, key))
got, err := s.repo.GetByKey(s.ctx, key.Key)
s.Require().NoError(err, "GetByKey")
s.Require().Equal(key.ID, got.ID)
s.Require().NotNil(got.User, "expected User preload")
s.Require().Equal(user.ID, got.User.ID)
s.Require().NotNil(got.Group, "expected Group preload")
s.Require().Equal(group.ID, got.Group.ID)
}
func (s *APIKeyRepoSuite) TestGetByKey_NotFound() {
_, err := s.repo.GetByKey(s.ctx, "non-existent-key")
s.Require().Error(err, "expected error for non-existent key")
}
// --- Update ---
func (s *APIKeyRepoSuite) TestUpdate() {
user := s.mustCreateUser("update@test.com")
key := &service.APIKey{
UserID: user.ID,
Key: "sk-update",
Name: "Original",
Status: service.StatusActive,
}
s.Require().NoError(s.repo.Create(s.ctx, key))
key.Name = "Renamed"
key.Status = service.StatusDisabled
err := s.repo.Update(s.ctx, key)
s.Require().NoError(err, "Update")
got, err := s.repo.GetByID(s.ctx, key.ID)
s.Require().NoError(err, "GetByID after update")
s.Require().Equal("sk-update", got.Key, "Update should not change key")
s.Require().Equal(user.ID, got.UserID, "Update should not change user_id")
s.Require().Equal("Renamed", got.Name)
s.Require().Equal(service.StatusDisabled, got.Status)
}
func (s *APIKeyRepoSuite) TestUpdate_ClearGroupID() {
user := s.mustCreateUser("cleargroup@test.com")
group := s.mustCreateGroup("g-clear")
key := &service.APIKey{
UserID: user.ID,
Key: "sk-clear-group",
Name: "Group Key",
GroupID: &group.ID,
Status: service.StatusActive,
}
s.Require().NoError(s.repo.Create(s.ctx, key))
key.GroupID = nil
err := s.repo.Update(s.ctx, key)
s.Require().NoError(err, "Update")
got, err := s.repo.GetByID(s.ctx, key.ID)
s.Require().NoError(err)
s.Require().Nil(got.GroupID, "expected GroupID to be cleared")
}
// --- Delete ---
func (s *APIKeyRepoSuite) TestDelete() {
user := s.mustCreateUser("delete@test.com")
key := &service.APIKey{
UserID: user.ID,
Key: "sk-delete",
Name: "Delete Me",
Status: service.StatusActive,
}
s.Require().NoError(s.repo.Create(s.ctx, key))
err := s.repo.Delete(s.ctx, key.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, key.ID)
s.Require().Error(err, "expected error after delete")
}
// --- ListByUserID / CountByUserID ---
func (s *APIKeyRepoSuite) TestListByUserID() {
user := s.mustCreateUser("listbyuser@test.com")
s.mustCreateApiKey(user.ID, "sk-list-1", "Key 1", nil)
s.mustCreateApiKey(user.ID, "sk-list-2", "Key 2", nil)
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByUserID")
s.Require().Len(keys, 2)
s.Require().Equal(int64(2), page.Total)
}
func (s *APIKeyRepoSuite) TestListByUserID_Pagination() {
user := s.mustCreateUser("paging@test.com")
for i := 0; i < 5; i++ {
s.mustCreateApiKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil)
}
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2})
s.Require().NoError(err)
s.Require().Len(keys, 2)
s.Require().Equal(int64(5), page.Total)
s.Require().Equal(3, page.Pages)
}
func (s *APIKeyRepoSuite) TestCountByUserID() {
user := s.mustCreateUser("count@test.com")
s.mustCreateApiKey(user.ID, "sk-count-1", "K1", nil)
s.mustCreateApiKey(user.ID, "sk-count-2", "K2", nil)
count, err := s.repo.CountByUserID(s.ctx, user.ID)
s.Require().NoError(err, "CountByUserID")
s.Require().Equal(int64(2), count)
}
// --- ListByGroupID / CountByGroupID ---
func (s *APIKeyRepoSuite) TestListByGroupID() {
user := s.mustCreateUser("listbygroup@test.com")
group := s.mustCreateGroup("g-list")
s.mustCreateApiKey(user.ID, "sk-grp-1", "K1", &group.ID)
s.mustCreateApiKey(user.ID, "sk-grp-2", "K2", &group.ID)
s.mustCreateApiKey(user.ID, "sk-grp-3", "K3", nil) // no group
keys, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByGroupID")
s.Require().Len(keys, 2)
s.Require().Equal(int64(2), page.Total)
// User preloaded
s.Require().NotNil(keys[0].User)
}
func (s *APIKeyRepoSuite) TestCountByGroupID() {
user := s.mustCreateUser("countgroup@test.com")
group := s.mustCreateGroup("g-count")
s.mustCreateApiKey(user.ID, "sk-gc-1", "K1", &group.ID)
count, err := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountByGroupID")
s.Require().Equal(int64(1), count)
}
// --- ExistsByKey ---
func (s *APIKeyRepoSuite) TestExistsByKey() {
user := s.mustCreateUser("exists@test.com")
s.mustCreateApiKey(user.ID, "sk-exists", "K", nil)
exists, err := s.repo.ExistsByKey(s.ctx, "sk-exists")
s.Require().NoError(err, "ExistsByKey")
s.Require().True(exists)
notExists, err := s.repo.ExistsByKey(s.ctx, "sk-not-exists")
s.Require().NoError(err)
s.Require().False(notExists)
}
// --- SearchAPIKeys ---
func (s *APIKeyRepoSuite) TestSearchAPIKeys() {
user := s.mustCreateUser("search@test.com")
s.mustCreateApiKey(user.ID, "sk-search-1", "Production Key", nil)
s.mustCreateApiKey(user.ID, "sk-search-2", "Development Key", nil)
found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "prod", 10)
s.Require().NoError(err, "SearchAPIKeys")
s.Require().Len(found, 1)
s.Require().Contains(found[0].Name, "Production")
}
func (s *APIKeyRepoSuite) TestSearchAPIKeys_NoKeyword() {
user := s.mustCreateUser("searchnokw@test.com")
s.mustCreateApiKey(user.ID, "sk-nk-1", "K1", nil)
s.mustCreateApiKey(user.ID, "sk-nk-2", "K2", nil)
found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "", 10)
s.Require().NoError(err)
s.Require().Len(found, 2)
}
func (s *APIKeyRepoSuite) TestSearchAPIKeys_NoUserID() {
user := s.mustCreateUser("searchnouid@test.com")
s.mustCreateApiKey(user.ID, "sk-nu-1", "TestKey", nil)
found, err := s.repo.SearchAPIKeys(s.ctx, 0, "testkey", 10)
s.Require().NoError(err)
s.Require().Len(found, 1)
}
// --- ClearGroupIDByGroupID ---
func (s *APIKeyRepoSuite) TestClearGroupIDByGroupID() {
user := s.mustCreateUser("cleargrp@test.com")
group := s.mustCreateGroup("g-clear-bulk")
k1 := s.mustCreateApiKey(user.ID, "sk-clr-1", "K1", &group.ID)
k2 := s.mustCreateApiKey(user.ID, "sk-clr-2", "K2", &group.ID)
s.mustCreateApiKey(user.ID, "sk-clr-3", "K3", nil) // no group
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "ClearGroupIDByGroupID")
s.Require().Equal(int64(2), affected)
got1, _ := s.repo.GetByID(s.ctx, k1.ID)
got2, _ := s.repo.GetByID(s.ctx, k2.ID)
s.Require().Nil(got1.GroupID)
s.Require().Nil(got2.GroupID)
count, _ := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().Zero(count)
}
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
func (s *APIKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
user := s.mustCreateUser("k@example.com")
group := s.mustCreateGroup("g-k")
key := s.mustCreateApiKey(user.ID, "sk-test-1", "My Key", &group.ID)
key.GroupID = &group.ID
got, err := s.repo.GetByKey(s.ctx, key.Key)
s.Require().NoError(err, "GetByKey")
s.Require().Equal(key.ID, got.ID)
s.Require().NotNil(got.User)
s.Require().Equal(user.ID, got.User.ID)
s.Require().NotNil(got.Group)
s.Require().Equal(group.ID, got.Group.ID)
key.Name = "Renamed"
key.Status = service.StatusDisabled
key.GroupID = nil
s.Require().NoError(s.repo.Update(s.ctx, key), "Update")
got2, err := s.repo.GetByID(s.ctx, key.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("sk-test-1", got2.Key, "Update should not change key")
s.Require().Equal(user.ID, got2.UserID, "Update should not change user_id")
s.Require().Equal("Renamed", got2.Name)
s.Require().Equal(service.StatusDisabled, got2.Status)
s.Require().Nil(got2.GroupID)
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByUserID")
s.Require().Equal(int64(1), page.Total)
s.Require().Len(keys, 1)
exists, err := s.repo.ExistsByKey(s.ctx, "sk-test-1")
s.Require().NoError(err, "ExistsByKey")
s.Require().True(exists, "expected key to exist")
found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "renam", 10)
s.Require().NoError(err, "SearchAPIKeys")
s.Require().Len(found, 1)
s.Require().Equal(key.ID, found[0].ID)
// ClearGroupIDByGroupID
k2 := s.mustCreateApiKey(user.ID, "sk-test-2", "Group Key", &group.ID)
k2.GroupID = &group.ID
countBefore, err := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountByGroupID")
s.Require().Equal(int64(1), countBefore, "expected 1 key in group before clear")
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "ClearGroupIDByGroupID")
s.Require().Equal(int64(1), affected, "expected 1 affected row")
got3, err := s.repo.GetByID(s.ctx, k2.ID)
s.Require().NoError(err, "GetByID")
s.Require().Nil(got3.GroupID, "expected GroupID cleared")
countAfter, err := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountByGroupID after clear")
s.Require().Equal(int64(0), countAfter, "expected 0 keys in group after clear")
}
func (s *APIKeyRepoSuite) mustCreateUser(email string) *service.User {
s.T().Helper()
u, err := s.client.User.Create().
SetEmail(email).
SetPasswordHash("test-password-hash").
SetStatus(service.StatusActive).
SetRole(service.RoleUser).
Save(s.ctx)
s.Require().NoError(err, "create user")
return userEntityToService(u)
}
func (s *APIKeyRepoSuite) mustCreateGroup(name string) *service.Group {
s.T().Helper()
g, err := s.client.Group.Create().
SetName(name).
SetStatus(service.StatusActive).
Save(s.ctx)
s.Require().NoError(err, "create group")
return groupEntityToService(g)
}
func (s *APIKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, groupID *int64) *service.APIKey {
s.T().Helper()
k := &service.APIKey{
UserID: userID,
Key: key,
Name: name,
GroupID: groupID,
Status: service.StatusActive,
}
s.Require().NoError(s.repo.Create(s.ctx, k), "create api key")
return k
}

View File

@@ -0,0 +1,183 @@
package repository
import (
"context"
"errors"
"fmt"
"log"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const (
billingBalanceKeyPrefix = "billing:balance:"
billingSubKeyPrefix = "billing:sub:"
billingCacheTTL = 5 * time.Minute
)
// billingBalanceKey generates the Redis key for user balance cache.
func billingBalanceKey(userID int64) string {
return fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
}
// billingSubKey generates the Redis key for subscription cache.
func billingSubKey(userID, groupID int64) string {
return fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
}
const (
subFieldStatus = "status"
subFieldExpiresAt = "expires_at"
subFieldDailyUsage = "daily_usage"
subFieldWeeklyUsage = "weekly_usage"
subFieldMonthlyUsage = "monthly_usage"
subFieldVersion = "version"
)
var (
deductBalanceScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
return 0
end
local newVal = tonumber(current) - tonumber(ARGV[1])
redis.call('SET', KEYS[1], newVal)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
updateSubUsageScript = redis.NewScript(`
local exists = redis.call('EXISTS', KEYS[1])
if exists == 0 then
return 0
end
local cost = tonumber(ARGV[1])
redis.call('HINCRBYFLOAT', KEYS[1], 'daily_usage', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'weekly_usage', cost)
redis.call('HINCRBYFLOAT', KEYS[1], 'monthly_usage', cost)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
)
type billingCache struct {
rdb *redis.Client
}
func NewBillingCache(rdb *redis.Client) service.BillingCache {
return &billingCache{rdb: rdb}
}
func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
key := billingBalanceKey(userID)
val, err := c.rdb.Get(ctx, key).Result()
if err != nil {
return 0, err
}
return strconv.ParseFloat(val, 64)
}
func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
key := billingBalanceKey(userID)
return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err()
}
func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
key := billingBalanceKey(userID)
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
}
return nil
}
func (c *billingCache) InvalidateUserBalance(ctx context.Context, userID int64) error {
key := billingBalanceKey(userID)
return c.rdb.Del(ctx, key).Err()
}
func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*service.SubscriptionCacheData, error) {
key := billingSubKey(userID, groupID)
result, err := c.rdb.HGetAll(ctx, key).Result()
if err != nil {
return nil, err
}
if len(result) == 0 {
return nil, redis.Nil
}
return c.parseSubscriptionCache(result)
}
func (c *billingCache) parseSubscriptionCache(data map[string]string) (*service.SubscriptionCacheData, error) {
result := &service.SubscriptionCacheData{}
result.Status = data[subFieldStatus]
if result.Status == "" {
return nil, errors.New("invalid cache: missing status")
}
if expiresStr, ok := data[subFieldExpiresAt]; ok {
expiresAt, err := strconv.ParseInt(expiresStr, 10, 64)
if err == nil {
result.ExpiresAt = time.Unix(expiresAt, 0)
}
}
if dailyStr, ok := data[subFieldDailyUsage]; ok {
result.DailyUsage, _ = strconv.ParseFloat(dailyStr, 64)
}
if weeklyStr, ok := data[subFieldWeeklyUsage]; ok {
result.WeeklyUsage, _ = strconv.ParseFloat(weeklyStr, 64)
}
if monthlyStr, ok := data[subFieldMonthlyUsage]; ok {
result.MonthlyUsage, _ = strconv.ParseFloat(monthlyStr, 64)
}
if versionStr, ok := data[subFieldVersion]; ok {
result.Version, _ = strconv.ParseInt(versionStr, 10, 64)
}
return result, nil
}
func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *service.SubscriptionCacheData) error {
if data == nil {
return nil
}
key := billingSubKey(userID, groupID)
fields := map[string]any{
subFieldStatus: data.Status,
subFieldExpiresAt: data.ExpiresAt.Unix(),
subFieldDailyUsage: data.DailyUsage,
subFieldWeeklyUsage: data.WeeklyUsage,
subFieldMonthlyUsage: data.MonthlyUsage,
subFieldVersion: data.Version,
}
pipe := c.rdb.Pipeline()
pipe.HSet(ctx, key, fields)
pipe.Expire(ctx, key, billingCacheTTL)
_, err := pipe.Exec(ctx)
return err
}
func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
key := billingSubKey(userID, groupID)
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
}
return nil
}
func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
key := billingSubKey(userID, groupID)
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -0,0 +1,283 @@
//go:build integration
package repository
import (
"context"
"fmt"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type BillingCacheSuite struct {
IntegrationRedisSuite
}
func (s *BillingCacheSuite) TestUserBalance() {
tests := []struct {
name string
fn func(ctx context.Context, rdb *redis.Client, cache service.BillingCache)
}{
{
name: "missing_key_returns_redis_nil",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
_, err := cache.GetUserBalance(ctx, 1)
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing balance key")
},
},
{
name: "deduct_on_nonexistent_is_noop",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(1)
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 1), "DeductUserBalance should not error")
_, err := rdb.Get(ctx, balanceKey).Result()
require.ErrorIs(s.T(), err, redis.Nil, "expected missing key after deduct on non-existent")
},
},
{
name: "set_and_get_with_ttl",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(2)
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 10.5), "SetUserBalance")
got, err := cache.GetUserBalance(ctx, userID)
require.NoError(s.T(), err, "GetUserBalance")
require.Equal(s.T(), 10.5, got, "balance mismatch")
ttl, err := rdb.TTL(ctx, balanceKey).Result()
require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, billingCacheTTL)
},
},
{
name: "deduct_reduces_balance",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(3)
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 10.5), "SetUserBalance")
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 2.25), "DeductUserBalance")
got, err := cache.GetUserBalance(ctx, userID)
require.NoError(s.T(), err, "GetUserBalance after deduct")
require.Equal(s.T(), 8.25, got, "deduct mismatch")
},
},
{
name: "invalidate_removes_key",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(100)
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 50.0), "SetUserBalance")
exists, err := rdb.Exists(ctx, balanceKey).Result()
require.NoError(s.T(), err, "Exists")
require.Equal(s.T(), int64(1), exists, "expected balance key to exist")
require.NoError(s.T(), cache.InvalidateUserBalance(ctx, userID), "InvalidateUserBalance")
exists, err = rdb.Exists(ctx, balanceKey).Result()
require.NoError(s.T(), err, "Exists after invalidate")
require.Equal(s.T(), int64(0), exists, "expected balance key to be removed after invalidate")
_, err = cache.GetUserBalance(ctx, userID)
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after invalidate")
},
},
{
name: "deduct_refreshes_ttl",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(103)
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 100.0), "SetUserBalance")
ttl1, err := rdb.TTL(ctx, balanceKey).Result()
require.NoError(s.T(), err, "TTL before deduct")
s.AssertTTLWithin(ttl1, 1*time.Second, billingCacheTTL)
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 25.0), "DeductUserBalance")
balance, err := cache.GetUserBalance(ctx, userID)
require.NoError(s.T(), err, "GetUserBalance")
require.Equal(s.T(), 75.0, balance, "expected balance 75.0")
ttl2, err := rdb.TTL(ctx, balanceKey).Result()
require.NoError(s.T(), err, "TTL after deduct")
s.AssertTTLWithin(ttl2, 1*time.Second, billingCacheTTL)
},
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
rdb := testRedis(s.T())
cache := NewBillingCache(rdb)
ctx := context.Background()
tt.fn(ctx, rdb, cache)
})
}
}
func (s *BillingCacheSuite) TestSubscriptionCache() {
tests := []struct {
name string
fn func(ctx context.Context, rdb *redis.Client, cache service.BillingCache)
}{
{
name: "missing_key_returns_redis_nil",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(10)
groupID := int64(20)
_, err := cache.GetSubscriptionCache(ctx, userID, groupID)
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing subscription key")
},
},
{
name: "update_usage_on_nonexistent_is_noop",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(11)
groupID := int64(21)
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
require.NoError(s.T(), cache.UpdateSubscriptionUsage(ctx, userID, groupID, 1.0), "UpdateSubscriptionUsage should not error")
exists, err := rdb.Exists(ctx, subKey).Result()
require.NoError(s.T(), err, "Exists")
require.Equal(s.T(), int64(0), exists, "expected missing subscription key after UpdateSubscriptionUsage on non-existent")
},
},
{
name: "set_and_get_with_ttl",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(12)
groupID := int64(22)
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
data := &service.SubscriptionCacheData{
Status: "active",
ExpiresAt: time.Now().Add(1 * time.Hour),
DailyUsage: 1.0,
WeeklyUsage: 2.0,
MonthlyUsage: 3.0,
Version: 7,
}
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
gotSub, err := cache.GetSubscriptionCache(ctx, userID, groupID)
require.NoError(s.T(), err, "GetSubscriptionCache")
require.Equal(s.T(), "active", gotSub.Status)
require.Equal(s.T(), int64(7), gotSub.Version)
require.Equal(s.T(), 1.0, gotSub.DailyUsage)
ttl, err := rdb.TTL(ctx, subKey).Result()
require.NoError(s.T(), err, "TTL subKey")
s.AssertTTLWithin(ttl, 1*time.Second, billingCacheTTL)
},
},
{
name: "update_usage_increments_all_fields",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(13)
groupID := int64(23)
data := &service.SubscriptionCacheData{
Status: "active",
ExpiresAt: time.Now().Add(1 * time.Hour),
DailyUsage: 1.0,
WeeklyUsage: 2.0,
MonthlyUsage: 3.0,
Version: 1,
}
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
require.NoError(s.T(), cache.UpdateSubscriptionUsage(ctx, userID, groupID, 0.5), "UpdateSubscriptionUsage")
gotSub, err := cache.GetSubscriptionCache(ctx, userID, groupID)
require.NoError(s.T(), err, "GetSubscriptionCache after update")
require.Equal(s.T(), 1.5, gotSub.DailyUsage)
require.Equal(s.T(), 2.5, gotSub.WeeklyUsage)
require.Equal(s.T(), 3.5, gotSub.MonthlyUsage)
},
},
{
name: "invalidate_removes_key",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(101)
groupID := int64(10)
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
data := &service.SubscriptionCacheData{
Status: "active",
ExpiresAt: time.Now().Add(1 * time.Hour),
DailyUsage: 1.0,
WeeklyUsage: 2.0,
MonthlyUsage: 3.0,
Version: 1,
}
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
exists, err := rdb.Exists(ctx, subKey).Result()
require.NoError(s.T(), err, "Exists")
require.Equal(s.T(), int64(1), exists, "expected subscription key to exist")
require.NoError(s.T(), cache.InvalidateSubscriptionCache(ctx, userID, groupID), "InvalidateSubscriptionCache")
exists, err = rdb.Exists(ctx, subKey).Result()
require.NoError(s.T(), err, "Exists after invalidate")
require.Equal(s.T(), int64(0), exists, "expected subscription key to be removed after invalidate")
_, err = cache.GetSubscriptionCache(ctx, userID, groupID)
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after invalidate")
},
},
{
name: "missing_status_returns_parsing_error",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(102)
groupID := int64(11)
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
fields := map[string]any{
"expires_at": time.Now().Add(1 * time.Hour).Unix(),
"daily_usage": 1.0,
"weekly_usage": 2.0,
"monthly_usage": 3.0,
"version": 1,
}
require.NoError(s.T(), rdb.HSet(ctx, subKey, fields).Err(), "HSet")
_, err := cache.GetSubscriptionCache(ctx, userID, groupID)
require.Error(s.T(), err, "expected error for missing status field")
require.NotErrorIs(s.T(), err, redis.Nil, "expected parsing error, not redis.Nil")
require.Equal(s.T(), "invalid cache: missing status", err.Error())
},
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
rdb := testRedis(s.T())
cache := NewBillingCache(rdb)
ctx := context.Background()
tt.fn(ctx, rdb, cache)
})
}
}
func TestBillingCacheSuite(t *testing.T) {
suite.Run(t, new(BillingCacheSuite))
}

View File

@@ -0,0 +1,87 @@
//go:build unit
package repository
import (
"math"
"testing"
"github.com/stretchr/testify/require"
)
func TestBillingBalanceKey(t *testing.T) {
tests := []struct {
name string
userID int64
expected string
}{
{
name: "normal_user_id",
userID: 123,
expected: "billing:balance:123",
},
{
name: "zero_user_id",
userID: 0,
expected: "billing:balance:0",
},
{
name: "negative_user_id",
userID: -1,
expected: "billing:balance:-1",
},
{
name: "max_int64",
userID: math.MaxInt64,
expected: "billing:balance:9223372036854775807",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := billingBalanceKey(tc.userID)
require.Equal(t, tc.expected, got)
})
}
}
func TestBillingSubKey(t *testing.T) {
tests := []struct {
name string
userID int64
groupID int64
expected string
}{
{
name: "normal_ids",
userID: 123,
groupID: 456,
expected: "billing:sub:123:456",
},
{
name: "zero_ids",
userID: 0,
groupID: 0,
expected: "billing:sub:0:0",
},
{
name: "negative_ids",
userID: -1,
groupID: -2,
expected: "billing:sub:-1:-2",
},
{
name: "max_int64_ids",
userID: math.MaxInt64,
groupID: math.MaxInt64,
expected: "billing:sub:9223372036854775807:9223372036854775807",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := billingSubKey(tc.userID, tc.groupID)
require.Equal(t, tc.expected, got)
})
}
}

View File

@@ -0,0 +1,248 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"net/url"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
"github.com/imroc/req/v3"
)
func NewClaudeOAuthClient() service.ClaudeOAuthClient {
return &claudeOAuthService{
baseURL: "https://claude.ai",
tokenURL: oauth.TokenURL,
clientFactory: createReqClient,
}
}
type claudeOAuthService struct {
baseURL string
tokenURL string
clientFactory func(proxyURL string) *req.Client
}
func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
client := s.clientFactory(proxyURL)
var orgs []struct {
UUID string `json:"uuid"`
}
targetURL := s.baseURL + "/api/organizations"
log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL)
resp, err := client.R().
SetContext(ctx).
SetCookies(&http.Cookie{
Name: "sessionKey",
Value: sessionKey,
}).
SetSuccessResult(&orgs).
Get(targetURL)
if err != nil {
log.Printf("[OAuth] Step 1 FAILED - Request error: %v", err)
return "", fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 1 Response - Status: %d", resp.StatusCode)
if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String())
}
if len(orgs) == 0 {
return "", fmt.Errorf("no organizations found")
}
log.Printf("[OAuth] Step 1 SUCCESS - Got org UUID: %s", orgs[0].UUID)
return orgs[0].UUID, nil
}
func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
client := s.clientFactory(proxyURL)
authURL := fmt.Sprintf("%s/v1/oauth/%s/authorize", s.baseURL, orgUUID)
reqBody := map[string]any{
"response_type": "code",
"client_id": oauth.ClientID,
"organization_uuid": orgUUID,
"redirect_uri": oauth.RedirectURI,
"scope": scope,
"state": state,
"code_challenge": codeChallenge,
"code_challenge_method": "S256",
}
log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL)
reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody))
log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
var result struct {
RedirectURI string `json:"redirect_uri"`
}
resp, err := client.R().
SetContext(ctx).
SetCookies(&http.Cookie{
Name: "sessionKey",
Value: sessionKey,
}).
SetHeader("Accept", "application/json").
SetHeader("Accept-Language", "en-US,en;q=0.9").
SetHeader("Cache-Control", "no-cache").
SetHeader("Origin", "https://claude.ai").
SetHeader("Referer", "https://claude.ai/new").
SetHeader("Content-Type", "application/json").
SetBody(reqBody).
SetSuccessResult(&result).
Post(authURL)
if err != nil {
log.Printf("[OAuth] Step 2 FAILED - Request error: %v", err)
return "", fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
if !resp.IsSuccessState() {
return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String())
}
if result.RedirectURI == "" {
return "", fmt.Errorf("no redirect_uri in response")
}
parsedURL, err := url.Parse(result.RedirectURI)
if err != nil {
return "", fmt.Errorf("failed to parse redirect_uri: %w", err)
}
queryParams := parsedURL.Query()
authCode := queryParams.Get("code")
responseState := queryParams.Get("state")
if authCode == "" {
return "", fmt.Errorf("no authorization code in redirect_uri")
}
fullCode := authCode
if responseState != "" {
fullCode = authCode + "#" + responseState
}
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code")
return fullCode, nil
}
func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) {
client := s.clientFactory(proxyURL)
// Parse code which may contain state in format "authCode#state"
authCode := code
codeState := ""
if idx := strings.Index(code, "#"); idx != -1 {
authCode = code[:idx]
codeState = code[idx+1:]
}
reqBody := map[string]any{
"code": authCode,
"grant_type": "authorization_code",
"client_id": oauth.ClientID,
"redirect_uri": oauth.RedirectURI,
"code_verifier": codeVerifier,
}
if codeState != "" {
reqBody["state"] = codeState
}
// Setup token requires longer expiration (1 year)
if isSetupToken {
reqBody["expires_in"] = 31536000 // 365 * 24 * 60 * 60 seconds
}
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL)
reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody))
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
var tokenResp oauth.TokenResponse
resp, err := client.R().
SetContext(ctx).
SetHeader("Content-Type", "application/json").
SetBody(reqBody).
SetSuccessResult(&tokenResp).
Post(s.tokenURL)
if err != nil {
log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err)
return nil, fmt.Errorf("request failed: %w", err)
}
log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
}
log.Printf("[OAuth] Step 3 SUCCESS - Got access token")
return &tokenResp, nil
}
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
client := s.clientFactory(proxyURL)
// 使用 JSON 格式(与 ExchangeCodeForToken 保持一致)
// Anthropic OAuth API 期望 JSON 格式的请求体
reqBody := map[string]any{
"grant_type": "refresh_token",
"refresh_token": refreshToken,
"client_id": oauth.ClientID,
}
var tokenResp oauth.TokenResponse
resp, err := client.R().
SetContext(ctx).
SetHeader("Content-Type", "application/json").
SetBody(reqBody).
SetSuccessResult(&tokenResp).
Post(s.tokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
}
return &tokenResp, nil
}
func createReqClient(proxyURL string) *req.Client {
// 禁用 CookieJar确保每次授权都是干净的会话
client := req.C().
SetTimeout(60 * time.Second).
ImpersonateChrome().
SetCookieJar(nil) // 禁用 CookieJar
if strings.TrimSpace(proxyURL) != "" {
client.SetProxyURL(strings.TrimSpace(proxyURL))
}
return client
}

View File

@@ -0,0 +1,396 @@
package repository
import (
"context"
"encoding/json"
"io"
"net/http"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
"github.com/imroc/req/v3"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type ClaudeOAuthServiceSuite struct {
suite.Suite
client *claudeOAuthService
}
// requestCapture holds captured request data for assertions in the main goroutine.
type requestCapture struct {
path string
method string
cookies []*http.Cookie
body []byte
bodyJSON map[string]any
contentType string
}
func newTestReqClient(rt http.RoundTripper) *req.Client {
c := req.C()
c.GetClient().Transport = rt
return c
}
func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() {
tests := []struct {
name string
handler http.HandlerFunc
wantErr bool
errContain string
wantUUID string
validate func(captured requestCapture)
}{
{
name: "success",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`[{"uuid":"org-1"}]`))
},
wantUUID: "org-1",
validate: func(captured requestCapture) {
require.Equal(s.T(), "/api/organizations", captured.path, "unexpected path")
require.Len(s.T(), captured.cookies, 1, "expected 1 cookie")
require.Equal(s.T(), "sessionKey", captured.cookies[0].Name)
require.Equal(s.T(), "sess", captured.cookies[0].Value)
},
},
{
name: "non_200_returns_error",
handler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte("unauthorized"))
},
wantErr: true,
errContain: "401",
},
{
name: "invalid_json_returns_error",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte("not-json"))
},
wantErr: true,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
var captured requestCapture
rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.path = r.URL.Path
captured.cookies = r.Cookies()
tt.handler(w, r)
}), nil)
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.baseURL = "http://in-process"
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
got, err := s.client.GetOrganizationUUID(context.Background(), "sess", "")
if tt.wantErr {
require.Error(s.T(), err)
if tt.errContain != "" {
require.ErrorContains(s.T(), err, tt.errContain)
}
return
}
require.NoError(s.T(), err)
require.Equal(s.T(), tt.wantUUID, got)
if tt.validate != nil {
tt.validate(captured)
}
})
}
}
func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() {
tests := []struct {
name string
handler http.HandlerFunc
wantErr bool
wantCode string
validate func(captured requestCapture)
}{
{
name: "parses_redirect_uri",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]string{
"redirect_uri": oauth.RedirectURI + "?code=AUTH&state=STATE",
})
},
wantCode: "AUTH#STATE",
validate: func(captured requestCapture) {
require.True(s.T(), strings.HasPrefix(captured.path, "/v1/oauth/") && strings.HasSuffix(captured.path, "/authorize"), "unexpected path: %s", captured.path)
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
require.Len(s.T(), captured.cookies, 1, "expected 1 cookie")
require.Equal(s.T(), "sess", captured.cookies[0].Value)
require.Equal(s.T(), "org-1", captured.bodyJSON["organization_uuid"])
require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
require.Equal(s.T(), oauth.RedirectURI, captured.bodyJSON["redirect_uri"])
require.Equal(s.T(), "st", captured.bodyJSON["state"])
},
},
{
name: "missing_code_returns_error",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]string{
"redirect_uri": oauth.RedirectURI + "?state=STATE", // no code
})
},
wantErr: true,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
var captured requestCapture
rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.path = r.URL.Path
captured.method = r.Method
captured.cookies = r.Cookies()
captured.body, _ = io.ReadAll(r.Body)
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
tt.handler(w, r)
}), nil)
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.baseURL = "http://in-process"
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeProfile, "cc", "st", "")
if tt.wantErr {
require.Error(s.T(), err)
return
}
require.NoError(s.T(), err)
require.Equal(s.T(), tt.wantCode, code)
if tt.validate != nil {
tt.validate(captured)
}
})
}
}
func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() {
tests := []struct {
name string
handler http.HandlerFunc
code string
isSetupToken bool
wantErr bool
wantResp *oauth.TokenResponse
validate func(captured requestCapture)
}{
{
name: "sends_state_when_embedded",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
AccessToken: "at",
TokenType: "bearer",
ExpiresIn: 3600,
RefreshToken: "rt",
Scope: "s",
})
},
code: "AUTH#STATE2",
isSetupToken: false,
wantResp: &oauth.TokenResponse{
AccessToken: "at",
RefreshToken: "rt",
},
validate: func(captured requestCapture) {
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
require.True(s.T(), strings.HasPrefix(captured.contentType, "application/json"), "unexpected content-type")
require.Equal(s.T(), "AUTH", captured.bodyJSON["code"])
require.Equal(s.T(), "STATE2", captured.bodyJSON["state"])
require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
require.Equal(s.T(), oauth.RedirectURI, captured.bodyJSON["redirect_uri"])
require.Equal(s.T(), "ver", captured.bodyJSON["code_verifier"])
// Regular OAuth should not include expires_in
require.Nil(s.T(), captured.bodyJSON["expires_in"], "regular OAuth should not include expires_in")
},
},
{
name: "setup_token_includes_expires_in",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
AccessToken: "at",
TokenType: "bearer",
ExpiresIn: 31536000,
})
},
code: "AUTH",
isSetupToken: true,
wantResp: &oauth.TokenResponse{
AccessToken: "at",
},
validate: func(captured requestCapture) {
// Setup token should include expires_in with 1 year value
require.Equal(s.T(), float64(31536000), captured.bodyJSON["expires_in"],
"setup token should include expires_in: 31536000")
},
},
{
name: "non_200_returns_error",
handler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte("bad request"))
},
code: "AUTH",
isSetupToken: false,
wantErr: true,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
var captured requestCapture
rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.method = r.Method
captured.contentType = r.Header.Get("Content-Type")
captured.body, _ = io.ReadAll(r.Body)
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
tt.handler(w, r)
}), nil)
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.tokenURL = "http://in-process/token"
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "", tt.isSetupToken)
if tt.wantErr {
require.Error(s.T(), err)
return
}
require.NoError(s.T(), err)
require.Equal(s.T(), tt.wantResp.AccessToken, resp.AccessToken)
require.Equal(s.T(), tt.wantResp.RefreshToken, resp.RefreshToken)
if tt.validate != nil {
tt.validate(captured)
}
})
}
}
func (s *ClaudeOAuthServiceSuite) TestRefreshToken() {
tests := []struct {
name string
handler http.HandlerFunc
wantErr bool
wantResp *oauth.TokenResponse
validate func(captured requestCapture)
}{
{
name: "sends_json_format",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
AccessToken: "new_access_token",
TokenType: "bearer",
ExpiresIn: 28800,
RefreshToken: "new_refresh_token",
Scope: "user:profile user:inference",
})
},
wantResp: &oauth.TokenResponse{
AccessToken: "new_access_token",
RefreshToken: "new_refresh_token",
},
validate: func(captured requestCapture) {
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
// 验证使用 JSON 格式(不是 form 格式)
require.True(s.T(), strings.HasPrefix(captured.contentType, "application/json"),
"expected JSON content-type, got: %s", captured.contentType)
// 验证 JSON body 内容
require.Equal(s.T(), "refresh_token", captured.bodyJSON["grant_type"])
require.Equal(s.T(), "rt", captured.bodyJSON["refresh_token"])
require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
},
},
{
name: "returns_new_refresh_token",
handler: func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
AccessToken: "at",
TokenType: "bearer",
ExpiresIn: 28800,
RefreshToken: "rotated_rt", // Anthropic rotates refresh tokens
})
},
wantResp: &oauth.TokenResponse{
AccessToken: "at",
RefreshToken: "rotated_rt",
},
},
{
name: "non_200_returns_error",
handler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte(`{"error":"invalid_grant"}`))
},
wantErr: true,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
var captured requestCapture
rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.method = r.Method
captured.contentType = r.Header.Get("Content-Type")
captured.body, _ = io.ReadAll(r.Body)
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
tt.handler(w, r)
}), nil)
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.tokenURL = "http://in-process/token"
s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
resp, err := s.client.RefreshToken(context.Background(), "rt", "")
if tt.wantErr {
require.Error(s.T(), err)
return
}
require.NoError(s.T(), err)
require.Equal(s.T(), tt.wantResp.AccessToken, resp.AccessToken)
require.Equal(s.T(), tt.wantResp.RefreshToken, resp.RefreshToken)
if tt.validate != nil {
tt.validate(captured)
}
})
}
}
func TestClaudeOAuthServiceSuite(t *testing.T) {
suite.Run(t, new(ClaudeOAuthServiceSuite))
}

View File

@@ -0,0 +1,62 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
)
const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage"
type claudeUsageService struct {
usageURL string
allowPrivateHosts bool
}
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
return &claudeUsageService{usageURL: defaultClaudeUsageURL}
}
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL,
Timeout: 30 * time.Second,
ValidateResolvedIP: true,
AllowPrivateHosts: s.allowPrivateHosts,
})
if err != nil {
client = &http.Client{Timeout: 30 * time.Second}
}
req, err := http.NewRequestWithContext(ctx, "GET", s.usageURL, nil)
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
}
var usageResp service.ClaudeUsageResponse
if err := json.NewDecoder(resp.Body).Decode(&usageResp); err != nil {
return nil, fmt.Errorf("decode response failed: %w", err)
}
return &usageResp, nil
}

View File

@@ -0,0 +1,117 @@
package repository
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type ClaudeUsageServiceSuite struct {
suite.Suite
srv *httptest.Server
fetcher *claudeUsageService
}
func (s *ClaudeUsageServiceSuite) TearDownTest() {
if s.srv != nil {
s.srv.Close()
s.srv = nil
}
}
// usageRequestCapture holds captured request data for assertions in the main goroutine.
type usageRequestCapture struct {
authorization string
anthropicBeta string
}
func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
var captured usageRequestCapture
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
captured.authorization = r.Header.Get("Authorization")
captured.anthropicBeta = r.Header.Get("anthropic-beta")
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{
"five_hour": {"utilization": 12.5, "resets_at": "2025-01-01T00:00:00Z"},
"seven_day": {"utilization": 34.0, "resets_at": "2025-01-08T00:00:00Z"},
"seven_day_sonnet": {"utilization": 56.0, "resets_at": "2025-01-08T00:00:00Z"}
}`)
}))
s.fetcher = &claudeUsageService{
usageURL: s.srv.URL,
allowPrivateHosts: true,
}
resp, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url")
require.NoError(s.T(), err, "FetchUsage")
require.Equal(s.T(), 12.5, resp.FiveHour.Utilization, "FiveHour utilization mismatch")
require.Equal(s.T(), 34.0, resp.SevenDay.Utilization, "SevenDay utilization mismatch")
require.Equal(s.T(), 56.0, resp.SevenDaySonnet.Utilization, "SevenDaySonnet utilization mismatch")
// Assertions on captured request data
require.Equal(s.T(), "Bearer at", captured.authorization, "Authorization header mismatch")
require.Equal(s.T(), "oauth-2025-04-20", captured.anthropicBeta, "anthropic-beta header mismatch")
}
func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = io.WriteString(w, "nope")
}))
s.fetcher = &claudeUsageService{
usageURL: s.srv.URL,
allowPrivateHosts: true,
}
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "status 401")
require.ErrorContains(s.T(), err, "nope")
}
func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, "not-json")
}))
s.fetcher = &claudeUsageService{
usageURL: s.srv.URL,
allowPrivateHosts: true,
}
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "decode response failed")
}
func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Never respond - simulate slow server
<-r.Context().Done()
}))
s.fetcher = &claudeUsageService{
usageURL: s.srv.URL,
allowPrivateHosts: true,
}
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
_, err := s.fetcher.FetchUsage(ctx, "at", "")
require.Error(s.T(), err, "expected error for cancelled context")
}
func TestClaudeUsageServiceSuite(t *testing.T) {
suite.Run(t, new(ClaudeUsageServiceSuite))
}

View File

@@ -0,0 +1,391 @@
package repository
import (
"context"
"errors"
"fmt"
"strconv"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
// 并发控制缓存常量定义
//
// 性能优化说明:
// 原实现使用 SCAN 命令遍历独立的槽位键concurrency:account:{id}:{requestID}
// 在高并发场景下 SCAN 需要多次往返,且遍历大量键时性能下降明显。
//
// 新实现改用 Redis 有序集合Sorted Set
// 1. 每个账号/用户只有一个键,成员为 requestID分数为时间戳
// 2. 使用 ZCARD 原子获取并发数,时间复杂度 O(1)
// 3. 使用 ZREMRANGEBYSCORE 清理过期槽位,避免手动管理 TTL
// 4. 单次 Redis 调用完成计数,减少网络往返
const (
// 并发槽位键前缀(有序集合)
// 格式: concurrency:account:{accountID}
accountSlotKeyPrefix = "concurrency:account:"
// 格式: concurrency:user:{userID}
userSlotKeyPrefix = "concurrency:user:"
// 等待队列计数器格式: concurrency:wait:{userID}
waitQueueKeyPrefix = "concurrency:wait:"
// 账号级等待队列计数器格式: wait:account:{accountID}
accountWaitKeyPrefix = "wait:account:"
// 默认槽位过期时间(分钟),可通过配置覆盖
defaultSlotTTLMinutes = 15
)
var (
// acquireScript 使用有序集合计数并在未达上限时添加槽位
// 使用 Redis TIME 命令获取服务器时间,避免多实例时钟不同步问题
// KEYS[1] = 有序集合键 (concurrency:account:{id} / concurrency:user:{id})
// ARGV[1] = maxConcurrency
// ARGV[2] = TTL
// ARGV[3] = requestID
acquireScript = redis.NewScript(`
local key = KEYS[1]
local maxConcurrency = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])
local requestID = ARGV[3]
-- 使用 Redis 服务器时间,确保多实例时钟一致
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
-- 清理过期槽位
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
-- 检查是否已存在(支持重试场景刷新时间戳)
local exists = redis.call('ZSCORE', key, requestID)
if exists ~= false then
redis.call('ZADD', key, now, requestID)
redis.call('EXPIRE', key, ttl)
return 1
end
-- 检查是否达到并发上限
local count = redis.call('ZCARD', key)
if count < maxConcurrency then
redis.call('ZADD', key, now, requestID)
redis.call('EXPIRE', key, ttl)
return 1
end
return 0
`)
// getCountScript 统计有序集合中的槽位数量并清理过期条目
// 使用 Redis TIME 命令获取服务器时间
// KEYS[1] = 有序集合键
// ARGV[1] = TTL
getCountScript = redis.NewScript(`
local key = KEYS[1]
local ttl = tonumber(ARGV[1])
-- 使用 Redis 服务器时间
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
return redis.call('ZCARD', key)
`)
// incrementWaitScript - refreshes TTL on each increment to keep queue depth accurate
// KEYS[1] = wait queue key
// ARGV[1] = maxWait
// ARGV[2] = TTL in seconds
incrementWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
current = 0
else
current = tonumber(current)
end
if current >= tonumber(ARGV[1]) then
return 0
end
local newVal = redis.call('INCR', KEYS[1])
-- Refresh TTL so long-running traffic doesn't expire active queue counters.
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
// incrementAccountWaitScript - account-level wait queue count (refresh TTL on each increment)
incrementAccountWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
current = 0
else
current = tonumber(current)
end
if current >= tonumber(ARGV[1]) then
return 0
end
local newVal = redis.call('INCR', KEYS[1])
-- Refresh TTL so long-running traffic doesn't expire active queue counters.
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
// decrementWaitScript - same as before
decrementWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current ~= false and tonumber(current) > 0 then
redis.call('DECR', KEYS[1])
end
return 1
`)
// getAccountsLoadBatchScript - batch load query with expired slot cleanup
// ARGV[1] = slot TTL (seconds)
// ARGV[2..n] = accountID1, maxConcurrency1, accountID2, maxConcurrency2, ...
getAccountsLoadBatchScript = redis.NewScript(`
local result = {}
local slotTTL = tonumber(ARGV[1])
-- Get current server time
local timeResult = redis.call('TIME')
local nowSeconds = tonumber(timeResult[1])
local cutoffTime = nowSeconds - slotTTL
local i = 2
while i <= #ARGV do
local accountID = ARGV[i]
local maxConcurrency = tonumber(ARGV[i + 1])
local slotKey = 'concurrency:account:' .. accountID
-- Clean up expired slots before counting
redis.call('ZREMRANGEBYSCORE', slotKey, '-inf', cutoffTime)
local currentConcurrency = redis.call('ZCARD', slotKey)
local waitKey = 'wait:account:' .. accountID
local waitingCount = redis.call('GET', waitKey)
if waitingCount == false then
waitingCount = 0
else
waitingCount = tonumber(waitingCount)
end
local loadRate = 0
if maxConcurrency > 0 then
loadRate = math.floor((currentConcurrency + waitingCount) * 100 / maxConcurrency)
end
table.insert(result, accountID)
table.insert(result, currentConcurrency)
table.insert(result, waitingCount)
table.insert(result, loadRate)
i = i + 2
end
return result
`)
// cleanupExpiredSlotsScript - remove expired slots
// KEYS[1] = concurrency:account:{accountID}
// ARGV[1] = TTL (seconds)
cleanupExpiredSlotsScript = redis.NewScript(`
local key = KEYS[1]
local ttl = tonumber(ARGV[1])
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - ttl
return redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
`)
)
type concurrencyCache struct {
rdb *redis.Client
slotTTLSeconds int // 槽位过期时间(秒)
waitQueueTTLSeconds int // 等待队列过期时间(秒)
}
// NewConcurrencyCache 创建并发控制缓存
// slotTTLMinutes: 槽位过期时间分钟0 或负数使用默认值 15 分钟
// waitQueueTTLSeconds: 等待队列过期时间0 或负数使用 slot TTL
func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int, waitQueueTTLSeconds int) service.ConcurrencyCache {
if slotTTLMinutes <= 0 {
slotTTLMinutes = defaultSlotTTLMinutes
}
if waitQueueTTLSeconds <= 0 {
waitQueueTTLSeconds = slotTTLMinutes * 60
}
return &concurrencyCache{
rdb: rdb,
slotTTLSeconds: slotTTLMinutes * 60,
waitQueueTTLSeconds: waitQueueTTLSeconds,
}
}
// Helper functions for key generation
func accountSlotKey(accountID int64) string {
return fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
}
func userSlotKey(userID int64) string {
return fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
}
func waitQueueKey(userID int64) string {
return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
}
func accountWaitKey(accountID int64) string {
return fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
}
// Account slot operations
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
key := accountSlotKey(accountID)
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取,确保多实例时钟一致
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, c.slotTTLSeconds, requestID).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
key := accountSlotKey(accountID)
return c.rdb.ZRem(ctx, key, requestID).Err()
}
func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
key := accountSlotKey(accountID)
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取
result, err := getCountScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Int()
if err != nil {
return 0, err
}
return result, nil
}
// User slot operations
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
key := userSlotKey(userID)
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取,确保多实例时钟一致
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, c.slotTTLSeconds, requestID).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
key := userSlotKey(userID)
return c.rdb.ZRem(ctx, key, requestID).Err()
}
func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
key := userSlotKey(userID)
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取
result, err := getCountScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Int()
if err != nil {
return 0, err
}
return result, nil
}
// Wait queue operations
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
key := waitQueueKey(userID)
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
key := waitQueueKey(userID)
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}
// Account wait queue operations
func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
key := accountWaitKey(accountID)
result, err := incrementAccountWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
func (c *concurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
key := accountWaitKey(accountID)
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
return err
}
func (c *concurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
key := accountWaitKey(accountID)
val, err := c.rdb.Get(ctx, key).Int()
if err != nil && !errors.Is(err, redis.Nil) {
return 0, err
}
if errors.Is(err, redis.Nil) {
return 0, nil
}
return val, nil
}
func (c *concurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
if len(accounts) == 0 {
return map[int64]*service.AccountLoadInfo{}, nil
}
args := []any{c.slotTTLSeconds}
for _, acc := range accounts {
args = append(args, acc.ID, acc.MaxConcurrency)
}
result, err := getAccountsLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice()
if err != nil {
return nil, err
}
loadMap := make(map[int64]*service.AccountLoadInfo)
for i := 0; i < len(result); i += 4 {
if i+3 >= len(result) {
break
}
accountID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64)
currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1]))
waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2]))
loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3]))
loadMap[accountID] = &service.AccountLoadInfo{
AccountID: accountID,
CurrentConcurrency: currentConcurrency,
WaitingCount: waitingCount,
LoadRate: loadRate,
}
}
return loadMap, nil
}
func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
key := accountSlotKey(accountID)
_, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result()
return err
}

View File

@@ -0,0 +1,135 @@
package repository
import (
"context"
"fmt"
"os"
"testing"
"time"
"github.com/redis/go-redis/v9"
)
// 基准测试用 TTL 配置
const benchSlotTTLMinutes = 15
var benchSlotTTL = time.Duration(benchSlotTTLMinutes) * time.Minute
// BenchmarkAccountConcurrency 用于对比 SCAN 与有序集合的计数性能。
func BenchmarkAccountConcurrency(b *testing.B) {
rdb := newBenchmarkRedisClient(b)
defer func() {
_ = rdb.Close()
}()
cache, _ := NewConcurrencyCache(rdb, benchSlotTTLMinutes, int(benchSlotTTL.Seconds())).(*concurrencyCache)
ctx := context.Background()
for _, size := range []int{10, 100, 1000} {
size := size
b.Run(fmt.Sprintf("zset/slots=%d", size), func(b *testing.B) {
accountID := time.Now().UnixNano()
key := accountSlotKey(accountID)
b.StopTimer()
members := make([]redis.Z, 0, size)
now := float64(time.Now().Unix())
for i := 0; i < size; i++ {
members = append(members, redis.Z{
Score: now,
Member: fmt.Sprintf("req_%d", i),
})
}
if err := rdb.ZAdd(ctx, key, members...).Err(); err != nil {
b.Fatalf("初始化有序集合失败: %v", err)
}
if err := rdb.Expire(ctx, key, benchSlotTTL).Err(); err != nil {
b.Fatalf("设置有序集合 TTL 失败: %v", err)
}
b.StartTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
if _, err := cache.GetAccountConcurrency(ctx, accountID); err != nil {
b.Fatalf("获取并发数量失败: %v", err)
}
}
b.StopTimer()
if err := rdb.Del(ctx, key).Err(); err != nil {
b.Fatalf("清理有序集合失败: %v", err)
}
})
b.Run(fmt.Sprintf("scan/slots=%d", size), func(b *testing.B) {
accountID := time.Now().UnixNano()
pattern := fmt.Sprintf("%s%d:*", accountSlotKeyPrefix, accountID)
keys := make([]string, 0, size)
b.StopTimer()
pipe := rdb.Pipeline()
for i := 0; i < size; i++ {
key := fmt.Sprintf("%s%d:req_%d", accountSlotKeyPrefix, accountID, i)
keys = append(keys, key)
pipe.Set(ctx, key, "1", benchSlotTTL)
}
if _, err := pipe.Exec(ctx); err != nil {
b.Fatalf("初始化扫描键失败: %v", err)
}
b.StartTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
if _, err := scanSlotCount(ctx, rdb, pattern); err != nil {
b.Fatalf("SCAN 计数失败: %v", err)
}
}
b.StopTimer()
if err := rdb.Del(ctx, keys...).Err(); err != nil {
b.Fatalf("清理扫描键失败: %v", err)
}
})
}
}
func scanSlotCount(ctx context.Context, rdb *redis.Client, pattern string) (int, error) {
var cursor uint64
count := 0
for {
keys, nextCursor, err := rdb.Scan(ctx, cursor, pattern, 100).Result()
if err != nil {
return 0, err
}
count += len(keys)
if nextCursor == 0 {
break
}
cursor = nextCursor
}
return count, nil
}
func newBenchmarkRedisClient(b *testing.B) *redis.Client {
b.Helper()
redisURL := os.Getenv("TEST_REDIS_URL")
if redisURL == "" {
b.Skip("未设置 TEST_REDIS_URL跳过 Redis 基准测试")
}
opt, err := redis.ParseURL(redisURL)
if err != nil {
b.Fatalf("解析 TEST_REDIS_URL 失败: %v", err)
}
client := redis.NewClient(opt)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
if err := client.Ping(ctx).Err(); err != nil {
b.Fatalf("Redis 连接失败: %v", err)
}
return client
}

View File

@@ -0,0 +1,412 @@
//go:build integration
package repository
import (
"errors"
"fmt"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
// 测试用 TTL 配置15 分钟,与默认值一致)
const testSlotTTLMinutes = 15
// 测试用 TTL Duration用于 TTL 断言
var testSlotTTL = time.Duration(testSlotTTLMinutes) * time.Minute
type ConcurrencyCacheSuite struct {
IntegrationRedisSuite
cache service.ConcurrencyCache
}
func (s *ConcurrencyCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds()))
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
accountID := int64(10)
reqID1, reqID2, reqID3 := "req1", "req2", "req3"
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID1)
require.NoError(s.T(), err, "AcquireAccountSlot 1")
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID2)
require.NoError(s.T(), err, "AcquireAccountSlot 2")
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID3)
require.NoError(s.T(), err, "AcquireAccountSlot 3")
require.False(s.T(), ok, "expected third acquire to fail")
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err, "GetAccountConcurrency")
require.Equal(s.T(), 2, cur, "concurrency mismatch")
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID1), "ReleaseAccountSlot")
cur, err = s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err, "GetAccountConcurrency after release")
require.Equal(s.T(), 1, cur, "expected 1 after release")
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() {
accountID := int64(11)
reqID := "req_ttl_test"
slotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, reqID)
require.NoError(s.T(), err, "AcquireAccountSlot")
require.True(s.T(), ok)
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_DuplicateReqID() {
accountID := int64(12)
reqID := "dup-req"
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID)
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Acquiring with same reqID should be idempotent
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID)
require.NoError(s.T(), err)
require.True(s.T(), ok)
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 1, cur, "expected concurrency=1 (idempotent)")
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_ReleaseIdempotent() {
accountID := int64(13)
reqID := "release-test"
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 1, reqID)
require.NoError(s.T(), err)
require.True(s.T(), ok)
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID), "ReleaseAccountSlot")
// Releasing again should not error
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID), "ReleaseAccountSlot again")
// Releasing non-existent should not error
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, "non-existent"), "ReleaseAccountSlot non-existent")
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 0, cur)
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_MaxZero() {
accountID := int64(14)
reqID := "max-zero-test"
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 0, reqID)
require.NoError(s.T(), err)
require.False(s.T(), ok, "expected acquire to fail with max=0")
}
func (s *ConcurrencyCacheSuite) TestUserSlot_AcquireAndRelease() {
userID := int64(42)
reqID1, reqID2 := "req1", "req2"
ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 1, reqID1)
require.NoError(s.T(), err, "AcquireUserSlot")
require.True(s.T(), ok)
ok, err = s.cache.AcquireUserSlot(s.ctx, userID, 1, reqID2)
require.NoError(s.T(), err, "AcquireUserSlot 2")
require.False(s.T(), ok, "expected second acquire to fail at max=1")
cur, err := s.cache.GetUserConcurrency(s.ctx, userID)
require.NoError(s.T(), err, "GetUserConcurrency")
require.Equal(s.T(), 1, cur, "expected concurrency=1")
require.NoError(s.T(), s.cache.ReleaseUserSlot(s.ctx, userID, reqID1), "ReleaseUserSlot")
// Releasing a non-existent slot should not error
require.NoError(s.T(), s.cache.ReleaseUserSlot(s.ctx, userID, "non-existent"), "ReleaseUserSlot non-existent")
cur, err = s.cache.GetUserConcurrency(s.ctx, userID)
require.NoError(s.T(), err, "GetUserConcurrency after release")
require.Equal(s.T(), 0, cur, "expected concurrency=0 after release")
}
func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
userID := int64(200)
reqID := "req_ttl_test"
slotKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 5, reqID)
require.NoError(s.T(), err, "AcquireUserSlot")
require.True(s.T(), ok)
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
}
func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
userID := int64(20)
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 2)
require.NoError(s.T(), err, "IncrementWaitCount 1")
require.True(s.T(), ok)
ok, err = s.cache.IncrementWaitCount(s.ctx, userID, 2)
require.NoError(s.T(), err, "IncrementWaitCount 2")
require.True(s.T(), ok)
ok, err = s.cache.IncrementWaitCount(s.ctx, userID, 2)
require.NoError(s.T(), err, "IncrementWaitCount 3")
require.False(s.T(), ok, "expected wait increment over max to fail")
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
require.NoError(s.T(), err, "TTL waitKey")
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.Equal(s.T(), 1, val, "expected wait count 1")
}
func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
userID := int64(300)
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
// Test decrement on non-existent key - should not error and should not create negative value
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on non-existent key")
// Verify no key was created or it's not negative
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count after decrement on empty")
// Set count to 1, then decrement twice
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 5)
require.NoError(s.T(), err, "IncrementWaitCount")
require.True(s.T(), ok)
// Decrement once (1 -> 0)
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
// Decrement again on 0 - should not go negative
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on zero")
// Verify count is 0, not negative
val, err = s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey after double decrement")
}
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count")
}
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {
accountID := int64(30)
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
ok, err := s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
require.NoError(s.T(), err, "IncrementAccountWaitCount 1")
require.True(s.T(), ok)
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
require.NoError(s.T(), err, "IncrementAccountWaitCount 2")
require.True(s.T(), ok)
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
require.NoError(s.T(), err, "IncrementAccountWaitCount 3")
require.False(s.T(), ok, "expected account wait increment over max to fail")
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
require.NoError(s.T(), err, "TTL account waitKey")
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount")
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.Equal(s.T(), 1, val, "expected account wait count 1")
}
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_DecrementNoNegative() {
accountID := int64(301)
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount on non-existent key")
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative account wait count after decrement on empty")
}
func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() {
// When no slots exist, GetAccountConcurrency should return 0
cur, err := s.cache.GetAccountConcurrency(s.ctx, 999)
require.NoError(s.T(), err)
require.Equal(s.T(), 0, cur)
}
func (s *ConcurrencyCacheSuite) TestGetUserConcurrency_Missing() {
// When no slots exist, GetUserConcurrency should return 0
cur, err := s.cache.GetUserConcurrency(s.ctx, 999)
require.NoError(s.T(), err)
require.Equal(s.T(), 0, cur)
}
func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch() {
s.T().Skip("TODO: Fix this test - CurrentConcurrency returns 0 instead of expected value in CI")
// Setup: Create accounts with different load states
account1 := int64(100)
account2 := int64(101)
account3 := int64(102)
// Account 1: 2/3 slots used, 1 waiting
ok, err := s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req1")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req2")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, account1, 5)
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Account 2: 1/2 slots used, 0 waiting
ok, err = s.cache.AcquireAccountSlot(s.ctx, account2, 2, "req3")
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Account 3: 0/1 slots used, 0 waiting (idle)
// Query batch load
accounts := []service.AccountWithConcurrency{
{ID: account1, MaxConcurrency: 3},
{ID: account2, MaxConcurrency: 2},
{ID: account3, MaxConcurrency: 1},
}
loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, accounts)
require.NoError(s.T(), err)
require.Len(s.T(), loadMap, 3)
// Verify account1: (2 + 1) / 3 = 100%
load1 := loadMap[account1]
require.NotNil(s.T(), load1)
require.Equal(s.T(), account1, load1.AccountID)
require.Equal(s.T(), 2, load1.CurrentConcurrency)
require.Equal(s.T(), 1, load1.WaitingCount)
require.Equal(s.T(), 100, load1.LoadRate)
// Verify account2: (1 + 0) / 2 = 50%
load2 := loadMap[account2]
require.NotNil(s.T(), load2)
require.Equal(s.T(), account2, load2.AccountID)
require.Equal(s.T(), 1, load2.CurrentConcurrency)
require.Equal(s.T(), 0, load2.WaitingCount)
require.Equal(s.T(), 50, load2.LoadRate)
// Verify account3: (0 + 0) / 1 = 0%
load3 := loadMap[account3]
require.NotNil(s.T(), load3)
require.Equal(s.T(), account3, load3.AccountID)
require.Equal(s.T(), 0, load3.CurrentConcurrency)
require.Equal(s.T(), 0, load3.WaitingCount)
require.Equal(s.T(), 0, load3.LoadRate)
}
func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch_Empty() {
// Test with empty account list
loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, []service.AccountWithConcurrency{})
require.NoError(s.T(), err)
require.Empty(s.T(), loadMap)
}
func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots() {
accountID := int64(200)
slotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
// Acquire 3 slots
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req3")
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Verify 3 slots exist
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 3, cur)
// Manually set old timestamps for req1 and req2 (simulate expired slots)
now := time.Now().Unix()
expiredTime := now - int64(testSlotTTL.Seconds()) - 10 // 10 seconds past TTL
err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req1"}).Err()
require.NoError(s.T(), err)
err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req2"}).Err()
require.NoError(s.T(), err)
// Run cleanup
err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID)
require.NoError(s.T(), err)
// Verify only 1 slot remains (req3)
cur, err = s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 1, cur)
// Verify req3 still exists
members, err := s.rdb.ZRange(s.ctx, slotKey, 0, -1).Result()
require.NoError(s.T(), err)
require.Len(s.T(), members, 1)
require.Equal(s.T(), "req3", members[0])
}
func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots_NoExpired() {
accountID := int64(201)
// Acquire 2 fresh slots
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2")
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Run cleanup (should not remove anything)
err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID)
require.NoError(s.T(), err)
// Verify both slots still exist
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 2, cur)
}
func TestConcurrencyCacheSuite(t *testing.T) {
suite.Run(t, new(ConcurrencyCacheSuite))
}

View File

@@ -0,0 +1,392 @@
package repository
import (
"context"
"database/sql"
"fmt"
"log"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
)
type dashboardAggregationRepository struct {
sql sqlExecutor
}
// NewDashboardAggregationRepository 创建仪表盘预聚合仓储。
func NewDashboardAggregationRepository(sqlDB *sql.DB) service.DashboardAggregationRepository {
if sqlDB == nil {
return nil
}
if !isPostgresDriver(sqlDB) {
log.Printf("[DashboardAggregation] 检测到非 PostgreSQL 驱动,已自动禁用预聚合")
return nil
}
return newDashboardAggregationRepositoryWithSQL(sqlDB)
}
func newDashboardAggregationRepositoryWithSQL(sqlq sqlExecutor) *dashboardAggregationRepository {
return &dashboardAggregationRepository{sql: sqlq}
}
func isPostgresDriver(db *sql.DB) bool {
if db == nil {
return false
}
_, ok := db.Driver().(*pq.Driver)
return ok
}
func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error {
loc := timezone.Location()
startLocal := start.In(loc)
endLocal := end.In(loc)
if !endLocal.After(startLocal) {
return nil
}
hourStart := startLocal.Truncate(time.Hour)
hourEnd := endLocal.Truncate(time.Hour)
if endLocal.After(hourEnd) {
hourEnd = hourEnd.Add(time.Hour)
}
dayStart := truncateToDay(startLocal)
dayEnd := truncateToDay(endLocal)
if endLocal.After(dayEnd) {
dayEnd = dayEnd.Add(24 * time.Hour)
}
// 以桶边界聚合,允许覆盖 end 所在桶的剩余区间。
if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil {
return err
}
if err := r.insertDailyActiveUsers(ctx, hourStart, hourEnd); err != nil {
return err
}
if err := r.upsertHourlyAggregates(ctx, hourStart, hourEnd); err != nil {
return err
}
if err := r.upsertDailyAggregates(ctx, dayStart, dayEnd); err != nil {
return err
}
return nil
}
func (r *dashboardAggregationRepository) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
var ts time.Time
query := "SELECT last_aggregated_at FROM usage_dashboard_aggregation_watermark WHERE id = 1"
if err := scanSingleRow(ctx, r.sql, query, nil, &ts); err != nil {
if err == sql.ErrNoRows {
return time.Unix(0, 0).UTC(), nil
}
return time.Time{}, err
}
return ts.UTC(), nil
}
func (r *dashboardAggregationRepository) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error {
query := `
INSERT INTO usage_dashboard_aggregation_watermark (id, last_aggregated_at, updated_at)
VALUES (1, $1, NOW())
ON CONFLICT (id)
DO UPDATE SET last_aggregated_at = EXCLUDED.last_aggregated_at, updated_at = EXCLUDED.updated_at
`
_, err := r.sql.ExecContext(ctx, query, aggregatedAt.UTC())
return err
}
func (r *dashboardAggregationRepository) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error {
hourlyCutoffUTC := hourlyCutoff.UTC()
dailyCutoffUTC := dailyCutoff.UTC()
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_hourly WHERE bucket_start < $1", hourlyCutoffUTC); err != nil {
return err
}
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_hourly_users WHERE bucket_start < $1", hourlyCutoffUTC); err != nil {
return err
}
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_daily WHERE bucket_date < $1::date", dailyCutoffUTC); err != nil {
return err
}
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_daily_users WHERE bucket_date < $1::date", dailyCutoffUTC); err != nil {
return err
}
return nil
}
func (r *dashboardAggregationRepository) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
isPartitioned, err := r.isUsageLogsPartitioned(ctx)
if err != nil {
return err
}
if isPartitioned {
return r.dropUsageLogsPartitions(ctx, cutoff)
}
_, err = r.sql.ExecContext(ctx, "DELETE FROM usage_logs WHERE created_at < $1", cutoff.UTC())
return err
}
func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
isPartitioned, err := r.isUsageLogsPartitioned(ctx)
if err != nil || !isPartitioned {
return err
}
monthStart := truncateToMonthUTC(now)
prevMonth := monthStart.AddDate(0, -1, 0)
nextMonth := monthStart.AddDate(0, 1, 0)
for _, m := range []time.Time{prevMonth, monthStart, nextMonth} {
if err := r.createUsageLogsPartition(ctx, m); err != nil {
return err
}
}
return nil
}
func (r *dashboardAggregationRepository) insertHourlyActiveUsers(ctx context.Context, start, end time.Time) error {
tzName := timezone.Name()
query := `
INSERT INTO usage_dashboard_hourly_users (bucket_start, user_id)
SELECT DISTINCT
date_trunc('hour', created_at AT TIME ZONE $3) AT TIME ZONE $3 AS bucket_start,
user_id
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
ON CONFLICT DO NOTHING
`
_, err := r.sql.ExecContext(ctx, query, start, end, tzName)
return err
}
func (r *dashboardAggregationRepository) insertDailyActiveUsers(ctx context.Context, start, end time.Time) error {
tzName := timezone.Name()
query := `
INSERT INTO usage_dashboard_daily_users (bucket_date, user_id)
SELECT DISTINCT
(bucket_start AT TIME ZONE $3)::date AS bucket_date,
user_id
FROM usage_dashboard_hourly_users
WHERE bucket_start >= $1 AND bucket_start < $2
ON CONFLICT DO NOTHING
`
_, err := r.sql.ExecContext(ctx, query, start, end, tzName)
return err
}
func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Context, start, end time.Time) error {
tzName := timezone.Name()
query := `
WITH hourly AS (
SELECT
date_trunc('hour', created_at AT TIME ZONE $3) AT TIME ZONE $3 AS bucket_start,
COUNT(*) AS total_requests,
COALESCE(SUM(input_tokens), 0) AS input_tokens,
COALESCE(SUM(output_tokens), 0) AS output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) AS cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) AS cache_read_tokens,
COALESCE(SUM(total_cost), 0) AS total_cost,
COALESCE(SUM(actual_cost), 0) AS actual_cost,
COALESCE(SUM(COALESCE(duration_ms, 0)), 0) AS total_duration_ms
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
GROUP BY 1
),
user_counts AS (
SELECT bucket_start, COUNT(*) AS active_users
FROM usage_dashboard_hourly_users
WHERE bucket_start >= $1 AND bucket_start < $2
GROUP BY bucket_start
)
INSERT INTO usage_dashboard_hourly (
bucket_start,
total_requests,
input_tokens,
output_tokens,
cache_creation_tokens,
cache_read_tokens,
total_cost,
actual_cost,
total_duration_ms,
active_users,
computed_at
)
SELECT
hourly.bucket_start,
hourly.total_requests,
hourly.input_tokens,
hourly.output_tokens,
hourly.cache_creation_tokens,
hourly.cache_read_tokens,
hourly.total_cost,
hourly.actual_cost,
hourly.total_duration_ms,
COALESCE(user_counts.active_users, 0) AS active_users,
NOW()
FROM hourly
LEFT JOIN user_counts ON user_counts.bucket_start = hourly.bucket_start
ON CONFLICT (bucket_start)
DO UPDATE SET
total_requests = EXCLUDED.total_requests,
input_tokens = EXCLUDED.input_tokens,
output_tokens = EXCLUDED.output_tokens,
cache_creation_tokens = EXCLUDED.cache_creation_tokens,
cache_read_tokens = EXCLUDED.cache_read_tokens,
total_cost = EXCLUDED.total_cost,
actual_cost = EXCLUDED.actual_cost,
total_duration_ms = EXCLUDED.total_duration_ms,
active_users = EXCLUDED.active_users,
computed_at = EXCLUDED.computed_at
`
_, err := r.sql.ExecContext(ctx, query, start, end, tzName)
return err
}
func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Context, start, end time.Time) error {
tzName := timezone.Name()
query := `
WITH daily AS (
SELECT
(bucket_start AT TIME ZONE $5)::date AS bucket_date,
COALESCE(SUM(total_requests), 0) AS total_requests,
COALESCE(SUM(input_tokens), 0) AS input_tokens,
COALESCE(SUM(output_tokens), 0) AS output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) AS cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) AS cache_read_tokens,
COALESCE(SUM(total_cost), 0) AS total_cost,
COALESCE(SUM(actual_cost), 0) AS actual_cost,
COALESCE(SUM(total_duration_ms), 0) AS total_duration_ms
FROM usage_dashboard_hourly
WHERE bucket_start >= $1 AND bucket_start < $2
GROUP BY (bucket_start AT TIME ZONE $5)::date
),
user_counts AS (
SELECT bucket_date, COUNT(*) AS active_users
FROM usage_dashboard_daily_users
WHERE bucket_date >= $3::date AND bucket_date < $4::date
GROUP BY bucket_date
)
INSERT INTO usage_dashboard_daily (
bucket_date,
total_requests,
input_tokens,
output_tokens,
cache_creation_tokens,
cache_read_tokens,
total_cost,
actual_cost,
total_duration_ms,
active_users,
computed_at
)
SELECT
daily.bucket_date,
daily.total_requests,
daily.input_tokens,
daily.output_tokens,
daily.cache_creation_tokens,
daily.cache_read_tokens,
daily.total_cost,
daily.actual_cost,
daily.total_duration_ms,
COALESCE(user_counts.active_users, 0) AS active_users,
NOW()
FROM daily
LEFT JOIN user_counts ON user_counts.bucket_date = daily.bucket_date
ON CONFLICT (bucket_date)
DO UPDATE SET
total_requests = EXCLUDED.total_requests,
input_tokens = EXCLUDED.input_tokens,
output_tokens = EXCLUDED.output_tokens,
cache_creation_tokens = EXCLUDED.cache_creation_tokens,
cache_read_tokens = EXCLUDED.cache_read_tokens,
total_cost = EXCLUDED.total_cost,
actual_cost = EXCLUDED.actual_cost,
total_duration_ms = EXCLUDED.total_duration_ms,
active_users = EXCLUDED.active_users,
computed_at = EXCLUDED.computed_at
`
_, err := r.sql.ExecContext(ctx, query, start, end, start, end, tzName)
return err
}
func (r *dashboardAggregationRepository) isUsageLogsPartitioned(ctx context.Context) (bool, error) {
query := `
SELECT EXISTS(
SELECT 1
FROM pg_partitioned_table pt
JOIN pg_class c ON c.oid = pt.partrelid
WHERE c.relname = 'usage_logs'
)
`
var partitioned bool
if err := scanSingleRow(ctx, r.sql, query, nil, &partitioned); err != nil {
return false, err
}
return partitioned, nil
}
func (r *dashboardAggregationRepository) dropUsageLogsPartitions(ctx context.Context, cutoff time.Time) error {
rows, err := r.sql.QueryContext(ctx, `
SELECT c.relname
FROM pg_inherits
JOIN pg_class c ON c.oid = pg_inherits.inhrelid
JOIN pg_class p ON p.oid = pg_inherits.inhparent
WHERE p.relname = 'usage_logs'
`)
if err != nil {
return err
}
defer func() {
_ = rows.Close()
}()
cutoffMonth := truncateToMonthUTC(cutoff)
for rows.Next() {
var name string
if err := rows.Scan(&name); err != nil {
return err
}
if !strings.HasPrefix(name, "usage_logs_") {
continue
}
suffix := strings.TrimPrefix(name, "usage_logs_")
month, err := time.Parse("200601", suffix)
if err != nil {
continue
}
month = month.UTC()
if month.Before(cutoffMonth) {
if _, err := r.sql.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", pq.QuoteIdentifier(name))); err != nil {
return err
}
}
}
return rows.Err()
}
func (r *dashboardAggregationRepository) createUsageLogsPartition(ctx context.Context, month time.Time) error {
monthStart := truncateToMonthUTC(month)
nextMonth := monthStart.AddDate(0, 1, 0)
name := fmt.Sprintf("usage_logs_%s", monthStart.Format("200601"))
query := fmt.Sprintf(
"CREATE TABLE IF NOT EXISTS %s PARTITION OF usage_logs FOR VALUES FROM (%s) TO (%s)",
pq.QuoteIdentifier(name),
pq.QuoteLiteral(monthStart.Format("2006-01-02")),
pq.QuoteLiteral(nextMonth.Format("2006-01-02")),
)
_, err := r.sql.ExecContext(ctx, query)
return err
}
func truncateToDay(t time.Time) time.Time {
return timezone.StartOfDay(t)
}
func truncateToMonthUTC(t time.Time) time.Time {
t = t.UTC()
return time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, time.UTC)
}

View File

@@ -0,0 +1,58 @@
package repository
import (
"context"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const dashboardStatsCacheKey = "dashboard:stats:v1"
type dashboardCache struct {
rdb *redis.Client
keyPrefix string
}
func NewDashboardCache(rdb *redis.Client, cfg *config.Config) service.DashboardStatsCache {
prefix := "sub2api:"
if cfg != nil {
prefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix)
}
if prefix != "" && !strings.HasSuffix(prefix, ":") {
prefix += ":"
}
return &dashboardCache{
rdb: rdb,
keyPrefix: prefix,
}
}
func (c *dashboardCache) GetDashboardStats(ctx context.Context) (string, error) {
val, err := c.rdb.Get(ctx, c.buildKey()).Result()
if err != nil {
if err == redis.Nil {
return "", service.ErrDashboardStatsCacheMiss
}
return "", err
}
return val, nil
}
func (c *dashboardCache) SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error {
return c.rdb.Set(ctx, c.buildKey(), data, ttl).Err()
}
func (c *dashboardCache) buildKey() string {
if c.keyPrefix == "" {
return dashboardStatsCacheKey
}
return c.keyPrefix + dashboardStatsCacheKey
}
func (c *dashboardCache) DeleteDashboardStats(ctx context.Context) error {
return c.rdb.Del(ctx, c.buildKey()).Err()
}

View File

@@ -0,0 +1,28 @@
package repository
import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func TestNewDashboardCacheKeyPrefix(t *testing.T) {
cache := NewDashboardCache(nil, &config.Config{
Dashboard: config.DashboardCacheConfig{
KeyPrefix: "prod",
},
})
impl, ok := cache.(*dashboardCache)
require.True(t, ok)
require.Equal(t, "prod:", impl.keyPrefix)
cache = NewDashboardCache(nil, &config.Config{
Dashboard: config.DashboardCacheConfig{
KeyPrefix: "staging:",
},
})
impl, ok = cache.(*dashboardCache)
require.True(t, ok)
require.Equal(t, "staging:", impl.keyPrefix)
}

View File

@@ -0,0 +1,32 @@
package repository
import (
"database/sql"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
)
type dbPoolSettings struct {
MaxOpenConns int
MaxIdleConns int
ConnMaxLifetime time.Duration
ConnMaxIdleTime time.Duration
}
func buildDBPoolSettings(cfg *config.Config) dbPoolSettings {
return dbPoolSettings{
MaxOpenConns: cfg.Database.MaxOpenConns,
MaxIdleConns: cfg.Database.MaxIdleConns,
ConnMaxLifetime: time.Duration(cfg.Database.ConnMaxLifetimeMinutes) * time.Minute,
ConnMaxIdleTime: time.Duration(cfg.Database.ConnMaxIdleTimeMinutes) * time.Minute,
}
}
func applyDBPoolSettings(db *sql.DB, cfg *config.Config) {
settings := buildDBPoolSettings(cfg)
db.SetMaxOpenConns(settings.MaxOpenConns)
db.SetMaxIdleConns(settings.MaxIdleConns)
db.SetConnMaxLifetime(settings.ConnMaxLifetime)
db.SetConnMaxIdleTime(settings.ConnMaxIdleTime)
}

View File

@@ -0,0 +1,50 @@
package repository
import (
"database/sql"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
_ "github.com/lib/pq"
)
func TestBuildDBPoolSettings(t *testing.T) {
cfg := &config.Config{
Database: config.DatabaseConfig{
MaxOpenConns: 50,
MaxIdleConns: 10,
ConnMaxLifetimeMinutes: 30,
ConnMaxIdleTimeMinutes: 5,
},
}
settings := buildDBPoolSettings(cfg)
require.Equal(t, 50, settings.MaxOpenConns)
require.Equal(t, 10, settings.MaxIdleConns)
require.Equal(t, 30*time.Minute, settings.ConnMaxLifetime)
require.Equal(t, 5*time.Minute, settings.ConnMaxIdleTime)
}
func TestApplyDBPoolSettings(t *testing.T) {
cfg := &config.Config{
Database: config.DatabaseConfig{
MaxOpenConns: 40,
MaxIdleConns: 8,
ConnMaxLifetimeMinutes: 15,
ConnMaxIdleTimeMinutes: 3,
},
}
db, err := sql.Open("postgres", "host=127.0.0.1 port=5432 user=postgres sslmode=disable")
require.NoError(t, err)
t.Cleanup(func() {
_ = db.Close()
})
applyDBPoolSettings(db, cfg)
stats := db.Stats()
require.Equal(t, 40, stats.MaxOpenConnections)
}

View File

@@ -0,0 +1,52 @@
package repository
import (
"context"
"encoding/json"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const verifyCodeKeyPrefix = "verify_code:"
// verifyCodeKey generates the Redis key for email verification code.
func verifyCodeKey(email string) string {
return verifyCodeKeyPrefix + email
}
type emailCache struct {
rdb *redis.Client
}
func NewEmailCache(rdb *redis.Client) service.EmailCache {
return &emailCache{rdb: rdb}
}
func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*service.VerificationCodeData, error) {
key := verifyCodeKey(email)
val, err := c.rdb.Get(ctx, key).Result()
if err != nil {
return nil, err
}
var data service.VerificationCodeData
if err := json.Unmarshal([]byte(val), &data); err != nil {
return nil, err
}
return &data, nil
}
func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data *service.VerificationCodeData, ttl time.Duration) error {
key := verifyCodeKey(email)
val, err := json.Marshal(data)
if err != nil {
return err
}
return c.rdb.Set(ctx, key, val, ttl).Err()
}
func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) error {
key := verifyCodeKey(email)
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -0,0 +1,92 @@
//go:build integration
package repository
import (
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type EmailCacheSuite struct {
IntegrationRedisSuite
cache service.EmailCache
}
func (s *EmailCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewEmailCache(s.rdb)
}
func (s *EmailCacheSuite) TestGetVerificationCode_Missing() {
_, err := s.cache.GetVerificationCode(s.ctx, "nonexistent@example.com")
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing verification code")
}
func (s *EmailCacheSuite) TestSetAndGetVerificationCode() {
email := "a@example.com"
emailTTL := 2 * time.Minute
data := &service.VerificationCodeData{Code: "123456", Attempts: 1, CreatedAt: time.Now()}
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, emailTTL), "SetVerificationCode")
got, err := s.cache.GetVerificationCode(s.ctx, email)
require.NoError(s.T(), err, "GetVerificationCode")
require.Equal(s.T(), "123456", got.Code)
require.Equal(s.T(), 1, got.Attempts)
}
func (s *EmailCacheSuite) TestVerificationCode_TTL() {
email := "ttl@example.com"
emailTTL := 2 * time.Minute
data := &service.VerificationCodeData{Code: "654321", Attempts: 0, CreatedAt: time.Now()}
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, emailTTL), "SetVerificationCode")
emailKey := verifyCodeKeyPrefix + email
ttl, err := s.rdb.TTL(s.ctx, emailKey).Result()
require.NoError(s.T(), err, "TTL emailKey")
s.AssertTTLWithin(ttl, 1*time.Second, emailTTL)
}
func (s *EmailCacheSuite) TestDeleteVerificationCode() {
email := "delete@example.com"
data := &service.VerificationCodeData{Code: "999999", Attempts: 0, CreatedAt: time.Now()}
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, 2*time.Minute), "SetVerificationCode")
// Verify it exists
_, err := s.cache.GetVerificationCode(s.ctx, email)
require.NoError(s.T(), err, "GetVerificationCode before delete")
// Delete
require.NoError(s.T(), s.cache.DeleteVerificationCode(s.ctx, email), "DeleteVerificationCode")
// Verify it's gone
_, err = s.cache.GetVerificationCode(s.ctx, email)
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
}
func (s *EmailCacheSuite) TestDeleteVerificationCode_NonExistent() {
// Deleting a non-existent key should not error
require.NoError(s.T(), s.cache.DeleteVerificationCode(s.ctx, "nonexistent@example.com"), "DeleteVerificationCode non-existent")
}
func (s *EmailCacheSuite) TestGetVerificationCode_JSONCorruption() {
emailKey := verifyCodeKeyPrefix + "corrupted@example.com"
require.NoError(s.T(), s.rdb.Set(s.ctx, emailKey, "not-json", 1*time.Minute).Err(), "Set invalid JSON")
_, err := s.cache.GetVerificationCode(s.ctx, "corrupted@example.com")
require.Error(s.T(), err, "expected error for corrupted JSON")
require.False(s.T(), errors.Is(err, redis.Nil), "expected decoding error, not redis.Nil")
}
func TestEmailCacheSuite(t *testing.T) {
suite.Run(t, new(EmailCacheSuite))
}

View File

@@ -0,0 +1,45 @@
//go:build unit
package repository
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestVerifyCodeKey(t *testing.T) {
tests := []struct {
name string
email string
expected string
}{
{
name: "normal_email",
email: "user@example.com",
expected: "verify_code:user@example.com",
},
{
name: "empty_email",
email: "",
expected: "verify_code:",
},
{
name: "email_with_plus",
email: "user+tag@example.com",
expected: "verify_code:user+tag@example.com",
},
{
name: "email_with_special_chars",
email: "user.name+tag@sub.domain.com",
expected: "verify_code:user.name+tag@sub.domain.com",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := verifyCodeKey(tc.email)
require.Equal(t, tc.expected, got)
})
}
}

View File

@@ -0,0 +1,69 @@
// Package repository 提供应用程序的基础设施层组件。
// 包括数据库连接初始化、ORM 客户端管理、Redis 连接、数据库迁移等核心功能。
package repository
import (
"context"
"database/sql"
"time"
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/migrations"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "github.com/lib/pq" // PostgreSQL 驱动,通过副作用导入注册驱动
)
// InitEnt 初始化 Ent ORM 客户端并返回客户端实例和底层的 *sql.DB。
//
// 该函数执行以下操作:
// 1. 初始化全局时区设置,确保时间处理一致性
// 2. 建立 PostgreSQL 数据库连接
// 3. 自动执行数据库迁移,确保 schema 与代码同步
// 4. 创建并返回 Ent 客户端实例
//
// 重要提示:调用者必须负责关闭返回的 ent.Client关闭时会自动关闭底层的 driver/db
//
// 参数:
// - cfg: 应用程序配置,包含数据库连接信息和时区设置
//
// 返回:
// - *ent.Client: Ent ORM 客户端,用于执行数据库操作
// - *sql.DB: 底层的 SQL 数据库连接,可用于直接执行原生 SQL
// - error: 初始化过程中的错误
func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
// 优先初始化时区设置,确保所有时间操作使用统一的时区。
// 这对于跨时区部署和日志时间戳的一致性至关重要。
if err := timezone.Init(cfg.Timezone); err != nil {
return nil, nil, err
}
// 构建包含时区信息的数据库连接字符串 (DSN)。
// 时区信息会传递给 PostgreSQL确保数据库层面的时间处理正确。
dsn := cfg.Database.DSNWithTimezone(cfg.Timezone)
// 使用 Ent 的 SQL 驱动打开 PostgreSQL 连接。
// dialect.Postgres 指定使用 PostgreSQL 方言进行 SQL 生成。
drv, err := entsql.Open(dialect.Postgres, dsn)
if err != nil {
return nil, nil, err
}
applyDBPoolSettings(drv.DB(), cfg)
// 确保数据库 schema 已准备就绪。
// SQL 迁移文件是 schema 的权威来源source of truth
// 这种方式比 Ent 的自动迁移更可控,支持复杂的迁移场景。
migrationCtx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
if err := applyMigrationsFS(migrationCtx, drv.DB(), migrations.FS); err != nil {
_ = drv.Close() // 迁移失败时关闭驱动,避免资源泄露
return nil, nil, err
}
// 创建 Ent 客户端,绑定到已配置的数据库驱动。
client := ent.NewClient(ent.Driver(drv))
return client, drv.DB(), nil
}

View File

@@ -0,0 +1,97 @@
package repository
import (
"context"
"database/sql"
"errors"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/lib/pq"
)
// clientFromContext 从 context 中获取事务 client如果不存在则返回默认 client。
//
// 这个辅助函数支持 repository 方法在事务上下文中工作:
// - 如果 context 中存在事务(通过 ent.NewTxContext 设置),返回事务的 client
// - 否则返回传入的默认 client
//
// 使用示例:
//
// func (r *someRepo) SomeMethod(ctx context.Context) error {
// client := clientFromContext(ctx, r.client)
// return client.SomeEntity.Create().Save(ctx)
// }
func clientFromContext(ctx context.Context, defaultClient *dbent.Client) *dbent.Client {
if tx := dbent.TxFromContext(ctx); tx != nil {
return tx.Client()
}
return defaultClient
}
// translatePersistenceError 将数据库层错误翻译为业务层错误。
//
// 这是 Repository 层的核心错误处理函数,确保数据库细节不会泄露到业务层。
// 通过统一的错误翻译,业务层可以使用语义明确的错误类型(如 ErrUserNotFound
// 而不是依赖于特定数据库的错误(如 sql.ErrNoRows
//
// 参数:
// - err: 原始数据库错误
// - notFound: 当记录不存在时返回的业务错误(可为 nil 表示不处理)
// - conflict: 当违反唯一约束时返回的业务错误(可为 nil 表示不处理)
//
// 返回:
// - 翻译后的业务错误,或原始错误(如果不匹配任何规则)
//
// 示例:
//
// err := translatePersistenceError(dbErr, service.ErrUserNotFound, service.ErrEmailExists)
func translatePersistenceError(err error, notFound, conflict *infraerrors.ApplicationError) error {
if err == nil {
return nil
}
// 兼容 Ent ORM 和标准 database/sql 的 NotFound 行为。
// Ent 使用自定义的 NotFoundError而标准库使用 sql.ErrNoRows。
// 这里同时处理两种情况,保持业务错误映射一致。
if notFound != nil && (errors.Is(err, sql.ErrNoRows) || dbent.IsNotFound(err)) {
return notFound.WithCause(err)
}
// 处理唯一约束冲突(如邮箱已存在、名称重复等)
if conflict != nil && isUniqueConstraintViolation(err) {
return conflict.WithCause(err)
}
// 未匹配任何规则,返回原始错误
return err
}
// isUniqueConstraintViolation 判断错误是否为唯一约束冲突。
//
// 支持多种检测方式:
// 1. PostgreSQL 特定错误码 23505唯一约束冲突
// 2. 错误消息中包含的通用关键词
//
// 这种多层次的检测确保了对不同数据库驱动和 ORM 的兼容性。
func isUniqueConstraintViolation(err error) bool {
if err == nil {
return false
}
// 优先检测 PostgreSQL 特定错误码(最精确)。
// 错误码 23505 对应 unique_violation。
// 参考https://www.postgresql.org/docs/current/errcodes-appendix.html
var pgErr *pq.Error
if errors.As(err, &pgErr) {
return pgErr.Code == "23505"
}
// 回退到错误消息检测(兼容其他场景)。
// 这些关键词覆盖了 PostgreSQL、MySQL 等主流数据库的错误消息。
msg := strings.ToLower(err.Error())
return strings.Contains(msg, "duplicate key") ||
strings.Contains(msg, "unique constraint") ||
strings.Contains(msg, "duplicate entry")
}

View File

@@ -0,0 +1,391 @@
//go:build integration
package repository
import (
"context"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func mustCreateUser(t *testing.T, client *dbent.Client, u *service.User) *service.User {
t.Helper()
ctx := context.Background()
if u.Email == "" {
u.Email = "user-" + time.Now().Format(time.RFC3339Nano) + "@example.com"
}
if u.PasswordHash == "" {
u.PasswordHash = "test-password-hash"
}
if u.Role == "" {
u.Role = service.RoleUser
}
if u.Status == "" {
u.Status = service.StatusActive
}
if u.Concurrency == 0 {
u.Concurrency = 5
}
create := client.User.Create().
SetEmail(u.Email).
SetPasswordHash(u.PasswordHash).
SetRole(u.Role).
SetStatus(u.Status).
SetBalance(u.Balance).
SetConcurrency(u.Concurrency).
SetUsername(u.Username).
SetNotes(u.Notes)
if !u.CreatedAt.IsZero() {
create.SetCreatedAt(u.CreatedAt)
}
if !u.UpdatedAt.IsZero() {
create.SetUpdatedAt(u.UpdatedAt)
}
created, err := create.Save(ctx)
require.NoError(t, err, "create user")
u.ID = created.ID
u.CreatedAt = created.CreatedAt
u.UpdatedAt = created.UpdatedAt
if len(u.AllowedGroups) > 0 {
for _, groupID := range u.AllowedGroups {
_, err := client.UserAllowedGroup.Create().
SetUserID(u.ID).
SetGroupID(groupID).
Save(ctx)
require.NoError(t, err, "create user_allowed_groups row")
}
}
return u
}
func mustCreateGroup(t *testing.T, client *dbent.Client, g *service.Group) *service.Group {
t.Helper()
ctx := context.Background()
if g.Platform == "" {
g.Platform = service.PlatformAnthropic
}
if g.Status == "" {
g.Status = service.StatusActive
}
if g.SubscriptionType == "" {
g.SubscriptionType = service.SubscriptionTypeStandard
}
create := client.Group.Create().
SetName(g.Name).
SetPlatform(g.Platform).
SetStatus(g.Status).
SetSubscriptionType(g.SubscriptionType).
SetRateMultiplier(g.RateMultiplier).
SetIsExclusive(g.IsExclusive)
if g.Description != "" {
create.SetDescription(g.Description)
}
if g.DailyLimitUSD != nil {
create.SetDailyLimitUsd(*g.DailyLimitUSD)
}
if g.WeeklyLimitUSD != nil {
create.SetWeeklyLimitUsd(*g.WeeklyLimitUSD)
}
if g.MonthlyLimitUSD != nil {
create.SetMonthlyLimitUsd(*g.MonthlyLimitUSD)
}
if !g.CreatedAt.IsZero() {
create.SetCreatedAt(g.CreatedAt)
}
if !g.UpdatedAt.IsZero() {
create.SetUpdatedAt(g.UpdatedAt)
}
created, err := create.Save(ctx)
require.NoError(t, err, "create group")
g.ID = created.ID
g.CreatedAt = created.CreatedAt
g.UpdatedAt = created.UpdatedAt
return g
}
func mustCreateProxy(t *testing.T, client *dbent.Client, p *service.Proxy) *service.Proxy {
t.Helper()
ctx := context.Background()
if p.Protocol == "" {
p.Protocol = "http"
}
if p.Host == "" {
p.Host = "127.0.0.1"
}
if p.Port == 0 {
p.Port = 8080
}
if p.Status == "" {
p.Status = service.StatusActive
}
create := client.Proxy.Create().
SetName(p.Name).
SetProtocol(p.Protocol).
SetHost(p.Host).
SetPort(p.Port).
SetStatus(p.Status)
if p.Username != "" {
create.SetUsername(p.Username)
}
if p.Password != "" {
create.SetPassword(p.Password)
}
if !p.CreatedAt.IsZero() {
create.SetCreatedAt(p.CreatedAt)
}
if !p.UpdatedAt.IsZero() {
create.SetUpdatedAt(p.UpdatedAt)
}
created, err := create.Save(ctx)
require.NoError(t, err, "create proxy")
p.ID = created.ID
p.CreatedAt = created.CreatedAt
p.UpdatedAt = created.UpdatedAt
return p
}
func mustCreateAccount(t *testing.T, client *dbent.Client, a *service.Account) *service.Account {
t.Helper()
ctx := context.Background()
if a.Platform == "" {
a.Platform = service.PlatformAnthropic
}
if a.Type == "" {
a.Type = service.AccountTypeOAuth
}
if a.Status == "" {
a.Status = service.StatusActive
}
if a.Concurrency == 0 {
a.Concurrency = 3
}
if a.Priority == 0 {
a.Priority = 50
}
if !a.Schedulable {
a.Schedulable = true
}
if a.Credentials == nil {
a.Credentials = map[string]any{}
}
if a.Extra == nil {
a.Extra = map[string]any{}
}
create := client.Account.Create().
SetName(a.Name).
SetPlatform(a.Platform).
SetType(a.Type).
SetCredentials(a.Credentials).
SetExtra(a.Extra).
SetConcurrency(a.Concurrency).
SetPriority(a.Priority).
SetStatus(a.Status).
SetSchedulable(a.Schedulable).
SetErrorMessage(a.ErrorMessage)
if a.ProxyID != nil {
create.SetProxyID(*a.ProxyID)
}
if a.LastUsedAt != nil {
create.SetLastUsedAt(*a.LastUsedAt)
}
if a.RateLimitedAt != nil {
create.SetRateLimitedAt(*a.RateLimitedAt)
}
if a.RateLimitResetAt != nil {
create.SetRateLimitResetAt(*a.RateLimitResetAt)
}
if a.OverloadUntil != nil {
create.SetOverloadUntil(*a.OverloadUntil)
}
if a.SessionWindowStart != nil {
create.SetSessionWindowStart(*a.SessionWindowStart)
}
if a.SessionWindowEnd != nil {
create.SetSessionWindowEnd(*a.SessionWindowEnd)
}
if a.SessionWindowStatus != "" {
create.SetSessionWindowStatus(a.SessionWindowStatus)
}
if !a.CreatedAt.IsZero() {
create.SetCreatedAt(a.CreatedAt)
}
if !a.UpdatedAt.IsZero() {
create.SetUpdatedAt(a.UpdatedAt)
}
created, err := create.Save(ctx)
require.NoError(t, err, "create account")
a.ID = created.ID
a.CreatedAt = created.CreatedAt
a.UpdatedAt = created.UpdatedAt
return a
}
func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.APIKey) *service.APIKey {
t.Helper()
ctx := context.Background()
if k.Status == "" {
k.Status = service.StatusActive
}
if k.Key == "" {
k.Key = "sk-" + time.Now().Format("150405.000000")
}
if k.Name == "" {
k.Name = "default"
}
create := client.APIKey.Create().
SetUserID(k.UserID).
SetKey(k.Key).
SetName(k.Name).
SetStatus(k.Status)
if k.GroupID != nil {
create.SetGroupID(*k.GroupID)
}
if !k.CreatedAt.IsZero() {
create.SetCreatedAt(k.CreatedAt)
}
if !k.UpdatedAt.IsZero() {
create.SetUpdatedAt(k.UpdatedAt)
}
created, err := create.Save(ctx)
require.NoError(t, err, "create api key")
k.ID = created.ID
k.CreatedAt = created.CreatedAt
k.UpdatedAt = created.UpdatedAt
return k
}
func mustCreateRedeemCode(t *testing.T, client *dbent.Client, c *service.RedeemCode) *service.RedeemCode {
t.Helper()
ctx := context.Background()
if c.Status == "" {
c.Status = service.StatusUnused
}
if c.Type == "" {
c.Type = service.RedeemTypeBalance
}
if c.Code == "" {
c.Code = "rc-" + time.Now().Format("150405.000000")
}
create := client.RedeemCode.Create().
SetCode(c.Code).
SetType(c.Type).
SetValue(c.Value).
SetStatus(c.Status).
SetNotes(c.Notes).
SetValidityDays(c.ValidityDays)
if c.UsedBy != nil {
create.SetUsedBy(*c.UsedBy)
}
if c.UsedAt != nil {
create.SetUsedAt(*c.UsedAt)
}
if c.GroupID != nil {
create.SetGroupID(*c.GroupID)
}
if !c.CreatedAt.IsZero() {
create.SetCreatedAt(c.CreatedAt)
}
created, err := create.Save(ctx)
require.NoError(t, err, "create redeem code")
c.ID = created.ID
c.CreatedAt = created.CreatedAt
return c
}
func mustCreateSubscription(t *testing.T, client *dbent.Client, s *service.UserSubscription) *service.UserSubscription {
t.Helper()
ctx := context.Background()
if s.Status == "" {
s.Status = service.SubscriptionStatusActive
}
now := time.Now()
if s.StartsAt.IsZero() {
s.StartsAt = now.Add(-1 * time.Hour)
}
if s.ExpiresAt.IsZero() {
s.ExpiresAt = now.Add(24 * time.Hour)
}
if s.AssignedAt.IsZero() {
s.AssignedAt = now
}
if s.CreatedAt.IsZero() {
s.CreatedAt = now
}
if s.UpdatedAt.IsZero() {
s.UpdatedAt = now
}
create := client.UserSubscription.Create().
SetUserID(s.UserID).
SetGroupID(s.GroupID).
SetStartsAt(s.StartsAt).
SetExpiresAt(s.ExpiresAt).
SetStatus(s.Status).
SetAssignedAt(s.AssignedAt).
SetNotes(s.Notes).
SetDailyUsageUsd(s.DailyUsageUSD).
SetWeeklyUsageUsd(s.WeeklyUsageUSD).
SetMonthlyUsageUsd(s.MonthlyUsageUSD)
if s.AssignedBy != nil {
create.SetAssignedBy(*s.AssignedBy)
}
if !s.CreatedAt.IsZero() {
create.SetCreatedAt(s.CreatedAt)
}
if !s.UpdatedAt.IsZero() {
create.SetUpdatedAt(s.UpdatedAt)
}
created, err := create.Save(ctx)
require.NoError(t, err, "create user subscription")
s.ID = created.ID
s.CreatedAt = created.CreatedAt
s.UpdatedAt = created.UpdatedAt
return s
}
func mustBindAccountToGroup(t *testing.T, client *dbent.Client, accountID, groupID int64, priority int) {
t.Helper()
ctx := context.Background()
_, err := client.AccountGroup.Create().
SetAccountID(accountID).
SetGroupID(groupID).
SetPriority(priority).
Save(ctx)
require.NoError(t, err, "create account_group")
}

View File

@@ -0,0 +1,41 @@
package repository
import (
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const stickySessionPrefix = "sticky_session:"
type gatewayCache struct {
rdb *redis.Client
}
func NewGatewayCache(rdb *redis.Client) service.GatewayCache {
return &gatewayCache{rdb: rdb}
}
// buildSessionKey 构建 session key包含 groupID 实现分组隔离
// 格式: sticky_session:{groupID}:{sessionHash}
func buildSessionKey(groupID int64, sessionHash string) string {
return fmt.Sprintf("%s%d:%s", stickySessionPrefix, groupID, sessionHash)
}
func (c *gatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Get(ctx, key).Int64()
}
func (c *gatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Set(ctx, key, accountID, ttl).Err()
}
func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Expire(ctx, key, ttl).Err()
}

View File

@@ -0,0 +1,96 @@
//go:build integration
package repository
import (
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type GatewayCacheSuite struct {
IntegrationRedisSuite
cache service.GatewayCache
}
func (s *GatewayCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewGatewayCache(s.rdb)
}
func (s *GatewayCacheSuite) TestGetSessionAccountID_Missing() {
_, err := s.cache.GetSessionAccountID(s.ctx, 1, "nonexistent")
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing session")
}
func (s *GatewayCacheSuite) TestSetAndGetSessionAccountID() {
sessionID := "s1"
accountID := int64(99)
groupID := int64(1)
sessionTTL := 1 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
sid, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
require.NoError(s.T(), err, "GetSessionAccountID")
require.Equal(s.T(), accountID, sid, "session id mismatch")
}
func (s *GatewayCacheSuite) TestSessionAccountID_TTL() {
sessionID := "s2"
accountID := int64(100)
groupID := int64(1)
sessionTTL := 1 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
sessionKey := buildSessionKey(groupID, sessionID)
ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result()
require.NoError(s.T(), err, "TTL sessionKey after Set")
s.AssertTTLWithin(ttl, 1*time.Second, sessionTTL)
}
func (s *GatewayCacheSuite) TestRefreshSessionTTL() {
sessionID := "s3"
accountID := int64(101)
groupID := int64(1)
initialTTL := 1 * time.Minute
refreshTTL := 3 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, initialTTL), "SetSessionAccountID")
require.NoError(s.T(), s.cache.RefreshSessionTTL(s.ctx, groupID, sessionID, refreshTTL), "RefreshSessionTTL")
sessionKey := buildSessionKey(groupID, sessionID)
ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result()
require.NoError(s.T(), err, "TTL after Refresh")
s.AssertTTLWithin(ttl, 1*time.Second, refreshTTL)
}
func (s *GatewayCacheSuite) TestRefreshSessionTTL_MissingKey() {
// RefreshSessionTTL on a missing key should not error (no-op)
err := s.cache.RefreshSessionTTL(s.ctx, 1, "missing-session", 1*time.Minute)
require.NoError(s.T(), err, "RefreshSessionTTL on missing key should not error")
}
func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
sessionID := "corrupted"
groupID := int64(1)
sessionKey := buildSessionKey(groupID, sessionID)
// Set a non-integer value
require.NoError(s.T(), s.rdb.Set(s.ctx, sessionKey, "not-a-number", 1*time.Minute).Err(), "Set invalid value")
_, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
require.Error(s.T(), err, "expected error for corrupted value")
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
}
func TestGatewayCacheSuite(t *testing.T) {
suite.Run(t, new(GatewayCacheSuite))
}

View File

@@ -0,0 +1,250 @@
//go:build integration
package repository
import (
"context"
"testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
)
// GatewayRoutingSuite 测试网关路由相关的数据库查询
// 验证账户选择和分流逻辑在真实数据库环境下的行为
type GatewayRoutingSuite struct {
suite.Suite
ctx context.Context
client *dbent.Client
accountRepo *accountRepository
}
func (s *GatewayRoutingSuite) SetupTest() {
s.ctx = context.Background()
tx := testEntTx(s.T())
s.client = tx.Client()
s.accountRepo = newAccountRepositoryWithSQL(s.client, tx)
}
func TestGatewayRoutingSuite(t *testing.T) {
suite.Run(t, new(GatewayRoutingSuite))
}
// TestListSchedulableByPlatforms_GeminiAndAntigravity 验证多平台账户查询
func (s *GatewayRoutingSuite) TestListSchedulableByPlatforms_GeminiAndAntigravity() {
// 创建各平台账户
geminiAcc := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "gemini-oauth",
Platform: service.PlatformGemini,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
Schedulable: true,
Priority: 1,
})
antigravityAcc := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "antigravity-oauth",
Platform: service.PlatformAntigravity,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
Schedulable: true,
Priority: 2,
Credentials: map[string]any{
"access_token": "test-token",
"refresh_token": "test-refresh",
"project_id": "test-project",
},
})
// 创建不应被选中的 anthropic 账户
mustCreateAccount(s.T(), s.client, &service.Account{
Name: "anthropic-oauth",
Platform: service.PlatformAnthropic,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
Schedulable: true,
Priority: 0,
})
// 查询 gemini + antigravity 平台
accounts, err := s.accountRepo.ListSchedulableByPlatforms(s.ctx, []string{
service.PlatformGemini,
service.PlatformAntigravity,
})
s.Require().NoError(err)
s.Require().Len(accounts, 2, "应返回 gemini 和 antigravity 两个账户")
// 验证返回的账户平台
platforms := make(map[string]bool)
for _, acc := range accounts {
platforms[acc.Platform] = true
}
s.Require().True(platforms[service.PlatformGemini], "应包含 gemini 账户")
s.Require().True(platforms[service.PlatformAntigravity], "应包含 antigravity 账户")
s.Require().False(platforms[service.PlatformAnthropic], "不应包含 anthropic 账户")
// 验证账户 ID 匹配
ids := make(map[int64]bool)
for _, acc := range accounts {
ids[acc.ID] = true
}
s.Require().True(ids[geminiAcc.ID])
s.Require().True(ids[antigravityAcc.ID])
}
// TestListSchedulableByGroupIDAndPlatforms_WithGroupBinding 验证按分组过滤
func (s *GatewayRoutingSuite) TestListSchedulableByGroupIDAndPlatforms_WithGroupBinding() {
// 创建 gemini 分组
group := mustCreateGroup(s.T(), s.client, &service.Group{
Name: "gemini-group",
Platform: service.PlatformGemini,
Status: service.StatusActive,
})
// 创建账户
boundAcc := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "bound-antigravity",
Platform: service.PlatformAntigravity,
Status: service.StatusActive,
Schedulable: true,
})
unboundAcc := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "unbound-antigravity",
Platform: service.PlatformAntigravity,
Status: service.StatusActive,
Schedulable: true,
})
// 只绑定一个账户到分组
mustBindAccountToGroup(s.T(), s.client, boundAcc.ID, group.ID, 1)
// 查询分组内的账户
accounts, err := s.accountRepo.ListSchedulableByGroupIDAndPlatforms(s.ctx, group.ID, []string{
service.PlatformGemini,
service.PlatformAntigravity,
})
s.Require().NoError(err)
s.Require().Len(accounts, 1, "应只返回绑定到分组的账户")
s.Require().Equal(boundAcc.ID, accounts[0].ID)
// 确认未绑定的账户不在结果中
for _, acc := range accounts {
s.Require().NotEqual(unboundAcc.ID, acc.ID, "不应包含未绑定的账户")
}
}
// TestListSchedulableByPlatform_Antigravity 验证单平台查询
func (s *GatewayRoutingSuite) TestListSchedulableByPlatform_Antigravity() {
// 创建多种平台账户
mustCreateAccount(s.T(), s.client, &service.Account{
Name: "gemini-1",
Platform: service.PlatformGemini,
Status: service.StatusActive,
Schedulable: true,
})
antigravity := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "antigravity-1",
Platform: service.PlatformAntigravity,
Status: service.StatusActive,
Schedulable: true,
})
// 只查询 antigravity 平台
accounts, err := s.accountRepo.ListSchedulableByPlatform(s.ctx, service.PlatformAntigravity)
s.Require().NoError(err)
s.Require().Len(accounts, 1)
s.Require().Equal(antigravity.ID, accounts[0].ID)
s.Require().Equal(service.PlatformAntigravity, accounts[0].Platform)
}
// TestSchedulableFilter_ExcludesInactive 验证不可调度账户被过滤
func (s *GatewayRoutingSuite) TestSchedulableFilter_ExcludesInactive() {
// 创建可调度账户
activeAcc := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "active-antigravity",
Platform: service.PlatformAntigravity,
Status: service.StatusActive,
Schedulable: true,
})
// 创建不可调度账户(需要先创建再更新,因为 fixture 默认设置 Schedulable=true
inactiveAcc := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "inactive-antigravity",
Platform: service.PlatformAntigravity,
Status: service.StatusActive,
})
s.Require().NoError(s.client.Account.UpdateOneID(inactiveAcc.ID).SetSchedulable(false).Exec(s.ctx))
// 创建错误状态账户
mustCreateAccount(s.T(), s.client, &service.Account{
Name: "error-antigravity",
Platform: service.PlatformAntigravity,
Status: service.StatusError,
Schedulable: true,
})
accounts, err := s.accountRepo.ListSchedulableByPlatform(s.ctx, service.PlatformAntigravity)
s.Require().NoError(err)
s.Require().Len(accounts, 1, "应只返回可调度的 active 账户")
s.Require().Equal(activeAcc.ID, accounts[0].ID)
}
// TestPlatformRoutingDecision 验证平台路由决策
// 这个测试模拟 Handler 层在选择账户后的路由决策逻辑
func (s *GatewayRoutingSuite) TestPlatformRoutingDecision() {
// 创建两种平台的账户
geminiAcc := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "gemini-route-test",
Platform: service.PlatformGemini,
Status: service.StatusActive,
Schedulable: true,
})
antigravityAcc := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "antigravity-route-test",
Platform: service.PlatformAntigravity,
Status: service.StatusActive,
Schedulable: true,
})
tests := []struct {
name string
accountID int64
expectedService string
}{
{
name: "Gemini账户路由到ForwardNative",
accountID: geminiAcc.ID,
expectedService: "GeminiMessagesCompatService.ForwardNative",
},
{
name: "Antigravity账户路由到ForwardGemini",
accountID: antigravityAcc.ID,
expectedService: "AntigravityGatewayService.ForwardGemini",
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
// 从数据库获取账户
account, err := s.accountRepo.GetByID(s.ctx, tt.accountID)
s.Require().NoError(err)
// 模拟 Handler 层的路由决策
var routedService string
if account.Platform == service.PlatformAntigravity {
routedService = "AntigravityGatewayService.ForwardGemini"
} else {
routedService = "GeminiMessagesCompatService.ForwardNative"
}
s.Require().Equal(tt.expectedService, routedService)
})
}
}

View File

@@ -0,0 +1,119 @@
package repository
import (
"context"
"fmt"
"net/url"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/imroc/req/v3"
)
type geminiOAuthClient struct {
tokenURL string
cfg *config.Config
}
func NewGeminiOAuthClient(cfg *config.Config) service.GeminiOAuthClient {
return &geminiOAuthClient{
tokenURL: geminicli.TokenURL,
cfg: cfg,
}
}
func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, codeVerifier, redirectURI, proxyURL string) (*geminicli.TokenResponse, error) {
client := createGeminiReqClient(proxyURL)
// Use different OAuth clients based on oauthType:
// - code_assist: always use built-in Gemini CLI OAuth client (public)
// - google_one: always use built-in Gemini CLI OAuth client (public)
// - ai_studio: requires a user-provided OAuth client
oauthCfgInput := geminicli.OAuthConfig{
ClientID: c.cfg.Gemini.OAuth.ClientID,
ClientSecret: c.cfg.Gemini.OAuth.ClientSecret,
Scopes: c.cfg.Gemini.OAuth.Scopes,
}
if oauthType == "code_assist" || oauthType == "google_one" {
// Force use of built-in Gemini CLI OAuth client
oauthCfgInput.ClientID = ""
oauthCfgInput.ClientSecret = ""
}
oauthCfg, err := geminicli.EffectiveOAuthConfig(oauthCfgInput, oauthType)
if err != nil {
return nil, err
}
formData := url.Values{}
formData.Set("grant_type", "authorization_code")
formData.Set("client_id", oauthCfg.ClientID)
formData.Set("client_secret", oauthCfg.ClientSecret)
formData.Set("code", code)
formData.Set("code_verifier", codeVerifier)
formData.Set("redirect_uri", redirectURI)
var tokenResp geminicli.TokenResponse
resp, err := client.R().
SetContext(ctx).
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(c.tokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, geminicli.SanitizeBodyForLogs(resp.String()))
}
return &tokenResp, nil
}
func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) {
client := createGeminiReqClient(proxyURL)
oauthCfgInput := geminicli.OAuthConfig{
ClientID: c.cfg.Gemini.OAuth.ClientID,
ClientSecret: c.cfg.Gemini.OAuth.ClientSecret,
Scopes: c.cfg.Gemini.OAuth.Scopes,
}
if oauthType == "code_assist" || oauthType == "google_one" {
// Force use of built-in Gemini CLI OAuth client
oauthCfgInput.ClientID = ""
oauthCfgInput.ClientSecret = ""
}
oauthCfg, err := geminicli.EffectiveOAuthConfig(oauthCfgInput, oauthType)
if err != nil {
return nil, err
}
formData := url.Values{}
formData.Set("grant_type", "refresh_token")
formData.Set("refresh_token", refreshToken)
formData.Set("client_id", oauthCfg.ClientID)
formData.Set("client_secret", oauthCfg.ClientSecret)
var tokenResp geminicli.TokenResponse
resp, err := client.R().
SetContext(ctx).
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(c.tokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, geminicli.SanitizeBodyForLogs(resp.String()))
}
return &tokenResp, nil
}
func createGeminiReqClient(proxyURL string) *req.Client {
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 60 * time.Second,
})
}

View File

@@ -0,0 +1,49 @@
package repository
import (
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const (
geminiTokenKeyPrefix = "gemini:token:"
geminiRefreshLockKeyPrefix = "gemini:refresh_lock:"
)
type geminiTokenCache struct {
rdb *redis.Client
}
func NewGeminiTokenCache(rdb *redis.Client) service.GeminiTokenCache {
return &geminiTokenCache{rdb: rdb}
}
func (c *geminiTokenCache) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey)
return c.rdb.Get(ctx, key).Result()
}
func (c *geminiTokenCache) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey)
return c.rdb.Set(ctx, key, token, ttl).Err()
}
func (c *geminiTokenCache) DeleteAccessToken(ctx context.Context, cacheKey string) error {
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey)
return c.rdb.Del(ctx, key).Err()
}
func (c *geminiTokenCache) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey)
return c.rdb.SetNX(ctx, key, 1, ttl).Result()
}
func (c *geminiTokenCache) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey)
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -0,0 +1,47 @@
//go:build integration
package repository
import (
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type GeminiTokenCacheSuite struct {
IntegrationRedisSuite
cache service.GeminiTokenCache
}
func (s *GeminiTokenCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewGeminiTokenCache(s.rdb)
}
func (s *GeminiTokenCacheSuite) TestDeleteAccessToken() {
cacheKey := "project-123"
token := "token-value"
require.NoError(s.T(), s.cache.SetAccessToken(s.ctx, cacheKey, token, time.Minute))
got, err := s.cache.GetAccessToken(s.ctx, cacheKey)
require.NoError(s.T(), err)
require.Equal(s.T(), token, got)
require.NoError(s.T(), s.cache.DeleteAccessToken(s.ctx, cacheKey))
_, err = s.cache.GetAccessToken(s.ctx, cacheKey)
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
}
func (s *GeminiTokenCacheSuite) TestDeleteAccessToken_MissingKey() {
require.NoError(s.T(), s.cache.DeleteAccessToken(s.ctx, "missing-key"))
}
func TestGeminiTokenCacheSuite(t *testing.T) {
suite.Run(t, new(GeminiTokenCacheSuite))
}

View File

@@ -0,0 +1,28 @@
//go:build unit
package repository
import (
"context"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
)
func TestGeminiTokenCache_DeleteAccessToken_RedisError(t *testing.T) {
rdb := redis.NewClient(&redis.Options{
Addr: "127.0.0.1:1",
DialTimeout: 50 * time.Millisecond,
ReadTimeout: 50 * time.Millisecond,
WriteTimeout: 50 * time.Millisecond,
})
t.Cleanup(func() {
_ = rdb.Close()
})
cache := NewGeminiTokenCache(rdb)
err := cache.DeleteAccessToken(context.Background(), "broken")
require.Error(t, err)
}

View File

@@ -0,0 +1,104 @@
package repository
import (
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/imroc/req/v3"
)
type geminiCliCodeAssistClient struct {
baseURL string
}
func NewGeminiCliCodeAssistClient() service.GeminiCliCodeAssistClient {
return &geminiCliCodeAssistClient{baseURL: geminicli.GeminiCliBaseURL}
}
func (c *geminiCliCodeAssistClient) LoadCodeAssist(ctx context.Context, accessToken, proxyURL string, reqBody *geminicli.LoadCodeAssistRequest) (*geminicli.LoadCodeAssistResponse, error) {
if reqBody == nil {
reqBody = defaultLoadCodeAssistRequest()
}
var out geminicli.LoadCodeAssistResponse
resp, err := createGeminiCliReqClient(proxyURL).R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+accessToken).
SetHeader("Content-Type", "application/json").
SetHeader("User-Agent", geminicli.GeminiCLIUserAgent).
SetBody(reqBody).
SetSuccessResult(&out).
Post(c.baseURL + "/v1internal:loadCodeAssist")
if err != nil {
fmt.Printf("[CodeAssist] LoadCodeAssist request error: %v\n", err)
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccessState() {
body := geminicli.SanitizeBodyForLogs(resp.String())
fmt.Printf("[CodeAssist] LoadCodeAssist failed: status %d, body: %s\n", resp.StatusCode, body)
return nil, fmt.Errorf("loadCodeAssist failed: status %d, body: %s", resp.StatusCode, body)
}
fmt.Printf("[CodeAssist] LoadCodeAssist success: status %d, response: %+v\n", resp.StatusCode, out)
return &out, nil
}
func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken, proxyURL string, reqBody *geminicli.OnboardUserRequest) (*geminicli.OnboardUserResponse, error) {
if reqBody == nil {
reqBody = defaultOnboardUserRequest()
}
fmt.Printf("[CodeAssist] OnboardUser request body: %+v\n", reqBody)
var out geminicli.OnboardUserResponse
resp, err := createGeminiCliReqClient(proxyURL).R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+accessToken).
SetHeader("Content-Type", "application/json").
SetHeader("User-Agent", geminicli.GeminiCLIUserAgent).
SetBody(reqBody).
SetSuccessResult(&out).
Post(c.baseURL + "/v1internal:onboardUser")
if err != nil {
fmt.Printf("[CodeAssist] OnboardUser request error: %v\n", err)
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccessState() {
body := geminicli.SanitizeBodyForLogs(resp.String())
fmt.Printf("[CodeAssist] OnboardUser failed: status %d, body: %s\n", resp.StatusCode, body)
return nil, fmt.Errorf("onboardUser failed: status %d, body: %s", resp.StatusCode, body)
}
fmt.Printf("[CodeAssist] OnboardUser success: status %d, response: %+v\n", resp.StatusCode, out)
return &out, nil
}
func createGeminiCliReqClient(proxyURL string) *req.Client {
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 30 * time.Second,
})
}
func defaultLoadCodeAssistRequest() *geminicli.LoadCodeAssistRequest {
return &geminicli.LoadCodeAssistRequest{
Metadata: geminicli.LoadCodeAssistMetadata{
IDEType: "ANTIGRAVITY",
Platform: "PLATFORM_UNSPECIFIED",
PluginType: "GEMINI",
},
}
}
func defaultOnboardUserRequest() *geminicli.OnboardUserRequest {
return &geminicli.OnboardUserRequest{
TierID: "LEGACY",
Metadata: geminicli.LoadCodeAssistMetadata{
IDEType: "ANTIGRAVITY",
Platform: "PLATFORM_UNSPECIFIED",
PluginType: "GEMINI",
},
}
}

View File

@@ -0,0 +1,136 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type githubReleaseClient struct {
httpClient *http.Client
downloadHTTPClient *http.Client
}
// NewGitHubReleaseClient 创建 GitHub Release 客户端
// proxyURL 为空时直连 GitHub支持 http/https/socks5/socks5h 协议
func NewGitHubReleaseClient(proxyURL string) service.GitHubReleaseClient {
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second,
ProxyURL: proxyURL,
})
if err != nil {
sharedClient = &http.Client{Timeout: 30 * time.Second}
}
// 下载客户端需要更长的超时时间
downloadClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 10 * time.Minute,
ProxyURL: proxyURL,
})
if err != nil {
downloadClient = &http.Client{Timeout: 10 * time.Minute}
}
return &githubReleaseClient{
httpClient: sharedClient,
downloadHTTPClient: downloadClient,
}
}
func (c *githubReleaseClient) FetchLatestRelease(ctx context.Context, repo string) (*service.GitHubRelease, error) {
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", repo)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
req.Header.Set("Accept", "application/vnd.github.v3+json")
req.Header.Set("User-Agent", "Sub2API-Updater")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("GitHub API returned %d", resp.StatusCode)
}
var release service.GitHubRelease
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
return nil, err
}
return &release, nil
}
func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string, maxSize int64) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return err
}
// 使用预配置的下载客户端(已包含代理配置)
resp, err := c.downloadHTTPClient.Do(req)
if err != nil {
return err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("download returned %d", resp.StatusCode)
}
// SECURITY: Check Content-Length if available
if resp.ContentLength > maxSize {
return fmt.Errorf("file too large: %d bytes (max %d)", resp.ContentLength, maxSize)
}
out, err := os.Create(dest)
if err != nil {
return err
}
defer func() { _ = out.Close() }()
// SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong
limited := io.LimitReader(resp.Body, maxSize+1)
written, err := io.Copy(out, limited)
if err != nil {
return err
}
// Check if we hit the limit (downloaded more than maxSize)
if written > maxSize {
_ = os.Remove(dest) // Clean up partial file (best-effort)
return fmt.Errorf("download exceeded maximum size of %d bytes", maxSize)
}
return nil
}
func (c *githubReleaseClient) FetchChecksumFile(ctx context.Context, url string) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
}
return io.ReadAll(resp.Body)
}

View File

@@ -0,0 +1,317 @@
package repository
import (
"bytes"
"context"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type GitHubReleaseServiceSuite struct {
suite.Suite
srv *httptest.Server
client *githubReleaseClient
tempDir string
}
// testTransport redirects requests to the test server
type testTransport struct {
testServerURL string
}
func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// Rewrite the URL to point to our test server
testURL := t.testServerURL + req.URL.Path
newReq, err := http.NewRequestWithContext(req.Context(), req.Method, testURL, req.Body)
if err != nil {
return nil, err
}
newReq.Header = req.Header
return http.DefaultTransport.RoundTrip(newReq)
}
func newTestGitHubReleaseClient() *githubReleaseClient {
return &githubReleaseClient{
httpClient: &http.Client{},
downloadHTTPClient: &http.Client{},
}
}
func (s *GitHubReleaseServiceSuite) SetupTest() {
s.tempDir = s.T().TempDir()
}
func (s *GitHubReleaseServiceSuite) TearDownTest() {
if s.srv != nil {
s.srv.Close()
s.srv = nil
}
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLength() {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Length", "100")
w.WriteHeader(http.StatusOK)
_, _ = w.Write(bytes.Repeat([]byte("a"), 100))
}))
s.client = newTestGitHubReleaseClient()
dest := filepath.Join(s.tempDir, "file1.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
require.Error(s.T(), err, "expected error for oversized download with Content-Length")
_, statErr := os.Stat(dest)
require.Error(s.T(), statErr, "expected file to not exist for rejected download")
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Force chunked encoding (unknown Content-Length) by flushing headers before writing.
w.WriteHeader(http.StatusOK)
if fl, ok := w.(http.Flusher); ok {
fl.Flush()
}
for i := 0; i < 10; i++ {
_, _ = w.Write(bytes.Repeat([]byte("b"), 10))
if fl, ok := w.(http.Flusher); ok {
fl.Flush()
}
}
}))
s.client = newTestGitHubReleaseClient()
dest := filepath.Join(s.tempDir, "file2.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
require.Error(s.T(), err, "expected error for oversized chunked download")
_, statErr := os.Stat(dest)
require.Error(s.T(), statErr, "expected file to be cleaned up for oversized chunked download")
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
if fl, ok := w.(http.Flusher); ok {
fl.Flush()
}
for i := 0; i < 10; i++ {
_, _ = w.Write(bytes.Repeat([]byte("b"), 10))
if fl, ok := w.(http.Flusher); ok {
fl.Flush()
}
}
}))
s.client = newTestGitHubReleaseClient()
dest := filepath.Join(s.tempDir, "file3.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 200)
require.NoError(s.T(), err, "expected success")
b, err := os.ReadFile(dest)
require.NoError(s.T(), err, "read")
require.True(s.T(), strings.HasPrefix(string(b), "b"), "downloaded content should start with 'b'")
require.Len(s.T(), b, 100, "downloaded content length mismatch")
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
s.client = newTestGitHubReleaseClient()
dest := filepath.Join(s.tempDir, "notfound.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 100)
require.Error(s.T(), err, "expected error for 404")
_, statErr := os.Stat(dest)
require.Error(s.T(), statErr, "expected file to not exist for 404")
}
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("sum"))
}))
s.client = newTestGitHubReleaseClient()
body, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
require.NoError(s.T(), err, "FetchChecksumFile")
require.Equal(s.T(), "sum", string(body), "checksum body mismatch")
}
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
s.client = newTestGitHubReleaseClient()
_, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
require.Error(s.T(), err, "expected error for non-200")
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
<-r.Context().Done()
}))
s.client = newTestGitHubReleaseClient()
ctx, cancel := context.WithCancel(context.Background())
cancel()
dest := filepath.Join(s.tempDir, "cancelled.bin")
err := s.client.DownloadFile(ctx, s.srv.URL, dest, 100)
require.Error(s.T(), err, "expected error for cancelled context")
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidURL() {
s.client = newTestGitHubReleaseClient()
dest := filepath.Join(s.tempDir, "invalid.bin")
err := s.client.DownloadFile(context.Background(), "://invalid-url", dest, 100)
require.Error(s.T(), err, "expected error for invalid URL")
}
func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("content"))
}))
s.client = newTestGitHubReleaseClient()
// Use a path that cannot be created (directory doesn't exist)
dest := filepath.Join(s.tempDir, "nonexistent", "subdir", "file.bin")
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 100)
require.Error(s.T(), err, "expected error for invalid destination path")
}
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_InvalidURL() {
s.client = newTestGitHubReleaseClient()
_, err := s.client.FetchChecksumFile(context.Background(), "://invalid-url")
require.Error(s.T(), err, "expected error for invalid URL")
}
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
releaseJSON := `{
"tag_name": "v1.0.0",
"name": "Release 1.0.0",
"body": "Release notes",
"html_url": "https://github.com/test/repo/releases/v1.0.0",
"assets": [
{
"name": "app-linux-amd64.tar.gz",
"browser_download_url": "https://github.com/test/repo/releases/download/v1.0.0/app-linux-amd64.tar.gz"
}
]
}`
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(s.T(), "/repos/test/repo/releases/latest", r.URL.Path)
require.Equal(s.T(), "application/vnd.github.v3+json", r.Header.Get("Accept"))
require.Equal(s.T(), "Sub2API-Updater", r.Header.Get("User-Agent"))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(releaseJSON))
}))
// Use custom transport to redirect requests to test server
s.client = &githubReleaseClient{
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
downloadHTTPClient: &http.Client{},
}
release, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
require.NoError(s.T(), err)
require.Equal(s.T(), "v1.0.0", release.TagName)
require.Equal(s.T(), "Release 1.0.0", release.Name)
require.Len(s.T(), release.Assets, 1)
require.Equal(s.T(), "app-linux-amd64.tar.gz", release.Assets[0].Name)
}
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
s.client = &githubReleaseClient{
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
downloadHTTPClient: &http.Client{},
}
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
require.Error(s.T(), err)
require.Contains(s.T(), err.Error(), "404")
}
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("not valid json"))
}))
s.client = &githubReleaseClient{
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
downloadHTTPClient: &http.Client{},
}
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
require.Error(s.T(), err)
}
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
<-r.Context().Done()
}))
s.client = &githubReleaseClient{
httpClient: &http.Client{
Transport: &testTransport{testServerURL: s.srv.URL},
},
downloadHTTPClient: &http.Client{},
}
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := s.client.FetchLatestRelease(ctx, "test/repo")
require.Error(s.T(), err)
}
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_ContextCancel() {
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
<-r.Context().Done()
}))
s.client = newTestGitHubReleaseClient()
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := s.client.FetchChecksumFile(ctx, s.srv.URL)
require.Error(s.T(), err)
}
func TestGitHubReleaseServiceSuite(t *testing.T) {
suite.Run(t, new(GitHubReleaseServiceSuite))
}

View File

@@ -0,0 +1,413 @@
package repository
import (
"context"
"database/sql"
"errors"
"log"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
)
type sqlExecutor interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}
type groupRepository struct {
client *dbent.Client
sql sqlExecutor
}
func NewGroupRepository(client *dbent.Client, sqlDB *sql.DB) service.GroupRepository {
return newGroupRepositoryWithSQL(client, sqlDB)
}
func newGroupRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *groupRepository {
return &groupRepository{client: client, sql: sqlq}
}
func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) error {
builder := r.client.Group.Create().
SetName(groupIn.Name).
SetDescription(groupIn.Description).
SetPlatform(groupIn.Platform).
SetRateMultiplier(groupIn.RateMultiplier).
SetIsExclusive(groupIn.IsExclusive).
SetStatus(groupIn.Status).
SetSubscriptionType(groupIn.SubscriptionType).
SetNillableDailyLimitUsd(groupIn.DailyLimitUSD).
SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD).
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD).
SetNillableImagePrice1k(groupIn.ImagePrice1K).
SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetNillableFallbackGroupID(groupIn.FallbackGroupID)
created, err := builder.Save(ctx)
if err == nil {
groupIn.ID = created.ID
groupIn.CreatedAt = created.CreatedAt
groupIn.UpdatedAt = created.UpdatedAt
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupIn.ID, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue group create failed: group=%d err=%v", groupIn.ID, err)
}
}
return translatePersistenceError(err, nil, service.ErrGroupExists)
}
func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group, error) {
out, err := r.GetByIDLite(ctx, id)
if err != nil {
return nil, err
}
count, _ := r.GetAccountCount(ctx, out.ID)
out.AccountCount = count
return out, nil
}
func (r *groupRepository) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) {
// AccountCount is intentionally not loaded here; use GetByID when needed.
m, err := r.client.Group.Query().
Where(group.IDEQ(id)).
Only(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
}
return groupEntityToService(m), nil
}
func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) error {
builder := r.client.Group.UpdateOneID(groupIn.ID).
SetName(groupIn.Name).
SetDescription(groupIn.Description).
SetPlatform(groupIn.Platform).
SetRateMultiplier(groupIn.RateMultiplier).
SetIsExclusive(groupIn.IsExclusive).
SetStatus(groupIn.Status).
SetSubscriptionType(groupIn.SubscriptionType).
SetNillableDailyLimitUsd(groupIn.DailyLimitUSD).
SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD).
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD).
SetNillableImagePrice1k(groupIn.ImagePrice1K).
SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly)
// 处理 FallbackGroupIDnil 时清除,否则设置
if groupIn.FallbackGroupID != nil {
builder = builder.SetFallbackGroupID(*groupIn.FallbackGroupID)
} else {
builder = builder.ClearFallbackGroupID()
}
updated, err := builder.Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
}
groupIn.UpdatedAt = updated.UpdatedAt
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupIn.ID, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue group update failed: group=%d err=%v", groupIn.ID, err)
}
return nil
}
func (r *groupRepository) Delete(ctx context.Context, id int64) error {
_, err := r.client.Group.Delete().Where(group.IDEQ(id)).Exec(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrGroupNotFound, nil)
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue group delete failed: group=%d err=%v", id, err)
}
return nil
}
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "", nil)
}
func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
q := r.client.Group.Query()
if platform != "" {
q = q.Where(group.PlatformEQ(platform))
}
if status != "" {
q = q.Where(group.StatusEQ(status))
}
if search != "" {
q = q.Where(group.Or(
group.NameContainsFold(search),
group.DescriptionContainsFold(search),
))
}
if isExclusive != nil {
q = q.Where(group.IsExclusiveEQ(*isExclusive))
}
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
}
groups, err := q.
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Asc(group.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
groupIDs := make([]int64, 0, len(groups))
outGroups := make([]service.Group, 0, len(groups))
for i := range groups {
g := groupEntityToService(groups[i])
outGroups = append(outGroups, *g)
groupIDs = append(groupIDs, g.ID)
}
counts, err := r.loadAccountCounts(ctx, groupIDs)
if err == nil {
for i := range outGroups {
outGroups[i].AccountCount = counts[outGroups[i].ID]
}
}
return outGroups, paginationResultFromTotal(int64(total), params), nil
}
func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, error) {
groups, err := r.client.Group.Query().
Where(group.StatusEQ(service.StatusActive)).
Order(dbent.Asc(group.FieldID)).
All(ctx)
if err != nil {
return nil, err
}
groupIDs := make([]int64, 0, len(groups))
outGroups := make([]service.Group, 0, len(groups))
for i := range groups {
g := groupEntityToService(groups[i])
outGroups = append(outGroups, *g)
groupIDs = append(groupIDs, g.ID)
}
counts, err := r.loadAccountCounts(ctx, groupIDs)
if err == nil {
for i := range outGroups {
outGroups[i].AccountCount = counts[outGroups[i].ID]
}
}
return outGroups, nil
}
func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
groups, err := r.client.Group.Query().
Where(group.StatusEQ(service.StatusActive), group.PlatformEQ(platform)).
Order(dbent.Asc(group.FieldID)).
All(ctx)
if err != nil {
return nil, err
}
groupIDs := make([]int64, 0, len(groups))
outGroups := make([]service.Group, 0, len(groups))
for i := range groups {
g := groupEntityToService(groups[i])
outGroups = append(outGroups, *g)
groupIDs = append(groupIDs, g.ID)
}
counts, err := r.loadAccountCounts(ctx, groupIDs)
if err == nil {
for i := range outGroups {
outGroups[i].AccountCount = counts[outGroups[i].ID]
}
}
return outGroups, nil
}
func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
return r.client.Group.Query().Where(group.NameEQ(name)).Exist(ctx)
}
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
var count int64
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", []any{groupID}, &count); err != nil {
return 0, err
}
return count, nil
}
func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
res, err := r.sql.ExecContext(ctx, "DELETE FROM account_groups WHERE group_id = $1", groupID)
if err != nil {
return 0, err
}
affected, _ := res.RowsAffected()
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue group account clear failed: group=%d err=%v", groupID, err)
}
return affected, nil
}
func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
g, err := r.client.Group.Query().Where(group.IDEQ(id)).Only(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
}
groupSvc := groupEntityToService(g)
// 使用 ent 事务统一包裹:避免手工基于 *sql.Tx 构造 ent client 带来的驱动断言问题,
// 同时保证级联删除的原子性。
tx, err := r.client.Tx(ctx)
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
return nil, err
}
exec := r.client
txClient := r.client
if err == nil {
defer func() { _ = tx.Rollback() }()
exec = tx.Client()
txClient = exec
}
// err 为 dbent.ErrTxStarted 时,复用当前 client 参与同一事务。
// Lock the group row to avoid concurrent writes while we cascade.
// 这里使用 exec.QueryContext 手动扫描,确保同一事务内加锁并能区分"未找到"与其他错误。
rows, err := exec.QueryContext(ctx, "SELECT id FROM groups WHERE id = $1 AND deleted_at IS NULL FOR UPDATE", id)
if err != nil {
return nil, err
}
var lockedID int64
if rows.Next() {
if err := rows.Scan(&lockedID); err != nil {
_ = rows.Close()
return nil, err
}
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
if lockedID == 0 {
return nil, service.ErrGroupNotFound
}
var affectedUserIDs []int64
if groupSvc.IsSubscriptionType() {
// 只查询未软删除的订阅,避免通知已取消订阅的用户
rows, err := exec.QueryContext(ctx, "SELECT user_id FROM user_subscriptions WHERE group_id = $1 AND deleted_at IS NULL", id)
if err != nil {
return nil, err
}
for rows.Next() {
var userID int64
if scanErr := rows.Scan(&userID); scanErr != nil {
_ = rows.Close()
return nil, scanErr
}
affectedUserIDs = append(affectedUserIDs, userID)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
// 软删除订阅:设置 deleted_at 而非硬删除
if _, err := exec.ExecContext(ctx, "UPDATE user_subscriptions SET deleted_at = NOW() WHERE group_id = $1 AND deleted_at IS NULL", id); err != nil {
return nil, err
}
}
// 2. Clear group_id for api keys bound to this group.
// 仅更新未软删除的记录,避免修改已删除数据,保证审计与历史回溯一致性。
// 与 APIKeyRepository 的软删除语义保持一致,减少跨模块行为差异。
if _, err := txClient.APIKey.Update().
Where(apikey.GroupIDEQ(id), apikey.DeletedAtIsNil()).
ClearGroupID().
Save(ctx); err != nil {
return nil, err
}
// 3. Remove the group id from user_allowed_groups join table.
// Legacy users.allowed_groups 列已弃用,不再同步。
if _, err := exec.ExecContext(ctx, "DELETE FROM user_allowed_groups WHERE group_id = $1", id); err != nil {
return nil, err
}
// 4. Delete account_groups join rows.
if _, err := exec.ExecContext(ctx, "DELETE FROM account_groups WHERE group_id = $1", id); err != nil {
return nil, err
}
// 5. Soft-delete group itself.
if _, err := txClient.Group.Delete().Where(group.IDEQ(id)).Exec(ctx); err != nil {
return nil, err
}
if tx != nil {
if err := tx.Commit(); err != nil {
return nil, err
}
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue group cascade delete failed: group=%d err=%v", id, err)
}
return affectedUserIDs, nil
}
func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (counts map[int64]int64, err error) {
counts = make(map[int64]int64, len(groupIDs))
if len(groupIDs) == 0 {
return counts, nil
}
rows, err := r.sql.QueryContext(
ctx,
"SELECT group_id, COUNT(*) FROM account_groups WHERE group_id = ANY($1) GROUP BY group_id",
pq.Array(groupIDs),
)
if err != nil {
return nil, err
}
defer func() {
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
counts = nil
}
}()
for rows.Next() {
var groupID int64
var count int64
if err = rows.Scan(&groupID, &count); err != nil {
return nil, err
}
counts[groupID] = count
}
if err = rows.Err(); err != nil {
return nil, err
}
return counts, nil
}

View File

@@ -0,0 +1,677 @@
//go:build integration
package repository
import (
"context"
"database/sql"
"errors"
"testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
)
type GroupRepoSuite struct {
suite.Suite
ctx context.Context
tx *dbent.Tx
repo *groupRepository
}
type forbidSQLExecutor struct {
called bool
}
func (s *forbidSQLExecutor) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
s.called = true
return nil, errors.New("unexpected sql exec")
}
func (s *forbidSQLExecutor) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
s.called = true
return nil, errors.New("unexpected sql query")
}
func (s *GroupRepoSuite) SetupTest() {
s.ctx = context.Background()
tx := testEntTx(s.T())
s.tx = tx
s.repo = newGroupRepositoryWithSQL(tx.Client(), tx)
}
func TestGroupRepoSuite(t *testing.T) {
suite.Run(t, new(GroupRepoSuite))
}
// --- Create / GetByID / Update / Delete ---
func (s *GroupRepoSuite) TestCreate() {
group := &service.Group{
Name: "test-create",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
err := s.repo.Create(s.ctx, group)
s.Require().NoError(err, "Create")
s.Require().NotZero(group.ID, "expected ID to be set")
got, err := s.repo.GetByID(s.ctx, group.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("test-create", got.Name)
}
func (s *GroupRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
s.Require().ErrorIs(err, service.ErrGroupNotFound)
}
func (s *GroupRepoSuite) TestGetByIDLite_DoesNotUseAccountCount() {
group := &service.Group{
Name: "lite-group",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, group))
spy := &forbidSQLExecutor{}
repo := newGroupRepositoryWithSQL(s.tx.Client(), spy)
got, err := repo.GetByIDLite(s.ctx, group.ID)
s.Require().NoError(err)
s.Require().Equal(group.ID, got.ID)
s.Require().False(spy.called, "expected no direct sql executor usage")
}
func (s *GroupRepoSuite) TestUpdate() {
group := &service.Group{
Name: "original",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, group))
group.Name = "updated"
err := s.repo.Update(s.ctx, group)
s.Require().NoError(err, "Update")
got, err := s.repo.GetByID(s.ctx, group.ID)
s.Require().NoError(err, "GetByID after update")
s.Require().Equal("updated", got.Name)
}
func (s *GroupRepoSuite) TestDelete() {
group := &service.Group{
Name: "to-delete",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, group))
err := s.repo.Delete(s.ctx, group.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, group.ID)
s.Require().Error(err, "expected error after delete")
s.Require().ErrorIs(err, service.ErrGroupNotFound)
}
// --- List / ListWithFilters ---
func (s *GroupRepoSuite) TestList() {
baseGroups, basePage, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List base")
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g1",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}))
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g2",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}))
groups, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
s.Require().Len(groups, len(baseGroups)+2)
s.Require().Equal(basePage.Total+2, page.Total)
}
func (s *GroupRepoSuite) TestListWithFilters_Platform() {
baseGroups, _, err := s.repo.ListWithFilters(
s.ctx,
pagination.PaginationParams{Page: 1, PageSize: 10},
service.PlatformOpenAI,
"",
"",
nil,
)
s.Require().NoError(err, "ListWithFilters base")
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g1",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}))
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g2",
Platform: service.PlatformOpenAI,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}))
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", "", nil)
s.Require().NoError(err)
s.Require().Len(groups, len(baseGroups)+1)
// Verify all groups are OpenAI platform
for _, g := range groups {
s.Require().Equal(service.PlatformOpenAI, g.Platform)
}
}
func (s *GroupRepoSuite) TestListWithFilters_Status() {
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g1",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}))
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g2",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusDisabled,
SubscriptionType: service.SubscriptionTypeStandard,
}))
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, "", nil)
s.Require().NoError(err)
s.Require().Len(groups, 1)
s.Require().Equal(service.StatusDisabled, groups[0].Status)
}
func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g1",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}))
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g2",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: true,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}))
isExclusive := true
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", &isExclusive)
s.Require().NoError(err)
s.Require().Len(groups, 1)
s.Require().True(groups[0].IsExclusive)
}
func (s *GroupRepoSuite) TestListWithFilters_Search() {
newRepo := func() (*groupRepository, context.Context) {
tx := testEntTx(s.T())
return newGroupRepositoryWithSQL(tx.Client(), tx), context.Background()
}
containsID := func(groups []service.Group, id int64) bool {
for i := range groups {
if groups[i].ID == id {
return true
}
}
return false
}
mustCreate := func(repo *groupRepository, ctx context.Context, g *service.Group) *service.Group {
s.Require().NoError(repo.Create(ctx, g))
s.Require().NotZero(g.ID)
return g
}
newGroup := func(name string) *service.Group {
return &service.Group{
Name: name,
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
}
s.Run("search_name_should_match", func() {
repo, ctx := newRepo()
target := mustCreate(repo, ctx, newGroup("it-group-search-name-target"))
other := mustCreate(repo, ctx, newGroup("it-group-search-name-other"))
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "name-target", nil)
s.Require().NoError(err)
s.Require().True(containsID(groups, target.ID), "expected target group to match by name")
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
})
s.Run("search_description_should_match", func() {
repo, ctx := newRepo()
target := newGroup("it-group-search-desc-target")
target.Description = "something about desc-needle in here"
target = mustCreate(repo, ctx, target)
other := newGroup("it-group-search-desc-other")
other.Description = "nothing to see here"
other = mustCreate(repo, ctx, other)
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "desc-needle", nil)
s.Require().NoError(err)
s.Require().True(containsID(groups, target.ID), "expected target group to match by description")
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
})
s.Run("search_nonexistent_should_return_empty", func() {
repo, ctx := newRepo()
_ = mustCreate(repo, ctx, newGroup("it-group-search-nonexistent-baseline"))
search := s.T().Name() + "__no_such_group__"
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", search, nil)
s.Require().NoError(err)
s.Require().Empty(groups)
})
s.Run("search_should_be_case_insensitive", func() {
repo, ctx := newRepo()
target := mustCreate(repo, ctx, newGroup("MiXeDCaSe-Needle"))
other := mustCreate(repo, ctx, newGroup("it-group-search-case-other"))
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "mixedcase-needle", nil)
s.Require().NoError(err)
s.Require().True(containsID(groups, target.ID), "expected case-insensitive match")
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
})
s.Run("search_should_escape_like_wildcards", func() {
repo, ctx := newRepo()
percentTarget := mustCreate(repo, ctx, newGroup("it-group-search-100%-target"))
percentOther := mustCreate(repo, ctx, newGroup("it-group-search-100X-other"))
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "100%", nil)
s.Require().NoError(err)
s.Require().True(containsID(groups, percentTarget.ID), "expected literal %% match")
s.Require().False(containsID(groups, percentOther.ID), "expected %% not to act as wildcard")
underscoreTarget := mustCreate(repo, ctx, newGroup("it-group-search-ab_cd-target"))
underscoreOther := mustCreate(repo, ctx, newGroup("it-group-search-abXcd-other"))
groups, _, err = repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "ab_cd", nil)
s.Require().NoError(err)
s.Require().True(containsID(groups, underscoreTarget.ID), "expected literal _ match")
s.Require().False(containsID(groups, underscoreOther.ID), "expected _ not to act as wildcard")
})
}
func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
g1 := &service.Group{
Name: "g1",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
g2 := &service.Group{
Name: "g2",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: true,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, g1))
s.Require().NoError(s.repo.Create(s.ctx, g2))
var accountID int64
s.Require().NoError(scanSingleRow(
s.ctx,
s.tx,
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
[]any{"acc1", service.PlatformAnthropic, service.AccountTypeOAuth},
&accountID,
))
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g1.ID, 1)
s.Require().NoError(err)
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g2.ID, 1)
s.Require().NoError(err)
isExclusive := true
groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, "", &isExclusive)
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total)
s.Require().Len(groups, 1)
s.Require().Equal(g2.ID, groups[0].ID, "ListWithFilters returned wrong group")
s.Require().Equal(int64(1), groups[0].AccountCount, "AccountCount mismatch")
}
// --- ListActive / ListActiveByPlatform ---
func (s *GroupRepoSuite) TestListActive() {
baseGroups, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive base")
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "active1",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}))
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "inactive1",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusDisabled,
SubscriptionType: service.SubscriptionTypeStandard,
}))
groups, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive")
s.Require().Len(groups, len(baseGroups)+1)
// Verify our test group is in the results
var found bool
for _, g := range groups {
if g.Name == "active1" {
found = true
break
}
}
s.Require().True(found, "active1 group should be in results")
}
func (s *GroupRepoSuite) TestListActiveByPlatform() {
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g1",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}))
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g2",
Platform: service.PlatformOpenAI,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}))
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "g3",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusDisabled,
SubscriptionType: service.SubscriptionTypeStandard,
}))
groups, err := s.repo.ListActiveByPlatform(s.ctx, service.PlatformAnthropic)
s.Require().NoError(err, "ListActiveByPlatform")
// 1 default anthropic group + 1 test active anthropic group = 2 total
s.Require().Len(groups, 2)
// Verify our test group is in the results
var found bool
for _, g := range groups {
if g.Name == "g1" {
found = true
break
}
}
s.Require().True(found, "g1 group should be in results")
}
// --- ExistsByName ---
func (s *GroupRepoSuite) TestExistsByName() {
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
Name: "existing-group",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}))
exists, err := s.repo.ExistsByName(s.ctx, "existing-group")
s.Require().NoError(err, "ExistsByName")
s.Require().True(exists)
notExists, err := s.repo.ExistsByName(s.ctx, "non-existing")
s.Require().NoError(err)
s.Require().False(notExists)
}
// --- GetAccountCount ---
func (s *GroupRepoSuite) TestGetAccountCount() {
group := &service.Group{
Name: "g-count",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, group))
var a1 int64
s.Require().NoError(scanSingleRow(
s.ctx,
s.tx,
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
[]any{"a1", service.PlatformAnthropic, service.AccountTypeOAuth},
&a1,
))
var a2 int64
s.Require().NoError(scanSingleRow(
s.ctx,
s.tx,
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
[]any{"a2", service.PlatformAnthropic, service.AccountTypeOAuth},
&a2,
))
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a1, group.ID, 1)
s.Require().NoError(err)
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a2, group.ID, 2)
s.Require().NoError(err)
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
s.Require().NoError(err, "GetAccountCount")
s.Require().Equal(int64(2), count)
}
func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
group := &service.Group{
Name: "g-empty",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, group))
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
s.Require().NoError(err)
s.Require().Zero(count)
}
// --- DeleteAccountGroupsByGroupID ---
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
g := &service.Group{
Name: "g-del",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, g))
var accountID int64
s.Require().NoError(scanSingleRow(
s.ctx,
s.tx,
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
[]any{"acc-del", service.PlatformAnthropic, service.AccountTypeOAuth},
&accountID,
))
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g.ID, 1)
s.Require().NoError(err)
affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID)
s.Require().NoError(err, "DeleteAccountGroupsByGroupID")
s.Require().Equal(int64(1), affected, "expected 1 affected row")
count, err := s.repo.GetAccountCount(s.ctx, g.ID)
s.Require().NoError(err, "GetAccountCount")
s.Require().Equal(int64(0), count, "expected 0 account groups")
}
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
g := &service.Group{
Name: "g-multi",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, g))
insertAccount := func(name string) int64 {
var id int64
s.Require().NoError(scanSingleRow(
s.ctx,
s.tx,
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
[]any{name, service.PlatformAnthropic, service.AccountTypeOAuth},
&id,
))
return id
}
a1 := insertAccount("a1")
a2 := insertAccount("a2")
a3 := insertAccount("a3")
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a1, g.ID, 1)
s.Require().NoError(err)
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a2, g.ID, 2)
s.Require().NoError(err)
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a3, g.ID, 3)
s.Require().NoError(err)
affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID)
s.Require().NoError(err)
s.Require().Equal(int64(3), affected)
count, _ := s.repo.GetAccountCount(s.ctx, g.ID)
s.Require().Zero(count)
}
// --- 软删除过滤测试 ---
func (s *GroupRepoSuite) TestDelete_SoftDelete_NotVisibleInList() {
group := &service.Group{
Name: "to-soft-delete",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, group))
// 获取删除前的列表数量
listBefore, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 100})
s.Require().NoError(err)
beforeCount := len(listBefore)
// 软删除
err = s.repo.Delete(s.ctx, group.ID)
s.Require().NoError(err, "Delete (soft delete)")
// 验证列表中不再包含软删除的 group
listAfter, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 100})
s.Require().NoError(err)
s.Require().Len(listAfter, beforeCount-1, "soft deleted group should not appear in list")
// 验证 GetByID 也无法找到
_, err = s.repo.GetByID(s.ctx, group.ID)
s.Require().Error(err)
s.Require().ErrorIs(err, service.ErrGroupNotFound)
}
func (s *GroupRepoSuite) TestDelete_SoftDeletedGroup_lockForUpdate() {
group := &service.Group{
Name: "lock-soft-delete",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, group))
// 软删除
err := s.repo.Delete(s.ctx, group.ID)
s.Require().NoError(err)
// 验证软删除的 group 在 GetByID 时返回 ErrGroupNotFound
// 这证明 lockForUpdate 的 deleted_at IS NULL 过滤正在工作
_, err = s.repo.GetByID(s.ctx, group.ID)
s.Require().Error(err, "should fail to get soft-deleted group")
s.Require().ErrorIs(err, service.ErrGroupNotFound)
}

View File

@@ -0,0 +1,653 @@
package repository
import (
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
// 默认配置常量
// 这些值在配置文件未指定时作为回退默认值使用
const (
// directProxyKey: 无代理时的缓存键标识
directProxyKey = "direct"
// defaultMaxIdleConns: 默认最大空闲连接总数
// HTTP/2 场景下单连接可多路复用240 足以支撑高并发
defaultMaxIdleConns = 240
// defaultMaxIdleConnsPerHost: 默认每主机最大空闲连接数
defaultMaxIdleConnsPerHost = 120
// defaultMaxConnsPerHost: 默认每主机最大连接数(含活跃连接)
// 达到上限后新请求会等待,而非无限创建连接
defaultMaxConnsPerHost = 240
// defaultIdleConnTimeout: 默认空闲连接超时时间90秒
// 超时后连接会被关闭,释放系统资源(建议小于上游 LB 超时)
defaultIdleConnTimeout = 90 * time.Second
// defaultResponseHeaderTimeout: 默认等待响应头超时时间5分钟
// LLM 请求可能排队较久,需要较长超时
defaultResponseHeaderTimeout = 300 * time.Second
// defaultMaxUpstreamClients: 默认最大客户端缓存数量
// 超出后会淘汰最久未使用的客户端
defaultMaxUpstreamClients = 5000
// defaultClientIdleTTLSeconds: 默认客户端空闲回收阈值15分钟
defaultClientIdleTTLSeconds = 900
)
var errUpstreamClientLimitReached = errors.New("upstream client cache limit reached")
// poolSettings 连接池配置参数
// 封装 Transport 所需的各项连接池参数
type poolSettings struct {
maxIdleConns int // 最大空闲连接总数
maxIdleConnsPerHost int // 每主机最大空闲连接数
maxConnsPerHost int // 每主机最大连接数(含活跃)
idleConnTimeout time.Duration // 空闲连接超时时间
responseHeaderTimeout time.Duration // 等待响应头超时时间
}
// upstreamClientEntry 上游客户端缓存条目
// 记录客户端实例及其元数据,用于连接池管理和淘汰策略
type upstreamClientEntry struct {
client *http.Client // HTTP 客户端实例
proxyKey string // 代理标识(用于检测代理变更)
poolKey string // 连接池配置标识(用于检测配置变更)
lastUsed int64 // 最后使用时间戳(纳秒),用于 LRU 淘汰
inFlight int64 // 当前进行中的请求数,>0 时不可淘汰
}
// httpUpstreamService 通用 HTTP 上游服务
// 用于向任意 HTTP APIClaude、OpenAI 等)发送请求,支持可选代理
//
// 架构设计:
// - 根据隔离策略proxy/account/account_proxy缓存客户端实例
// - 每个客户端拥有独立的 Transport 连接池
// - 支持 LRU + 空闲时间双重淘汰策略
//
// 性能优化:
// 1. 根据隔离策略缓存客户端实例,避免频繁创建 http.Client
// 2. 复用 Transport 连接池,减少 TCP 握手和 TLS 协商开销
// 3. 支持账号级隔离与空闲回收,降低连接层关联风险
// 4. 达到最大连接数后等待可用连接,而非无限创建
// 5. 仅回收空闲客户端,避免中断活跃请求
// 6. HTTP/2 多路复用,连接上限不等于并发请求上限
// 7. 代理变更时清空旧连接池,避免复用错误代理
// 8. 账号并发数与连接池上限对应(账号隔离策略下)
type httpUpstreamService struct {
cfg *config.Config // 全局配置
mu sync.RWMutex // 保护 clients map 的读写锁
clients map[string]*upstreamClientEntry // 客户端缓存池key 由隔离策略决定
}
// NewHTTPUpstream 创建通用 HTTP 上游服务
// 使用配置中的连接池参数构建 Transport
//
// 参数:
// - cfg: 全局配置,包含连接池参数和隔离策略
//
// 返回:
// - service.HTTPUpstream 接口实现
func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream {
return &httpUpstreamService{
cfg: cfg,
clients: make(map[string]*upstreamClientEntry),
}
}
// Do 执行 HTTP 请求
// 根据隔离策略获取或创建客户端,并跟踪请求生命周期
//
// 参数:
// - req: HTTP 请求对象
// - proxyURL: 代理地址,空字符串表示直连
// - accountID: 账户 ID用于账户级隔离
// - accountConcurrency: 账户并发限制,用于动态调整连接池大小
//
// 返回:
// - *http.Response: HTTP 响应Body 已包装,关闭时自动更新计数)
// - error: 请求错误
//
// 注意:
// - 调用方必须关闭 resp.Body否则会导致 inFlight 计数泄漏
// - inFlight > 0 的客户端不会被淘汰,确保活跃请求不被中断
func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
if err := s.validateRequestHost(req); err != nil {
return nil, err
}
// 获取或创建对应的客户端,并标记请求占用
entry, err := s.acquireClient(proxyURL, accountID, accountConcurrency)
if err != nil {
return nil, err
}
// 执行请求
resp, err := entry.client.Do(req)
if err != nil {
// 请求失败,立即减少计数
atomic.AddInt64(&entry.inFlight, -1)
atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano())
return nil, err
}
// 包装响应体,在关闭时自动减少计数并更新时间戳
// 这确保了流式响应(如 SSE在完全读取前不会被淘汰
resp.Body = wrapTrackedBody(resp.Body, func() {
atomic.AddInt64(&entry.inFlight, -1)
atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano())
})
return resp, nil
}
func (s *httpUpstreamService) shouldValidateResolvedIP() bool {
if s.cfg == nil {
return false
}
if !s.cfg.Security.URLAllowlist.Enabled {
return false
}
return !s.cfg.Security.URLAllowlist.AllowPrivateHosts
}
func (s *httpUpstreamService) validateRequestHost(req *http.Request) error {
if !s.shouldValidateResolvedIP() {
return nil
}
if req == nil || req.URL == nil {
return errors.New("request url is nil")
}
host := strings.TrimSpace(req.URL.Hostname())
if host == "" {
return errors.New("request host is empty")
}
if err := urlvalidator.ValidateResolvedIP(host); err != nil {
return err
}
return nil
}
func (s *httpUpstreamService) redirectChecker(req *http.Request, via []*http.Request) error {
if len(via) >= 10 {
return errors.New("stopped after 10 redirects")
}
return s.validateRequestHost(req)
}
// acquireClient 获取或创建客户端,并标记为进行中请求
// 用于请求路径,避免在获取后被淘汰
func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) {
return s.getClientEntry(proxyURL, accountID, accountConcurrency, true, true)
}
// getOrCreateClient 获取或创建客户端
// 根据隔离策略和参数决定缓存键,处理代理变更和配置变更
//
// 参数:
// - proxyURL: 代理地址
// - accountID: 账户 ID
// - accountConcurrency: 账户并发限制
//
// 返回:
// - *upstreamClientEntry: 客户端缓存条目
//
// 隔离策略说明:
// - proxy: 按代理地址隔离,同一代理共享客户端
// - account: 按账户隔离,同一账户共享客户端(代理变更时重建)
// - account_proxy: 按账户+代理组合隔离,最细粒度
func (s *httpUpstreamService) getOrCreateClient(proxyURL string, accountID int64, accountConcurrency int) *upstreamClientEntry {
entry, _ := s.getClientEntry(proxyURL, accountID, accountConcurrency, false, false)
return entry
}
// getClientEntry 获取或创建客户端条目
// markInFlight=true 时会标记进行中请求,用于请求路径防止被淘汰
// enforceLimit=true 时会限制客户端数量,超限且无法淘汰时返回错误
func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, accountConcurrency int, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) {
// 获取隔离模式
isolation := s.getIsolationMode()
// 标准化代理 URL 并解析
proxyKey, parsedProxy := normalizeProxyURL(proxyURL)
// 构建缓存键(根据隔离策略不同)
cacheKey := buildCacheKey(isolation, proxyKey, accountID)
// 构建连接池配置键(用于检测配置变更)
poolKey := s.buildPoolKey(isolation, accountConcurrency)
now := time.Now()
nowUnix := now.UnixNano()
// 读锁快速路径:命中缓存直接返回,减少锁竞争
s.mu.RLock()
if entry, ok := s.clients[cacheKey]; ok && s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) {
atomic.StoreInt64(&entry.lastUsed, nowUnix)
if markInFlight {
atomic.AddInt64(&entry.inFlight, 1)
}
s.mu.RUnlock()
return entry, nil
}
s.mu.RUnlock()
// 写锁慢路径:创建或重建客户端
s.mu.Lock()
if entry, ok := s.clients[cacheKey]; ok {
if s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) {
atomic.StoreInt64(&entry.lastUsed, nowUnix)
if markInFlight {
atomic.AddInt64(&entry.inFlight, 1)
}
s.mu.Unlock()
return entry, nil
}
s.removeClientLocked(cacheKey, entry)
}
// 超出缓存上限时尝试淘汰,无法淘汰则拒绝新建
if enforceLimit && s.maxUpstreamClients() > 0 {
s.evictIdleLocked(now)
if len(s.clients) >= s.maxUpstreamClients() {
if !s.evictOldestIdleLocked() {
s.mu.Unlock()
return nil, errUpstreamClientLimitReached
}
}
}
// 缓存未命中或需要重建,创建新客户端
settings := s.resolvePoolSettings(isolation, accountConcurrency)
transport, err := buildUpstreamTransport(settings, parsedProxy)
if err != nil {
s.mu.Unlock()
return nil, fmt.Errorf("build transport: %w", err)
}
client := &http.Client{Transport: transport}
if s.shouldValidateResolvedIP() {
client.CheckRedirect = s.redirectChecker
}
entry := &upstreamClientEntry{
client: client,
proxyKey: proxyKey,
poolKey: poolKey,
}
atomic.StoreInt64(&entry.lastUsed, nowUnix)
if markInFlight {
atomic.StoreInt64(&entry.inFlight, 1)
}
s.clients[cacheKey] = entry
// 执行淘汰策略:先淘汰空闲超时的,再淘汰超出数量限制的
s.evictIdleLocked(now)
s.evictOverLimitLocked()
s.mu.Unlock()
return entry, nil
}
// shouldReuseEntry 判断缓存条目是否可复用
// 若代理或连接池配置发生变化,则需要重建客户端
func (s *httpUpstreamService) shouldReuseEntry(entry *upstreamClientEntry, isolation, proxyKey, poolKey string) bool {
if entry == nil {
return false
}
if isolation == config.ConnectionPoolIsolationAccount && entry.proxyKey != proxyKey {
return false
}
if entry.poolKey != poolKey {
return false
}
return true
}
// removeClientLocked 移除客户端(需持有锁)
// 从缓存中删除并关闭空闲连接
//
// 参数:
// - key: 缓存键
// - entry: 客户端条目
func (s *httpUpstreamService) removeClientLocked(key string, entry *upstreamClientEntry) {
delete(s.clients, key)
if entry != nil && entry.client != nil {
// 关闭空闲连接,释放系统资源
// 注意:这不会中断活跃连接
entry.client.CloseIdleConnections()
}
}
// evictIdleLocked 淘汰空闲超时的客户端(需持有锁)
// 遍历所有客户端,移除超过 TTL 且无活跃请求的条目
//
// 参数:
// - now: 当前时间
func (s *httpUpstreamService) evictIdleLocked(now time.Time) {
ttl := s.clientIdleTTL()
if ttl <= 0 {
return
}
// 计算淘汰截止时间
cutoff := now.Add(-ttl).UnixNano()
for key, entry := range s.clients {
// 跳过有活跃请求的客户端
if atomic.LoadInt64(&entry.inFlight) != 0 {
continue
}
// 淘汰超时的空闲客户端
if atomic.LoadInt64(&entry.lastUsed) <= cutoff {
s.removeClientLocked(key, entry)
}
}
}
// evictOldestIdleLocked 淘汰最久未使用且无活跃请求的客户端(需持有锁)
func (s *httpUpstreamService) evictOldestIdleLocked() bool {
var (
oldestKey string
oldestEntry *upstreamClientEntry
oldestTime int64
)
// 查找最久未使用且无活跃请求的客户端
for key, entry := range s.clients {
// 跳过有活跃请求的客户端
if atomic.LoadInt64(&entry.inFlight) != 0 {
continue
}
lastUsed := atomic.LoadInt64(&entry.lastUsed)
if oldestEntry == nil || lastUsed < oldestTime {
oldestKey = key
oldestEntry = entry
oldestTime = lastUsed
}
}
// 所有客户端都有活跃请求,无法淘汰
if oldestEntry == nil {
return false
}
s.removeClientLocked(oldestKey, oldestEntry)
return true
}
// evictOverLimitLocked 淘汰超出数量限制的客户端(需持有锁)
// 使用 LRU 策略,优先淘汰最久未使用且无活跃请求的客户端
func (s *httpUpstreamService) evictOverLimitLocked() bool {
maxClients := s.maxUpstreamClients()
if maxClients <= 0 {
return false
}
evicted := false
// 循环淘汰直到满足数量限制
for len(s.clients) > maxClients {
if !s.evictOldestIdleLocked() {
return evicted
}
evicted = true
}
return evicted
}
// getIsolationMode 获取连接池隔离模式
// 从配置中读取,无效值回退到 account_proxy 模式
//
// 返回:
// - string: 隔离模式proxy/account/account_proxy
func (s *httpUpstreamService) getIsolationMode() string {
if s.cfg == nil {
return config.ConnectionPoolIsolationAccountProxy
}
mode := strings.ToLower(strings.TrimSpace(s.cfg.Gateway.ConnectionPoolIsolation))
if mode == "" {
return config.ConnectionPoolIsolationAccountProxy
}
switch mode {
case config.ConnectionPoolIsolationProxy, config.ConnectionPoolIsolationAccount, config.ConnectionPoolIsolationAccountProxy:
return mode
default:
return config.ConnectionPoolIsolationAccountProxy
}
}
// maxUpstreamClients 获取最大客户端缓存数量
// 从配置中读取,无效值使用默认值
func (s *httpUpstreamService) maxUpstreamClients() int {
if s.cfg == nil {
return defaultMaxUpstreamClients
}
if s.cfg.Gateway.MaxUpstreamClients > 0 {
return s.cfg.Gateway.MaxUpstreamClients
}
return defaultMaxUpstreamClients
}
// clientIdleTTL 获取客户端空闲回收阈值
// 从配置中读取,无效值使用默认值
func (s *httpUpstreamService) clientIdleTTL() time.Duration {
if s.cfg == nil {
return time.Duration(defaultClientIdleTTLSeconds) * time.Second
}
if s.cfg.Gateway.ClientIdleTTLSeconds > 0 {
return time.Duration(s.cfg.Gateway.ClientIdleTTLSeconds) * time.Second
}
return time.Duration(defaultClientIdleTTLSeconds) * time.Second
}
// resolvePoolSettings 解析连接池配置
// 根据隔离策略和账户并发数动态调整连接池参数
//
// 参数:
// - isolation: 隔离模式
// - accountConcurrency: 账户并发限制
//
// 返回:
// - poolSettings: 连接池配置
//
// 说明:
// - 账户隔离模式下,连接池大小与账户并发数对应
// - 这确保了单账户不会占用过多连接资源
func (s *httpUpstreamService) resolvePoolSettings(isolation string, accountConcurrency int) poolSettings {
settings := defaultPoolSettings(s.cfg)
// 账户隔离模式下,根据账户并发数调整连接池大小
if (isolation == config.ConnectionPoolIsolationAccount || isolation == config.ConnectionPoolIsolationAccountProxy) && accountConcurrency > 0 {
settings.maxIdleConns = accountConcurrency
settings.maxIdleConnsPerHost = accountConcurrency
settings.maxConnsPerHost = accountConcurrency
}
return settings
}
// buildPoolKey 构建连接池配置键
// 用于检测配置变更,配置变更时需要重建客户端
//
// 参数:
// - isolation: 隔离模式
// - accountConcurrency: 账户并发限制
//
// 返回:
// - string: 配置键
func (s *httpUpstreamService) buildPoolKey(isolation string, accountConcurrency int) string {
if isolation == config.ConnectionPoolIsolationAccount || isolation == config.ConnectionPoolIsolationAccountProxy {
if accountConcurrency > 0 {
return fmt.Sprintf("account:%d", accountConcurrency)
}
}
return "default"
}
// buildCacheKey 构建客户端缓存键
// 根据隔离策略决定缓存键的组成
//
// 参数:
// - isolation: 隔离模式
// - proxyKey: 代理标识
// - accountID: 账户 ID
//
// 返回:
// - string: 缓存键
//
// 缓存键格式:
// - proxy 模式: "proxy:{proxyKey}"
// - account 模式: "account:{accountID}"
// - account_proxy 模式: "account:{accountID}|proxy:{proxyKey}"
func buildCacheKey(isolation, proxyKey string, accountID int64) string {
switch isolation {
case config.ConnectionPoolIsolationAccount:
return fmt.Sprintf("account:%d", accountID)
case config.ConnectionPoolIsolationAccountProxy:
return fmt.Sprintf("account:%d|proxy:%s", accountID, proxyKey)
default:
return fmt.Sprintf("proxy:%s", proxyKey)
}
}
// normalizeProxyURL 标准化代理 URL
// 处理空值和解析错误,返回标准化的键和解析后的 URL
//
// 参数:
// - raw: 原始代理 URL 字符串
//
// 返回:
// - string: 标准化的代理键(空或解析失败返回 "direct"
// - *url.URL: 解析后的 URL空或解析失败返回 nil
func normalizeProxyURL(raw string) (string, *url.URL) {
proxyURL := strings.TrimSpace(raw)
if proxyURL == "" {
return directProxyKey, nil
}
parsed, err := url.Parse(proxyURL)
if err != nil {
return directProxyKey, nil
}
parsed.Scheme = strings.ToLower(parsed.Scheme)
parsed.Host = strings.ToLower(parsed.Host)
parsed.Path = ""
parsed.RawPath = ""
parsed.RawQuery = ""
parsed.Fragment = ""
parsed.ForceQuery = false
if hostname := parsed.Hostname(); hostname != "" {
port := parsed.Port()
if (parsed.Scheme == "http" && port == "80") || (parsed.Scheme == "https" && port == "443") {
port = ""
}
hostname = strings.ToLower(hostname)
if port != "" {
parsed.Host = net.JoinHostPort(hostname, port)
} else {
parsed.Host = hostname
}
}
return parsed.String(), parsed
}
// defaultPoolSettings 获取默认连接池配置
// 从全局配置中读取,无效值使用常量默认值
//
// 参数:
// - cfg: 全局配置
//
// 返回:
// - poolSettings: 连接池配置
func defaultPoolSettings(cfg *config.Config) poolSettings {
maxIdleConns := defaultMaxIdleConns
maxIdleConnsPerHost := defaultMaxIdleConnsPerHost
maxConnsPerHost := defaultMaxConnsPerHost
idleConnTimeout := defaultIdleConnTimeout
responseHeaderTimeout := defaultResponseHeaderTimeout
if cfg != nil {
if cfg.Gateway.MaxIdleConns > 0 {
maxIdleConns = cfg.Gateway.MaxIdleConns
}
if cfg.Gateway.MaxIdleConnsPerHost > 0 {
maxIdleConnsPerHost = cfg.Gateway.MaxIdleConnsPerHost
}
if cfg.Gateway.MaxConnsPerHost >= 0 {
maxConnsPerHost = cfg.Gateway.MaxConnsPerHost
}
if cfg.Gateway.IdleConnTimeoutSeconds > 0 {
idleConnTimeout = time.Duration(cfg.Gateway.IdleConnTimeoutSeconds) * time.Second
}
if cfg.Gateway.ResponseHeaderTimeout > 0 {
responseHeaderTimeout = time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
}
}
return poolSettings{
maxIdleConns: maxIdleConns,
maxIdleConnsPerHost: maxIdleConnsPerHost,
maxConnsPerHost: maxConnsPerHost,
idleConnTimeout: idleConnTimeout,
responseHeaderTimeout: responseHeaderTimeout,
}
}
// buildUpstreamTransport 构建上游请求的 Transport
// 使用配置文件中的连接池参数,支持生产环境调优
//
// 参数:
// - settings: 连接池配置
// - proxyURL: 代理 URLnil 表示直连)
//
// 返回:
// - *http.Transport: 配置好的 Transport 实例
// - error: 代理配置错误
//
// Transport 参数说明:
// - MaxIdleConns: 所有主机的最大空闲连接总数
// - MaxIdleConnsPerHost: 每主机最大空闲连接数(影响连接复用率)
// - MaxConnsPerHost: 每主机最大连接数(达到后新请求等待)
// - IdleConnTimeout: 空闲连接超时(超时后关闭)
// - ResponseHeaderTimeout: 等待响应头超时(不影响流式传输)
func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) (*http.Transport, error) {
transport := &http.Transport{
MaxIdleConns: settings.maxIdleConns,
MaxIdleConnsPerHost: settings.maxIdleConnsPerHost,
MaxConnsPerHost: settings.maxConnsPerHost,
IdleConnTimeout: settings.idleConnTimeout,
ResponseHeaderTimeout: settings.responseHeaderTimeout,
}
if err := proxyutil.ConfigureTransportProxy(transport, proxyURL); err != nil {
return nil, err
}
return transport, nil
}
// trackedBody 带跟踪功能的响应体包装器
// 在 Close 时执行回调,用于更新请求计数
type trackedBody struct {
io.ReadCloser // 原始响应体
once sync.Once
onClose func() // 关闭时的回调函数
}
// Close 关闭响应体并执行回调
// 使用 sync.Once 确保回调只执行一次
func (b *trackedBody) Close() error {
err := b.ReadCloser.Close()
if b.onClose != nil {
b.once.Do(b.onClose)
}
return err
}
// wrapTrackedBody 包装响应体以跟踪关闭事件
// 用于在响应体关闭时更新 inFlight 计数
//
// 参数:
// - body: 原始响应体
// - onClose: 关闭时的回调函数
//
// 返回:
// - io.ReadCloser: 包装后的响应体
func wrapTrackedBody(body io.ReadCloser, onClose func()) io.ReadCloser {
if body == nil {
return body
}
return &trackedBody{ReadCloser: body, onClose: onClose}
}

View File

@@ -0,0 +1,70 @@
package repository
import (
"net/http"
"net/url"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
)
// httpClientSink 用于防止编译器优化掉基准测试中的赋值操作
// 这是 Go 基准测试的常见模式,确保测试结果准确
var httpClientSink *http.Client
// BenchmarkHTTPUpstreamProxyClient 对比重复创建与复用代理客户端的开销
//
// 测试目的:
// - 验证连接池复用相比每次新建的性能提升
// - 量化内存分配差异
//
// 预期结果:
// - "复用" 子测试应显著快于 "新建"
// - "复用" 子测试应零内存分配
func BenchmarkHTTPUpstreamProxyClient(b *testing.B) {
// 创建测试配置
cfg := &config.Config{
Gateway: config.GatewayConfig{ResponseHeaderTimeout: 300},
}
upstream := NewHTTPUpstream(cfg)
svc, ok := upstream.(*httpUpstreamService)
if !ok {
b.Fatalf("类型断言失败,无法获取 httpUpstreamService")
}
proxyURL := "http://127.0.0.1:8080"
b.ReportAllocs() // 报告内存分配统计
// 子测试:每次新建客户端
// 模拟未优化前的行为,每次请求都创建新的 http.Client
b.Run("新建", func(b *testing.B) {
parsedProxy, err := url.Parse(proxyURL)
if err != nil {
b.Fatalf("解析代理地址失败: %v", err)
}
settings := defaultPoolSettings(cfg)
for i := 0; i < b.N; i++ {
// 每次迭代都创建新客户端,包含 Transport 分配
transport, err := buildUpstreamTransport(settings, parsedProxy)
if err != nil {
b.Fatalf("创建 Transport 失败: %v", err)
}
httpClientSink = &http.Client{
Transport: transport,
}
}
})
// 子测试:复用已缓存的客户端
// 模拟优化后的行为,从缓存获取客户端
b.Run("复用", func(b *testing.B) {
// 预热:确保客户端已缓存
entry := svc.getOrCreateClient(proxyURL, 1, 1)
client := entry.client
b.ResetTimer() // 重置计时器,排除预热时间
for i := 0; i < b.N; i++ {
// 直接使用缓存的客户端,无内存分配
httpClientSink = client
}
})
}

View File

@@ -0,0 +1,291 @@
package repository
import (
"io"
"net/http"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
// HTTPUpstreamSuite HTTP 上游服务测试套件
// 使用 testify/suite 组织测试,支持 SetupTest 初始化
type HTTPUpstreamSuite struct {
suite.Suite
cfg *config.Config // 测试用配置
}
// SetupTest 每个测试用例执行前的初始化
// 创建空配置,各测试用例可按需覆盖
func (s *HTTPUpstreamSuite) SetupTest() {
s.cfg = &config.Config{
Security: config.SecurityConfig{
URLAllowlist: config.URLAllowlistConfig{
AllowPrivateHosts: true,
},
},
}
}
// newService 创建测试用的 httpUpstreamService 实例
// 返回具体类型以便访问内部状态进行断言
func (s *HTTPUpstreamSuite) newService() *httpUpstreamService {
up := NewHTTPUpstream(s.cfg)
svc, ok := up.(*httpUpstreamService)
require.True(s.T(), ok, "expected *httpUpstreamService")
return svc
}
// TestDefaultResponseHeaderTimeout 测试默认响应头超时配置
// 验证未配置时使用 300 秒默认值
func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() {
svc := s.newService()
entry := svc.getOrCreateClient("", 0, 0)
transport, ok := entry.client.Transport.(*http.Transport)
require.True(s.T(), ok, "expected *http.Transport")
require.Equal(s.T(), 300*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch")
}
// TestCustomResponseHeaderTimeout 测试自定义响应头超时配置
// 验证配置值能正确应用到 Transport
func (s *HTTPUpstreamSuite) TestCustomResponseHeaderTimeout() {
s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 7}
svc := s.newService()
entry := svc.getOrCreateClient("", 0, 0)
transport, ok := entry.client.Transport.(*http.Transport)
require.True(s.T(), ok, "expected *http.Transport")
require.Equal(s.T(), 7*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch")
}
// TestGetOrCreateClient_InvalidURLFallsBackToDirect 测试无效代理 URL 回退
// 验证解析失败时回退到直连模式
func (s *HTTPUpstreamSuite) TestGetOrCreateClient_InvalidURLFallsBackToDirect() {
svc := s.newService()
entry := svc.getOrCreateClient("://bad-proxy-url", 1, 1)
require.Equal(s.T(), directProxyKey, entry.proxyKey, "expected direct proxy fallback")
}
// TestNormalizeProxyURL_Canonicalizes 测试代理 URL 规范化
// 验证等价地址能够映射到同一缓存键
func (s *HTTPUpstreamSuite) TestNormalizeProxyURL_Canonicalizes() {
key1, _ := normalizeProxyURL("http://proxy.local:8080")
key2, _ := normalizeProxyURL("http://proxy.local:8080/")
require.Equal(s.T(), key1, key2, "expected normalized proxy keys to match")
}
// TestAcquireClient_OverLimitReturnsError 测试连接池缓存上限保护
// 验证超限且无可淘汰条目时返回错误
func (s *HTTPUpstreamSuite) TestAcquireClient_OverLimitReturnsError() {
s.cfg.Gateway = config.GatewayConfig{
ConnectionPoolIsolation: config.ConnectionPoolIsolationAccountProxy,
MaxUpstreamClients: 1,
}
svc := s.newService()
entry1, err := svc.acquireClient("http://proxy-a:8080", 1, 1)
require.NoError(s.T(), err, "expected first acquire to succeed")
require.NotNil(s.T(), entry1, "expected entry")
entry2, err := svc.acquireClient("http://proxy-b:8080", 2, 1)
require.Error(s.T(), err, "expected error when cache limit reached")
require.Nil(s.T(), entry2, "expected nil entry when cache limit reached")
}
// TestDo_WithoutProxy_GoesDirect 测试无代理时直连
// 验证空代理 URL 时请求直接发送到目标服务器
func (s *HTTPUpstreamSuite) TestDo_WithoutProxy_GoesDirect() {
// 创建模拟上游服务器
upstream := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "direct")
}))
s.T().Cleanup(upstream.Close)
up := NewHTTPUpstream(s.cfg)
req, err := http.NewRequest(http.MethodGet, upstream.URL+"/x", nil)
require.NoError(s.T(), err, "NewRequest")
resp, err := up.Do(req, "", 1, 1)
require.NoError(s.T(), err, "Do")
defer func() { _ = resp.Body.Close() }()
b, _ := io.ReadAll(resp.Body)
require.Equal(s.T(), "direct", string(b), "unexpected body")
}
// TestDo_WithHTTPProxy_UsesProxy 测试 HTTP 代理功能
// 验证请求通过代理服务器转发,使用绝对 URI 格式
func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() {
// 用于接收代理请求的通道
seen := make(chan string, 1)
// 创建模拟代理服务器
proxySrv := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
seen <- r.RequestURI // 记录请求 URI
_, _ = io.WriteString(w, "proxied")
}))
s.T().Cleanup(proxySrv.Close)
s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 1}
up := NewHTTPUpstream(s.cfg)
// 发送请求到外部地址,应通过代理
req, err := http.NewRequest(http.MethodGet, "http://example.com/test", nil)
require.NoError(s.T(), err, "NewRequest")
resp, err := up.Do(req, proxySrv.URL, 1, 1)
require.NoError(s.T(), err, "Do")
defer func() { _ = resp.Body.Close() }()
b, _ := io.ReadAll(resp.Body)
require.Equal(s.T(), "proxied", string(b), "unexpected body")
// 验证代理收到的是绝对 URI 格式HTTP 代理规范要求)
select {
case uri := <-seen:
require.Equal(s.T(), "http://example.com/test", uri, "expected absolute-form request URI")
default:
require.Fail(s.T(), "expected proxy to receive request")
}
}
// TestDo_EmptyProxy_UsesDirect 测试空代理字符串
// 验证空字符串代理等同于直连
func (s *HTTPUpstreamSuite) TestDo_EmptyProxy_UsesDirect() {
upstream := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = io.WriteString(w, "direct-empty")
}))
s.T().Cleanup(upstream.Close)
up := NewHTTPUpstream(s.cfg)
req, err := http.NewRequest(http.MethodGet, upstream.URL+"/y", nil)
require.NoError(s.T(), err, "NewRequest")
resp, err := up.Do(req, "", 1, 1)
require.NoError(s.T(), err, "Do with empty proxy")
defer func() { _ = resp.Body.Close() }()
b, _ := io.ReadAll(resp.Body)
require.Equal(s.T(), "direct-empty", string(b))
}
// TestAccountIsolation_DifferentAccounts 测试账户隔离模式
// 验证不同账户使用独立的连接池
func (s *HTTPUpstreamSuite) TestAccountIsolation_DifferentAccounts() {
s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount}
svc := s.newService()
// 同一代理,不同账户
entry1 := svc.getOrCreateClient("http://proxy.local:8080", 1, 3)
entry2 := svc.getOrCreateClient("http://proxy.local:8080", 2, 3)
require.NotSame(s.T(), entry1, entry2, "不同账号不应共享连接池")
require.Equal(s.T(), 2, len(svc.clients), "账号隔离应缓存两个客户端")
}
// TestAccountProxyIsolation_DifferentProxy 测试账户+代理组合隔离模式
// 验证同一账户使用不同代理时创建独立连接池
func (s *HTTPUpstreamSuite) TestAccountProxyIsolation_DifferentProxy() {
s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccountProxy}
svc := s.newService()
// 同一账户,不同代理
entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 3)
entry2 := svc.getOrCreateClient("http://proxy-b:8080", 1, 3)
require.NotSame(s.T(), entry1, entry2, "账号+代理隔离应区分不同代理")
require.Equal(s.T(), 2, len(svc.clients), "账号+代理隔离应缓存两个客户端")
}
// TestAccountModeProxyChangeClearsPool 测试账户模式下代理变更
// 验证账户切换代理时清理旧连接池,避免复用错误代理
func (s *HTTPUpstreamSuite) TestAccountModeProxyChangeClearsPool() {
s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount}
svc := s.newService()
// 同一账户,先后使用不同代理
entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 3)
entry2 := svc.getOrCreateClient("http://proxy-b:8080", 1, 3)
require.NotSame(s.T(), entry1, entry2, "账号切换代理应创建新连接池")
require.Equal(s.T(), 1, len(svc.clients), "账号模式下应仅保留一个连接池")
require.False(s.T(), hasEntry(svc, entry1), "旧连接池应被清理")
}
// TestAccountConcurrencyOverridesPoolSettings 测试账户并发数覆盖连接池配置
// 验证账户隔离模式下,连接池大小与账户并发数对应
func (s *HTTPUpstreamSuite) TestAccountConcurrencyOverridesPoolSettings() {
s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount}
svc := s.newService()
// 账户并发数为 12
entry := svc.getOrCreateClient("", 1, 12)
transport, ok := entry.client.Transport.(*http.Transport)
require.True(s.T(), ok, "expected *http.Transport")
// 连接池参数应与并发数一致
require.Equal(s.T(), 12, transport.MaxConnsPerHost, "MaxConnsPerHost mismatch")
require.Equal(s.T(), 12, transport.MaxIdleConns, "MaxIdleConns mismatch")
require.Equal(s.T(), 12, transport.MaxIdleConnsPerHost, "MaxIdleConnsPerHost mismatch")
}
// TestAccountConcurrencyFallbackToDefault 测试账户并发数为 0 时回退到默认配置
// 验证未指定并发数时使用全局配置值
func (s *HTTPUpstreamSuite) TestAccountConcurrencyFallbackToDefault() {
s.cfg.Gateway = config.GatewayConfig{
ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount,
MaxIdleConns: 77,
MaxIdleConnsPerHost: 55,
MaxConnsPerHost: 66,
}
svc := s.newService()
// 账户并发数为 0应使用全局配置
entry := svc.getOrCreateClient("", 1, 0)
transport, ok := entry.client.Transport.(*http.Transport)
require.True(s.T(), ok, "expected *http.Transport")
require.Equal(s.T(), 66, transport.MaxConnsPerHost, "MaxConnsPerHost fallback mismatch")
require.Equal(s.T(), 77, transport.MaxIdleConns, "MaxIdleConns fallback mismatch")
require.Equal(s.T(), 55, transport.MaxIdleConnsPerHost, "MaxIdleConnsPerHost fallback mismatch")
}
// TestEvictOverLimitRemovesOldestIdle 测试超出数量限制时的 LRU 淘汰
// 验证优先淘汰最久未使用的空闲客户端
func (s *HTTPUpstreamSuite) TestEvictOverLimitRemovesOldestIdle() {
s.cfg.Gateway = config.GatewayConfig{
ConnectionPoolIsolation: config.ConnectionPoolIsolationAccountProxy,
MaxUpstreamClients: 2, // 最多缓存 2 个客户端
}
svc := s.newService()
// 创建两个客户端,设置不同的最后使用时间
entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 1)
entry2 := svc.getOrCreateClient("http://proxy-b:8080", 2, 1)
atomic.StoreInt64(&entry1.lastUsed, time.Now().Add(-2*time.Hour).UnixNano()) // 最久
atomic.StoreInt64(&entry2.lastUsed, time.Now().Add(-time.Hour).UnixNano())
// 创建第三个客户端,触发淘汰
_ = svc.getOrCreateClient("http://proxy-c:8080", 3, 1)
require.LessOrEqual(s.T(), len(svc.clients), 2, "应保持在缓存上限内")
require.False(s.T(), hasEntry(svc, entry1), "最久未使用的连接池应被清理")
}
// TestIdleTTLDoesNotEvictActive 测试活跃请求保护
// 验证有进行中请求的客户端不会被空闲超时淘汰
func (s *HTTPUpstreamSuite) TestIdleTTLDoesNotEvictActive() {
s.cfg.Gateway = config.GatewayConfig{
ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount,
ClientIdleTTLSeconds: 1, // 1 秒空闲超时
}
svc := s.newService()
entry1 := svc.getOrCreateClient("", 1, 1)
// 设置为很久之前使用,但有活跃请求
atomic.StoreInt64(&entry1.lastUsed, time.Now().Add(-2*time.Minute).UnixNano())
atomic.StoreInt64(&entry1.inFlight, 1) // 模拟有活跃请求
// 创建新客户端,触发淘汰检查
_ = svc.getOrCreateClient("", 2, 1)
require.True(s.T(), hasEntry(svc, entry1), "有活跃请求时不应回收")
}
// TestHTTPUpstreamSuite 运行测试套件
func TestHTTPUpstreamSuite(t *testing.T) {
suite.Run(t, new(HTTPUpstreamSuite))
}
// hasEntry 检查客户端是否存在于缓存中
// 辅助函数,用于验证淘汰逻辑
func hasEntry(svc *httpUpstreamService, target *upstreamClientEntry) bool {
for _, entry := range svc.clients {
if entry == target {
return true
}
}
return false
}

View File

@@ -0,0 +1,51 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const (
fingerprintKeyPrefix = "fingerprint:"
fingerprintTTL = 24 * time.Hour
)
// fingerprintKey generates the Redis key for account fingerprint cache.
func fingerprintKey(accountID int64) string {
return fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
}
type identityCache struct {
rdb *redis.Client
}
func NewIdentityCache(rdb *redis.Client) service.IdentityCache {
return &identityCache{rdb: rdb}
}
func (c *identityCache) GetFingerprint(ctx context.Context, accountID int64) (*service.Fingerprint, error) {
key := fingerprintKey(accountID)
val, err := c.rdb.Get(ctx, key).Result()
if err != nil {
return nil, err
}
var fp service.Fingerprint
if err := json.Unmarshal([]byte(val), &fp); err != nil {
return nil, err
}
return &fp, nil
}
func (c *identityCache) SetFingerprint(ctx context.Context, accountID int64, fp *service.Fingerprint) error {
key := fingerprintKey(accountID)
val, err := json.Marshal(fp)
if err != nil {
return err
}
return c.rdb.Set(ctx, key, val, fingerprintTTL).Err()
}

View File

@@ -0,0 +1,67 @@
//go:build integration
package repository
import (
"errors"
"fmt"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type IdentityCacheSuite struct {
IntegrationRedisSuite
cache *identityCache
}
func (s *IdentityCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewIdentityCache(s.rdb).(*identityCache)
}
func (s *IdentityCacheSuite) TestGetFingerprint_Missing() {
_, err := s.cache.GetFingerprint(s.ctx, 1)
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing fingerprint")
}
func (s *IdentityCacheSuite) TestSetAndGetFingerprint() {
fp := &service.Fingerprint{ClientID: "c1", UserAgent: "ua"}
require.NoError(s.T(), s.cache.SetFingerprint(s.ctx, 1, fp), "SetFingerprint")
gotFP, err := s.cache.GetFingerprint(s.ctx, 1)
require.NoError(s.T(), err, "GetFingerprint")
require.Equal(s.T(), "c1", gotFP.ClientID)
require.Equal(s.T(), "ua", gotFP.UserAgent)
}
func (s *IdentityCacheSuite) TestFingerprint_TTL() {
fp := &service.Fingerprint{ClientID: "c1", UserAgent: "ua"}
require.NoError(s.T(), s.cache.SetFingerprint(s.ctx, 2, fp))
fpKey := fmt.Sprintf("%s%d", fingerprintKeyPrefix, 2)
ttl, err := s.rdb.TTL(s.ctx, fpKey).Result()
require.NoError(s.T(), err, "TTL fpKey")
s.AssertTTLWithin(ttl, 1*time.Second, fingerprintTTL)
}
func (s *IdentityCacheSuite) TestGetFingerprint_JSONCorruption() {
fpKey := fmt.Sprintf("%s%d", fingerprintKeyPrefix, 999)
require.NoError(s.T(), s.rdb.Set(s.ctx, fpKey, "invalid-json-data", 1*time.Minute).Err(), "Set invalid JSON")
_, err := s.cache.GetFingerprint(s.ctx, 999)
require.Error(s.T(), err, "expected error for corrupted JSON")
require.False(s.T(), errors.Is(err, redis.Nil), "expected decoding error, not redis.Nil")
}
func (s *IdentityCacheSuite) TestSetFingerprint_Nil() {
err := s.cache.SetFingerprint(s.ctx, 100, nil)
require.NoError(s.T(), err, "SetFingerprint(nil) should succeed")
}
func TestIdentityCacheSuite(t *testing.T) {
suite.Run(t, new(IdentityCacheSuite))
}

View File

@@ -0,0 +1,46 @@
//go:build unit
package repository
import (
"math"
"testing"
"github.com/stretchr/testify/require"
)
func TestFingerprintKey(t *testing.T) {
tests := []struct {
name string
accountID int64
expected string
}{
{
name: "normal_account_id",
accountID: 123,
expected: "fingerprint:123",
},
{
name: "zero_account_id",
accountID: 0,
expected: "fingerprint:0",
},
{
name: "negative_account_id",
accountID: -1,
expected: "fingerprint:-1",
},
{
name: "max_int64",
accountID: math.MaxInt64,
expected: "fingerprint:9223372036854775807",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := fingerprintKey(tc.accountID)
require.Equal(t, tc.expected, got)
})
}
}

View File

@@ -0,0 +1,63 @@
package repository
import (
"bytes"
"io"
"net"
"net/http"
"net/http/httptest"
"sync"
"testing"
)
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }
// newInProcessTransport adapts an http.HandlerFunc into an http.RoundTripper without opening sockets.
// It captures the request body (if any) and then rewinds it before invoking the handler.
func newInProcessTransport(handler http.HandlerFunc, capture func(r *http.Request, body []byte)) http.RoundTripper {
return roundTripFunc(func(r *http.Request) (*http.Response, error) {
var body []byte
if r.Body != nil {
body, _ = io.ReadAll(r.Body)
_ = r.Body.Close()
r.Body = io.NopCloser(bytes.NewReader(body))
}
if capture != nil {
capture(r, body)
}
rec := httptest.NewRecorder()
handler(rec, r)
return rec.Result(), nil
})
}
var (
canListenOnce sync.Once
canListen bool
canListenErr error
)
func localListenerAvailable() bool {
canListenOnce.Do(func() {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
canListenErr = err
canListen = false
return
}
_ = ln.Close()
canListen = true
})
return canListen
}
func newLocalTestServer(tb testing.TB, handler http.Handler) *httptest.Server {
tb.Helper()
if !localListenerAvailable() {
tb.Skipf("local listeners are not permitted in this environment: %v", canListenErr)
}
return httptest.NewServer(handler)
}

View File

@@ -0,0 +1,408 @@
//go:build integration
package repository
import (
"context"
"database/sql"
"fmt"
"log"
"os"
"os/exec"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
_ "github.com/Wei-Shaw/sub2api/ent/runtime"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"entgo.io/ent/dialect"
entsql "entgo.io/ent/dialect/sql"
_ "github.com/lib/pq"
redisclient "github.com/redis/go-redis/v9"
tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres"
tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
)
const (
redisImageTag = "redis:8.4-alpine"
postgresImageTag = "postgres:18.1-alpine3.23"
)
var (
integrationDB *sql.DB
integrationEntClient *dbent.Client
integrationRedis *redisclient.Client
redisNamespaceSeq uint64
)
func TestMain(m *testing.M) {
ctx := context.Background()
if err := timezone.Init("UTC"); err != nil {
log.Printf("failed to init timezone: %v", err)
os.Exit(1)
}
if !dockerIsAvailable(ctx) {
// In CI we expect Docker to be available so integration tests should fail loudly.
if os.Getenv("CI") != "" {
log.Printf("docker is not available (CI=true); failing integration tests")
os.Exit(1)
}
log.Printf("docker is not available; skipping integration tests (start Docker to enable)")
os.Exit(0)
}
postgresImage := selectDockerImage(ctx, postgresImageTag)
pgContainer, err := tcpostgres.Run(
ctx,
postgresImage,
tcpostgres.WithDatabase("sub2api_test"),
tcpostgres.WithUsername("postgres"),
tcpostgres.WithPassword("postgres"),
tcpostgres.BasicWaitStrategies(),
)
if err != nil {
log.Printf("failed to start postgres container: %v", err)
os.Exit(1)
}
defer func() { _ = pgContainer.Terminate(ctx) }()
redisContainer, err := tcredis.Run(
ctx,
redisImageTag,
)
if err != nil {
log.Printf("failed to start redis container: %v", err)
os.Exit(1)
}
defer func() { _ = redisContainer.Terminate(ctx) }()
dsn, err := pgContainer.ConnectionString(ctx, "sslmode=disable", "TimeZone=UTC")
if err != nil {
log.Printf("failed to get postgres dsn: %v", err)
os.Exit(1)
}
integrationDB, err = openSQLWithRetry(ctx, dsn, 30*time.Second)
if err != nil {
log.Printf("failed to open sql db: %v", err)
os.Exit(1)
}
if err := ApplyMigrations(ctx, integrationDB); err != nil {
log.Printf("failed to apply db migrations: %v", err)
os.Exit(1)
}
// 创建 ent client 用于集成测试
drv := entsql.OpenDB(dialect.Postgres, integrationDB)
integrationEntClient = dbent.NewClient(dbent.Driver(drv))
redisHost, err := redisContainer.Host(ctx)
if err != nil {
log.Printf("failed to get redis host: %v", err)
os.Exit(1)
}
redisPort, err := redisContainer.MappedPort(ctx, "6379/tcp")
if err != nil {
log.Printf("failed to get redis port: %v", err)
os.Exit(1)
}
integrationRedis = redisclient.NewClient(&redisclient.Options{
Addr: fmt.Sprintf("%s:%d", redisHost, redisPort.Int()),
DB: 0,
})
if err := integrationRedis.Ping(ctx).Err(); err != nil {
log.Printf("failed to ping redis: %v", err)
os.Exit(1)
}
code := m.Run()
_ = integrationEntClient.Close()
_ = integrationRedis.Close()
_ = integrationDB.Close()
os.Exit(code)
}
func dockerIsAvailable(ctx context.Context) bool {
cmd := exec.CommandContext(ctx, "docker", "info")
cmd.Env = os.Environ()
return cmd.Run() == nil
}
func selectDockerImage(ctx context.Context, preferred string) string {
if dockerImageExists(ctx, preferred) {
return preferred
}
return preferred
}
func dockerImageExists(ctx context.Context, image string) bool {
cmd := exec.CommandContext(ctx, "docker", "image", "inspect", image)
cmd.Env = os.Environ()
cmd.Stdout = nil
cmd.Stderr = nil
return cmd.Run() == nil
}
func openSQLWithRetry(ctx context.Context, dsn string, timeout time.Duration) (*sql.DB, error) {
deadline := time.Now().Add(timeout)
var lastErr error
for time.Now().Before(deadline) {
db, err := sql.Open("postgres", dsn)
if err != nil {
lastErr = err
time.Sleep(250 * time.Millisecond)
continue
}
if err := pingWithTimeout(ctx, db, 2*time.Second); err != nil {
lastErr = err
_ = db.Close()
time.Sleep(250 * time.Millisecond)
continue
}
return db, nil
}
return nil, fmt.Errorf("db not ready after %s: %w", timeout, lastErr)
}
func pingWithTimeout(ctx context.Context, db *sql.DB, timeout time.Duration) error {
pingCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
return db.PingContext(pingCtx)
}
func testTx(t *testing.T) *sql.Tx {
t.Helper()
tx, err := integrationDB.BeginTx(context.Background(), nil)
require.NoError(t, err, "begin tx")
t.Cleanup(func() {
_ = tx.Rollback()
})
return tx
}
// testEntClient 返回全局的 ent client用于测试需要内部管理事务的代码如 Create/Update 方法)。
// 注意:此 client 的操作会真正写入数据库,测试结束后不会自动回滚。
func testEntClient(t *testing.T) *dbent.Client {
t.Helper()
return integrationEntClient
}
// testEntTx 返回一个 ent 事务,用于需要事务隔离的测试。
// 测试结束后会自动回滚,不会影响数据库状态。
func testEntTx(t *testing.T) *dbent.Tx {
t.Helper()
tx, err := integrationEntClient.Tx(context.Background())
require.NoError(t, err, "begin ent tx")
t.Cleanup(func() {
_ = tx.Rollback()
})
return tx
}
// testEntSQLTx 已弃用:不要在新测试中使用此函数。
// 基于 *sql.Tx 创建的 ent client 在调用 client.Tx() 时会 panic。
// 对于需要测试内部使用事务的代码,请使用 testEntClient。
// 对于需要事务隔离的测试,请使用 testEntTx。
//
// Deprecated: Use testEntClient or testEntTx instead.
func testEntSQLTx(t *testing.T) (*dbent.Client, *sql.Tx) {
t.Helper()
// 直接失败,避免旧测试误用导致的事务嵌套 panic。
t.Fatalf("testEntSQLTx 已弃用:请使用 testEntClient 或 testEntTx")
return nil, nil
}
func testRedis(t *testing.T) *redisclient.Client {
t.Helper()
prefix := fmt.Sprintf(
"it:%s:%d:%d:",
sanitizeRedisNamespace(t.Name()),
time.Now().UnixNano(),
atomic.AddUint64(&redisNamespaceSeq, 1),
)
opts := *integrationRedis.Options()
rdb := redisclient.NewClient(&opts)
rdb.AddHook(prefixHook{prefix: prefix})
t.Cleanup(func() {
ctx := context.Background()
var cursor uint64
for {
keys, nextCursor, err := integrationRedis.Scan(ctx, cursor, prefix+"*", 500).Result()
require.NoError(t, err, "scan redis keys for cleanup")
if len(keys) > 0 {
require.NoError(t, integrationRedis.Unlink(ctx, keys...).Err(), "unlink redis keys for cleanup")
}
cursor = nextCursor
if cursor == 0 {
break
}
}
_ = rdb.Close()
})
return rdb
}
func assertTTLWithin(t *testing.T, ttl time.Duration, min, max time.Duration) {
t.Helper()
require.GreaterOrEqual(t, ttl, min, "ttl should be >= min")
require.LessOrEqual(t, ttl, max, "ttl should be <= max")
}
func sanitizeRedisNamespace(name string) string {
name = strings.ReplaceAll(name, "/", "_")
name = strings.ReplaceAll(name, " ", "_")
return name
}
type prefixHook struct {
prefix string
}
func (h prefixHook) DialHook(next redisclient.DialHook) redisclient.DialHook { return next }
func (h prefixHook) ProcessHook(next redisclient.ProcessHook) redisclient.ProcessHook {
return func(ctx context.Context, cmd redisclient.Cmder) error {
h.prefixCmd(cmd)
return next(ctx, cmd)
}
}
func (h prefixHook) ProcessPipelineHook(next redisclient.ProcessPipelineHook) redisclient.ProcessPipelineHook {
return func(ctx context.Context, cmds []redisclient.Cmder) error {
for _, cmd := range cmds {
h.prefixCmd(cmd)
}
return next(ctx, cmds)
}
}
func (h prefixHook) prefixCmd(cmd redisclient.Cmder) {
args := cmd.Args()
if len(args) < 2 {
return
}
prefixOne := func(i int) {
if i < 0 || i >= len(args) {
return
}
switch v := args[i].(type) {
case string:
if v != "" && !strings.HasPrefix(v, h.prefix) {
args[i] = h.prefix + v
}
case []byte:
s := string(v)
if s != "" && !strings.HasPrefix(s, h.prefix) {
args[i] = []byte(h.prefix + s)
}
}
}
switch strings.ToLower(cmd.Name()) {
case "get", "set", "setnx", "setex", "psetex", "incr", "decr", "incrby", "expire", "pexpire", "ttl", "pttl",
"hgetall", "hget", "hset", "hdel", "hincrbyfloat", "exists",
"zadd", "zcard", "zrange", "zrangebyscore", "zrem", "zremrangebyscore", "zrevrange", "zrevrangebyscore", "zscore":
prefixOne(1)
case "del", "unlink":
for i := 1; i < len(args); i++ {
prefixOne(i)
}
case "eval", "evalsha", "eval_ro", "evalsha_ro":
if len(args) < 3 {
return
}
numKeys, err := strconv.Atoi(fmt.Sprint(args[2]))
if err != nil || numKeys <= 0 {
return
}
for i := 0; i < numKeys && 3+i < len(args); i++ {
prefixOne(3 + i)
}
case "scan":
for i := 2; i+1 < len(args); i++ {
if strings.EqualFold(fmt.Sprint(args[i]), "match") {
prefixOne(i + 1)
break
}
}
}
}
// IntegrationRedisSuite provides a base suite for Redis integration tests.
// Embedding suites should call SetupTest to initialize ctx and rdb.
type IntegrationRedisSuite struct {
suite.Suite
ctx context.Context
rdb *redisclient.Client
}
// SetupTest initializes ctx and rdb for each test method.
func (s *IntegrationRedisSuite) SetupTest() {
s.ctx = context.Background()
s.rdb = testRedis(s.T())
}
// RequireNoError is a convenience method wrapping require.NoError with s.T().
func (s *IntegrationRedisSuite) RequireNoError(err error, msgAndArgs ...any) {
s.T().Helper()
require.NoError(s.T(), err, msgAndArgs...)
}
// AssertTTLWithin asserts that ttl is within [min, max].
func (s *IntegrationRedisSuite) AssertTTLWithin(ttl, min, max time.Duration) {
s.T().Helper()
assertTTLWithin(s.T(), ttl, min, max)
}
// IntegrationDBSuite provides a base suite for DB integration tests.
// Embedding suites should call SetupTest to initialize ctx and client.
type IntegrationDBSuite struct {
suite.Suite
ctx context.Context
client *dbent.Client
tx *dbent.Tx
}
// SetupTest initializes ctx and client for each test method.
func (s *IntegrationDBSuite) SetupTest() {
s.ctx = context.Background()
// 统一使用 ent.Tx确保每个测试都有独立事务并自动回滚。
tx := testEntTx(s.T())
s.tx = tx
s.client = tx.Client()
}
// RequireNoError is a convenience method wrapping require.NoError with s.T().
func (s *IntegrationDBSuite) RequireNoError(err error, msgAndArgs ...any) {
s.T().Helper()
require.NoError(s.T(), err, msgAndArgs...)
}

View File

@@ -0,0 +1,302 @@
package repository
import (
"context"
"crypto/sha256"
"database/sql"
"encoding/hex"
"errors"
"fmt"
"io/fs"
"sort"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/migrations"
)
// schemaMigrationsTableDDL 定义迁移记录表的 DDL。
// 该表用于跟踪已应用的迁移文件及其校验和。
// - filename: 迁移文件名,作为主键唯一标识每个迁移
// - checksum: 文件内容的 SHA256 哈希值,用于检测迁移文件是否被篡改
// - applied_at: 迁移应用时间戳
const schemaMigrationsTableDDL = `
CREATE TABLE IF NOT EXISTS schema_migrations (
filename TEXT PRIMARY KEY,
checksum TEXT NOT NULL,
applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
`
const atlasSchemaRevisionsTableDDL = `
CREATE TABLE IF NOT EXISTS atlas_schema_revisions (
version TEXT PRIMARY KEY,
description TEXT NOT NULL,
type INTEGER NOT NULL,
applied INTEGER NOT NULL DEFAULT 0,
total INTEGER NOT NULL DEFAULT 0,
executed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
execution_time BIGINT NOT NULL DEFAULT 0,
error TEXT NULL,
error_stmt TEXT NULL,
hash TEXT NOT NULL DEFAULT '',
partial_hashes TEXT[] NULL,
operator_version TEXT NULL
);
`
// migrationsAdvisoryLockID 是用于序列化迁移操作的 PostgreSQL Advisory Lock ID。
// 在多实例部署场景下,该锁确保同一时间只有一个实例执行迁移。
// 任何稳定的 int64 值都可以,只要不与同一数据库中的其他锁冲突即可。
const migrationsAdvisoryLockID int64 = 694208311321144027
const migrationsLockRetryInterval = 500 * time.Millisecond
// ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。
//
// 该函数可以在每次应用启动时安全调用:
// - 已应用的迁移会被自动跳过(通过校验 filename 判断)
// - 如果迁移文件内容被修改checksum 不匹配),会返回错误
// - 使用 PostgreSQL Advisory Lock 确保多实例并发安全
//
// 参数:
// - ctx: 上下文,用于超时控制和取消
// - db: 数据库连接
//
// 返回:
// - error: 迁移过程中的任何错误
func ApplyMigrations(ctx context.Context, db *sql.DB) error {
if db == nil {
return errors.New("nil sql db")
}
return applyMigrationsFS(ctx, db, migrations.FS)
}
// applyMigrationsFS 是迁移执行的核心实现。
// 它从指定的文件系统读取 SQL 迁移文件并按顺序应用。
//
// 迁移执行流程:
// 1. 获取 PostgreSQL Advisory Lock防止多实例并发迁移
// 2. 确保 schema_migrations 表存在
// 3. 按文件名排序读取所有 .sql 文件
// 4. 对于每个迁移文件:
// - 计算文件内容的 SHA256 校验和
// - 检查该迁移是否已应用(通过 filename 查询)
// - 如果已应用,验证校验和是否匹配
// - 如果未应用,在事务中执行迁移并记录
// 5. 释放 Advisory Lock
//
// 参数:
// - ctx: 上下文
// - db: 数据库连接
// - fsys: 包含迁移文件的文件系统(通常是 embed.FS
func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
if db == nil {
return errors.New("nil sql db")
}
// 获取分布式锁,确保多实例部署时只有一个实例执行迁移。
// 这是 PostgreSQL 特有的 Advisory Lock 机制。
if err := pgAdvisoryLock(ctx, db); err != nil {
return err
}
defer func() {
// 无论迁移是否成功,都要释放锁。
// 使用 context.Background() 确保即使原 ctx 已取消也能释放锁。
_ = pgAdvisoryUnlock(context.Background(), db)
}()
// 创建迁移记录表(如果不存在)。
// 该表记录所有已应用的迁移及其校验和。
if _, err := db.ExecContext(ctx, schemaMigrationsTableDDL); err != nil {
return fmt.Errorf("create schema_migrations: %w", err)
}
// 自动对齐 Atlas 基线(如果检测到 legacy schema_migrations 且缺失 atlas_schema_revisions
if err := ensureAtlasBaselineAligned(ctx, db, fsys); err != nil {
return err
}
// 获取所有 .sql 迁移文件并按文件名排序。
// 命名规范:使用零填充数字前缀(如 001_init.sql, 002_add_users.sql
files, err := fs.Glob(fsys, "*.sql")
if err != nil {
return fmt.Errorf("list migrations: %w", err)
}
sort.Strings(files) // 确保按文件名顺序执行迁移
for _, name := range files {
// 读取迁移文件内容
contentBytes, err := fs.ReadFile(fsys, name)
if err != nil {
return fmt.Errorf("read migration %s: %w", name, err)
}
content := strings.TrimSpace(string(contentBytes))
if content == "" {
continue // 跳过空文件
}
// 计算文件内容的 SHA256 校验和,用于检测文件是否被修改。
// 这是一种防篡改机制:如果有人修改了已应用的迁移文件,系统会拒绝启动。
sum := sha256.Sum256([]byte(content))
checksum := hex.EncodeToString(sum[:])
// 检查该迁移是否已经应用
var existing string
rowErr := db.QueryRowContext(ctx, "SELECT checksum FROM schema_migrations WHERE filename = $1", name).Scan(&existing)
if rowErr == nil {
// 迁移已应用,验证校验和是否匹配
if existing != checksum {
// 校验和不匹配意味着迁移文件在应用后被修改,这是危险的。
// 正确的做法是创建新的迁移文件来进行变更。
return fmt.Errorf(
"migration %s checksum mismatch (db=%s file=%s)\n"+
"This means the migration file was modified after being applied to the database.\n"+
"Solutions:\n"+
" 1. Revert to original: git log --oneline -- migrations/%s && git checkout <commit> -- migrations/%s\n"+
" 2. For new changes, create a new migration file instead of modifying existing ones\n"+
"Note: Modifying applied migrations breaks the immutability principle and can cause inconsistencies across environments",
name, existing, checksum, name, name,
)
}
continue // 迁移已应用且校验和匹配,跳过
}
if !errors.Is(rowErr, sql.ErrNoRows) {
return fmt.Errorf("check migration %s: %w", name, rowErr)
}
// 迁移未应用,在事务中执行。
// 使用事务确保迁移的原子性:要么完全成功,要么完全回滚。
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("begin migration %s: %w", name, err)
}
// 执行迁移 SQL
if _, err := tx.ExecContext(ctx, content); err != nil {
_ = tx.Rollback()
return fmt.Errorf("apply migration %s: %w", name, err)
}
// 记录迁移已完成,保存文件名和校验和
if _, err := tx.ExecContext(ctx, "INSERT INTO schema_migrations (filename, checksum) VALUES ($1, $2)", name, checksum); err != nil {
_ = tx.Rollback()
return fmt.Errorf("record migration %s: %w", name, err)
}
// 提交事务
if err := tx.Commit(); err != nil {
_ = tx.Rollback()
return fmt.Errorf("commit migration %s: %w", name, err)
}
}
return nil
}
func ensureAtlasBaselineAligned(ctx context.Context, db *sql.DB, fsys fs.FS) error {
hasLegacy, err := tableExists(ctx, db, "schema_migrations")
if err != nil {
return fmt.Errorf("check schema_migrations: %w", err)
}
if !hasLegacy {
return nil
}
hasAtlas, err := tableExists(ctx, db, "atlas_schema_revisions")
if err != nil {
return fmt.Errorf("check atlas_schema_revisions: %w", err)
}
if !hasAtlas {
if _, err := db.ExecContext(ctx, atlasSchemaRevisionsTableDDL); err != nil {
return fmt.Errorf("create atlas_schema_revisions: %w", err)
}
}
var count int
if err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM atlas_schema_revisions").Scan(&count); err != nil {
return fmt.Errorf("count atlas_schema_revisions: %w", err)
}
if count > 0 {
return nil
}
version, description, hash, err := latestMigrationBaseline(fsys)
if err != nil {
return fmt.Errorf("atlas baseline version: %w", err)
}
if _, err := db.ExecContext(ctx, `
INSERT INTO atlas_schema_revisions (version, description, type, applied, total, executed_at, execution_time, hash)
VALUES ($1, $2, $3, 0, 0, NOW(), 0, $4)
`, version, description, 1, hash); err != nil {
return fmt.Errorf("insert atlas baseline: %w", err)
}
return nil
}
func tableExists(ctx context.Context, db *sql.DB, tableName string) (bool, error) {
var exists bool
err := db.QueryRowContext(ctx, `
SELECT EXISTS (
SELECT 1
FROM information_schema.tables
WHERE table_schema = 'public' AND table_name = $1
)
`, tableName).Scan(&exists)
return exists, err
}
func latestMigrationBaseline(fsys fs.FS) (string, string, string, error) {
files, err := fs.Glob(fsys, "*.sql")
if err != nil {
return "", "", "", err
}
if len(files) == 0 {
return "baseline", "baseline", "", nil
}
sort.Strings(files)
name := files[len(files)-1]
contentBytes, err := fs.ReadFile(fsys, name)
if err != nil {
return "", "", "", err
}
content := strings.TrimSpace(string(contentBytes))
sum := sha256.Sum256([]byte(content))
hash := hex.EncodeToString(sum[:])
version := strings.TrimSuffix(name, ".sql")
return version, version, hash, nil
}
// pgAdvisoryLock 获取 PostgreSQL Advisory Lock。
// Advisory Lock 是一种轻量级的锁机制,不与任何特定的数据库对象关联。
// 它非常适合用于应用层面的分布式锁场景,如迁移序列化。
func pgAdvisoryLock(ctx context.Context, db *sql.DB) error {
ticker := time.NewTicker(migrationsLockRetryInterval)
defer ticker.Stop()
for {
var locked bool
if err := db.QueryRowContext(ctx, "SELECT pg_try_advisory_lock($1)", migrationsAdvisoryLockID).Scan(&locked); err != nil {
return fmt.Errorf("acquire migrations lock: %w", err)
}
if locked {
return nil
}
select {
case <-ctx.Done():
return fmt.Errorf("acquire migrations lock: %w", ctx.Err())
case <-ticker.C:
}
}
}
// pgAdvisoryUnlock 释放 PostgreSQL Advisory Lock。
// 必须在获取锁后确保释放,否则会阻塞其他实例的迁移操作。
func pgAdvisoryUnlock(ctx context.Context, db *sql.DB) error {
_, err := db.ExecContext(ctx, "SELECT pg_advisory_unlock($1)", migrationsAdvisoryLockID)
if err != nil {
return fmt.Errorf("release migrations lock: %w", err)
}
return nil
}

View File

@@ -0,0 +1,103 @@
//go:build integration
package repository
import (
"context"
"database/sql"
"testing"
"github.com/stretchr/testify/require"
)
func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
tx := testTx(t)
// Re-apply migrations to verify idempotency (no errors, no duplicate rows).
require.NoError(t, ApplyMigrations(context.Background(), integrationDB))
// schema_migrations should have at least the current migration set.
var applied int
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT COUNT(*) FROM schema_migrations").Scan(&applied))
require.GreaterOrEqual(t, applied, 7, "expected schema_migrations to contain applied migrations")
// users: columns required by repository queries
requireColumn(t, tx, "users", "username", "character varying", 100, false)
requireColumn(t, tx, "users", "notes", "text", 0, false)
// accounts: schedulable and rate-limit fields
requireColumn(t, tx, "accounts", "notes", "text", 0, true)
requireColumn(t, tx, "accounts", "schedulable", "boolean", 0, false)
requireColumn(t, tx, "accounts", "rate_limited_at", "timestamp with time zone", 0, true)
requireColumn(t, tx, "accounts", "rate_limit_reset_at", "timestamp with time zone", 0, true)
requireColumn(t, tx, "accounts", "overload_until", "timestamp with time zone", 0, true)
requireColumn(t, tx, "accounts", "session_window_status", "character varying", 20, true)
// api_keys: key length should be 128
requireColumn(t, tx, "api_keys", "key", "character varying", 128, false)
// redeem_codes: subscription fields
requireColumn(t, tx, "redeem_codes", "group_id", "bigint", 0, true)
requireColumn(t, tx, "redeem_codes", "validity_days", "integer", 0, false)
// usage_logs: billing_type used by filters/stats
requireColumn(t, tx, "usage_logs", "billing_type", "smallint", 0, false)
// settings table should exist
var settingsRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.settings')").Scan(&settingsRegclass))
require.True(t, settingsRegclass.Valid, "expected settings table to exist")
// user_allowed_groups table should exist
var uagRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.user_allowed_groups')").Scan(&uagRegclass))
require.True(t, uagRegclass.Valid, "expected user_allowed_groups table to exist")
// user_subscriptions: deleted_at for soft delete support (migration 012)
requireColumn(t, tx, "user_subscriptions", "deleted_at", "timestamp with time zone", 0, true)
// orphan_allowed_groups_audit table should exist (migration 013)
var orphanAuditRegclass sql.NullString
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.orphan_allowed_groups_audit')").Scan(&orphanAuditRegclass))
require.True(t, orphanAuditRegclass.Valid, "expected orphan_allowed_groups_audit table to exist")
// account_groups: created_at should be timestamptz
requireColumn(t, tx, "account_groups", "created_at", "timestamp with time zone", 0, false)
// user_allowed_groups: created_at should be timestamptz
requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false)
}
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
t.Helper()
var row struct {
DataType string
MaxLen sql.NullInt64
Nullable string
}
err := tx.QueryRowContext(context.Background(), `
SELECT
data_type,
character_maximum_length,
is_nullable
FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = $1
AND column_name = $2
`, table, column).Scan(&row.DataType, &row.MaxLen, &row.Nullable)
require.NoError(t, err, "query information_schema.columns for %s.%s", table, column)
require.Equal(t, dataType, row.DataType, "data_type mismatch for %s.%s", table, column)
if maxLen > 0 {
require.True(t, row.MaxLen.Valid, "expected maxLen for %s.%s", table, column)
require.Equal(t, int64(maxLen), row.MaxLen.Int64, "maxLen mismatch for %s.%s", table, column)
}
if nullable {
require.Equal(t, "YES", row.Nullable, "nullable mismatch for %s.%s", table, column)
} else {
require.Equal(t, "NO", row.Nullable, "nullable mismatch for %s.%s", table, column)
}
}

View File

@@ -0,0 +1,89 @@
package repository
import (
"context"
"fmt"
"net/url"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/imroc/req/v3"
)
// NewOpenAIOAuthClient creates a new OpenAI OAuth client
func NewOpenAIOAuthClient() service.OpenAIOAuthClient {
return &openaiOAuthService{tokenURL: openai.TokenURL}
}
type openaiOAuthService struct {
tokenURL string
}
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
client := createOpenAIReqClient(proxyURL)
if redirectURI == "" {
redirectURI = openai.DefaultRedirectURI
}
formData := url.Values{}
formData.Set("grant_type", "authorization_code")
formData.Set("client_id", openai.ClientID)
formData.Set("code", code)
formData.Set("redirect_uri", redirectURI)
formData.Set("code_verifier", codeVerifier)
var tokenResp openai.TokenResponse
resp, err := client.R().
SetContext(ctx).
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(s.tokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
}
return &tokenResp, nil
}
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
client := createOpenAIReqClient(proxyURL)
formData := url.Values{}
formData.Set("grant_type", "refresh_token")
formData.Set("refresh_token", refreshToken)
formData.Set("client_id", openai.ClientID)
formData.Set("scope", openai.RefreshScopes)
var tokenResp openai.TokenResponse
resp, err := client.R().
SetContext(ctx).
SetFormDataFromValues(formData).
SetSuccessResult(&tokenResp).
Post(s.tokenURL)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccessState() {
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
}
return &tokenResp, nil
}
func createOpenAIReqClient(proxyURL string) *req.Client {
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 60 * time.Second,
})
}

View File

@@ -0,0 +1,249 @@
package repository
import (
"context"
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type OpenAIOAuthServiceSuite struct {
suite.Suite
ctx context.Context
srv *httptest.Server
svc *openaiOAuthService
received chan url.Values
}
func (s *OpenAIOAuthServiceSuite) SetupTest() {
s.ctx = context.Background()
s.received = make(chan url.Values, 1)
}
func (s *OpenAIOAuthServiceSuite) TearDownTest() {
if s.srv != nil {
s.srv.Close()
s.srv = nil
}
}
func (s *OpenAIOAuthServiceSuite) setupServer(handler http.HandlerFunc) {
s.srv = newLocalTestServer(s.T(), handler)
s.svc = &openaiOAuthService{tokenURL: s.srv.URL}
}
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_DefaultRedirectURI() {
errCh := make(chan string, 1)
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
errCh <- "method mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if err := r.ParseForm(); err != nil {
errCh <- "ParseForm failed"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("grant_type"); got != "authorization_code" {
errCh <- "grant_type mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("client_id"); got != openai.ClientID {
errCh <- "client_id mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("code"); got != "code" {
errCh <- "code mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("redirect_uri"); got != openai.DefaultRedirectURI {
errCh <- "redirect_uri mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("code_verifier"); got != "ver" {
errCh <- "code_verifier mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`)
}))
resp, err := s.svc.ExchangeCode(s.ctx, "code", "ver", "", "")
require.NoError(s.T(), err, "ExchangeCode")
select {
case msg := <-errCh:
require.Fail(s.T(), msg)
default:
}
require.Equal(s.T(), "at", resp.AccessToken)
require.Equal(s.T(), "rt", resp.RefreshToken)
}
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() {
errCh := make(chan string, 1)
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
errCh <- "ParseForm failed"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("grant_type"); got != "refresh_token" {
errCh <- "grant_type mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("refresh_token"); got != "rt" {
errCh <- "refresh_token mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("client_id"); got != openai.ClientID {
errCh <- "client_id mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
if got := r.PostForm.Get("scope"); got != openai.RefreshScopes {
errCh <- "scope mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at2","refresh_token":"rt2","token_type":"bearer","expires_in":3600}`)
}))
resp, err := s.svc.RefreshToken(s.ctx, "rt", "")
require.NoError(s.T(), err, "RefreshToken")
select {
case msg := <-errCh:
require.Fail(s.T(), msg)
default:
}
require.Equal(s.T(), "at2", resp.AccessToken)
require.Equal(s.T(), "rt2", resp.RefreshToken)
}
func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
_, _ = io.WriteString(w, "bad")
}))
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "status 400")
require.ErrorContains(s.T(), err, "bad")
}
func (s *OpenAIOAuthServiceSuite) TestRequestError_ClosedServer() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
s.srv.Close()
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "request failed")
}
func (s *OpenAIOAuthServiceSuite) TestContextCancel() {
started := make(chan struct{})
block := make(chan struct{})
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
close(started)
<-block
}))
ctx, cancel := context.WithCancel(s.ctx)
done := make(chan error, 1)
go func() {
_, err := s.svc.ExchangeCode(ctx, "code", "ver", openai.DefaultRedirectURI, "")
done <- err
}()
<-started
cancel()
close(block)
err := <-done
require.Error(s.T(), err)
}
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() {
want := "http://localhost:9999/cb"
errCh := make(chan string, 1)
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = r.ParseForm()
if got := r.PostForm.Get("redirect_uri"); got != want {
errCh <- "redirect_uri mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
}))
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", want, "")
require.NoError(s.T(), err, "ExchangeCode")
select {
case msg := <-errCh:
require.Fail(s.T(), msg)
default:
}
}
func (s *OpenAIOAuthServiceSuite) TestTokenURL_CanBeOverriddenWithQuery() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = r.ParseForm()
s.received <- r.PostForm
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
}))
s.svc.tokenURL = s.srv.URL + "?x=1"
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
require.NoError(s.T(), err, "ExchangeCode")
select {
case <-s.received:
default:
require.Fail(s.T(), "expected server to receive request")
}
}
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_SuccessButInvalidJSON() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = io.WriteString(w, "not-valid-json")
}))
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
require.Error(s.T(), err, "expected error for invalid JSON response")
}
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_NonSuccessStatus() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = io.WriteString(w, "unauthorized")
}))
_, err := s.svc.RefreshToken(s.ctx, "rt", "")
require.Error(s.T(), err, "expected error for non-2xx status")
require.ErrorContains(s.T(), err, "status 401")
}
func TestOpenAIOAuthServiceSuite(t *testing.T) {
suite.Run(t, new(OpenAIOAuthServiceSuite))
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,853 @@
package repository
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (r *opsRepository) ListAlertRules(ctx context.Context) ([]*service.OpsAlertRule, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
q := `
SELECT
id,
name,
COALESCE(description, ''),
enabled,
COALESCE(severity, ''),
metric_type,
operator,
threshold,
window_minutes,
sustained_minutes,
cooldown_minutes,
COALESCE(notify_email, true),
filters,
last_triggered_at,
created_at,
updated_at
FROM ops_alert_rules
ORDER BY id DESC`
rows, err := r.db.QueryContext(ctx, q)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
out := []*service.OpsAlertRule{}
for rows.Next() {
var rule service.OpsAlertRule
var filtersRaw []byte
var lastTriggeredAt sql.NullTime
if err := rows.Scan(
&rule.ID,
&rule.Name,
&rule.Description,
&rule.Enabled,
&rule.Severity,
&rule.MetricType,
&rule.Operator,
&rule.Threshold,
&rule.WindowMinutes,
&rule.SustainedMinutes,
&rule.CooldownMinutes,
&rule.NotifyEmail,
&filtersRaw,
&lastTriggeredAt,
&rule.CreatedAt,
&rule.UpdatedAt,
); err != nil {
return nil, err
}
if lastTriggeredAt.Valid {
v := lastTriggeredAt.Time
rule.LastTriggeredAt = &v
}
if len(filtersRaw) > 0 && string(filtersRaw) != "null" {
var decoded map[string]any
if err := json.Unmarshal(filtersRaw, &decoded); err == nil {
rule.Filters = decoded
}
}
out = append(out, &rule)
}
if err := rows.Err(); err != nil {
return nil, err
}
return out, nil
}
func (r *opsRepository) CreateAlertRule(ctx context.Context, input *service.OpsAlertRule) (*service.OpsAlertRule, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if input == nil {
return nil, fmt.Errorf("nil input")
}
filtersArg, err := opsNullJSONMap(input.Filters)
if err != nil {
return nil, err
}
q := `
INSERT INTO ops_alert_rules (
name,
description,
enabled,
severity,
metric_type,
operator,
threshold,
window_minutes,
sustained_minutes,
cooldown_minutes,
notify_email,
filters,
created_at,
updated_at
) VALUES (
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,NOW(),NOW()
)
RETURNING
id,
name,
COALESCE(description, ''),
enabled,
COALESCE(severity, ''),
metric_type,
operator,
threshold,
window_minutes,
sustained_minutes,
cooldown_minutes,
COALESCE(notify_email, true),
filters,
last_triggered_at,
created_at,
updated_at`
var out service.OpsAlertRule
var filtersRaw []byte
var lastTriggeredAt sql.NullTime
if err := r.db.QueryRowContext(
ctx,
q,
strings.TrimSpace(input.Name),
strings.TrimSpace(input.Description),
input.Enabled,
strings.TrimSpace(input.Severity),
strings.TrimSpace(input.MetricType),
strings.TrimSpace(input.Operator),
input.Threshold,
input.WindowMinutes,
input.SustainedMinutes,
input.CooldownMinutes,
input.NotifyEmail,
filtersArg,
).Scan(
&out.ID,
&out.Name,
&out.Description,
&out.Enabled,
&out.Severity,
&out.MetricType,
&out.Operator,
&out.Threshold,
&out.WindowMinutes,
&out.SustainedMinutes,
&out.CooldownMinutes,
&out.NotifyEmail,
&filtersRaw,
&lastTriggeredAt,
&out.CreatedAt,
&out.UpdatedAt,
); err != nil {
return nil, err
}
if lastTriggeredAt.Valid {
v := lastTriggeredAt.Time
out.LastTriggeredAt = &v
}
if len(filtersRaw) > 0 && string(filtersRaw) != "null" {
var decoded map[string]any
if err := json.Unmarshal(filtersRaw, &decoded); err == nil {
out.Filters = decoded
}
}
return &out, nil
}
func (r *opsRepository) UpdateAlertRule(ctx context.Context, input *service.OpsAlertRule) (*service.OpsAlertRule, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if input == nil {
return nil, fmt.Errorf("nil input")
}
if input.ID <= 0 {
return nil, fmt.Errorf("invalid id")
}
filtersArg, err := opsNullJSONMap(input.Filters)
if err != nil {
return nil, err
}
q := `
UPDATE ops_alert_rules
SET
name = $2,
description = $3,
enabled = $4,
severity = $5,
metric_type = $6,
operator = $7,
threshold = $8,
window_minutes = $9,
sustained_minutes = $10,
cooldown_minutes = $11,
notify_email = $12,
filters = $13,
updated_at = NOW()
WHERE id = $1
RETURNING
id,
name,
COALESCE(description, ''),
enabled,
COALESCE(severity, ''),
metric_type,
operator,
threshold,
window_minutes,
sustained_minutes,
cooldown_minutes,
COALESCE(notify_email, true),
filters,
last_triggered_at,
created_at,
updated_at`
var out service.OpsAlertRule
var filtersRaw []byte
var lastTriggeredAt sql.NullTime
if err := r.db.QueryRowContext(
ctx,
q,
input.ID,
strings.TrimSpace(input.Name),
strings.TrimSpace(input.Description),
input.Enabled,
strings.TrimSpace(input.Severity),
strings.TrimSpace(input.MetricType),
strings.TrimSpace(input.Operator),
input.Threshold,
input.WindowMinutes,
input.SustainedMinutes,
input.CooldownMinutes,
input.NotifyEmail,
filtersArg,
).Scan(
&out.ID,
&out.Name,
&out.Description,
&out.Enabled,
&out.Severity,
&out.MetricType,
&out.Operator,
&out.Threshold,
&out.WindowMinutes,
&out.SustainedMinutes,
&out.CooldownMinutes,
&out.NotifyEmail,
&filtersRaw,
&lastTriggeredAt,
&out.CreatedAt,
&out.UpdatedAt,
); err != nil {
return nil, err
}
if lastTriggeredAt.Valid {
v := lastTriggeredAt.Time
out.LastTriggeredAt = &v
}
if len(filtersRaw) > 0 && string(filtersRaw) != "null" {
var decoded map[string]any
if err := json.Unmarshal(filtersRaw, &decoded); err == nil {
out.Filters = decoded
}
}
return &out, nil
}
func (r *opsRepository) DeleteAlertRule(ctx context.Context, id int64) error {
if r == nil || r.db == nil {
return fmt.Errorf("nil ops repository")
}
if id <= 0 {
return fmt.Errorf("invalid id")
}
res, err := r.db.ExecContext(ctx, "DELETE FROM ops_alert_rules WHERE id = $1", id)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
return sql.ErrNoRows
}
return nil
}
func (r *opsRepository) ListAlertEvents(ctx context.Context, filter *service.OpsAlertEventFilter) ([]*service.OpsAlertEvent, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if filter == nil {
filter = &service.OpsAlertEventFilter{}
}
limit := filter.Limit
if limit <= 0 {
limit = 100
}
if limit > 500 {
limit = 500
}
where, args := buildOpsAlertEventsWhere(filter)
args = append(args, limit)
limitArg := "$" + itoa(len(args))
q := `
SELECT
id,
COALESCE(rule_id, 0),
COALESCE(severity, ''),
COALESCE(status, ''),
COALESCE(title, ''),
COALESCE(description, ''),
metric_value,
threshold_value,
dimensions,
fired_at,
resolved_at,
email_sent,
created_at
FROM ops_alert_events
` + where + `
ORDER BY fired_at DESC, id DESC
LIMIT ` + limitArg
rows, err := r.db.QueryContext(ctx, q, args...)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
out := []*service.OpsAlertEvent{}
for rows.Next() {
var ev service.OpsAlertEvent
var metricValue sql.NullFloat64
var thresholdValue sql.NullFloat64
var dimensionsRaw []byte
var resolvedAt sql.NullTime
if err := rows.Scan(
&ev.ID,
&ev.RuleID,
&ev.Severity,
&ev.Status,
&ev.Title,
&ev.Description,
&metricValue,
&thresholdValue,
&dimensionsRaw,
&ev.FiredAt,
&resolvedAt,
&ev.EmailSent,
&ev.CreatedAt,
); err != nil {
return nil, err
}
if metricValue.Valid {
v := metricValue.Float64
ev.MetricValue = &v
}
if thresholdValue.Valid {
v := thresholdValue.Float64
ev.ThresholdValue = &v
}
if resolvedAt.Valid {
v := resolvedAt.Time
ev.ResolvedAt = &v
}
if len(dimensionsRaw) > 0 && string(dimensionsRaw) != "null" {
var decoded map[string]any
if err := json.Unmarshal(dimensionsRaw, &decoded); err == nil {
ev.Dimensions = decoded
}
}
out = append(out, &ev)
}
if err := rows.Err(); err != nil {
return nil, err
}
return out, nil
}
func (r *opsRepository) GetAlertEventByID(ctx context.Context, eventID int64) (*service.OpsAlertEvent, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if eventID <= 0 {
return nil, fmt.Errorf("invalid event id")
}
q := `
SELECT
id,
COALESCE(rule_id, 0),
COALESCE(severity, ''),
COALESCE(status, ''),
COALESCE(title, ''),
COALESCE(description, ''),
metric_value,
threshold_value,
dimensions,
fired_at,
resolved_at,
email_sent,
created_at
FROM ops_alert_events
WHERE id = $1`
row := r.db.QueryRowContext(ctx, q, eventID)
ev, err := scanOpsAlertEvent(row)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
return ev, nil
}
func (r *opsRepository) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*service.OpsAlertEvent, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if ruleID <= 0 {
return nil, fmt.Errorf("invalid rule id")
}
q := `
SELECT
id,
COALESCE(rule_id, 0),
COALESCE(severity, ''),
COALESCE(status, ''),
COALESCE(title, ''),
COALESCE(description, ''),
metric_value,
threshold_value,
dimensions,
fired_at,
resolved_at,
email_sent,
created_at
FROM ops_alert_events
WHERE rule_id = $1 AND status = $2
ORDER BY fired_at DESC
LIMIT 1`
row := r.db.QueryRowContext(ctx, q, ruleID, service.OpsAlertStatusFiring)
ev, err := scanOpsAlertEvent(row)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
return ev, nil
}
func (r *opsRepository) GetLatestAlertEvent(ctx context.Context, ruleID int64) (*service.OpsAlertEvent, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if ruleID <= 0 {
return nil, fmt.Errorf("invalid rule id")
}
q := `
SELECT
id,
COALESCE(rule_id, 0),
COALESCE(severity, ''),
COALESCE(status, ''),
COALESCE(title, ''),
COALESCE(description, ''),
metric_value,
threshold_value,
dimensions,
fired_at,
resolved_at,
email_sent,
created_at
FROM ops_alert_events
WHERE rule_id = $1
ORDER BY fired_at DESC
LIMIT 1`
row := r.db.QueryRowContext(ctx, q, ruleID)
ev, err := scanOpsAlertEvent(row)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
return ev, nil
}
func (r *opsRepository) CreateAlertEvent(ctx context.Context, event *service.OpsAlertEvent) (*service.OpsAlertEvent, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if event == nil {
return nil, fmt.Errorf("nil event")
}
dimensionsArg, err := opsNullJSONMap(event.Dimensions)
if err != nil {
return nil, err
}
q := `
INSERT INTO ops_alert_events (
rule_id,
severity,
status,
title,
description,
metric_value,
threshold_value,
dimensions,
fired_at,
resolved_at,
email_sent,
created_at
) VALUES (
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,NOW()
)
RETURNING
id,
COALESCE(rule_id, 0),
COALESCE(severity, ''),
COALESCE(status, ''),
COALESCE(title, ''),
COALESCE(description, ''),
metric_value,
threshold_value,
dimensions,
fired_at,
resolved_at,
email_sent,
created_at`
row := r.db.QueryRowContext(
ctx,
q,
opsNullInt64(&event.RuleID),
opsNullString(event.Severity),
opsNullString(event.Status),
opsNullString(event.Title),
opsNullString(event.Description),
opsNullFloat64(event.MetricValue),
opsNullFloat64(event.ThresholdValue),
dimensionsArg,
event.FiredAt,
opsNullTime(event.ResolvedAt),
event.EmailSent,
)
return scanOpsAlertEvent(row)
}
func (r *opsRepository) UpdateAlertEventStatus(ctx context.Context, eventID int64, status string, resolvedAt *time.Time) error {
if r == nil || r.db == nil {
return fmt.Errorf("nil ops repository")
}
if eventID <= 0 {
return fmt.Errorf("invalid event id")
}
if strings.TrimSpace(status) == "" {
return fmt.Errorf("invalid status")
}
q := `
UPDATE ops_alert_events
SET status = $2,
resolved_at = $3
WHERE id = $1`
_, err := r.db.ExecContext(ctx, q, eventID, strings.TrimSpace(status), opsNullTime(resolvedAt))
return err
}
func (r *opsRepository) UpdateAlertEventEmailSent(ctx context.Context, eventID int64, emailSent bool) error {
if r == nil || r.db == nil {
return fmt.Errorf("nil ops repository")
}
if eventID <= 0 {
return fmt.Errorf("invalid event id")
}
_, err := r.db.ExecContext(ctx, "UPDATE ops_alert_events SET email_sent = $2 WHERE id = $1", eventID, emailSent)
return err
}
type opsAlertEventRow interface {
Scan(dest ...any) error
}
func (r *opsRepository) CreateAlertSilence(ctx context.Context, input *service.OpsAlertSilence) (*service.OpsAlertSilence, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if input == nil {
return nil, fmt.Errorf("nil input")
}
if input.RuleID <= 0 {
return nil, fmt.Errorf("invalid rule_id")
}
platform := strings.TrimSpace(input.Platform)
if platform == "" {
return nil, fmt.Errorf("invalid platform")
}
if input.Until.IsZero() {
return nil, fmt.Errorf("invalid until")
}
q := `
INSERT INTO ops_alert_silences (
rule_id,
platform,
group_id,
region,
until,
reason,
created_by,
created_at
) VALUES (
$1,$2,$3,$4,$5,$6,$7,NOW()
)
RETURNING id, rule_id, platform, group_id, region, until, COALESCE(reason,''), created_by, created_at`
row := r.db.QueryRowContext(
ctx,
q,
input.RuleID,
platform,
opsNullInt64(input.GroupID),
opsNullString(input.Region),
input.Until,
opsNullString(input.Reason),
opsNullInt64(input.CreatedBy),
)
var out service.OpsAlertSilence
var groupID sql.NullInt64
var region sql.NullString
var createdBy sql.NullInt64
if err := row.Scan(
&out.ID,
&out.RuleID,
&out.Platform,
&groupID,
&region,
&out.Until,
&out.Reason,
&createdBy,
&out.CreatedAt,
); err != nil {
return nil, err
}
if groupID.Valid {
v := groupID.Int64
out.GroupID = &v
}
if region.Valid {
v := strings.TrimSpace(region.String)
if v != "" {
out.Region = &v
}
}
if createdBy.Valid {
v := createdBy.Int64
out.CreatedBy = &v
}
return &out, nil
}
func (r *opsRepository) IsAlertSilenced(ctx context.Context, ruleID int64, platform string, groupID *int64, region *string, now time.Time) (bool, error) {
if r == nil || r.db == nil {
return false, fmt.Errorf("nil ops repository")
}
if ruleID <= 0 {
return false, fmt.Errorf("invalid rule id")
}
platform = strings.TrimSpace(platform)
if platform == "" {
return false, nil
}
if now.IsZero() {
now = time.Now().UTC()
}
q := `
SELECT 1
FROM ops_alert_silences
WHERE rule_id = $1
AND platform = $2
AND (group_id IS NOT DISTINCT FROM $3)
AND (region IS NOT DISTINCT FROM $4)
AND until > $5
LIMIT 1`
var dummy int
err := r.db.QueryRowContext(ctx, q, ruleID, platform, opsNullInt64(groupID), opsNullString(region), now).Scan(&dummy)
if err != nil {
if err == sql.ErrNoRows {
return false, nil
}
return false, err
}
return true, nil
}
func scanOpsAlertEvent(row opsAlertEventRow) (*service.OpsAlertEvent, error) {
var ev service.OpsAlertEvent
var metricValue sql.NullFloat64
var thresholdValue sql.NullFloat64
var dimensionsRaw []byte
var resolvedAt sql.NullTime
if err := row.Scan(
&ev.ID,
&ev.RuleID,
&ev.Severity,
&ev.Status,
&ev.Title,
&ev.Description,
&metricValue,
&thresholdValue,
&dimensionsRaw,
&ev.FiredAt,
&resolvedAt,
&ev.EmailSent,
&ev.CreatedAt,
); err != nil {
return nil, err
}
if metricValue.Valid {
v := metricValue.Float64
ev.MetricValue = &v
}
if thresholdValue.Valid {
v := thresholdValue.Float64
ev.ThresholdValue = &v
}
if resolvedAt.Valid {
v := resolvedAt.Time
ev.ResolvedAt = &v
}
if len(dimensionsRaw) > 0 && string(dimensionsRaw) != "null" {
var decoded map[string]any
if err := json.Unmarshal(dimensionsRaw, &decoded); err == nil {
ev.Dimensions = decoded
}
}
return &ev, nil
}
func buildOpsAlertEventsWhere(filter *service.OpsAlertEventFilter) (string, []any) {
clauses := []string{"1=1"}
args := []any{}
if filter == nil {
return "WHERE " + strings.Join(clauses, " AND "), args
}
if status := strings.TrimSpace(filter.Status); status != "" {
args = append(args, status)
clauses = append(clauses, "status = $"+itoa(len(args)))
}
if severity := strings.TrimSpace(filter.Severity); severity != "" {
args = append(args, severity)
clauses = append(clauses, "severity = $"+itoa(len(args)))
}
if filter.EmailSent != nil {
args = append(args, *filter.EmailSent)
clauses = append(clauses, "email_sent = $"+itoa(len(args)))
}
if filter.StartTime != nil && !filter.StartTime.IsZero() {
args = append(args, *filter.StartTime)
clauses = append(clauses, "fired_at >= $"+itoa(len(args)))
}
if filter.EndTime != nil && !filter.EndTime.IsZero() {
args = append(args, *filter.EndTime)
clauses = append(clauses, "fired_at < $"+itoa(len(args)))
}
// Cursor pagination (descending by fired_at, then id)
if filter.BeforeFiredAt != nil && !filter.BeforeFiredAt.IsZero() && filter.BeforeID != nil && *filter.BeforeID > 0 {
args = append(args, *filter.BeforeFiredAt)
tsArg := "$" + itoa(len(args))
args = append(args, *filter.BeforeID)
idArg := "$" + itoa(len(args))
clauses = append(clauses, fmt.Sprintf("(fired_at < %s OR (fired_at = %s AND id < %s))", tsArg, tsArg, idArg))
}
// Dimensions are stored in JSONB. We filter best-effort without requiring GIN indexes.
if platform := strings.TrimSpace(filter.Platform); platform != "" {
args = append(args, platform)
clauses = append(clauses, "(dimensions->>'platform') = $"+itoa(len(args)))
}
if filter.GroupID != nil && *filter.GroupID > 0 {
args = append(args, fmt.Sprintf("%d", *filter.GroupID))
clauses = append(clauses, "(dimensions->>'group_id') = $"+itoa(len(args)))
}
return "WHERE " + strings.Join(clauses, " AND "), args
}
func opsNullJSONMap(v map[string]any) (any, error) {
if v == nil {
return sql.NullString{}, nil
}
b, err := json.Marshal(v)
if err != nil {
return nil, err
}
if len(b) == 0 {
return sql.NullString{}, nil
}
return sql.NullString{String: string(b), Valid: true}, nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,79 @@
package repository
import (
"context"
"fmt"
"strings"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (r *opsRepository) GetLatencyHistogram(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsLatencyHistogramResponse, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if filter == nil {
return nil, fmt.Errorf("nil filter")
}
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
return nil, fmt.Errorf("start_time/end_time required")
}
start := filter.StartTime.UTC()
end := filter.EndTime.UTC()
join, where, args, _ := buildUsageWhere(filter, start, end, 1)
rangeExpr := latencyHistogramRangeCaseExpr("ul.duration_ms")
orderExpr := latencyHistogramRangeOrderCaseExpr("ul.duration_ms")
q := `
SELECT
` + rangeExpr + ` AS range,
COALESCE(COUNT(*), 0) AS count,
` + orderExpr + ` AS ord
FROM usage_logs ul
` + join + `
` + where + `
AND ul.duration_ms IS NOT NULL
GROUP BY 1, 3
ORDER BY 3 ASC`
rows, err := r.db.QueryContext(ctx, q, args...)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
counts := make(map[string]int64, len(latencyHistogramOrderedRanges))
var total int64
for rows.Next() {
var label string
var count int64
var _ord int
if err := rows.Scan(&label, &count, &_ord); err != nil {
return nil, err
}
counts[label] = count
total += count
}
if err := rows.Err(); err != nil {
return nil, err
}
buckets := make([]*service.OpsLatencyHistogramBucket, 0, len(latencyHistogramOrderedRanges))
for _, label := range latencyHistogramOrderedRanges {
buckets = append(buckets, &service.OpsLatencyHistogramBucket{
Range: label,
Count: counts[label],
})
}
return &service.OpsLatencyHistogramResponse{
StartTime: start,
EndTime: end,
Platform: strings.TrimSpace(filter.Platform),
GroupID: filter.GroupID,
TotalRequests: total,
Buckets: buckets,
}, nil
}

View File

@@ -0,0 +1,64 @@
package repository
import (
"fmt"
"strings"
)
type latencyHistogramBucket struct {
upperMs int
label string
}
var latencyHistogramBuckets = []latencyHistogramBucket{
{upperMs: 100, label: "0-100ms"},
{upperMs: 200, label: "100-200ms"},
{upperMs: 500, label: "200-500ms"},
{upperMs: 1000, label: "500-1000ms"},
{upperMs: 2000, label: "1000-2000ms"},
{upperMs: 0, label: "2000ms+"}, // default bucket
}
var latencyHistogramOrderedRanges = func() []string {
out := make([]string, 0, len(latencyHistogramBuckets))
for _, b := range latencyHistogramBuckets {
out = append(out, b.label)
}
return out
}()
func latencyHistogramRangeCaseExpr(column string) string {
var sb strings.Builder
_, _ = sb.WriteString("CASE\n")
for _, b := range latencyHistogramBuckets {
if b.upperMs <= 0 {
continue
}
_, _ = sb.WriteString(fmt.Sprintf("\tWHEN %s < %d THEN '%s'\n", column, b.upperMs, b.label))
}
// Default bucket.
last := latencyHistogramBuckets[len(latencyHistogramBuckets)-1]
_, _ = sb.WriteString(fmt.Sprintf("\tELSE '%s'\n", last.label))
_, _ = sb.WriteString("END")
return sb.String()
}
func latencyHistogramRangeOrderCaseExpr(column string) string {
var sb strings.Builder
_, _ = sb.WriteString("CASE\n")
order := 1
for _, b := range latencyHistogramBuckets {
if b.upperMs <= 0 {
continue
}
_, _ = sb.WriteString(fmt.Sprintf("\tWHEN %s < %d THEN %d\n", column, b.upperMs, order))
order++
}
_, _ = sb.WriteString(fmt.Sprintf("\tELSE %d\n", order))
_, _ = sb.WriteString("END")
return sb.String()
}

View File

@@ -0,0 +1,14 @@
package repository
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestLatencyHistogramBuckets_AreConsistent(t *testing.T) {
require.Equal(t, len(latencyHistogramBuckets), len(latencyHistogramOrderedRanges))
for i, b := range latencyHistogramBuckets {
require.Equal(t, b.label, latencyHistogramOrderedRanges[i])
}
}

View File

@@ -0,0 +1,422 @@
package repository
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (r *opsRepository) InsertSystemMetrics(ctx context.Context, input *service.OpsInsertSystemMetricsInput) error {
if r == nil || r.db == nil {
return fmt.Errorf("nil ops repository")
}
if input == nil {
return fmt.Errorf("nil input")
}
window := input.WindowMinutes
if window <= 0 {
window = 1
}
createdAt := input.CreatedAt
if createdAt.IsZero() {
createdAt = time.Now().UTC()
}
q := `
INSERT INTO ops_system_metrics (
created_at,
window_minutes,
platform,
group_id,
success_count,
error_count_total,
business_limited_count,
error_count_sla,
upstream_error_count_excl_429_529,
upstream_429_count,
upstream_529_count,
token_consumed,
qps,
tps,
duration_p50_ms,
duration_p90_ms,
duration_p95_ms,
duration_p99_ms,
duration_avg_ms,
duration_max_ms,
ttft_p50_ms,
ttft_p90_ms,
ttft_p95_ms,
ttft_p99_ms,
ttft_avg_ms,
ttft_max_ms,
cpu_usage_percent,
memory_used_mb,
memory_total_mb,
memory_usage_percent,
db_ok,
redis_ok,
redis_conn_total,
redis_conn_idle,
db_conn_active,
db_conn_idle,
db_conn_waiting,
goroutine_count,
concurrency_queue_depth
) VALUES (
$1,$2,$3,$4,
$5,$6,$7,$8,
$9,$10,$11,
$12,$13,$14,
$15,$16,$17,$18,$19,$20,
$21,$22,$23,$24,$25,$26,
$27,$28,$29,$30,
$31,$32,
$33,$34,
$35,$36,$37,
$38,$39
)`
_, err := r.db.ExecContext(
ctx,
q,
createdAt,
window,
opsNullString(input.Platform),
opsNullInt64(input.GroupID),
input.SuccessCount,
input.ErrorCountTotal,
input.BusinessLimitedCount,
input.ErrorCountSLA,
input.UpstreamErrorCountExcl429529,
input.Upstream429Count,
input.Upstream529Count,
input.TokenConsumed,
opsNullFloat64(input.QPS),
opsNullFloat64(input.TPS),
opsNullInt(input.DurationP50Ms),
opsNullInt(input.DurationP90Ms),
opsNullInt(input.DurationP95Ms),
opsNullInt(input.DurationP99Ms),
opsNullFloat64(input.DurationAvgMs),
opsNullInt(input.DurationMaxMs),
opsNullInt(input.TTFTP50Ms),
opsNullInt(input.TTFTP90Ms),
opsNullInt(input.TTFTP95Ms),
opsNullInt(input.TTFTP99Ms),
opsNullFloat64(input.TTFTAvgMs),
opsNullInt(input.TTFTMaxMs),
opsNullFloat64(input.CPUUsagePercent),
opsNullInt(input.MemoryUsedMB),
opsNullInt(input.MemoryTotalMB),
opsNullFloat64(input.MemoryUsagePercent),
opsNullBool(input.DBOK),
opsNullBool(input.RedisOK),
opsNullInt(input.RedisConnTotal),
opsNullInt(input.RedisConnIdle),
opsNullInt(input.DBConnActive),
opsNullInt(input.DBConnIdle),
opsNullInt(input.DBConnWaiting),
opsNullInt(input.GoroutineCount),
opsNullInt(input.ConcurrencyQueueDepth),
)
return err
}
func (r *opsRepository) GetLatestSystemMetrics(ctx context.Context, windowMinutes int) (*service.OpsSystemMetricsSnapshot, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if windowMinutes <= 0 {
windowMinutes = 1
}
q := `
SELECT
id,
created_at,
window_minutes,
cpu_usage_percent,
memory_used_mb,
memory_total_mb,
memory_usage_percent,
db_ok,
redis_ok,
redis_conn_total,
redis_conn_idle,
db_conn_active,
db_conn_idle,
db_conn_waiting,
goroutine_count,
concurrency_queue_depth
FROM ops_system_metrics
WHERE window_minutes = $1
AND platform IS NULL
AND group_id IS NULL
ORDER BY created_at DESC
LIMIT 1`
var out service.OpsSystemMetricsSnapshot
var cpu sql.NullFloat64
var memUsed sql.NullInt64
var memTotal sql.NullInt64
var memPct sql.NullFloat64
var dbOK sql.NullBool
var redisOK sql.NullBool
var redisTotal sql.NullInt64
var redisIdle sql.NullInt64
var dbActive sql.NullInt64
var dbIdle sql.NullInt64
var dbWaiting sql.NullInt64
var goroutines sql.NullInt64
var queueDepth sql.NullInt64
if err := r.db.QueryRowContext(ctx, q, windowMinutes).Scan(
&out.ID,
&out.CreatedAt,
&out.WindowMinutes,
&cpu,
&memUsed,
&memTotal,
&memPct,
&dbOK,
&redisOK,
&redisTotal,
&redisIdle,
&dbActive,
&dbIdle,
&dbWaiting,
&goroutines,
&queueDepth,
); err != nil {
return nil, err
}
if cpu.Valid {
v := cpu.Float64
out.CPUUsagePercent = &v
}
if memUsed.Valid {
v := memUsed.Int64
out.MemoryUsedMB = &v
}
if memTotal.Valid {
v := memTotal.Int64
out.MemoryTotalMB = &v
}
if memPct.Valid {
v := memPct.Float64
out.MemoryUsagePercent = &v
}
if dbOK.Valid {
v := dbOK.Bool
out.DBOK = &v
}
if redisOK.Valid {
v := redisOK.Bool
out.RedisOK = &v
}
if redisTotal.Valid {
v := int(redisTotal.Int64)
out.RedisConnTotal = &v
}
if redisIdle.Valid {
v := int(redisIdle.Int64)
out.RedisConnIdle = &v
}
if dbActive.Valid {
v := int(dbActive.Int64)
out.DBConnActive = &v
}
if dbIdle.Valid {
v := int(dbIdle.Int64)
out.DBConnIdle = &v
}
if dbWaiting.Valid {
v := int(dbWaiting.Int64)
out.DBConnWaiting = &v
}
if goroutines.Valid {
v := int(goroutines.Int64)
out.GoroutineCount = &v
}
if queueDepth.Valid {
v := int(queueDepth.Int64)
out.ConcurrencyQueueDepth = &v
}
return &out, nil
}
func (r *opsRepository) UpsertJobHeartbeat(ctx context.Context, input *service.OpsUpsertJobHeartbeatInput) error {
if r == nil || r.db == nil {
return fmt.Errorf("nil ops repository")
}
if input == nil {
return fmt.Errorf("nil input")
}
if input.JobName == "" {
return fmt.Errorf("job_name required")
}
q := `
INSERT INTO ops_job_heartbeats (
job_name,
last_run_at,
last_success_at,
last_error_at,
last_error,
last_duration_ms,
updated_at
) VALUES (
$1,$2,$3,$4,$5,$6,NOW()
)
ON CONFLICT (job_name) DO UPDATE SET
last_run_at = COALESCE(EXCLUDED.last_run_at, ops_job_heartbeats.last_run_at),
last_success_at = COALESCE(EXCLUDED.last_success_at, ops_job_heartbeats.last_success_at),
last_error_at = CASE
WHEN EXCLUDED.last_success_at IS NOT NULL THEN NULL
ELSE COALESCE(EXCLUDED.last_error_at, ops_job_heartbeats.last_error_at)
END,
last_error = CASE
WHEN EXCLUDED.last_success_at IS NOT NULL THEN NULL
ELSE COALESCE(EXCLUDED.last_error, ops_job_heartbeats.last_error)
END,
last_duration_ms = COALESCE(EXCLUDED.last_duration_ms, ops_job_heartbeats.last_duration_ms),
updated_at = NOW()`
_, err := r.db.ExecContext(
ctx,
q,
input.JobName,
opsNullTime(input.LastRunAt),
opsNullTime(input.LastSuccessAt),
opsNullTime(input.LastErrorAt),
opsNullString(input.LastError),
opsNullInt(input.LastDurationMs),
)
return err
}
func (r *opsRepository) ListJobHeartbeats(ctx context.Context) ([]*service.OpsJobHeartbeat, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
q := `
SELECT
job_name,
last_run_at,
last_success_at,
last_error_at,
last_error,
last_duration_ms,
updated_at
FROM ops_job_heartbeats
ORDER BY job_name ASC`
rows, err := r.db.QueryContext(ctx, q)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
out := make([]*service.OpsJobHeartbeat, 0, 8)
for rows.Next() {
var item service.OpsJobHeartbeat
var lastRun sql.NullTime
var lastSuccess sql.NullTime
var lastErrorAt sql.NullTime
var lastError sql.NullString
var lastDuration sql.NullInt64
if err := rows.Scan(
&item.JobName,
&lastRun,
&lastSuccess,
&lastErrorAt,
&lastError,
&lastDuration,
&item.UpdatedAt,
); err != nil {
return nil, err
}
if lastRun.Valid {
v := lastRun.Time
item.LastRunAt = &v
}
if lastSuccess.Valid {
v := lastSuccess.Time
item.LastSuccessAt = &v
}
if lastErrorAt.Valid {
v := lastErrorAt.Time
item.LastErrorAt = &v
}
if lastError.Valid {
v := lastError.String
item.LastError = &v
}
if lastDuration.Valid {
v := lastDuration.Int64
item.LastDurationMs = &v
}
out = append(out, &item)
}
if err := rows.Err(); err != nil {
return nil, err
}
return out, nil
}
func opsNullBool(v *bool) any {
if v == nil {
return sql.NullBool{}
}
return sql.NullBool{Bool: *v, Valid: true}
}
func opsNullFloat64(v *float64) any {
if v == nil {
return sql.NullFloat64{}
}
return sql.NullFloat64{Float64: *v, Valid: true}
}
func opsNullTime(v *time.Time) any {
if v == nil || v.IsZero() {
return sql.NullTime{}
}
return sql.NullTime{Time: *v, Valid: true}
}

View File

@@ -0,0 +1,363 @@
package repository
import (
"context"
"database/sql"
"fmt"
"time"
)
func (r *opsRepository) UpsertHourlyMetrics(ctx context.Context, startTime, endTime time.Time) error {
if r == nil || r.db == nil {
return fmt.Errorf("nil ops repository")
}
if startTime.IsZero() || endTime.IsZero() || !endTime.After(startTime) {
return nil
}
start := startTime.UTC()
end := endTime.UTC()
// NOTE:
// - We aggregate usage_logs + ops_error_logs into ops_metrics_hourly.
// - We emit three dimension granularities via GROUPING SETS:
// 1) overall: (bucket_start)
// 2) platform: (bucket_start, platform)
// 3) group: (bucket_start, platform, group_id)
//
// IMPORTANT: Postgres UNIQUE treats NULLs as distinct, so the table uses a COALESCE-based
// unique index; our ON CONFLICT target must match that expression set.
q := `
WITH usage_base AS (
SELECT
date_trunc('hour', ul.created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start,
g.platform AS platform,
ul.group_id AS group_id,
ul.duration_ms AS duration_ms,
ul.first_token_ms AS first_token_ms,
(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens) AS tokens
FROM usage_logs ul
JOIN groups g ON g.id = ul.group_id
WHERE ul.created_at >= $1 AND ul.created_at < $2
),
usage_agg AS (
SELECT
bucket_start,
CASE WHEN GROUPING(platform) = 1 THEN NULL ELSE platform END AS platform,
CASE WHEN GROUPING(group_id) = 1 THEN NULL ELSE group_id END AS group_id,
COUNT(*) AS success_count,
COALESCE(SUM(tokens), 0) AS token_consumed,
percentile_cont(0.50) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p50_ms,
percentile_cont(0.90) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p90_ms,
percentile_cont(0.95) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p95_ms,
percentile_cont(0.99) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p99_ms,
AVG(duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_avg_ms,
MAX(duration_ms) AS duration_max_ms,
percentile_cont(0.50) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p50_ms,
percentile_cont(0.90) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p90_ms,
percentile_cont(0.95) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p95_ms,
percentile_cont(0.99) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p99_ms,
AVG(first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_avg_ms,
MAX(first_token_ms) AS ttft_max_ms
FROM usage_base
GROUP BY GROUPING SETS (
(bucket_start),
(bucket_start, platform),
(bucket_start, platform, group_id)
)
),
error_base AS (
SELECT
date_trunc('hour', created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start,
-- platform is NULL for some early-phase errors (e.g. before routing); map to a sentinel
-- value so platform-level GROUPING SETS don't collide with the overall (platform=NULL) row.
COALESCE(platform, 'unknown') AS platform,
group_id AS group_id,
is_business_limited AS is_business_limited,
error_owner AS error_owner,
status_code AS client_status_code,
COALESCE(upstream_status_code, status_code, 0) AS effective_status_code
FROM ops_error_logs
-- Exclude count_tokens requests from error metrics as they are informational probes
WHERE created_at >= $1 AND created_at < $2
AND is_count_tokens = FALSE
),
error_agg AS (
SELECT
bucket_start,
CASE WHEN GROUPING(platform) = 1 THEN NULL ELSE platform END AS platform,
CASE WHEN GROUPING(group_id) = 1 THEN NULL ELSE group_id END AS group_id,
COUNT(*) FILTER (WHERE COALESCE(client_status_code, 0) >= 400) AS error_count_total,
COUNT(*) FILTER (WHERE COALESCE(client_status_code, 0) >= 400 AND is_business_limited) AS business_limited_count,
COUNT(*) FILTER (WHERE COALESCE(client_status_code, 0) >= 400 AND NOT is_business_limited) AS error_count_sla,
COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(effective_status_code, 0) NOT IN (429, 529)) AS upstream_error_count_excl_429_529,
COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(effective_status_code, 0) = 429) AS upstream_429_count,
COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(effective_status_code, 0) = 529) AS upstream_529_count
FROM error_base
GROUP BY GROUPING SETS (
(bucket_start),
(bucket_start, platform),
(bucket_start, platform, group_id)
)
HAVING GROUPING(group_id) = 1 OR group_id IS NOT NULL
),
combined AS (
SELECT
COALESCE(u.bucket_start, e.bucket_start) AS bucket_start,
COALESCE(u.platform, e.platform) AS platform,
COALESCE(u.group_id, e.group_id) AS group_id,
COALESCE(u.success_count, 0) AS success_count,
COALESCE(e.error_count_total, 0) AS error_count_total,
COALESCE(e.business_limited_count, 0) AS business_limited_count,
COALESCE(e.error_count_sla, 0) AS error_count_sla,
COALESCE(e.upstream_error_count_excl_429_529, 0) AS upstream_error_count_excl_429_529,
COALESCE(e.upstream_429_count, 0) AS upstream_429_count,
COALESCE(e.upstream_529_count, 0) AS upstream_529_count,
COALESCE(u.token_consumed, 0) AS token_consumed,
u.duration_p50_ms,
u.duration_p90_ms,
u.duration_p95_ms,
u.duration_p99_ms,
u.duration_avg_ms,
u.duration_max_ms,
u.ttft_p50_ms,
u.ttft_p90_ms,
u.ttft_p95_ms,
u.ttft_p99_ms,
u.ttft_avg_ms,
u.ttft_max_ms
FROM usage_agg u
FULL OUTER JOIN error_agg e
ON u.bucket_start = e.bucket_start
AND COALESCE(u.platform, '') = COALESCE(e.platform, '')
AND COALESCE(u.group_id, 0) = COALESCE(e.group_id, 0)
)
INSERT INTO ops_metrics_hourly (
bucket_start,
platform,
group_id,
success_count,
error_count_total,
business_limited_count,
error_count_sla,
upstream_error_count_excl_429_529,
upstream_429_count,
upstream_529_count,
token_consumed,
duration_p50_ms,
duration_p90_ms,
duration_p95_ms,
duration_p99_ms,
duration_avg_ms,
duration_max_ms,
ttft_p50_ms,
ttft_p90_ms,
ttft_p95_ms,
ttft_p99_ms,
ttft_avg_ms,
ttft_max_ms,
computed_at
)
SELECT
bucket_start,
NULLIF(platform, '') AS platform,
group_id,
success_count,
error_count_total,
business_limited_count,
error_count_sla,
upstream_error_count_excl_429_529,
upstream_429_count,
upstream_529_count,
token_consumed,
duration_p50_ms::int,
duration_p90_ms::int,
duration_p95_ms::int,
duration_p99_ms::int,
duration_avg_ms,
duration_max_ms::int,
ttft_p50_ms::int,
ttft_p90_ms::int,
ttft_p95_ms::int,
ttft_p99_ms::int,
ttft_avg_ms,
ttft_max_ms::int,
NOW()
FROM combined
WHERE bucket_start IS NOT NULL
AND (platform IS NULL OR platform <> '')
ON CONFLICT (bucket_start, COALESCE(platform, ''), COALESCE(group_id, 0)) DO UPDATE SET
success_count = EXCLUDED.success_count,
error_count_total = EXCLUDED.error_count_total,
business_limited_count = EXCLUDED.business_limited_count,
error_count_sla = EXCLUDED.error_count_sla,
upstream_error_count_excl_429_529 = EXCLUDED.upstream_error_count_excl_429_529,
upstream_429_count = EXCLUDED.upstream_429_count,
upstream_529_count = EXCLUDED.upstream_529_count,
token_consumed = EXCLUDED.token_consumed,
duration_p50_ms = EXCLUDED.duration_p50_ms,
duration_p90_ms = EXCLUDED.duration_p90_ms,
duration_p95_ms = EXCLUDED.duration_p95_ms,
duration_p99_ms = EXCLUDED.duration_p99_ms,
duration_avg_ms = EXCLUDED.duration_avg_ms,
duration_max_ms = EXCLUDED.duration_max_ms,
ttft_p50_ms = EXCLUDED.ttft_p50_ms,
ttft_p90_ms = EXCLUDED.ttft_p90_ms,
ttft_p95_ms = EXCLUDED.ttft_p95_ms,
ttft_p99_ms = EXCLUDED.ttft_p99_ms,
ttft_avg_ms = EXCLUDED.ttft_avg_ms,
ttft_max_ms = EXCLUDED.ttft_max_ms,
computed_at = NOW()
`
_, err := r.db.ExecContext(ctx, q, start, end)
return err
}
func (r *opsRepository) UpsertDailyMetrics(ctx context.Context, startTime, endTime time.Time) error {
if r == nil || r.db == nil {
return fmt.Errorf("nil ops repository")
}
if startTime.IsZero() || endTime.IsZero() || !endTime.After(startTime) {
return nil
}
start := startTime.UTC()
end := endTime.UTC()
q := `
INSERT INTO ops_metrics_daily (
bucket_date,
platform,
group_id,
success_count,
error_count_total,
business_limited_count,
error_count_sla,
upstream_error_count_excl_429_529,
upstream_429_count,
upstream_529_count,
token_consumed,
duration_p50_ms,
duration_p90_ms,
duration_p95_ms,
duration_p99_ms,
duration_avg_ms,
duration_max_ms,
ttft_p50_ms,
ttft_p90_ms,
ttft_p95_ms,
ttft_p99_ms,
ttft_avg_ms,
ttft_max_ms,
computed_at
)
SELECT
(bucket_start AT TIME ZONE 'UTC')::date AS bucket_date,
platform,
group_id,
COALESCE(SUM(success_count), 0) AS success_count,
COALESCE(SUM(error_count_total), 0) AS error_count_total,
COALESCE(SUM(business_limited_count), 0) AS business_limited_count,
COALESCE(SUM(error_count_sla), 0) AS error_count_sla,
COALESCE(SUM(upstream_error_count_excl_429_529), 0) AS upstream_error_count_excl_429_529,
COALESCE(SUM(upstream_429_count), 0) AS upstream_429_count,
COALESCE(SUM(upstream_529_count), 0) AS upstream_529_count,
COALESCE(SUM(token_consumed), 0) AS token_consumed,
-- Approximation: weighted average for p50/p90, max for p95/p99 (conservative tail).
ROUND(SUM(duration_p50_ms::double precision * success_count) FILTER (WHERE duration_p50_ms IS NOT NULL)
/ NULLIF(SUM(success_count) FILTER (WHERE duration_p50_ms IS NOT NULL), 0))::int AS duration_p50_ms,
ROUND(SUM(duration_p90_ms::double precision * success_count) FILTER (WHERE duration_p90_ms IS NOT NULL)
/ NULLIF(SUM(success_count) FILTER (WHERE duration_p90_ms IS NOT NULL), 0))::int AS duration_p90_ms,
MAX(duration_p95_ms) AS duration_p95_ms,
MAX(duration_p99_ms) AS duration_p99_ms,
SUM(duration_avg_ms * success_count) FILTER (WHERE duration_avg_ms IS NOT NULL)
/ NULLIF(SUM(success_count) FILTER (WHERE duration_avg_ms IS NOT NULL), 0) AS duration_avg_ms,
MAX(duration_max_ms) AS duration_max_ms,
ROUND(SUM(ttft_p50_ms::double precision * success_count) FILTER (WHERE ttft_p50_ms IS NOT NULL)
/ NULLIF(SUM(success_count) FILTER (WHERE ttft_p50_ms IS NOT NULL), 0))::int AS ttft_p50_ms,
ROUND(SUM(ttft_p90_ms::double precision * success_count) FILTER (WHERE ttft_p90_ms IS NOT NULL)
/ NULLIF(SUM(success_count) FILTER (WHERE ttft_p90_ms IS NOT NULL), 0))::int AS ttft_p90_ms,
MAX(ttft_p95_ms) AS ttft_p95_ms,
MAX(ttft_p99_ms) AS ttft_p99_ms,
SUM(ttft_avg_ms * success_count) FILTER (WHERE ttft_avg_ms IS NOT NULL)
/ NULLIF(SUM(success_count) FILTER (WHERE ttft_avg_ms IS NOT NULL), 0) AS ttft_avg_ms,
MAX(ttft_max_ms) AS ttft_max_ms,
NOW()
FROM ops_metrics_hourly
WHERE bucket_start >= $1 AND bucket_start < $2
GROUP BY 1, 2, 3
ON CONFLICT (bucket_date, COALESCE(platform, ''), COALESCE(group_id, 0)) DO UPDATE SET
success_count = EXCLUDED.success_count,
error_count_total = EXCLUDED.error_count_total,
business_limited_count = EXCLUDED.business_limited_count,
error_count_sla = EXCLUDED.error_count_sla,
upstream_error_count_excl_429_529 = EXCLUDED.upstream_error_count_excl_429_529,
upstream_429_count = EXCLUDED.upstream_429_count,
upstream_529_count = EXCLUDED.upstream_529_count,
token_consumed = EXCLUDED.token_consumed,
duration_p50_ms = EXCLUDED.duration_p50_ms,
duration_p90_ms = EXCLUDED.duration_p90_ms,
duration_p95_ms = EXCLUDED.duration_p95_ms,
duration_p99_ms = EXCLUDED.duration_p99_ms,
duration_avg_ms = EXCLUDED.duration_avg_ms,
duration_max_ms = EXCLUDED.duration_max_ms,
ttft_p50_ms = EXCLUDED.ttft_p50_ms,
ttft_p90_ms = EXCLUDED.ttft_p90_ms,
ttft_p95_ms = EXCLUDED.ttft_p95_ms,
ttft_p99_ms = EXCLUDED.ttft_p99_ms,
ttft_avg_ms = EXCLUDED.ttft_avg_ms,
ttft_max_ms = EXCLUDED.ttft_max_ms,
computed_at = NOW()
`
_, err := r.db.ExecContext(ctx, q, start, end)
return err
}
func (r *opsRepository) GetLatestHourlyBucketStart(ctx context.Context) (time.Time, bool, error) {
if r == nil || r.db == nil {
return time.Time{}, false, fmt.Errorf("nil ops repository")
}
var value sql.NullTime
if err := r.db.QueryRowContext(ctx, `SELECT MAX(bucket_start) FROM ops_metrics_hourly`).Scan(&value); err != nil {
return time.Time{}, false, err
}
if !value.Valid {
return time.Time{}, false, nil
}
return value.Time.UTC(), true, nil
}
func (r *opsRepository) GetLatestDailyBucketDate(ctx context.Context) (time.Time, bool, error) {
if r == nil || r.db == nil {
return time.Time{}, false, fmt.Errorf("nil ops repository")
}
var value sql.NullTime
if err := r.db.QueryRowContext(ctx, `SELECT MAX(bucket_date) FROM ops_metrics_daily`).Scan(&value); err != nil {
return time.Time{}, false, err
}
if !value.Valid {
return time.Time{}, false, nil
}
t := value.Time.UTC()
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC), true, nil
}

View File

@@ -0,0 +1,129 @@
package repository
import (
"context"
"fmt"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (r *opsRepository) GetRealtimeTrafficSummary(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsRealtimeTrafficSummary, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if filter == nil {
return nil, fmt.Errorf("nil filter")
}
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
return nil, fmt.Errorf("start_time/end_time required")
}
start := filter.StartTime.UTC()
end := filter.EndTime.UTC()
if start.After(end) {
return nil, fmt.Errorf("start_time must be <= end_time")
}
window := end.Sub(start)
if window <= 0 {
return nil, fmt.Errorf("invalid time window")
}
if window > time.Hour {
return nil, fmt.Errorf("window too large")
}
usageJoin, usageWhere, usageArgs, next := buildUsageWhere(filter, start, end, 1)
errorWhere, errorArgs, _ := buildErrorWhere(filter, start, end, next)
q := `
WITH usage_buckets AS (
SELECT
date_trunc('minute', ul.created_at) AS bucket,
COALESCE(COUNT(*), 0) AS success_count,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_sum
FROM usage_logs ul
` + usageJoin + `
` + usageWhere + `
GROUP BY 1
),
error_buckets AS (
SELECT
date_trunc('minute', created_at) AS bucket,
COALESCE(COUNT(*), 0) AS error_count
FROM ops_error_logs
` + errorWhere + `
AND COALESCE(status_code, 0) >= 400
GROUP BY 1
),
combined AS (
SELECT
COALESCE(u.bucket, e.bucket) AS bucket,
COALESCE(u.success_count, 0) AS success_count,
COALESCE(u.token_sum, 0) AS token_sum,
COALESCE(e.error_count, 0) AS error_count,
COALESCE(u.success_count, 0) + COALESCE(e.error_count, 0) AS request_total
FROM usage_buckets u
FULL OUTER JOIN error_buckets e ON u.bucket = e.bucket
)
SELECT
COALESCE(SUM(success_count), 0) AS success_total,
COALESCE(SUM(error_count), 0) AS error_total,
COALESCE(SUM(token_sum), 0) AS token_total,
COALESCE(MAX(request_total), 0) AS peak_requests_per_min,
COALESCE(MAX(token_sum), 0) AS peak_tokens_per_min
FROM combined`
args := append(usageArgs, errorArgs...)
var successCount int64
var errorTotal int64
var tokenConsumed int64
var peakRequestsPerMin int64
var peakTokensPerMin int64
if err := r.db.QueryRowContext(ctx, q, args...).Scan(
&successCount,
&errorTotal,
&tokenConsumed,
&peakRequestsPerMin,
&peakTokensPerMin,
); err != nil {
return nil, err
}
windowSeconds := window.Seconds()
if windowSeconds <= 0 {
windowSeconds = 1
}
requestCountTotal := successCount + errorTotal
qpsAvg := roundTo1DP(float64(requestCountTotal) / windowSeconds)
tpsAvg := roundTo1DP(float64(tokenConsumed) / windowSeconds)
// Keep "current" consistent with the dashboard overview semantics: last 1 minute.
// This remains "within the selected window" since end=start+window.
qpsCurrent, tpsCurrent, err := r.queryCurrentRates(ctx, filter, end)
if err != nil {
return nil, err
}
qpsPeak := roundTo1DP(float64(peakRequestsPerMin) / 60.0)
tpsPeak := roundTo1DP(float64(peakTokensPerMin) / 60.0)
return &service.OpsRealtimeTrafficSummary{
StartTime: start,
EndTime: end,
Platform: strings.TrimSpace(filter.Platform),
GroupID: filter.GroupID,
QPS: service.OpsRateSummary{
Current: qpsCurrent,
Peak: qpsPeak,
Avg: qpsAvg,
},
TPS: service.OpsRateSummary{
Current: tpsCurrent,
Peak: tpsPeak,
Avg: tpsAvg,
},
}, nil
}

View File

@@ -0,0 +1,286 @@
package repository
import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (r *opsRepository) ListRequestDetails(ctx context.Context, filter *service.OpsRequestDetailFilter) ([]*service.OpsRequestDetail, int64, error) {
if r == nil || r.db == nil {
return nil, 0, fmt.Errorf("nil ops repository")
}
page, pageSize, startTime, endTime := filter.Normalize()
offset := (page - 1) * pageSize
conditions := make([]string, 0, 16)
args := make([]any, 0, 24)
// Placeholders $1/$2 reserved for time window inside the CTE.
args = append(args, startTime.UTC(), endTime.UTC())
addCondition := func(condition string, values ...any) {
conditions = append(conditions, condition)
args = append(args, values...)
}
if filter != nil {
if kind := strings.TrimSpace(strings.ToLower(filter.Kind)); kind != "" && kind != "all" {
if kind != string(service.OpsRequestKindSuccess) && kind != string(service.OpsRequestKindError) {
return nil, 0, fmt.Errorf("invalid kind")
}
addCondition(fmt.Sprintf("kind = $%d", len(args)+1), kind)
}
if platform := strings.TrimSpace(strings.ToLower(filter.Platform)); platform != "" {
addCondition(fmt.Sprintf("platform = $%d", len(args)+1), platform)
}
if filter.GroupID != nil && *filter.GroupID > 0 {
addCondition(fmt.Sprintf("group_id = $%d", len(args)+1), *filter.GroupID)
}
if filter.UserID != nil && *filter.UserID > 0 {
addCondition(fmt.Sprintf("user_id = $%d", len(args)+1), *filter.UserID)
}
if filter.APIKeyID != nil && *filter.APIKeyID > 0 {
addCondition(fmt.Sprintf("api_key_id = $%d", len(args)+1), *filter.APIKeyID)
}
if filter.AccountID != nil && *filter.AccountID > 0 {
addCondition(fmt.Sprintf("account_id = $%d", len(args)+1), *filter.AccountID)
}
if model := strings.TrimSpace(filter.Model); model != "" {
addCondition(fmt.Sprintf("model = $%d", len(args)+1), model)
}
if requestID := strings.TrimSpace(filter.RequestID); requestID != "" {
addCondition(fmt.Sprintf("request_id = $%d", len(args)+1), requestID)
}
if q := strings.TrimSpace(filter.Query); q != "" {
like := "%" + strings.ToLower(q) + "%"
startIdx := len(args) + 1
addCondition(
fmt.Sprintf("(LOWER(COALESCE(request_id,'')) LIKE $%d OR LOWER(COALESCE(model,'')) LIKE $%d OR LOWER(COALESCE(message,'')) LIKE $%d)",
startIdx, startIdx+1, startIdx+2,
),
like, like, like,
)
}
if filter.MinDurationMs != nil {
addCondition(fmt.Sprintf("duration_ms >= $%d", len(args)+1), *filter.MinDurationMs)
}
if filter.MaxDurationMs != nil {
addCondition(fmt.Sprintf("duration_ms <= $%d", len(args)+1), *filter.MaxDurationMs)
}
}
where := ""
if len(conditions) > 0 {
where = "WHERE " + strings.Join(conditions, " AND ")
}
cte := `
WITH combined AS (
SELECT
'success'::TEXT AS kind,
ul.created_at AS created_at,
ul.request_id AS request_id,
COALESCE(NULLIF(g.platform, ''), NULLIF(a.platform, ''), '') AS platform,
ul.model AS model,
ul.duration_ms AS duration_ms,
NULL::INT AS status_code,
NULL::BIGINT AS error_id,
NULL::TEXT AS phase,
NULL::TEXT AS severity,
NULL::TEXT AS message,
ul.user_id AS user_id,
ul.api_key_id AS api_key_id,
ul.account_id AS account_id,
ul.group_id AS group_id,
ul.stream AS stream
FROM usage_logs ul
LEFT JOIN groups g ON g.id = ul.group_id
LEFT JOIN accounts a ON a.id = ul.account_id
WHERE ul.created_at >= $1 AND ul.created_at < $2
UNION ALL
SELECT
'error'::TEXT AS kind,
o.created_at AS created_at,
COALESCE(NULLIF(o.request_id,''), NULLIF(o.client_request_id,''), '') AS request_id,
COALESCE(NULLIF(o.platform, ''), NULLIF(g.platform, ''), NULLIF(a.platform, ''), '') AS platform,
o.model AS model,
o.duration_ms AS duration_ms,
o.status_code AS status_code,
o.id AS error_id,
o.error_phase AS phase,
o.severity AS severity,
o.error_message AS message,
o.user_id AS user_id,
o.api_key_id AS api_key_id,
o.account_id AS account_id,
o.group_id AS group_id,
o.stream AS stream
FROM ops_error_logs o
LEFT JOIN groups g ON g.id = o.group_id
LEFT JOIN accounts a ON a.id = o.account_id
WHERE o.created_at >= $1 AND o.created_at < $2
AND COALESCE(o.status_code, 0) >= 400
)
`
countQuery := fmt.Sprintf(`%s SELECT COUNT(1) FROM combined %s`, cte, where)
var total int64
if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
if err == sql.ErrNoRows {
total = 0
} else {
return nil, 0, err
}
}
sort := "ORDER BY created_at DESC"
if filter != nil {
switch strings.TrimSpace(strings.ToLower(filter.Sort)) {
case "", "created_at_desc":
// default
case "duration_desc":
sort = "ORDER BY duration_ms DESC NULLS LAST, created_at DESC"
default:
return nil, 0, fmt.Errorf("invalid sort")
}
}
listQuery := fmt.Sprintf(`
%s
SELECT
kind,
created_at,
request_id,
platform,
model,
duration_ms,
status_code,
error_id,
phase,
severity,
message,
user_id,
api_key_id,
account_id,
group_id,
stream
FROM combined
%s
%s
LIMIT $%d OFFSET $%d
`, cte, where, sort, len(args)+1, len(args)+2)
listArgs := append(append([]any{}, args...), pageSize, offset)
rows, err := r.db.QueryContext(ctx, listQuery, listArgs...)
if err != nil {
return nil, 0, err
}
defer func() { _ = rows.Close() }()
toIntPtr := func(v sql.NullInt64) *int {
if !v.Valid {
return nil
}
i := int(v.Int64)
return &i
}
toInt64Ptr := func(v sql.NullInt64) *int64 {
if !v.Valid {
return nil
}
i := v.Int64
return &i
}
out := make([]*service.OpsRequestDetail, 0, pageSize)
for rows.Next() {
var (
kind string
createdAt time.Time
requestID sql.NullString
platform sql.NullString
model sql.NullString
durationMs sql.NullInt64
statusCode sql.NullInt64
errorID sql.NullInt64
phase sql.NullString
severity sql.NullString
message sql.NullString
userID sql.NullInt64
apiKeyID sql.NullInt64
accountID sql.NullInt64
groupID sql.NullInt64
stream bool
)
if err := rows.Scan(
&kind,
&createdAt,
&requestID,
&platform,
&model,
&durationMs,
&statusCode,
&errorID,
&phase,
&severity,
&message,
&userID,
&apiKeyID,
&accountID,
&groupID,
&stream,
); err != nil {
return nil, 0, err
}
item := &service.OpsRequestDetail{
Kind: service.OpsRequestKind(kind),
CreatedAt: createdAt,
RequestID: strings.TrimSpace(requestID.String),
Platform: strings.TrimSpace(platform.String),
Model: strings.TrimSpace(model.String),
DurationMs: toIntPtr(durationMs),
StatusCode: toIntPtr(statusCode),
ErrorID: toInt64Ptr(errorID),
Phase: phase.String,
Severity: severity.String,
Message: message.String,
UserID: toInt64Ptr(userID),
APIKeyID: toInt64Ptr(apiKeyID),
AccountID: toInt64Ptr(accountID),
GroupID: toInt64Ptr(groupID),
Stream: stream,
}
if item.Platform == "" {
item.Platform = "unknown"
}
out = append(out, item)
}
if err := rows.Err(); err != nil {
return nil, 0, err
}
return out, total, nil
}

View File

@@ -0,0 +1,573 @@
package repository
import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (r *opsRepository) GetThroughputTrend(ctx context.Context, filter *service.OpsDashboardFilter, bucketSeconds int) (*service.OpsThroughputTrendResponse, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if filter == nil {
return nil, fmt.Errorf("nil filter")
}
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
return nil, fmt.Errorf("start_time/end_time required")
}
if bucketSeconds <= 0 {
bucketSeconds = 60
}
if bucketSeconds != 60 && bucketSeconds != 300 && bucketSeconds != 3600 {
// Keep a small, predictable set of supported buckets for now.
bucketSeconds = 60
}
start := filter.StartTime.UTC()
end := filter.EndTime.UTC()
usageJoin, usageWhere, usageArgs, next := buildUsageWhere(filter, start, end, 1)
errorWhere, errorArgs, _ := buildErrorWhere(filter, start, end, next)
usageBucketExpr := opsBucketExprForUsage(bucketSeconds)
errorBucketExpr := opsBucketExprForError(bucketSeconds)
q := `
WITH usage_buckets AS (
SELECT ` + usageBucketExpr + ` AS bucket,
COUNT(*) AS success_count,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_consumed
FROM usage_logs ul
` + usageJoin + `
` + usageWhere + `
GROUP BY 1
),
error_buckets AS (
SELECT ` + errorBucketExpr + ` AS bucket,
COUNT(*) AS error_count
FROM ops_error_logs
` + errorWhere + `
AND COALESCE(status_code, 0) >= 400
GROUP BY 1
),
combined AS (
SELECT COALESCE(u.bucket, e.bucket) AS bucket,
COALESCE(u.success_count, 0) AS success_count,
COALESCE(e.error_count, 0) AS error_count,
COALESCE(u.token_consumed, 0) AS token_consumed
FROM usage_buckets u
FULL OUTER JOIN error_buckets e ON u.bucket = e.bucket
)
SELECT
bucket,
(success_count + error_count) AS request_count,
token_consumed
FROM combined
ORDER BY bucket ASC`
args := append(usageArgs, errorArgs...)
rows, err := r.db.QueryContext(ctx, q, args...)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
points := make([]*service.OpsThroughputTrendPoint, 0, 256)
for rows.Next() {
var bucket time.Time
var requests int64
var tokens sql.NullInt64
if err := rows.Scan(&bucket, &requests, &tokens); err != nil {
return nil, err
}
tokenConsumed := int64(0)
if tokens.Valid {
tokenConsumed = tokens.Int64
}
denom := float64(bucketSeconds)
if denom <= 0 {
denom = 60
}
qps := roundTo1DP(float64(requests) / denom)
tps := roundTo1DP(float64(tokenConsumed) / denom)
points = append(points, &service.OpsThroughputTrendPoint{
BucketStart: bucket.UTC(),
RequestCount: requests,
TokenConsumed: tokenConsumed,
QPS: qps,
TPS: tps,
})
}
if err := rows.Err(); err != nil {
return nil, err
}
// Fill missing buckets with zeros so charts render continuous timelines.
points = fillOpsThroughputBuckets(start, end, bucketSeconds, points)
var byPlatform []*service.OpsThroughputPlatformBreakdownItem
var topGroups []*service.OpsThroughputGroupBreakdownItem
platform := ""
if filter != nil {
platform = strings.TrimSpace(strings.ToLower(filter.Platform))
}
groupID := (*int64)(nil)
if filter != nil {
groupID = filter.GroupID
}
// Drilldown helpers:
// - No platform/group: totals by platform
// - Platform selected but no group: top groups in that platform
if platform == "" && (groupID == nil || *groupID <= 0) {
items, err := r.getThroughputBreakdownByPlatform(ctx, start, end)
if err != nil {
return nil, err
}
byPlatform = items
} else if platform != "" && (groupID == nil || *groupID <= 0) {
items, err := r.getThroughputTopGroupsByPlatform(ctx, start, end, platform, 10)
if err != nil {
return nil, err
}
topGroups = items
}
return &service.OpsThroughputTrendResponse{
Bucket: opsBucketLabel(bucketSeconds),
Points: points,
ByPlatform: byPlatform,
TopGroups: topGroups,
}, nil
}
func (r *opsRepository) getThroughputBreakdownByPlatform(ctx context.Context, start, end time.Time) ([]*service.OpsThroughputPlatformBreakdownItem, error) {
q := `
WITH usage_totals AS (
SELECT COALESCE(NULLIF(g.platform,''), a.platform) AS platform,
COUNT(*) AS success_count,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_consumed
FROM usage_logs ul
LEFT JOIN groups g ON g.id = ul.group_id
LEFT JOIN accounts a ON a.id = ul.account_id
WHERE ul.created_at >= $1 AND ul.created_at < $2
GROUP BY 1
),
error_totals AS (
SELECT platform,
COUNT(*) AS error_count
FROM ops_error_logs
WHERE created_at >= $1 AND created_at < $2
AND COALESCE(status_code, 0) >= 400
AND is_count_tokens = FALSE -- 排除 count_tokens 请求的错误
GROUP BY 1
),
combined AS (
SELECT COALESCE(u.platform, e.platform) AS platform,
COALESCE(u.success_count, 0) AS success_count,
COALESCE(e.error_count, 0) AS error_count,
COALESCE(u.token_consumed, 0) AS token_consumed
FROM usage_totals u
FULL OUTER JOIN error_totals e ON u.platform = e.platform
)
SELECT platform, (success_count + error_count) AS request_count, token_consumed
FROM combined
WHERE platform IS NOT NULL AND platform <> ''
ORDER BY request_count DESC`
rows, err := r.db.QueryContext(ctx, q, start, end)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
items := make([]*service.OpsThroughputPlatformBreakdownItem, 0, 8)
for rows.Next() {
var platform string
var requests int64
var tokens sql.NullInt64
if err := rows.Scan(&platform, &requests, &tokens); err != nil {
return nil, err
}
tokenConsumed := int64(0)
if tokens.Valid {
tokenConsumed = tokens.Int64
}
items = append(items, &service.OpsThroughputPlatformBreakdownItem{
Platform: platform,
RequestCount: requests,
TokenConsumed: tokenConsumed,
})
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
func (r *opsRepository) getThroughputTopGroupsByPlatform(ctx context.Context, start, end time.Time, platform string, limit int) ([]*service.OpsThroughputGroupBreakdownItem, error) {
if strings.TrimSpace(platform) == "" {
return nil, nil
}
if limit <= 0 || limit > 100 {
limit = 10
}
q := `
WITH usage_totals AS (
SELECT ul.group_id AS group_id,
g.name AS group_name,
COUNT(*) AS success_count,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_consumed
FROM usage_logs ul
JOIN groups g ON g.id = ul.group_id
WHERE ul.created_at >= $1 AND ul.created_at < $2
AND g.platform = $3
GROUP BY 1, 2
),
error_totals AS (
SELECT group_id,
COUNT(*) AS error_count
FROM ops_error_logs
WHERE created_at >= $1 AND created_at < $2
AND platform = $3
AND group_id IS NOT NULL
AND COALESCE(status_code, 0) >= 400
AND is_count_tokens = FALSE -- 排除 count_tokens 请求的错误
GROUP BY 1
),
combined AS (
SELECT COALESCE(u.group_id, e.group_id) AS group_id,
COALESCE(u.group_name, g2.name, '') AS group_name,
COALESCE(u.success_count, 0) AS success_count,
COALESCE(e.error_count, 0) AS error_count,
COALESCE(u.token_consumed, 0) AS token_consumed
FROM usage_totals u
FULL OUTER JOIN error_totals e ON u.group_id = e.group_id
LEFT JOIN groups g2 ON g2.id = COALESCE(u.group_id, e.group_id)
)
SELECT group_id, group_name, (success_count + error_count) AS request_count, token_consumed
FROM combined
WHERE group_id IS NOT NULL
ORDER BY request_count DESC
LIMIT $4`
rows, err := r.db.QueryContext(ctx, q, start, end, platform, limit)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
items := make([]*service.OpsThroughputGroupBreakdownItem, 0, limit)
for rows.Next() {
var groupID int64
var groupName sql.NullString
var requests int64
var tokens sql.NullInt64
if err := rows.Scan(&groupID, &groupName, &requests, &tokens); err != nil {
return nil, err
}
tokenConsumed := int64(0)
if tokens.Valid {
tokenConsumed = tokens.Int64
}
name := ""
if groupName.Valid {
name = groupName.String
}
items = append(items, &service.OpsThroughputGroupBreakdownItem{
GroupID: groupID,
GroupName: name,
RequestCount: requests,
TokenConsumed: tokenConsumed,
})
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
func opsBucketExprForUsage(bucketSeconds int) string {
switch bucketSeconds {
case 3600:
return "date_trunc('hour', ul.created_at)"
case 300:
// 5-minute buckets in UTC.
return "to_timestamp(floor(extract(epoch from ul.created_at) / 300) * 300)"
default:
return "date_trunc('minute', ul.created_at)"
}
}
func opsBucketExprForError(bucketSeconds int) string {
switch bucketSeconds {
case 3600:
return "date_trunc('hour', created_at)"
case 300:
return "to_timestamp(floor(extract(epoch from created_at) / 300) * 300)"
default:
return "date_trunc('minute', created_at)"
}
}
func opsBucketLabel(bucketSeconds int) string {
if bucketSeconds <= 0 {
return "1m"
}
if bucketSeconds%3600 == 0 {
h := bucketSeconds / 3600
if h <= 0 {
h = 1
}
return fmt.Sprintf("%dh", h)
}
m := bucketSeconds / 60
if m <= 0 {
m = 1
}
return fmt.Sprintf("%dm", m)
}
func opsFloorToBucketStart(t time.Time, bucketSeconds int) time.Time {
t = t.UTC()
if bucketSeconds <= 0 {
bucketSeconds = 60
}
secs := t.Unix()
floored := secs - (secs % int64(bucketSeconds))
return time.Unix(floored, 0).UTC()
}
func fillOpsThroughputBuckets(start, end time.Time, bucketSeconds int, points []*service.OpsThroughputTrendPoint) []*service.OpsThroughputTrendPoint {
if bucketSeconds <= 0 {
bucketSeconds = 60
}
if !start.Before(end) {
return points
}
endMinus := end.Add(-time.Nanosecond)
if endMinus.Before(start) {
return points
}
first := opsFloorToBucketStart(start, bucketSeconds)
last := opsFloorToBucketStart(endMinus, bucketSeconds)
step := time.Duration(bucketSeconds) * time.Second
existing := make(map[int64]*service.OpsThroughputTrendPoint, len(points))
for _, p := range points {
if p == nil {
continue
}
existing[p.BucketStart.UTC().Unix()] = p
}
out := make([]*service.OpsThroughputTrendPoint, 0, int(last.Sub(first)/step)+1)
for cursor := first; !cursor.After(last); cursor = cursor.Add(step) {
if p, ok := existing[cursor.Unix()]; ok && p != nil {
out = append(out, p)
continue
}
out = append(out, &service.OpsThroughputTrendPoint{
BucketStart: cursor,
RequestCount: 0,
TokenConsumed: 0,
QPS: 0,
TPS: 0,
})
}
return out
}
func (r *opsRepository) GetErrorTrend(ctx context.Context, filter *service.OpsDashboardFilter, bucketSeconds int) (*service.OpsErrorTrendResponse, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if filter == nil {
return nil, fmt.Errorf("nil filter")
}
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
return nil, fmt.Errorf("start_time/end_time required")
}
if bucketSeconds <= 0 {
bucketSeconds = 60
}
if bucketSeconds != 60 && bucketSeconds != 300 && bucketSeconds != 3600 {
bucketSeconds = 60
}
start := filter.StartTime.UTC()
end := filter.EndTime.UTC()
where, args, _ := buildErrorWhere(filter, start, end, 1)
bucketExpr := opsBucketExprForError(bucketSeconds)
q := `
SELECT
` + bucketExpr + ` AS bucket,
COUNT(*) FILTER (WHERE COALESCE(status_code, 0) >= 400) AS error_total,
COUNT(*) FILTER (WHERE COALESCE(status_code, 0) >= 400 AND is_business_limited) AS business_limited,
COUNT(*) FILTER (WHERE COALESCE(status_code, 0) >= 400 AND NOT is_business_limited) AS error_sla,
COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)) AS upstream_excl,
COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(upstream_status_code, status_code, 0) = 429) AS upstream_429,
COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(upstream_status_code, status_code, 0) = 529) AS upstream_529
FROM ops_error_logs
` + where + `
GROUP BY 1
ORDER BY 1 ASC`
rows, err := r.db.QueryContext(ctx, q, args...)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
points := make([]*service.OpsErrorTrendPoint, 0, 256)
for rows.Next() {
var bucket time.Time
var total, businessLimited, sla, upstreamExcl, upstream429, upstream529 int64
if err := rows.Scan(&bucket, &total, &businessLimited, &sla, &upstreamExcl, &upstream429, &upstream529); err != nil {
return nil, err
}
points = append(points, &service.OpsErrorTrendPoint{
BucketStart: bucket.UTC(),
ErrorCountTotal: total,
BusinessLimitedCount: businessLimited,
ErrorCountSLA: sla,
UpstreamErrorCountExcl429529: upstreamExcl,
Upstream429Count: upstream429,
Upstream529Count: upstream529,
})
}
if err := rows.Err(); err != nil {
return nil, err
}
points = fillOpsErrorTrendBuckets(start, end, bucketSeconds, points)
return &service.OpsErrorTrendResponse{
Bucket: opsBucketLabel(bucketSeconds),
Points: points,
}, nil
}
func fillOpsErrorTrendBuckets(start, end time.Time, bucketSeconds int, points []*service.OpsErrorTrendPoint) []*service.OpsErrorTrendPoint {
if bucketSeconds <= 0 {
bucketSeconds = 60
}
if !start.Before(end) {
return points
}
endMinus := end.Add(-time.Nanosecond)
if endMinus.Before(start) {
return points
}
first := opsFloorToBucketStart(start, bucketSeconds)
last := opsFloorToBucketStart(endMinus, bucketSeconds)
step := time.Duration(bucketSeconds) * time.Second
existing := make(map[int64]*service.OpsErrorTrendPoint, len(points))
for _, p := range points {
if p == nil {
continue
}
existing[p.BucketStart.UTC().Unix()] = p
}
out := make([]*service.OpsErrorTrendPoint, 0, int(last.Sub(first)/step)+1)
for cursor := first; !cursor.After(last); cursor = cursor.Add(step) {
if p, ok := existing[cursor.Unix()]; ok && p != nil {
out = append(out, p)
continue
}
out = append(out, &service.OpsErrorTrendPoint{
BucketStart: cursor,
ErrorCountTotal: 0,
BusinessLimitedCount: 0,
ErrorCountSLA: 0,
UpstreamErrorCountExcl429529: 0,
Upstream429Count: 0,
Upstream529Count: 0,
})
}
return out
}
func (r *opsRepository) GetErrorDistribution(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsErrorDistributionResponse, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if filter == nil {
return nil, fmt.Errorf("nil filter")
}
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
return nil, fmt.Errorf("start_time/end_time required")
}
start := filter.StartTime.UTC()
end := filter.EndTime.UTC()
where, args, _ := buildErrorWhere(filter, start, end, 1)
q := `
SELECT
COALESCE(upstream_status_code, status_code, 0) AS status_code,
COUNT(*) AS total,
COUNT(*) FILTER (WHERE NOT is_business_limited) AS sla,
COUNT(*) FILTER (WHERE is_business_limited) AS business_limited
FROM ops_error_logs
` + where + `
AND COALESCE(status_code, 0) >= 400
GROUP BY 1
ORDER BY total DESC
LIMIT 20`
rows, err := r.db.QueryContext(ctx, q, args...)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
items := make([]*service.OpsErrorDistributionItem, 0, 16)
var total int64
for rows.Next() {
var statusCode int
var cntTotal, cntSLA, cntBiz int64
if err := rows.Scan(&statusCode, &cntTotal, &cntSLA, &cntBiz); err != nil {
return nil, err
}
total += cntTotal
items = append(items, &service.OpsErrorDistributionItem{
StatusCode: statusCode,
Total: cntTotal,
SLA: cntSLA,
BusinessLimited: cntBiz,
})
}
if err := rows.Err(); err != nil {
return nil, err
}
return &service.OpsErrorDistributionResponse{
Total: total,
Items: items,
}, nil
}

View File

@@ -0,0 +1,50 @@
package repository
import (
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (r *opsRepository) GetWindowStats(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsWindowStats, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if filter == nil {
return nil, fmt.Errorf("nil filter")
}
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
return nil, fmt.Errorf("start_time/end_time required")
}
start := filter.StartTime.UTC()
end := filter.EndTime.UTC()
if start.After(end) {
return nil, fmt.Errorf("start_time must be <= end_time")
}
// Bound excessively large windows to prevent accidental heavy queries.
if end.Sub(start) > 24*time.Hour {
return nil, fmt.Errorf("window too large")
}
successCount, tokenConsumed, err := r.queryUsageCounts(ctx, filter, start, end)
if err != nil {
return nil, err
}
errorTotal, _, _, _, _, _, err := r.queryErrorCounts(ctx, filter, start, end)
if err != nil {
return nil, err
}
return &service.OpsWindowStats{
StartTime: start,
EndTime: end,
SuccessCount: successCount,
ErrorCountTotal: errorTotal,
TokenConsumed: tokenConsumed,
}, nil
}

View File

@@ -0,0 +1,16 @@
package repository
import "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
func paginationResultFromTotal(total int64, params pagination.PaginationParams) *pagination.PaginationResult {
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}
}

View File

@@ -0,0 +1,81 @@
package repository
import (
"context"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type pricingRemoteClient struct {
httpClient *http.Client
}
// NewPricingRemoteClient 创建定价数据远程客户端
// proxyURL 为空时直连,支持 http/https/socks5/socks5h 协议
func NewPricingRemoteClient(proxyURL string) service.PricingRemoteClient {
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second,
ProxyURL: proxyURL,
})
if err != nil {
sharedClient = &http.Client{Timeout: 30 * time.Second}
}
return &pricingRemoteClient{
httpClient: sharedClient,
}
}
func (c *pricingRemoteClient) FetchPricingJSON(ctx context.Context, url string) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
}
return io.ReadAll(resp.Body)
}
func (c *pricingRemoteClient) FetchHashText(ctx context.Context, url string) (string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return "", err
}
resp, err := c.httpClient.Do(req)
if err != nil {
return "", err
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("HTTP %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
// 哈希文件格式hash filename 或者纯 hash
hash := strings.TrimSpace(string(body))
parts := strings.Fields(hash)
if len(parts) > 0 {
return parts[0], nil
}
return hash, nil
}

View File

@@ -0,0 +1,145 @@
package repository
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type PricingServiceSuite struct {
suite.Suite
ctx context.Context
srv *httptest.Server
client *pricingRemoteClient
}
func (s *PricingServiceSuite) SetupTest() {
s.ctx = context.Background()
client, ok := NewPricingRemoteClient("").(*pricingRemoteClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
}
func (s *PricingServiceSuite) TearDownTest() {
if s.srv != nil {
s.srv.Close()
s.srv = nil
}
}
func (s *PricingServiceSuite) setupServer(handler http.HandlerFunc) {
s.srv = newLocalTestServer(s.T(), handler)
}
func (s *PricingServiceSuite) TestFetchPricingJSON_Success() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/ok" {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"ok":true}`))
return
}
w.WriteHeader(http.StatusInternalServerError)
}))
body, err := s.client.FetchPricingJSON(s.ctx, s.srv.URL+"/ok")
require.NoError(s.T(), err, "FetchPricingJSON")
require.Equal(s.T(), `{"ok":true}`, string(body), "body mismatch")
}
func (s *PricingServiceSuite) TestFetchPricingJSON_NonOKStatus() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
_, err := s.client.FetchPricingJSON(s.ctx, s.srv.URL+"/err")
require.Error(s.T(), err, "expected error for non-200 status")
}
func (s *PricingServiceSuite) TestFetchHashText_ParsesFields() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/hashfile":
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("abc123 model_prices.json\n"))
case "/hashonly":
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("def456\n"))
default:
w.WriteHeader(http.StatusNotFound)
}
}))
hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/hashfile")
require.NoError(s.T(), err, "FetchHashText")
require.Equal(s.T(), "abc123", hash, "hash mismatch")
hash2, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/hashonly")
require.NoError(s.T(), err, "FetchHashText")
require.Equal(s.T(), "def456", hash2, "hash mismatch")
}
func (s *PricingServiceSuite) TestFetchHashText_NonOKStatus() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
_, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/nope")
require.Error(s.T(), err, "expected error for non-200 status")
}
func (s *PricingServiceSuite) TestFetchPricingJSON_InvalidURL() {
_, err := s.client.FetchPricingJSON(s.ctx, "://invalid-url")
require.Error(s.T(), err, "expected error for invalid URL")
}
func (s *PricingServiceSuite) TestFetchHashText_EmptyBody() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
// empty body
}))
hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/empty")
require.NoError(s.T(), err, "FetchHashText empty body should not error")
require.Equal(s.T(), "", hash, "expected empty hash")
}
func (s *PricingServiceSuite) TestFetchHashText_WhitespaceOnly() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(" \n"))
}))
hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/ws")
require.NoError(s.T(), err, "FetchHashText whitespace body should not error")
require.Equal(s.T(), "", hash, "expected empty hash after trimming")
}
func (s *PricingServiceSuite) TestFetchPricingJSON_ContextCancel() {
started := make(chan struct{})
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
close(started)
<-r.Context().Done()
}))
ctx, cancel := context.WithCancel(s.ctx)
done := make(chan error, 1)
go func() {
_, err := s.client.FetchPricingJSON(ctx, s.srv.URL+"/block")
done <- err
}()
<-started
cancel()
err := <-done
require.Error(s.T(), err)
}
func TestPricingServiceSuite(t *testing.T) {
suite.Run(t, new(PricingServiceSuite))
}

View File

@@ -0,0 +1,273 @@
package repository
import (
"context"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type promoCodeRepository struct {
client *dbent.Client
}
func NewPromoCodeRepository(client *dbent.Client) service.PromoCodeRepository {
return &promoCodeRepository{client: client}
}
func (r *promoCodeRepository) Create(ctx context.Context, code *service.PromoCode) error {
client := clientFromContext(ctx, r.client)
builder := client.PromoCode.Create().
SetCode(code.Code).
SetBonusAmount(code.BonusAmount).
SetMaxUses(code.MaxUses).
SetUsedCount(code.UsedCount).
SetStatus(code.Status).
SetNotes(code.Notes)
if code.ExpiresAt != nil {
builder.SetExpiresAt(*code.ExpiresAt)
}
created, err := builder.Save(ctx)
if err != nil {
return err
}
code.ID = created.ID
code.CreatedAt = created.CreatedAt
code.UpdatedAt = created.UpdatedAt
return nil
}
func (r *promoCodeRepository) GetByID(ctx context.Context, id int64) (*service.PromoCode, error) {
m, err := r.client.PromoCode.Query().
Where(promocode.IDEQ(id)).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrPromoCodeNotFound
}
return nil, err
}
return promoCodeEntityToService(m), nil
}
func (r *promoCodeRepository) GetByCode(ctx context.Context, code string) (*service.PromoCode, error) {
m, err := r.client.PromoCode.Query().
Where(promocode.CodeEqualFold(code)).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrPromoCodeNotFound
}
return nil, err
}
return promoCodeEntityToService(m), nil
}
func (r *promoCodeRepository) GetByCodeForUpdate(ctx context.Context, code string) (*service.PromoCode, error) {
client := clientFromContext(ctx, r.client)
m, err := client.PromoCode.Query().
Where(promocode.CodeEqualFold(code)).
ForUpdate().
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrPromoCodeNotFound
}
return nil, err
}
return promoCodeEntityToService(m), nil
}
func (r *promoCodeRepository) Update(ctx context.Context, code *service.PromoCode) error {
client := clientFromContext(ctx, r.client)
builder := client.PromoCode.UpdateOneID(code.ID).
SetCode(code.Code).
SetBonusAmount(code.BonusAmount).
SetMaxUses(code.MaxUses).
SetUsedCount(code.UsedCount).
SetStatus(code.Status).
SetNotes(code.Notes)
if code.ExpiresAt != nil {
builder.SetExpiresAt(*code.ExpiresAt)
} else {
builder.ClearExpiresAt()
}
updated, err := builder.Save(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return service.ErrPromoCodeNotFound
}
return err
}
code.UpdatedAt = updated.UpdatedAt
return nil
}
func (r *promoCodeRepository) Delete(ctx context.Context, id int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.PromoCode.Delete().Where(promocode.IDEQ(id)).Exec(ctx)
return err
}
func (r *promoCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.PromoCode, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "")
}
func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, search string) ([]service.PromoCode, *pagination.PaginationResult, error) {
q := r.client.PromoCode.Query()
if status != "" {
q = q.Where(promocode.StatusEQ(status))
}
if search != "" {
q = q.Where(promocode.CodeContainsFold(search))
}
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
}
codes, err := q.
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(promocode.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
outCodes := promoCodeEntitiesToService(codes)
return outCodes, paginationResultFromTotal(int64(total), params), nil
}
func (r *promoCodeRepository) CreateUsage(ctx context.Context, usage *service.PromoCodeUsage) error {
client := clientFromContext(ctx, r.client)
created, err := client.PromoCodeUsage.Create().
SetPromoCodeID(usage.PromoCodeID).
SetUserID(usage.UserID).
SetBonusAmount(usage.BonusAmount).
SetUsedAt(usage.UsedAt).
Save(ctx)
if err != nil {
return err
}
usage.ID = created.ID
return nil
}
func (r *promoCodeRepository) GetUsageByPromoCodeAndUser(ctx context.Context, promoCodeID, userID int64) (*service.PromoCodeUsage, error) {
m, err := r.client.PromoCodeUsage.Query().
Where(
promocodeusage.PromoCodeIDEQ(promoCodeID),
promocodeusage.UserIDEQ(userID),
).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, nil
}
return nil, err
}
return promoCodeUsageEntityToService(m), nil
}
func (r *promoCodeRepository) ListUsagesByPromoCode(ctx context.Context, promoCodeID int64, params pagination.PaginationParams) ([]service.PromoCodeUsage, *pagination.PaginationResult, error) {
q := r.client.PromoCodeUsage.Query().
Where(promocodeusage.PromoCodeIDEQ(promoCodeID))
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
}
usages, err := q.
WithUser().
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(promocodeusage.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
outUsages := promoCodeUsageEntitiesToService(usages)
return outUsages, paginationResultFromTotal(int64(total), params), nil
}
func (r *promoCodeRepository) IncrementUsedCount(ctx context.Context, id int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.PromoCode.UpdateOneID(id).
AddUsedCount(1).
Save(ctx)
return err
}
// Entity to Service conversions
func promoCodeEntityToService(m *dbent.PromoCode) *service.PromoCode {
if m == nil {
return nil
}
return &service.PromoCode{
ID: m.ID,
Code: m.Code,
BonusAmount: m.BonusAmount,
MaxUses: m.MaxUses,
UsedCount: m.UsedCount,
Status: m.Status,
ExpiresAt: m.ExpiresAt,
Notes: derefString(m.Notes),
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
}
}
func promoCodeEntitiesToService(models []*dbent.PromoCode) []service.PromoCode {
out := make([]service.PromoCode, 0, len(models))
for i := range models {
if s := promoCodeEntityToService(models[i]); s != nil {
out = append(out, *s)
}
}
return out
}
func promoCodeUsageEntityToService(m *dbent.PromoCodeUsage) *service.PromoCodeUsage {
if m == nil {
return nil
}
out := &service.PromoCodeUsage{
ID: m.ID,
PromoCodeID: m.PromoCodeID,
UserID: m.UserID,
BonusAmount: m.BonusAmount,
UsedAt: m.UsedAt,
}
if m.Edges.User != nil {
out.User = userEntityToService(m.Edges.User)
}
return out
}
func promoCodeUsageEntitiesToService(models []*dbent.PromoCodeUsage) []service.PromoCodeUsage {
out := make([]service.PromoCodeUsage, 0, len(models))
for i := range models {
if s := promoCodeUsageEntityToService(models[i]); s != nil {
out = append(out, *s)
}
}
return out
}

View File

@@ -0,0 +1,74 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const proxyLatencyKeyPrefix = "proxy:latency:"
func proxyLatencyKey(proxyID int64) string {
return fmt.Sprintf("%s%d", proxyLatencyKeyPrefix, proxyID)
}
type proxyLatencyCache struct {
rdb *redis.Client
}
func NewProxyLatencyCache(rdb *redis.Client) service.ProxyLatencyCache {
return &proxyLatencyCache{rdb: rdb}
}
func (c *proxyLatencyCache) GetProxyLatencies(ctx context.Context, proxyIDs []int64) (map[int64]*service.ProxyLatencyInfo, error) {
results := make(map[int64]*service.ProxyLatencyInfo)
if len(proxyIDs) == 0 {
return results, nil
}
keys := make([]string, 0, len(proxyIDs))
for _, id := range proxyIDs {
keys = append(keys, proxyLatencyKey(id))
}
values, err := c.rdb.MGet(ctx, keys...).Result()
if err != nil {
return results, err
}
for i, raw := range values {
if raw == nil {
continue
}
var payload []byte
switch v := raw.(type) {
case string:
payload = []byte(v)
case []byte:
payload = v
default:
continue
}
var info service.ProxyLatencyInfo
if err := json.Unmarshal(payload, &info); err != nil {
continue
}
results[proxyIDs[i]] = &info
}
return results, nil
}
func (c *proxyLatencyCache) SetProxyLatency(ctx context.Context, proxyID int64, info *service.ProxyLatencyInfo) error {
if info == nil {
return nil
}
payload, err := json.Marshal(info)
if err != nil {
return err
}
return c.rdb.Set(ctx, proxyLatencyKey(proxyID), payload, 0).Err()
}

View File

@@ -0,0 +1,118 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
insecure := false
allowPrivate := false
validateResolvedIP := true
if cfg != nil {
insecure = cfg.Security.ProxyProbe.InsecureSkipVerify
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
validateResolvedIP = cfg.Security.URLAllowlist.Enabled
}
if insecure {
log.Printf("[ProxyProbe] Warning: insecure_skip_verify is not allowed and will cause probe failure.")
}
return &proxyProbeService{
ipInfoURL: defaultIPInfoURL,
insecureSkipVerify: insecure,
allowPrivateHosts: allowPrivate,
validateResolvedIP: validateResolvedIP,
}
}
const (
defaultIPInfoURL = "http://ip-api.com/json/?lang=zh-CN"
defaultProxyProbeTimeout = 30 * time.Second
)
type proxyProbeService struct {
ipInfoURL string
insecureSkipVerify bool
allowPrivateHosts bool
validateResolvedIP bool
}
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL,
Timeout: defaultProxyProbeTimeout,
InsecureSkipVerify: s.insecureSkipVerify,
ProxyStrict: true,
ValidateResolvedIP: s.validateResolvedIP,
AllowPrivateHosts: s.allowPrivateHosts,
})
if err != nil {
return nil, 0, fmt.Errorf("failed to create proxy client: %w", err)
}
startTime := time.Now()
req, err := http.NewRequestWithContext(ctx, "GET", s.ipInfoURL, nil)
if err != nil {
return nil, 0, fmt.Errorf("failed to create request: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return nil, 0, fmt.Errorf("proxy connection failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
latencyMs := time.Since(startTime).Milliseconds()
if resp.StatusCode != http.StatusOK {
return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode)
}
var ipInfo struct {
Status string `json:"status"`
Message string `json:"message"`
Query string `json:"query"`
City string `json:"city"`
Region string `json:"region"`
RegionName string `json:"regionName"`
Country string `json:"country"`
CountryCode string `json:"countryCode"`
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, latencyMs, fmt.Errorf("failed to read response: %w", err)
}
if err := json.Unmarshal(body, &ipInfo); err != nil {
return nil, latencyMs, fmt.Errorf("failed to parse response: %w", err)
}
if strings.ToLower(ipInfo.Status) != "success" {
if ipInfo.Message == "" {
ipInfo.Message = "ip-api request failed"
}
return nil, latencyMs, fmt.Errorf("ip-api request failed: %s", ipInfo.Message)
}
region := ipInfo.RegionName
if region == "" {
region = ipInfo.Region
}
return &service.ProxyExitInfo{
IP: ipInfo.Query,
City: ipInfo.City,
Region: region,
Country: ipInfo.Country,
CountryCode: ipInfo.CountryCode,
}, latencyMs, nil
}

View File

@@ -0,0 +1,119 @@
package repository
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type ProxyProbeServiceSuite struct {
suite.Suite
ctx context.Context
proxySrv *httptest.Server
prober *proxyProbeService
}
func (s *ProxyProbeServiceSuite) SetupTest() {
s.ctx = context.Background()
s.prober = &proxyProbeService{
ipInfoURL: "http://ip-api.test/json/?lang=zh-CN",
allowPrivateHosts: true,
}
}
func (s *ProxyProbeServiceSuite) TearDownTest() {
if s.proxySrv != nil {
s.proxySrv.Close()
s.proxySrv = nil
}
}
func (s *ProxyProbeServiceSuite) setupProxyServer(handler http.HandlerFunc) {
s.proxySrv = newLocalTestServer(s.T(), handler)
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidProxyURL() {
_, _, err := s.prober.ProbeProxy(s.ctx, "://bad")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "failed to create proxy client")
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_UnsupportedProxyScheme() {
_, _, err := s.prober.ProbeProxy(s.ctx, "ftp://127.0.0.1:1")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "failed to create proxy client")
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
seen := make(chan string, 1)
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
seen <- r.RequestURI
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"status":"success","query":"1.2.3.4","city":"c","regionName":"r","country":"cc","countryCode":"CC"}`)
}))
info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
require.NoError(s.T(), err, "ProbeProxy")
require.GreaterOrEqual(s.T(), latencyMs, int64(0), "unexpected latency")
require.Equal(s.T(), "1.2.3.4", info.IP)
require.Equal(s.T(), "c", info.City)
require.Equal(s.T(), "r", info.Region)
require.Equal(s.T(), "cc", info.Country)
require.Equal(s.T(), "CC", info.CountryCode)
// Verify proxy received the request
select {
case uri := <-seen:
require.Contains(s.T(), uri, "ip-api.test", "expected request to go through proxy")
default:
require.Fail(s.T(), "expected proxy to receive request")
}
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_NonOKStatus() {
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable)
}))
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "status: 503")
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidJSON() {
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, "not-json")
}))
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "failed to parse response")
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidIPInfoURL() {
s.prober.ipInfoURL = "://invalid-url"
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
require.Error(s.T(), err, "expected error for invalid ipInfoURL")
}
func (s *ProxyProbeServiceSuite) TestProbeProxy_ProxyServerClosed() {
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
s.proxySrv.Close()
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
require.Error(s.T(), err, "expected error when proxy server is closed")
}
func TestProxyProbeServiceSuite(t *testing.T) {
suite.Run(t, new(ProxyProbeServiceSuite))
}

View File

@@ -0,0 +1,359 @@
package repository
import (
"context"
"database/sql"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
type sqlQuerier interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}
type proxyRepository struct {
client *dbent.Client
sql sqlQuerier
}
func NewProxyRepository(client *dbent.Client, sqlDB *sql.DB) service.ProxyRepository {
return newProxyRepositoryWithSQL(client, sqlDB)
}
func newProxyRepositoryWithSQL(client *dbent.Client, sqlq sqlQuerier) *proxyRepository {
return &proxyRepository{client: client, sql: sqlq}
}
func (r *proxyRepository) Create(ctx context.Context, proxyIn *service.Proxy) error {
builder := r.client.Proxy.Create().
SetName(proxyIn.Name).
SetProtocol(proxyIn.Protocol).
SetHost(proxyIn.Host).
SetPort(proxyIn.Port).
SetStatus(proxyIn.Status)
if proxyIn.Username != "" {
builder.SetUsername(proxyIn.Username)
}
if proxyIn.Password != "" {
builder.SetPassword(proxyIn.Password)
}
created, err := builder.Save(ctx)
if err == nil {
applyProxyEntityToService(proxyIn, created)
}
return err
}
func (r *proxyRepository) GetByID(ctx context.Context, id int64) (*service.Proxy, error) {
m, err := r.client.Proxy.Get(ctx, id)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrProxyNotFound
}
return nil, err
}
return proxyEntityToService(m), nil
}
func (r *proxyRepository) Update(ctx context.Context, proxyIn *service.Proxy) error {
builder := r.client.Proxy.UpdateOneID(proxyIn.ID).
SetName(proxyIn.Name).
SetProtocol(proxyIn.Protocol).
SetHost(proxyIn.Host).
SetPort(proxyIn.Port).
SetStatus(proxyIn.Status)
if proxyIn.Username != "" {
builder.SetUsername(proxyIn.Username)
} else {
builder.ClearUsername()
}
if proxyIn.Password != "" {
builder.SetPassword(proxyIn.Password)
} else {
builder.ClearPassword()
}
updated, err := builder.Save(ctx)
if err == nil {
applyProxyEntityToService(proxyIn, updated)
return nil
}
if dbent.IsNotFound(err) {
return service.ErrProxyNotFound
}
return err
}
func (r *proxyRepository) Delete(ctx context.Context, id int64) error {
_, err := r.client.Proxy.Delete().Where(proxy.IDEQ(id)).Exec(ctx)
return err
}
func (r *proxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Proxy, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "")
}
// ListWithFilters lists proxies with optional filtering by protocol, status, and search query
func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]service.Proxy, *pagination.PaginationResult, error) {
q := r.client.Proxy.Query()
if protocol != "" {
q = q.Where(proxy.ProtocolEQ(protocol))
}
if status != "" {
q = q.Where(proxy.StatusEQ(status))
}
if search != "" {
q = q.Where(proxy.NameContainsFold(search))
}
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
}
proxies, err := q.
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(proxy.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
outProxies := make([]service.Proxy, 0, len(proxies))
for i := range proxies {
outProxies = append(outProxies, *proxyEntityToService(proxies[i]))
}
return outProxies, paginationResultFromTotal(int64(total), params), nil
}
// ListWithFiltersAndAccountCount lists proxies with filters and includes account count per proxy
func (r *proxyRepository) ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]service.ProxyWithAccountCount, *pagination.PaginationResult, error) {
q := r.client.Proxy.Query()
if protocol != "" {
q = q.Where(proxy.ProtocolEQ(protocol))
}
if status != "" {
q = q.Where(proxy.StatusEQ(status))
}
if search != "" {
q = q.Where(proxy.NameContainsFold(search))
}
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
}
proxies, err := q.
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(proxy.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
// Get account counts
counts, err := r.GetAccountCountsForProxies(ctx)
if err != nil {
return nil, nil, err
}
// Build result with account counts
result := make([]service.ProxyWithAccountCount, 0, len(proxies))
for i := range proxies {
proxyOut := proxyEntityToService(proxies[i])
if proxyOut == nil {
continue
}
result = append(result, service.ProxyWithAccountCount{
Proxy: *proxyOut,
AccountCount: counts[proxyOut.ID],
})
}
return result, paginationResultFromTotal(int64(total), params), nil
}
func (r *proxyRepository) ListActive(ctx context.Context) ([]service.Proxy, error) {
proxies, err := r.client.Proxy.Query().
Where(proxy.StatusEQ(service.StatusActive)).
All(ctx)
if err != nil {
return nil, err
}
outProxies := make([]service.Proxy, 0, len(proxies))
for i := range proxies {
outProxies = append(outProxies, *proxyEntityToService(proxies[i]))
}
return outProxies, nil
}
// ExistsByHostPortAuth checks if a proxy with the same host, port, username, and password exists
func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
q := r.client.Proxy.Query().
Where(proxy.HostEQ(host), proxy.PortEQ(port))
if username == "" {
q = q.Where(proxy.Or(proxy.UsernameIsNil(), proxy.UsernameEQ("")))
} else {
q = q.Where(proxy.UsernameEQ(username))
}
if password == "" {
q = q.Where(proxy.Or(proxy.PasswordIsNil(), proxy.PasswordEQ("")))
} else {
q = q.Where(proxy.PasswordEQ(password))
}
count, err := q.Count(ctx)
return count > 0, err
}
// CountAccountsByProxyID returns the number of accounts using a specific proxy
func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
var count int64
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM accounts WHERE proxy_id = $1 AND deleted_at IS NULL", []any{proxyID}, &count); err != nil {
return 0, err
}
return count, nil
}
func (r *proxyRepository) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]service.ProxyAccountSummary, error) {
rows, err := r.sql.QueryContext(ctx, `
SELECT id, name, platform, type, notes
FROM accounts
WHERE proxy_id = $1 AND deleted_at IS NULL
ORDER BY id DESC
`, proxyID)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
out := make([]service.ProxyAccountSummary, 0)
for rows.Next() {
var (
id int64
name string
platform string
accType string
notes sql.NullString
)
if err := rows.Scan(&id, &name, &platform, &accType, &notes); err != nil {
return nil, err
}
var notesPtr *string
if notes.Valid {
notesPtr = &notes.String
}
out = append(out, service.ProxyAccountSummary{
ID: id,
Name: name,
Platform: platform,
Type: accType,
Notes: notesPtr,
})
}
if err := rows.Err(); err != nil {
return nil, err
}
return out, nil
}
// GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies
func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (counts map[int64]int64, err error) {
rows, err := r.sql.QueryContext(ctx, "SELECT proxy_id, COUNT(*) AS count FROM accounts WHERE proxy_id IS NOT NULL AND deleted_at IS NULL GROUP BY proxy_id")
if err != nil {
return nil, err
}
defer func() {
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = closeErr
counts = nil
}
}()
counts = make(map[int64]int64)
for rows.Next() {
var proxyID, count int64
if err = rows.Scan(&proxyID, &count); err != nil {
return nil, err
}
counts[proxyID] = count
}
if err = rows.Err(); err != nil {
return nil, err
}
return counts, nil
}
// ListActiveWithAccountCount returns all active proxies with account count, sorted by creation time descending
func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]service.ProxyWithAccountCount, error) {
proxies, err := r.client.Proxy.Query().
Where(proxy.StatusEQ(service.StatusActive)).
Order(dbent.Desc(proxy.FieldCreatedAt)).
All(ctx)
if err != nil {
return nil, err
}
// Get account counts
counts, err := r.GetAccountCountsForProxies(ctx)
if err != nil {
return nil, err
}
// Build result with account counts
result := make([]service.ProxyWithAccountCount, 0, len(proxies))
for i := range proxies {
proxyOut := proxyEntityToService(proxies[i])
if proxyOut == nil {
continue
}
result = append(result, service.ProxyWithAccountCount{
Proxy: *proxyOut,
AccountCount: counts[proxyOut.ID],
})
}
return result, nil
}
func proxyEntityToService(m *dbent.Proxy) *service.Proxy {
if m == nil {
return nil
}
out := &service.Proxy{
ID: m.ID,
Name: m.Name,
Protocol: m.Protocol,
Host: m.Host,
Port: m.Port,
Status: m.Status,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
}
if m.Username != nil {
out.Username = *m.Username
}
if m.Password != nil {
out.Password = *m.Password
}
return out
}
func applyProxyEntityToService(dst *service.Proxy, src *dbent.Proxy) {
if dst == nil || src == nil {
return
}
dst.ID = src.ID
dst.CreatedAt = src.CreatedAt
dst.UpdatedAt = src.UpdatedAt
}

View File

@@ -0,0 +1,329 @@
//go:build integration
package repository
import (
"context"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
)
type ProxyRepoSuite struct {
suite.Suite
ctx context.Context
tx *dbent.Tx
repo *proxyRepository
}
func (s *ProxyRepoSuite) SetupTest() {
s.ctx = context.Background()
tx := testEntTx(s.T())
s.tx = tx
s.repo = newProxyRepositoryWithSQL(tx.Client(), tx)
}
func TestProxyRepoSuite(t *testing.T) {
suite.Run(t, new(ProxyRepoSuite))
}
// --- Create / GetByID / Update / Delete ---
func (s *ProxyRepoSuite) TestCreate() {
proxy := &service.Proxy{
Name: "test-create",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Status: service.StatusActive,
}
err := s.repo.Create(s.ctx, proxy)
s.Require().NoError(err, "Create")
s.Require().NotZero(proxy.ID, "expected ID to be set")
got, err := s.repo.GetByID(s.ctx, proxy.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("test-create", got.Name)
}
func (s *ProxyRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *ProxyRepoSuite) TestUpdate() {
proxy := &service.Proxy{
Name: "original",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Status: service.StatusActive,
}
s.Require().NoError(s.repo.Create(s.ctx, proxy))
proxy.Name = "updated"
err := s.repo.Update(s.ctx, proxy)
s.Require().NoError(err, "Update")
got, err := s.repo.GetByID(s.ctx, proxy.ID)
s.Require().NoError(err, "GetByID after update")
s.Require().Equal("updated", got.Name)
}
func (s *ProxyRepoSuite) TestDelete() {
proxy := &service.Proxy{
Name: "to-delete",
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Status: service.StatusActive,
}
s.Require().NoError(s.repo.Create(s.ctx, proxy))
err := s.repo.Delete(s.ctx, proxy.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, proxy.ID)
s.Require().Error(err, "expected error after delete")
}
// --- List / ListWithFilters ---
func (s *ProxyRepoSuite) TestList() {
s.mustCreateProxy(&service.Proxy{Name: "p1", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
s.mustCreateProxy(&service.Proxy{Name: "p2", Protocol: "http", Host: "127.0.0.1", Port: 8081, Status: service.StatusActive})
proxies, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
s.Require().Len(proxies, 2)
s.Require().Equal(int64(2), page.Total)
}
func (s *ProxyRepoSuite) TestListWithFilters_Protocol() {
s.mustCreateProxy(&service.Proxy{Name: "p1", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
s.mustCreateProxy(&service.Proxy{Name: "p2", Protocol: "socks5", Host: "127.0.0.1", Port: 8081, Status: service.StatusActive})
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "socks5", "", "")
s.Require().NoError(err)
s.Require().Len(proxies, 1)
s.Require().Equal("socks5", proxies[0].Protocol)
}
func (s *ProxyRepoSuite) TestListWithFilters_Status() {
s.mustCreateProxy(&service.Proxy{Name: "p1", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
s.mustCreateProxy(&service.Proxy{Name: "p2", Protocol: "http", Host: "127.0.0.1", Port: 8081, Status: service.StatusDisabled})
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, "")
s.Require().NoError(err)
s.Require().Len(proxies, 1)
s.Require().Equal(service.StatusDisabled, proxies[0].Status)
}
func (s *ProxyRepoSuite) TestListWithFilters_Search() {
s.mustCreateProxy(&service.Proxy{Name: "production-proxy", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
s.mustCreateProxy(&service.Proxy{Name: "dev-proxy", Protocol: "http", Host: "127.0.0.1", Port: 8081, Status: service.StatusActive})
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "prod")
s.Require().NoError(err)
s.Require().Len(proxies, 1)
s.Require().Contains(proxies[0].Name, "production")
}
// --- ListActive ---
func (s *ProxyRepoSuite) TestListActive() {
s.mustCreateProxy(&service.Proxy{Name: "active1", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
s.mustCreateProxy(&service.Proxy{Name: "inactive1", Protocol: "http", Host: "127.0.0.1", Port: 8081, Status: service.StatusDisabled})
proxies, err := s.repo.ListActive(s.ctx)
s.Require().NoError(err, "ListActive")
s.Require().Len(proxies, 1)
s.Require().Equal("active1", proxies[0].Name)
}
// --- ExistsByHostPortAuth ---
func (s *ProxyRepoSuite) TestExistsByHostPortAuth() {
s.mustCreateProxy(&service.Proxy{
Name: "p1",
Protocol: "http",
Host: "1.2.3.4",
Port: 8080,
Username: "user",
Password: "pass",
Status: service.StatusActive,
})
exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "user", "pass")
s.Require().NoError(err, "ExistsByHostPortAuth")
s.Require().True(exists)
notExists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "wrong", "creds")
s.Require().NoError(err)
s.Require().False(notExists)
}
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_NoAuth() {
s.mustCreateProxy(&service.Proxy{
Name: "p-noauth",
Protocol: "http",
Host: "5.6.7.8",
Port: 8081,
Username: "",
Password: "",
Status: service.StatusActive,
})
exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "5.6.7.8", 8081, "", "")
s.Require().NoError(err)
s.Require().True(exists)
}
// --- CountAccountsByProxyID ---
func (s *ProxyRepoSuite) TestCountAccountsByProxyID() {
proxy := s.mustCreateProxy(&service.Proxy{Name: "p-count", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
s.mustInsertAccount("a1", &proxy.ID)
s.mustInsertAccount("a2", &proxy.ID)
s.mustInsertAccount("a3", nil) // no proxy
count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
s.Require().NoError(err, "CountAccountsByProxyID")
s.Require().Equal(int64(2), count)
}
func (s *ProxyRepoSuite) TestCountAccountsByProxyID_Zero() {
proxy := s.mustCreateProxy(&service.Proxy{Name: "p-zero", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
s.Require().NoError(err)
s.Require().Zero(count)
}
// --- GetAccountCountsForProxies ---
func (s *ProxyRepoSuite) TestGetAccountCountsForProxies() {
p1 := s.mustCreateProxy(&service.Proxy{Name: "p1", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
p2 := s.mustCreateProxy(&service.Proxy{Name: "p2", Protocol: "http", Host: "127.0.0.1", Port: 8081, Status: service.StatusActive})
s.mustInsertAccount("a1", &p1.ID)
s.mustInsertAccount("a2", &p1.ID)
s.mustInsertAccount("a3", &p2.ID)
counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
s.Require().NoError(err, "GetAccountCountsForProxies")
s.Require().Equal(int64(2), counts[p1.ID])
s.Require().Equal(int64(1), counts[p2.ID])
}
func (s *ProxyRepoSuite) TestGetAccountCountsForProxies_Empty() {
counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
s.Require().NoError(err)
s.Require().Empty(counts)
}
// --- ListActiveWithAccountCount ---
func (s *ProxyRepoSuite) TestListActiveWithAccountCount() {
base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
p1 := s.mustCreateProxyWithTimes("p1", service.StatusActive, base.Add(-1*time.Hour))
p2 := s.mustCreateProxyWithTimes("p2", service.StatusActive, base)
s.mustCreateProxyWithTimes("p3-inactive", service.StatusDisabled, base.Add(1*time.Hour))
s.mustInsertAccount("a1", &p1.ID)
s.mustInsertAccount("a2", &p1.ID)
s.mustInsertAccount("a3", &p2.ID)
withCounts, err := s.repo.ListActiveWithAccountCount(s.ctx)
s.Require().NoError(err, "ListActiveWithAccountCount")
s.Require().Len(withCounts, 2, "expected 2 active proxies")
// Sorted by created_at DESC, so p2 first
s.Require().Equal(p2.ID, withCounts[0].ID)
s.Require().Equal(int64(1), withCounts[0].AccountCount)
s.Require().Equal(p1.ID, withCounts[1].ID)
s.Require().Equal(int64(2), withCounts[1].AccountCount)
}
// --- Combined original test ---
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() {
p1 := s.mustCreateProxy(&service.Proxy{Name: "p1", Protocol: "http", Host: "1.2.3.4", Port: 8080, Username: "u", Password: "p", Status: service.StatusActive})
p2 := s.mustCreateProxy(&service.Proxy{Name: "p2", Protocol: "http", Host: "5.6.7.8", Port: 8081, Username: "", Password: "", Status: service.StatusActive})
exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "u", "p")
s.Require().NoError(err, "ExistsByHostPortAuth")
s.Require().True(exists, "expected proxy to exist")
s.mustInsertAccount("a1", &p1.ID)
s.mustInsertAccount("a2", &p1.ID)
s.mustInsertAccount("a3", &p2.ID)
count1, err := s.repo.CountAccountsByProxyID(s.ctx, p1.ID)
s.Require().NoError(err, "CountAccountsByProxyID")
s.Require().Equal(int64(2), count1, "expected 2 accounts for p1")
counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
s.Require().NoError(err, "GetAccountCountsForProxies")
s.Require().Equal(int64(2), counts[p1.ID])
s.Require().Equal(int64(1), counts[p2.ID])
withCounts, err := s.repo.ListActiveWithAccountCount(s.ctx)
s.Require().NoError(err, "ListActiveWithAccountCount")
s.Require().Len(withCounts, 2, "expected 2 proxies")
for _, pc := range withCounts {
switch pc.ID {
case p1.ID:
s.Require().Equal(int64(2), pc.AccountCount, "p1 count mismatch")
case p2.ID:
s.Require().Equal(int64(1), pc.AccountCount, "p2 count mismatch")
default:
s.Require().Fail("unexpected proxy id", pc.ID)
}
}
}
func (s *ProxyRepoSuite) mustCreateProxy(p *service.Proxy) *service.Proxy {
s.T().Helper()
s.Require().NoError(s.repo.Create(s.ctx, p), "create proxy")
return p
}
func (s *ProxyRepoSuite) mustCreateProxyWithTimes(name, status string, createdAt time.Time) *service.Proxy {
s.T().Helper()
// Use the repository create for standard fields, then update timestamps via raw SQL to keep deterministic ordering.
p := s.mustCreateProxy(&service.Proxy{
Name: name,
Protocol: "http",
Host: "127.0.0.1",
Port: 8080,
Status: status,
})
_, err := s.tx.ExecContext(s.ctx, "UPDATE proxies SET created_at = $1, updated_at = $1 WHERE id = $2", createdAt, p.ID)
s.Require().NoError(err, "update proxy timestamps")
return p
}
func (s *ProxyRepoSuite) mustInsertAccount(name string, proxyID *int64) {
s.T().Helper()
var pid any
if proxyID != nil {
pid = *proxyID
}
_, err := s.tx.ExecContext(
s.ctx,
"INSERT INTO accounts (name, platform, type, proxy_id) VALUES ($1, $2, $3, $4)",
name,
service.PlatformAnthropic,
service.AccountTypeOAuth,
pid,
)
s.Require().NoError(err, "insert account")
}

View File

@@ -0,0 +1,62 @@
package repository
import (
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const (
redeemRateLimitKeyPrefix = "redeem:ratelimit:"
redeemLockKeyPrefix = "redeem:lock:"
redeemRateLimitDuration = 24 * time.Hour
)
// redeemRateLimitKey generates the Redis key for redeem attempt rate limiting.
func redeemRateLimitKey(userID int64) string {
return fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
}
// redeemLockKey generates the Redis key for redeem code locking.
func redeemLockKey(code string) string {
return redeemLockKeyPrefix + code
}
type redeemCache struct {
rdb *redis.Client
}
func NewRedeemCache(rdb *redis.Client) service.RedeemCache {
return &redeemCache{rdb: rdb}
}
func (c *redeemCache) GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error) {
key := redeemRateLimitKey(userID)
count, err := c.rdb.Get(ctx, key).Int()
if err == redis.Nil {
return 0, nil
}
return count, err
}
func (c *redeemCache) IncrementRedeemAttemptCount(ctx context.Context, userID int64) error {
key := redeemRateLimitKey(userID)
pipe := c.rdb.Pipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, redeemRateLimitDuration)
_, err := pipe.Exec(ctx)
return err
}
func (c *redeemCache) AcquireRedeemLock(ctx context.Context, code string, ttl time.Duration) (bool, error) {
key := redeemLockKey(code)
return c.rdb.SetNX(ctx, key, 1, ttl).Result()
}
func (c *redeemCache) ReleaseRedeemLock(ctx context.Context, code string) error {
key := redeemLockKey(code)
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -0,0 +1,103 @@
//go:build integration
package repository
import (
"fmt"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type RedeemCacheSuite struct {
IntegrationRedisSuite
cache *redeemCache
}
func (s *RedeemCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewRedeemCache(s.rdb).(*redeemCache)
}
func (s *RedeemCacheSuite) TestGetRedeemAttemptCount_Missing() {
missingUserID := int64(99999)
count, err := s.cache.GetRedeemAttemptCount(s.ctx, missingUserID)
require.NoError(s.T(), err, "expected nil error for missing rate-limit key")
require.Equal(s.T(), 0, count, "expected zero count for missing key")
}
func (s *RedeemCacheSuite) TestIncrementAndGetRedeemAttemptCount() {
userID := int64(1)
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID), "IncrementRedeemAttemptCount")
count, err := s.cache.GetRedeemAttemptCount(s.ctx, userID)
require.NoError(s.T(), err, "GetRedeemAttemptCount")
require.Equal(s.T(), 1, count, "count mismatch")
ttl, err := s.rdb.TTL(s.ctx, key).Result()
require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, redeemRateLimitDuration)
}
func (s *RedeemCacheSuite) TestMultipleIncrements() {
userID := int64(2)
require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID))
require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID))
require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID))
count, err := s.cache.GetRedeemAttemptCount(s.ctx, userID)
require.NoError(s.T(), err)
require.Equal(s.T(), 3, count, "count after 3 increments")
}
func (s *RedeemCacheSuite) TestAcquireAndReleaseRedeemLock() {
ok, err := s.cache.AcquireRedeemLock(s.ctx, "CODE", 10*time.Second)
require.NoError(s.T(), err, "AcquireRedeemLock")
require.True(s.T(), ok)
// Second acquire should fail
ok, err = s.cache.AcquireRedeemLock(s.ctx, "CODE", 10*time.Second)
require.NoError(s.T(), err, "AcquireRedeemLock 2")
require.False(s.T(), ok, "expected lock to be held")
// Release
require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "CODE"), "ReleaseRedeemLock")
// Now acquire should succeed
ok, err = s.cache.AcquireRedeemLock(s.ctx, "CODE", 10*time.Second)
require.NoError(s.T(), err, "AcquireRedeemLock after release")
require.True(s.T(), ok)
}
func (s *RedeemCacheSuite) TestAcquireRedeemLock_TTL() {
lockKey := redeemLockKeyPrefix + "CODE2"
lockTTL := 15 * time.Second
ok, err := s.cache.AcquireRedeemLock(s.ctx, "CODE2", lockTTL)
require.NoError(s.T(), err, "AcquireRedeemLock CODE2")
require.True(s.T(), ok)
ttl, err := s.rdb.TTL(s.ctx, lockKey).Result()
require.NoError(s.T(), err, "TTL lock key")
s.AssertTTLWithin(ttl, 1*time.Second, lockTTL)
}
func (s *RedeemCacheSuite) TestReleaseRedeemLock_Idempotent() {
// Release a lock that doesn't exist should not error
require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "NONEXISTENT"))
// Acquire, release, release again
ok, err := s.cache.AcquireRedeemLock(s.ctx, "IDEMPOTENT", 10*time.Second)
require.NoError(s.T(), err)
require.True(s.T(), ok)
require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "IDEMPOTENT"))
require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "IDEMPOTENT"), "second release should be idempotent")
}
func TestRedeemCacheSuite(t *testing.T) {
suite.Run(t, new(RedeemCacheSuite))
}

View File

@@ -0,0 +1,77 @@
//go:build unit
package repository
import (
"math"
"testing"
"github.com/stretchr/testify/require"
)
func TestRedeemRateLimitKey(t *testing.T) {
tests := []struct {
name string
userID int64
expected string
}{
{
name: "normal_user_id",
userID: 123,
expected: "redeem:ratelimit:123",
},
{
name: "zero_user_id",
userID: 0,
expected: "redeem:ratelimit:0",
},
{
name: "negative_user_id",
userID: -1,
expected: "redeem:ratelimit:-1",
},
{
name: "max_int64",
userID: math.MaxInt64,
expected: "redeem:ratelimit:9223372036854775807",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := redeemRateLimitKey(tc.userID)
require.Equal(t, tc.expected, got)
})
}
}
func TestRedeemLockKey(t *testing.T) {
tests := []struct {
name string
code string
expected string
}{
{
name: "normal_code",
code: "ABC123",
expected: "redeem:lock:ABC123",
},
{
name: "empty_code",
code: "",
expected: "redeem:lock:",
},
{
name: "code_with_special_chars",
code: "CODE-2024:test",
expected: "redeem:lock:CODE-2024:test",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := redeemLockKey(tc.code)
require.Equal(t, tc.expected, got)
})
}
}

View File

@@ -0,0 +1,239 @@
package repository
import (
"context"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type redeemCodeRepository struct {
client *dbent.Client
}
func NewRedeemCodeRepository(client *dbent.Client) service.RedeemCodeRepository {
return &redeemCodeRepository{client: client}
}
func (r *redeemCodeRepository) Create(ctx context.Context, code *service.RedeemCode) error {
created, err := r.client.RedeemCode.Create().
SetCode(code.Code).
SetType(code.Type).
SetValue(code.Value).
SetStatus(code.Status).
SetNotes(code.Notes).
SetValidityDays(code.ValidityDays).
SetNillableUsedBy(code.UsedBy).
SetNillableUsedAt(code.UsedAt).
SetNillableGroupID(code.GroupID).
Save(ctx)
if err == nil {
code.ID = created.ID
code.CreatedAt = created.CreatedAt
}
return err
}
func (r *redeemCodeRepository) CreateBatch(ctx context.Context, codes []service.RedeemCode) error {
if len(codes) == 0 {
return nil
}
builders := make([]*dbent.RedeemCodeCreate, 0, len(codes))
for i := range codes {
c := &codes[i]
b := r.client.RedeemCode.Create().
SetCode(c.Code).
SetType(c.Type).
SetValue(c.Value).
SetStatus(c.Status).
SetNotes(c.Notes).
SetValidityDays(c.ValidityDays).
SetNillableUsedBy(c.UsedBy).
SetNillableUsedAt(c.UsedAt).
SetNillableGroupID(c.GroupID)
builders = append(builders, b)
}
return r.client.RedeemCode.CreateBulk(builders...).Exec(ctx)
}
func (r *redeemCodeRepository) GetByID(ctx context.Context, id int64) (*service.RedeemCode, error) {
m, err := r.client.RedeemCode.Query().
Where(redeemcode.IDEQ(id)).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrRedeemCodeNotFound
}
return nil, err
}
return redeemCodeEntityToService(m), nil
}
func (r *redeemCodeRepository) GetByCode(ctx context.Context, code string) (*service.RedeemCode, error) {
m, err := r.client.RedeemCode.Query().
Where(redeemcode.CodeEQ(code)).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrRedeemCodeNotFound
}
return nil, err
}
return redeemCodeEntityToService(m), nil
}
func (r *redeemCodeRepository) Delete(ctx context.Context, id int64) error {
_, err := r.client.RedeemCode.Delete().Where(redeemcode.IDEQ(id)).Exec(ctx)
return err
}
func (r *redeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.RedeemCode, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "")
}
func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
q := r.client.RedeemCode.Query()
if codeType != "" {
q = q.Where(redeemcode.TypeEQ(codeType))
}
if status != "" {
q = q.Where(redeemcode.StatusEQ(status))
}
if search != "" {
q = q.Where(redeemcode.CodeContainsFold(search))
}
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
}
codes, err := q.
WithUser().
WithGroup().
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(redeemcode.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
outCodes := redeemCodeEntitiesToService(codes)
return outCodes, paginationResultFromTotal(int64(total), params), nil
}
func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemCode) error {
up := r.client.RedeemCode.UpdateOneID(code.ID).
SetCode(code.Code).
SetType(code.Type).
SetValue(code.Value).
SetStatus(code.Status).
SetNotes(code.Notes).
SetValidityDays(code.ValidityDays)
if code.UsedBy != nil {
up.SetUsedBy(*code.UsedBy)
} else {
up.ClearUsedBy()
}
if code.UsedAt != nil {
up.SetUsedAt(*code.UsedAt)
} else {
up.ClearUsedAt()
}
if code.GroupID != nil {
up.SetGroupID(*code.GroupID)
} else {
up.ClearGroupID()
}
updated, err := up.Save(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return service.ErrRedeemCodeNotFound
}
return err
}
code.CreatedAt = updated.CreatedAt
return nil
}
func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error {
now := time.Now()
client := clientFromContext(ctx, r.client)
affected, err := client.RedeemCode.Update().
Where(redeemcode.IDEQ(id), redeemcode.StatusEQ(service.StatusUnused)).
SetStatus(service.StatusUsed).
SetUsedBy(userID).
SetUsedAt(now).
Save(ctx)
if err != nil {
return err
}
if affected == 0 {
return service.ErrRedeemCodeUsed
}
return nil
}
func (r *redeemCodeRepository) ListByUser(ctx context.Context, userID int64, limit int) ([]service.RedeemCode, error) {
if limit <= 0 {
limit = 10
}
codes, err := r.client.RedeemCode.Query().
Where(redeemcode.UsedByEQ(userID)).
WithGroup().
Order(dbent.Desc(redeemcode.FieldUsedAt)).
Limit(limit).
All(ctx)
if err != nil {
return nil, err
}
return redeemCodeEntitiesToService(codes), nil
}
func redeemCodeEntityToService(m *dbent.RedeemCode) *service.RedeemCode {
if m == nil {
return nil
}
out := &service.RedeemCode{
ID: m.ID,
Code: m.Code,
Type: m.Type,
Value: m.Value,
Status: m.Status,
UsedBy: m.UsedBy,
UsedAt: m.UsedAt,
Notes: derefString(m.Notes),
CreatedAt: m.CreatedAt,
GroupID: m.GroupID,
ValidityDays: m.ValidityDays,
}
if m.Edges.User != nil {
out.User = userEntityToService(m.Edges.User)
}
if m.Edges.Group != nil {
out.Group = groupEntityToService(m.Edges.Group)
}
return out
}
func redeemCodeEntitiesToService(models []*dbent.RedeemCode) []service.RedeemCode {
out := make([]service.RedeemCode, 0, len(models))
for i := range models {
if s := redeemCodeEntityToService(models[i]); s != nil {
out = append(out, *s)
}
}
return out
}

View File

@@ -0,0 +1,390 @@
//go:build integration
package repository
import (
"context"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
)
type RedeemCodeRepoSuite struct {
suite.Suite
ctx context.Context
client *dbent.Client
repo *redeemCodeRepository
}
func (s *RedeemCodeRepoSuite) SetupTest() {
s.ctx = context.Background()
tx := testEntTx(s.T())
s.client = tx.Client()
s.repo = NewRedeemCodeRepository(s.client).(*redeemCodeRepository)
}
func TestRedeemCodeRepoSuite(t *testing.T) {
suite.Run(t, new(RedeemCodeRepoSuite))
}
func (s *RedeemCodeRepoSuite) createUser(email string) *dbent.User {
u, err := s.client.User.Create().
SetEmail(email).
SetPasswordHash("test-password-hash").
Save(s.ctx)
s.Require().NoError(err, "create user")
return u
}
func (s *RedeemCodeRepoSuite) createGroup(name string) *dbent.Group {
g, err := s.client.Group.Create().
SetName(name).
Save(s.ctx)
s.Require().NoError(err, "create group")
return g
}
// --- Create / CreateBatch / GetByID / GetByCode ---
func (s *RedeemCodeRepoSuite) TestCreate() {
code := &service.RedeemCode{
Code: "TEST-CREATE",
Type: service.RedeemTypeBalance,
Value: 100,
Status: service.StatusUnused,
}
err := s.repo.Create(s.ctx, code)
s.Require().NoError(err, "Create")
s.Require().NotZero(code.ID, "expected ID to be set")
got, err := s.repo.GetByID(s.ctx, code.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("TEST-CREATE", got.Code)
}
func (s *RedeemCodeRepoSuite) TestCreateBatch() {
codes := []service.RedeemCode{
{Code: "BATCH-1", Type: service.RedeemTypeBalance, Value: 10, Status: service.StatusUnused},
{Code: "BATCH-2", Type: service.RedeemTypeBalance, Value: 20, Status: service.StatusUnused},
}
err := s.repo.CreateBatch(s.ctx, codes)
s.Require().NoError(err, "CreateBatch")
got1, err := s.repo.GetByCode(s.ctx, "BATCH-1")
s.Require().NoError(err)
s.Require().Equal(float64(10), got1.Value)
got2, err := s.repo.GetByCode(s.ctx, "BATCH-2")
s.Require().NoError(err)
s.Require().Equal(float64(20), got2.Value)
}
func (s *RedeemCodeRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
s.Require().ErrorIs(err, service.ErrRedeemCodeNotFound)
}
func (s *RedeemCodeRepoSuite) TestGetByCode() {
_, err := s.client.RedeemCode.Create().
SetCode("GET-BY-CODE").
SetType(service.RedeemTypeBalance).
SetStatus(service.StatusUnused).
SetValue(0).
SetNotes("").
SetValidityDays(30).
Save(s.ctx)
s.Require().NoError(err, "seed redeem code")
got, err := s.repo.GetByCode(s.ctx, "GET-BY-CODE")
s.Require().NoError(err, "GetByCode")
s.Require().Equal("GET-BY-CODE", got.Code)
}
func (s *RedeemCodeRepoSuite) TestGetByCode_NotFound() {
_, err := s.repo.GetByCode(s.ctx, "NON-EXISTENT")
s.Require().Error(err, "expected error for non-existent code")
s.Require().ErrorIs(err, service.ErrRedeemCodeNotFound)
}
// --- Delete ---
func (s *RedeemCodeRepoSuite) TestDelete() {
created, err := s.client.RedeemCode.Create().
SetCode("TO-DELETE").
SetType(service.RedeemTypeBalance).
SetStatus(service.StatusUnused).
SetValue(0).
SetNotes("").
SetValidityDays(30).
Save(s.ctx)
s.Require().NoError(err)
err = s.repo.Delete(s.ctx, created.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, created.ID)
s.Require().Error(err, "expected error after delete")
s.Require().ErrorIs(err, service.ErrRedeemCodeNotFound)
}
// --- List / ListWithFilters ---
func (s *RedeemCodeRepoSuite) TestList() {
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "LIST-1", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "LIST-2", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))
codes, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
s.Require().Len(codes, 2)
s.Require().Equal(int64(2), page.Total)
}
func (s *RedeemCodeRepoSuite) TestListWithFilters_Type() {
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "TYPE-BAL", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "TYPE-SUB", Type: service.RedeemTypeSubscription, Value: 0, Status: service.StatusUnused}))
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.RedeemTypeSubscription, "", "")
s.Require().NoError(err)
s.Require().Len(codes, 1)
s.Require().Equal(service.RedeemTypeSubscription, codes[0].Type)
}
func (s *RedeemCodeRepoSuite) TestListWithFilters_Status() {
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "STAT-UNUSED", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "STAT-USED", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUsed}))
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusUsed, "")
s.Require().NoError(err)
s.Require().Len(codes, 1)
s.Require().Equal(service.StatusUsed, codes[0].Status)
}
func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() {
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "ALPHA-CODE", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "BETA-CODE", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alpha")
s.Require().NoError(err)
s.Require().Len(codes, 1)
s.Require().Contains(codes[0].Code, "ALPHA")
}
func (s *RedeemCodeRepoSuite) TestListWithFilters_GroupPreload() {
group := s.createGroup(uniqueTestValue(s.T(), "g-preload"))
_, err := s.client.RedeemCode.Create().
SetCode("WITH-GROUP").
SetType(service.RedeemTypeSubscription).
SetStatus(service.StatusUnused).
SetValue(0).
SetNotes("").
SetValidityDays(30).
SetGroupID(group.ID).
Save(s.ctx)
s.Require().NoError(err)
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "")
s.Require().NoError(err)
s.Require().Len(codes, 1)
s.Require().NotNil(codes[0].Group, "expected Group preload")
s.Require().Equal(group.ID, codes[0].Group.ID)
}
// --- Update ---
func (s *RedeemCodeRepoSuite) TestUpdate() {
code := &service.RedeemCode{
Code: "UPDATE-ME",
Type: service.RedeemTypeBalance,
Value: 10,
Status: service.StatusUnused,
}
s.Require().NoError(s.repo.Create(s.ctx, code))
code.Value = 50
err := s.repo.Update(s.ctx, code)
s.Require().NoError(err, "Update")
got, err := s.repo.GetByID(s.ctx, code.ID)
s.Require().NoError(err)
s.Require().Equal(float64(50), got.Value)
}
// --- Use ---
func (s *RedeemCodeRepoSuite) TestUse() {
user := s.createUser(uniqueTestValue(s.T(), "use") + "@example.com")
code := &service.RedeemCode{Code: "USE-ME", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}
s.Require().NoError(s.repo.Create(s.ctx, code))
err := s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().NoError(err, "Use")
got, err := s.repo.GetByID(s.ctx, code.ID)
s.Require().NoError(err)
s.Require().Equal(service.StatusUsed, got.Status)
s.Require().NotNil(got.UsedBy)
s.Require().Equal(user.ID, *got.UsedBy)
s.Require().NotNil(got.UsedAt)
}
func (s *RedeemCodeRepoSuite) TestUse_Idempotency() {
user := s.createUser(uniqueTestValue(s.T(), "idem") + "@example.com")
code := &service.RedeemCode{Code: "IDEM-CODE", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}
s.Require().NoError(s.repo.Create(s.ctx, code))
err := s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().NoError(err, "Use first time")
// Second use should fail
err = s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().Error(err, "Use expected error on second call")
s.Require().ErrorIs(err, service.ErrRedeemCodeUsed)
}
func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() {
user := s.createUser(uniqueTestValue(s.T(), "already") + "@example.com")
code := &service.RedeemCode{Code: "ALREADY-USED", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUsed}
s.Require().NoError(s.repo.Create(s.ctx, code))
err := s.repo.Use(s.ctx, code.ID, user.ID)
s.Require().Error(err, "expected error for already used code")
s.Require().ErrorIs(err, service.ErrRedeemCodeUsed)
}
// --- ListByUser ---
func (s *RedeemCodeRepoSuite) TestListByUser() {
user := s.createUser(uniqueTestValue(s.T(), "listby") + "@example.com")
base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
usedAt1 := base
_, err := s.client.RedeemCode.Create().
SetCode("USER-1").
SetType(service.RedeemTypeBalance).
SetStatus(service.StatusUsed).
SetValue(0).
SetNotes("").
SetValidityDays(30).
SetUsedBy(user.ID).
SetUsedAt(usedAt1).
Save(s.ctx)
s.Require().NoError(err)
usedAt2 := base.Add(1 * time.Hour)
_, err = s.client.RedeemCode.Create().
SetCode("USER-2").
SetType(service.RedeemTypeBalance).
SetStatus(service.StatusUsed).
SetValue(0).
SetNotes("").
SetValidityDays(30).
SetUsedBy(user.ID).
SetUsedAt(usedAt2).
Save(s.ctx)
s.Require().NoError(err)
codes, err := s.repo.ListByUser(s.ctx, user.ID, 10)
s.Require().NoError(err, "ListByUser")
s.Require().Len(codes, 2)
// Ordered by used_at DESC, so USER-2 first
s.Require().Equal("USER-2", codes[0].Code)
s.Require().Equal("USER-1", codes[1].Code)
}
func (s *RedeemCodeRepoSuite) TestListByUser_WithGroupPreload() {
user := s.createUser(uniqueTestValue(s.T(), "grp") + "@example.com")
group := s.createGroup(uniqueTestValue(s.T(), "g-listby"))
_, err := s.client.RedeemCode.Create().
SetCode("WITH-GRP").
SetType(service.RedeemTypeSubscription).
SetStatus(service.StatusUsed).
SetValue(0).
SetNotes("").
SetValidityDays(30).
SetUsedBy(user.ID).
SetUsedAt(time.Now()).
SetGroupID(group.ID).
Save(s.ctx)
s.Require().NoError(err)
codes, err := s.repo.ListByUser(s.ctx, user.ID, 10)
s.Require().NoError(err)
s.Require().Len(codes, 1)
s.Require().NotNil(codes[0].Group)
s.Require().Equal(group.ID, codes[0].Group.ID)
}
func (s *RedeemCodeRepoSuite) TestListByUser_DefaultLimit() {
user := s.createUser(uniqueTestValue(s.T(), "deflimit") + "@example.com")
_, err := s.client.RedeemCode.Create().
SetCode("DEF-LIM").
SetType(service.RedeemTypeBalance).
SetStatus(service.StatusUsed).
SetValue(0).
SetNotes("").
SetValidityDays(30).
SetUsedBy(user.ID).
SetUsedAt(time.Now()).
Save(s.ctx)
s.Require().NoError(err)
// limit <= 0 should default to 10
codes, err := s.repo.ListByUser(s.ctx, user.ID, 0)
s.Require().NoError(err)
s.Require().Len(codes, 1)
}
// --- Combined original test ---
func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser() {
user := s.createUser(uniqueTestValue(s.T(), "rc") + "@example.com")
group := s.createGroup(uniqueTestValue(s.T(), "g-rc"))
groupID := group.ID
codes := []service.RedeemCode{
{Code: "CODEA", Type: service.RedeemTypeBalance, Value: 1, Status: service.StatusUnused, Notes: ""},
{Code: "CODEB", Type: service.RedeemTypeSubscription, Value: 0, Status: service.StatusUnused, Notes: "", GroupID: &groupID, ValidityDays: 7},
}
s.Require().NoError(s.repo.CreateBatch(s.ctx, codes), "CreateBatch")
list, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.RedeemTypeSubscription, service.StatusUnused, "code")
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total)
s.Require().Len(list, 1)
s.Require().NotNil(list[0].Group, "expected Group preload")
s.Require().Equal(group.ID, list[0].Group.ID)
codeB, err := s.repo.GetByCode(s.ctx, "CODEB")
s.Require().NoError(err, "GetByCode")
s.Require().NoError(s.repo.Use(s.ctx, codeB.ID, user.ID), "Use")
err = s.repo.Use(s.ctx, codeB.ID, user.ID)
s.Require().Error(err, "Use expected error on second call")
s.Require().ErrorIs(err, service.ErrRedeemCodeUsed)
codeA, err := s.repo.GetByCode(s.ctx, "CODEA")
s.Require().NoError(err, "GetByCode")
// Use fixed time instead of time.Sleep for deterministic ordering.
_, err = s.client.RedeemCode.UpdateOneID(codeB.ID).
SetUsedAt(time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)).
Save(s.ctx)
s.Require().NoError(err)
s.Require().NoError(s.repo.Use(s.ctx, codeA.ID, user.ID), "Use codeA")
_, err = s.client.RedeemCode.UpdateOneID(codeA.ID).
SetUsedAt(time.Date(2025, 1, 1, 13, 0, 0, 0, time.UTC)).
Save(s.ctx)
s.Require().NoError(err)
used, err := s.repo.ListByUser(s.ctx, user.ID, 10)
s.Require().NoError(err, "ListByUser")
s.Require().Len(used, 2, "expected 2 used codes")
s.Require().Equal("CODEA", used[0].Code, "expected newest used code first")
}

View File

@@ -0,0 +1,39 @@
package repository
import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/redis/go-redis/v9"
)
// InitRedis 初始化 Redis 客户端
//
// 性能优化说明:
// 原实现使用 go-redis 默认配置,未设置连接池和超时参数:
// 1. 默认连接池大小可能不足以支撑高并发
// 2. 无超时控制可能导致慢操作阻塞
//
// 新实现支持可配置的连接池和超时参数:
// 1. PoolSize: 控制最大并发连接数(默认 128
// 2. MinIdleConns: 保持最小空闲连接,减少冷启动延迟(默认 10
// 3. DialTimeout/ReadTimeout/WriteTimeout: 精确控制各阶段超时
func InitRedis(cfg *config.Config) *redis.Client {
return redis.NewClient(buildRedisOptions(cfg))
}
// buildRedisOptions 构建 Redis 连接选项
// 从配置文件读取连接池和超时参数,支持生产环境调优
func buildRedisOptions(cfg *config.Config) *redis.Options {
return &redis.Options{
Addr: cfg.Redis.Address(),
Password: cfg.Redis.Password,
DB: cfg.Redis.DB,
DialTimeout: time.Duration(cfg.Redis.DialTimeoutSeconds) * time.Second, // 建连超时
ReadTimeout: time.Duration(cfg.Redis.ReadTimeoutSeconds) * time.Second, // 读取超时
WriteTimeout: time.Duration(cfg.Redis.WriteTimeoutSeconds) * time.Second, // 写入超时
PoolSize: cfg.Redis.PoolSize, // 连接池大小
MinIdleConns: cfg.Redis.MinIdleConns, // 最小空闲连接
}
}

View File

@@ -0,0 +1,35 @@
package repository
import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func TestBuildRedisOptions(t *testing.T) {
cfg := &config.Config{
Redis: config.RedisConfig{
Host: "localhost",
Port: 6379,
Password: "secret",
DB: 2,
DialTimeoutSeconds: 5,
ReadTimeoutSeconds: 3,
WriteTimeoutSeconds: 4,
PoolSize: 100,
MinIdleConns: 10,
},
}
opts := buildRedisOptions(cfg)
require.Equal(t, "localhost:6379", opts.Addr)
require.Equal(t, "secret", opts.Password)
require.Equal(t, 2, opts.DB)
require.Equal(t, 5*time.Second, opts.DialTimeout)
require.Equal(t, 3*time.Second, opts.ReadTimeout)
require.Equal(t, 4*time.Second, opts.WriteTimeout)
require.Equal(t, 100, opts.PoolSize)
require.Equal(t, 10, opts.MinIdleConns)
}

View File

@@ -0,0 +1,64 @@
package repository
import (
"fmt"
"strings"
"sync"
"time"
"github.com/imroc/req/v3"
)
// reqClientOptions 定义 req 客户端的构建参数
type reqClientOptions struct {
ProxyURL string // 代理 URL支持 http/https/socks5
Timeout time.Duration // 请求超时时间
Impersonate bool // 是否模拟 Chrome 浏览器指纹
}
// sharedReqClients 存储按配置参数缓存的 req 客户端实例
//
// 性能优化说明:
// 原实现在每次 OAuth 刷新时都创建新的 req.Client
// 1. claude_oauth_service.go: 每次刷新创建新客户端
// 2. openai_oauth_service.go: 每次刷新创建新客户端
// 3. gemini_oauth_client.go: 每次刷新创建新客户端
//
// 新实现使用 sync.Map 缓存客户端:
// 1. 相同配置(代理+超时+模拟设置)复用同一客户端
// 2. 复用底层连接池,减少 TLS 握手开销
// 3. LoadOrStore 保证并发安全,避免重复创建
var sharedReqClients sync.Map
// getSharedReqClient 获取共享的 req 客户端实例
// 性能优化:相同配置复用同一客户端,避免重复创建
func getSharedReqClient(opts reqClientOptions) *req.Client {
key := buildReqClientKey(opts)
if cached, ok := sharedReqClients.Load(key); ok {
if c, ok := cached.(*req.Client); ok {
return c
}
}
client := req.C().SetTimeout(opts.Timeout)
if opts.Impersonate {
client = client.ImpersonateChrome()
}
if strings.TrimSpace(opts.ProxyURL) != "" {
client.SetProxyURL(strings.TrimSpace(opts.ProxyURL))
}
actual, _ := sharedReqClients.LoadOrStore(key, client)
if c, ok := actual.(*req.Client); ok {
return c
}
return client
}
func buildReqClientKey(opts reqClientOptions) string {
return fmt.Sprintf("%s|%s|%t",
strings.TrimSpace(opts.ProxyURL),
opts.Timeout.String(),
opts.Impersonate,
)
}

View File

@@ -0,0 +1,276 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const (
schedulerBucketSetKey = "sched:buckets"
schedulerOutboxWatermarkKey = "sched:outbox:watermark"
schedulerAccountPrefix = "sched:acc:"
schedulerActivePrefix = "sched:active:"
schedulerReadyPrefix = "sched:ready:"
schedulerVersionPrefix = "sched:ver:"
schedulerSnapshotPrefix = "sched:"
schedulerLockPrefix = "sched:lock:"
)
type schedulerCache struct {
rdb *redis.Client
}
func NewSchedulerCache(rdb *redis.Client) service.SchedulerCache {
return &schedulerCache{rdb: rdb}
}
func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) {
readyKey := schedulerBucketKey(schedulerReadyPrefix, bucket)
readyVal, err := c.rdb.Get(ctx, readyKey).Result()
if err == redis.Nil {
return nil, false, nil
}
if err != nil {
return nil, false, err
}
if readyVal != "1" {
return nil, false, nil
}
activeKey := schedulerBucketKey(schedulerActivePrefix, bucket)
activeVal, err := c.rdb.Get(ctx, activeKey).Result()
if err == redis.Nil {
return nil, false, nil
}
if err != nil {
return nil, false, err
}
snapshotKey := schedulerSnapshotKey(bucket, activeVal)
ids, err := c.rdb.ZRange(ctx, snapshotKey, 0, -1).Result()
if err != nil {
return nil, false, err
}
if len(ids) == 0 {
return []*service.Account{}, true, nil
}
keys := make([]string, 0, len(ids))
for _, id := range ids {
keys = append(keys, schedulerAccountKey(id))
}
values, err := c.rdb.MGet(ctx, keys...).Result()
if err != nil {
return nil, false, err
}
accounts := make([]*service.Account, 0, len(values))
for _, val := range values {
if val == nil {
return nil, false, nil
}
account, err := decodeCachedAccount(val)
if err != nil {
return nil, false, err
}
accounts = append(accounts, account)
}
return accounts, true, nil
}
func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.SchedulerBucket, accounts []service.Account) error {
activeKey := schedulerBucketKey(schedulerActivePrefix, bucket)
oldActive, _ := c.rdb.Get(ctx, activeKey).Result()
versionKey := schedulerBucketKey(schedulerVersionPrefix, bucket)
version, err := c.rdb.Incr(ctx, versionKey).Result()
if err != nil {
return err
}
versionStr := strconv.FormatInt(version, 10)
snapshotKey := schedulerSnapshotKey(bucket, versionStr)
pipe := c.rdb.Pipeline()
for _, account := range accounts {
payload, err := json.Marshal(account)
if err != nil {
return err
}
pipe.Set(ctx, schedulerAccountKey(strconv.FormatInt(account.ID, 10)), payload, 0)
}
if len(accounts) > 0 {
// 使用序号作为 score保持数据库返回的排序语义。
members := make([]redis.Z, 0, len(accounts))
for idx, account := range accounts {
members = append(members, redis.Z{
Score: float64(idx),
Member: strconv.FormatInt(account.ID, 10),
})
}
pipe.ZAdd(ctx, snapshotKey, members...)
} else {
pipe.Del(ctx, snapshotKey)
}
pipe.Set(ctx, activeKey, versionStr, 0)
pipe.Set(ctx, schedulerBucketKey(schedulerReadyPrefix, bucket), "1", 0)
pipe.SAdd(ctx, schedulerBucketSetKey, bucket.String())
if _, err := pipe.Exec(ctx); err != nil {
return err
}
if oldActive != "" && oldActive != versionStr {
_ = c.rdb.Del(ctx, schedulerSnapshotKey(bucket, oldActive)).Err()
}
return nil
}
func (c *schedulerCache) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) {
key := schedulerAccountKey(strconv.FormatInt(accountID, 10))
val, err := c.rdb.Get(ctx, key).Result()
if err == redis.Nil {
return nil, nil
}
if err != nil {
return nil, err
}
return decodeCachedAccount(val)
}
func (c *schedulerCache) SetAccount(ctx context.Context, account *service.Account) error {
if account == nil || account.ID <= 0 {
return nil
}
payload, err := json.Marshal(account)
if err != nil {
return err
}
key := schedulerAccountKey(strconv.FormatInt(account.ID, 10))
return c.rdb.Set(ctx, key, payload, 0).Err()
}
func (c *schedulerCache) DeleteAccount(ctx context.Context, accountID int64) error {
if accountID <= 0 {
return nil
}
key := schedulerAccountKey(strconv.FormatInt(accountID, 10))
return c.rdb.Del(ctx, key).Err()
}
func (c *schedulerCache) UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
if len(updates) == 0 {
return nil
}
keys := make([]string, 0, len(updates))
ids := make([]int64, 0, len(updates))
for id := range updates {
keys = append(keys, schedulerAccountKey(strconv.FormatInt(id, 10)))
ids = append(ids, id)
}
values, err := c.rdb.MGet(ctx, keys...).Result()
if err != nil {
return err
}
pipe := c.rdb.Pipeline()
for i, val := range values {
if val == nil {
continue
}
account, err := decodeCachedAccount(val)
if err != nil {
return err
}
account.LastUsedAt = ptrTime(updates[ids[i]])
updated, err := json.Marshal(account)
if err != nil {
return err
}
pipe.Set(ctx, keys[i], updated, 0)
}
_, err = pipe.Exec(ctx)
return err
}
func (c *schedulerCache) TryLockBucket(ctx context.Context, bucket service.SchedulerBucket, ttl time.Duration) (bool, error) {
key := schedulerBucketKey(schedulerLockPrefix, bucket)
return c.rdb.SetNX(ctx, key, time.Now().UnixNano(), ttl).Result()
}
func (c *schedulerCache) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) {
raw, err := c.rdb.SMembers(ctx, schedulerBucketSetKey).Result()
if err != nil {
return nil, err
}
out := make([]service.SchedulerBucket, 0, len(raw))
for _, entry := range raw {
bucket, ok := service.ParseSchedulerBucket(entry)
if !ok {
continue
}
out = append(out, bucket)
}
return out, nil
}
func (c *schedulerCache) GetOutboxWatermark(ctx context.Context) (int64, error) {
val, err := c.rdb.Get(ctx, schedulerOutboxWatermarkKey).Result()
if err == redis.Nil {
return 0, nil
}
if err != nil {
return 0, err
}
id, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return 0, err
}
return id, nil
}
func (c *schedulerCache) SetOutboxWatermark(ctx context.Context, id int64) error {
return c.rdb.Set(ctx, schedulerOutboxWatermarkKey, strconv.FormatInt(id, 10), 0).Err()
}
func schedulerBucketKey(prefix string, bucket service.SchedulerBucket) string {
return fmt.Sprintf("%s%d:%s:%s", prefix, bucket.GroupID, bucket.Platform, bucket.Mode)
}
func schedulerSnapshotKey(bucket service.SchedulerBucket, version string) string {
return fmt.Sprintf("%s%d:%s:%s:v%s", schedulerSnapshotPrefix, bucket.GroupID, bucket.Platform, bucket.Mode, version)
}
func schedulerAccountKey(id string) string {
return schedulerAccountPrefix + id
}
func ptrTime(t time.Time) *time.Time {
return &t
}
func decodeCachedAccount(val any) (*service.Account, error) {
var payload []byte
switch raw := val.(type) {
case string:
payload = []byte(raw)
case []byte:
payload = raw
default:
return nil, fmt.Errorf("unexpected account cache type: %T", val)
}
var account service.Account
if err := json.Unmarshal(payload, &account); err != nil {
return nil, err
}
return &account, nil
}

View File

@@ -0,0 +1,96 @@
package repository
import (
"context"
"database/sql"
"encoding/json"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type schedulerOutboxRepository struct {
db *sql.DB
}
func NewSchedulerOutboxRepository(db *sql.DB) service.SchedulerOutboxRepository {
return &schedulerOutboxRepository{db: db}
}
func (r *schedulerOutboxRepository) ListAfter(ctx context.Context, afterID int64, limit int) ([]service.SchedulerOutboxEvent, error) {
if limit <= 0 {
limit = 100
}
rows, err := r.db.QueryContext(ctx, `
SELECT id, event_type, account_id, group_id, payload, created_at
FROM scheduler_outbox
WHERE id > $1
ORDER BY id ASC
LIMIT $2
`, afterID, limit)
if err != nil {
return nil, err
}
defer func() {
_ = rows.Close()
}()
events := make([]service.SchedulerOutboxEvent, 0, limit)
for rows.Next() {
var (
payloadRaw []byte
accountID sql.NullInt64
groupID sql.NullInt64
event service.SchedulerOutboxEvent
)
if err := rows.Scan(&event.ID, &event.EventType, &accountID, &groupID, &payloadRaw, &event.CreatedAt); err != nil {
return nil, err
}
if accountID.Valid {
v := accountID.Int64
event.AccountID = &v
}
if groupID.Valid {
v := groupID.Int64
event.GroupID = &v
}
if len(payloadRaw) > 0 {
var payload map[string]any
if err := json.Unmarshal(payloadRaw, &payload); err != nil {
return nil, err
}
event.Payload = payload
}
events = append(events, event)
}
if err := rows.Err(); err != nil {
return nil, err
}
return events, nil
}
func (r *schedulerOutboxRepository) MaxID(ctx context.Context) (int64, error) {
var maxID int64
if err := r.db.QueryRowContext(ctx, "SELECT COALESCE(MAX(id), 0) FROM scheduler_outbox").Scan(&maxID); err != nil {
return 0, err
}
return maxID, nil
}
func enqueueSchedulerOutbox(ctx context.Context, exec sqlExecutor, eventType string, accountID *int64, groupID *int64, payload any) error {
if exec == nil {
return nil
}
var payloadArg any
if payload != nil {
encoded, err := json.Marshal(payload)
if err != nil {
return err
}
payloadArg = encoded
}
_, err := exec.ExecContext(ctx, `
INSERT INTO scheduler_outbox (event_type, account_id, group_id, payload)
VALUES ($1, $2, $3, $4)
`, eventType, accountID, groupID, payloadArg)
return err
}

View File

@@ -0,0 +1,68 @@
//go:build integration
package repository
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func TestSchedulerSnapshotOutboxReplay(t *testing.T) {
ctx := context.Background()
rdb := testRedis(t)
client := testEntClient(t)
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox")
accountRepo := newAccountRepositoryWithSQL(client, integrationDB)
outboxRepo := NewSchedulerOutboxRepository(integrationDB)
cache := NewSchedulerCache(rdb)
cfg := &config.Config{
RunMode: config.RunModeStandard,
Gateway: config.GatewayConfig{
Scheduling: config.GatewaySchedulingConfig{
OutboxPollIntervalSeconds: 1,
FullRebuildIntervalSeconds: 0,
DbFallbackEnabled: true,
},
},
}
account := &service.Account{
Name: "outbox-replay-" + time.Now().Format("150405.000000"),
Platform: service.PlatformOpenAI,
Type: service.AccountTypeAPIKey,
Status: service.StatusActive,
Schedulable: true,
Concurrency: 3,
Priority: 1,
Credentials: map[string]any{},
Extra: map[string]any{},
}
require.NoError(t, accountRepo.Create(ctx, account))
require.NoError(t, cache.SetAccount(ctx, account))
svc := service.NewSchedulerSnapshotService(cache, outboxRepo, accountRepo, nil, cfg)
svc.Start()
t.Cleanup(svc.Stop)
require.NoError(t, accountRepo.UpdateLastUsed(ctx, account.ID))
updated, err := accountRepo.GetByID(ctx, account.ID)
require.NoError(t, err)
require.NotNil(t, updated.LastUsedAt)
expectedUnix := updated.LastUsedAt.Unix()
require.Eventually(t, func() bool {
cached, err := cache.GetAccount(ctx, account.ID)
if err != nil || cached == nil || cached.LastUsedAt == nil {
return false
}
return cached.LastUsedAt.Unix() == expectedUnix
}, 5*time.Second, 100*time.Millisecond)
}

View File

@@ -0,0 +1,105 @@
package repository
import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/setting"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type settingRepository struct {
client *ent.Client
}
func NewSettingRepository(client *ent.Client) service.SettingRepository {
return &settingRepository{client: client}
}
func (r *settingRepository) Get(ctx context.Context, key string) (*service.Setting, error) {
m, err := r.client.Setting.Query().Where(setting.KeyEQ(key)).Only(ctx)
if err != nil {
if ent.IsNotFound(err) {
return nil, service.ErrSettingNotFound
}
return nil, err
}
return &service.Setting{
ID: m.ID,
Key: m.Key,
Value: m.Value,
UpdatedAt: m.UpdatedAt,
}, nil
}
func (r *settingRepository) GetValue(ctx context.Context, key string) (string, error) {
setting, err := r.Get(ctx, key)
if err != nil {
return "", err
}
return setting.Value, nil
}
func (r *settingRepository) Set(ctx context.Context, key, value string) error {
now := time.Now()
return r.client.Setting.
Create().
SetKey(key).
SetValue(value).
SetUpdatedAt(now).
OnConflictColumns(setting.FieldKey).
UpdateNewValues().
Exec(ctx)
}
func (r *settingRepository) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
if len(keys) == 0 {
return map[string]string{}, nil
}
settings, err := r.client.Setting.Query().Where(setting.KeyIn(keys...)).All(ctx)
if err != nil {
return nil, err
}
result := make(map[string]string)
for _, s := range settings {
result[s.Key] = s.Value
}
return result, nil
}
func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string]string) error {
if len(settings) == 0 {
return nil
}
now := time.Now()
builders := make([]*ent.SettingCreate, 0, len(settings))
for key, value := range settings {
builders = append(builders, r.client.Setting.Create().SetKey(key).SetValue(value).SetUpdatedAt(now))
}
return r.client.Setting.
CreateBulk(builders...).
OnConflictColumns(setting.FieldKey).
UpdateNewValues().
Exec(ctx)
}
func (r *settingRepository) GetAll(ctx context.Context) (map[string]string, error) {
settings, err := r.client.Setting.Query().All(ctx)
if err != nil {
return nil, err
}
result := make(map[string]string)
for _, s := range settings {
result[s.Key] = s.Value
}
return result, nil
}
func (r *settingRepository) Delete(ctx context.Context, key string) error {
_, err := r.client.Setting.Delete().Where(setting.KeyEQ(key)).Exec(ctx)
return err
}

View File

@@ -0,0 +1,163 @@
//go:build integration
package repository
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
)
type SettingRepoSuite struct {
suite.Suite
ctx context.Context
repo *settingRepository
}
func (s *SettingRepoSuite) SetupTest() {
s.ctx = context.Background()
tx := testEntTx(s.T())
s.repo = NewSettingRepository(tx.Client()).(*settingRepository)
}
func TestSettingRepoSuite(t *testing.T) {
suite.Run(t, new(SettingRepoSuite))
}
func (s *SettingRepoSuite) TestSetAndGetValue() {
s.Require().NoError(s.repo.Set(s.ctx, "k1", "v1"), "Set")
got, err := s.repo.GetValue(s.ctx, "k1")
s.Require().NoError(err, "GetValue")
s.Require().Equal("v1", got, "GetValue mismatch")
}
func (s *SettingRepoSuite) TestSet_Upsert() {
s.Require().NoError(s.repo.Set(s.ctx, "k1", "v1"), "Set")
s.Require().NoError(s.repo.Set(s.ctx, "k1", "v2"), "Set upsert")
got, err := s.repo.GetValue(s.ctx, "k1")
s.Require().NoError(err, "GetValue after upsert")
s.Require().Equal("v2", got, "upsert mismatch")
}
func (s *SettingRepoSuite) TestGetValue_Missing() {
_, err := s.repo.GetValue(s.ctx, "nonexistent")
s.Require().Error(err, "expected error for missing key")
s.Require().ErrorIs(err, service.ErrSettingNotFound)
}
func (s *SettingRepoSuite) TestSetMultiple_AndGetMultiple() {
s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"k2": "v2", "k3": "v3"}), "SetMultiple")
m, err := s.repo.GetMultiple(s.ctx, []string{"k2", "k3"})
s.Require().NoError(err, "GetMultiple")
s.Require().Equal("v2", m["k2"])
s.Require().Equal("v3", m["k3"])
}
func (s *SettingRepoSuite) TestGetMultiple_EmptyKeys() {
m, err := s.repo.GetMultiple(s.ctx, []string{})
s.Require().NoError(err, "GetMultiple with empty keys")
s.Require().Empty(m, "expected empty map")
}
func (s *SettingRepoSuite) TestGetMultiple_Subset() {
s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"a": "1", "b": "2", "c": "3"}))
m, err := s.repo.GetMultiple(s.ctx, []string{"a", "c", "nonexistent"})
s.Require().NoError(err, "GetMultiple subset")
s.Require().Equal("1", m["a"])
s.Require().Equal("3", m["c"])
_, exists := m["nonexistent"]
s.Require().False(exists, "nonexistent key should not be in map")
}
func (s *SettingRepoSuite) TestGetAll() {
s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"x": "1", "y": "2"}))
all, err := s.repo.GetAll(s.ctx)
s.Require().NoError(err, "GetAll")
s.Require().GreaterOrEqual(len(all), 2, "expected at least 2 settings")
s.Require().Equal("1", all["x"])
s.Require().Equal("2", all["y"])
}
func (s *SettingRepoSuite) TestDelete() {
s.Require().NoError(s.repo.Set(s.ctx, "todelete", "val"))
s.Require().NoError(s.repo.Delete(s.ctx, "todelete"), "Delete")
_, err := s.repo.GetValue(s.ctx, "todelete")
s.Require().Error(err, "expected missing key error after Delete")
s.Require().ErrorIs(err, service.ErrSettingNotFound)
}
func (s *SettingRepoSuite) TestDelete_Idempotent() {
// Delete a key that doesn't exist should not error
s.Require().NoError(s.repo.Delete(s.ctx, "nonexistent_delete"), "Delete nonexistent should be idempotent")
}
func (s *SettingRepoSuite) TestSetMultiple_Upsert() {
s.Require().NoError(s.repo.Set(s.ctx, "upsert_key", "old_value"))
s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"upsert_key": "new_value", "new_key": "new_val"}))
got, err := s.repo.GetValue(s.ctx, "upsert_key")
s.Require().NoError(err)
s.Require().Equal("new_value", got, "SetMultiple should upsert existing key")
got2, err := s.repo.GetValue(s.ctx, "new_key")
s.Require().NoError(err)
s.Require().Equal("new_val", got2)
}
// TestSet_EmptyValue 测试保存空字符串值
// 这是一个回归测试确保可选设置如站点Logo、API端点地址等可以保存为空字符串
func (s *SettingRepoSuite) TestSet_EmptyValue() {
// 测试 Set 方法保存空值
s.Require().NoError(s.repo.Set(s.ctx, "empty_key", ""), "Set with empty value should succeed")
got, err := s.repo.GetValue(s.ctx, "empty_key")
s.Require().NoError(err, "GetValue for empty value")
s.Require().Equal("", got, "empty value should be preserved")
}
// TestSetMultiple_WithEmptyValues 测试批量保存包含空字符串的设置
// 模拟用户保存站点设置时部分字段为空的场景
func (s *SettingRepoSuite) TestSetMultiple_WithEmptyValues() {
// 模拟保存站点设置,部分字段有值,部分字段为空
settings := map[string]string{
"site_name": "AICodex2API",
"site_subtitle": "Subscription to API",
"site_logo": "", // 用户未上传Logo
"api_base_url": "", // 用户未设置API地址
"contact_info": "", // 用户未设置联系方式
"doc_url": "", // 用户未设置文档链接
}
s.Require().NoError(s.repo.SetMultiple(s.ctx, settings), "SetMultiple with empty values should succeed")
// 验证所有值都正确保存
result, err := s.repo.GetMultiple(s.ctx, []string{"site_name", "site_subtitle", "site_logo", "api_base_url", "contact_info", "doc_url"})
s.Require().NoError(err, "GetMultiple after SetMultiple with empty values")
s.Require().Equal("AICodex2API", result["site_name"])
s.Require().Equal("Subscription to API", result["site_subtitle"])
s.Require().Equal("", result["site_logo"], "empty site_logo should be preserved")
s.Require().Equal("", result["api_base_url"], "empty api_base_url should be preserved")
s.Require().Equal("", result["contact_info"], "empty contact_info should be preserved")
s.Require().Equal("", result["doc_url"], "empty doc_url should be preserved")
}
// TestSetMultiple_UpdateToEmpty 测试将已有值更新为空字符串
// 确保用户可以清空之前设置的值
func (s *SettingRepoSuite) TestSetMultiple_UpdateToEmpty() {
// 先设置非空值
s.Require().NoError(s.repo.Set(s.ctx, "clearable_key", "initial_value"))
got, err := s.repo.GetValue(s.ctx, "clearable_key")
s.Require().NoError(err)
s.Require().Equal("initial_value", got)
// 更新为空值
s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"clearable_key": ""}), "Update to empty should succeed")
got, err = s.repo.GetValue(s.ctx, "clearable_key")
s.Require().NoError(err)
s.Require().Equal("", got, "value should be updated to empty string")
}

View File

@@ -0,0 +1,216 @@
//go:build integration
package repository
import (
"context"
"fmt"
"strings"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func uniqueSoftDeleteValue(t *testing.T, prefix string) string {
t.Helper()
safeName := strings.NewReplacer("/", "_", " ", "_").Replace(t.Name())
return fmt.Sprintf("%s-%s", prefix, safeName)
}
func createEntUser(t *testing.T, ctx context.Context, client *dbent.Client, email string) *dbent.User {
t.Helper()
u, err := client.User.Create().
SetEmail(email).
SetPasswordHash("test-password-hash").
Save(ctx)
require.NoError(t, err, "create ent user")
return u
}
func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) {
ctx := context.Background()
// 使用全局 ent client确保软删除验证在实际持久化数据上进行。
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user")+"@example.com")
repo := NewAPIKeyRepository(client)
key := &service.APIKey{
UserID: u.ID,
Key: uniqueSoftDeleteValue(t, "sk-soft-delete"),
Name: "soft-delete",
Status: service.StatusActive,
}
require.NoError(t, repo.Create(ctx, key), "create api key")
require.NoError(t, repo.Delete(ctx, key.ID), "soft delete api key")
_, err := repo.GetByID(ctx, key.ID)
require.ErrorIs(t, err, service.ErrAPIKeyNotFound, "deleted rows should be hidden by default")
_, err = client.APIKey.Query().Where(apikey.IDEQ(key.ID)).Only(ctx)
require.Error(t, err, "default ent query should not see soft-deleted rows")
require.True(t, dbent.IsNotFound(err), "expected ent not-found after default soft delete filter")
got, err := client.APIKey.Query().
Where(apikey.IDEQ(key.ID)).
Only(mixins.SkipSoftDelete(ctx))
require.NoError(t, err, "SkipSoftDelete should include soft-deleted rows")
require.NotNil(t, got.DeletedAt, "deleted_at should be set after soft delete")
}
func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) {
ctx := context.Background()
// 使用全局 ent client避免事务回滚影响幂等性验证。
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user2")+"@example.com")
repo := NewAPIKeyRepository(client)
key := &service.APIKey{
UserID: u.ID,
Key: uniqueSoftDeleteValue(t, "sk-soft-delete2"),
Name: "soft-delete2",
Status: service.StatusActive,
}
require.NoError(t, repo.Create(ctx, key), "create api key")
require.NoError(t, repo.Delete(ctx, key.ID), "first delete")
require.NoError(t, repo.Delete(ctx, key.ID), "second delete should be idempotent")
}
func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
ctx := context.Background()
// 使用全局 ent client确保 SkipSoftDelete 的硬删除语义可验证。
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user3")+"@example.com")
repo := NewAPIKeyRepository(client)
key := &service.APIKey{
UserID: u.ID,
Key: uniqueSoftDeleteValue(t, "sk-soft-delete3"),
Name: "soft-delete3",
Status: service.StatusActive,
}
require.NoError(t, repo.Create(ctx, key), "create api key")
require.NoError(t, repo.Delete(ctx, key.ID), "soft delete api key")
// Hard delete using SkipSoftDelete so the hook doesn't convert it to update-deleted_at.
_, err := client.APIKey.Delete().Where(apikey.IDEQ(key.ID)).Exec(mixins.SkipSoftDelete(ctx))
require.NoError(t, err, "hard delete")
_, err = client.APIKey.Query().
Where(apikey.IDEQ(key.ID)).
Only(mixins.SkipSoftDelete(ctx))
require.True(t, dbent.IsNotFound(err), "expected row to be hard deleted")
}
// --- UserSubscription 软删除测试 ---
func createEntGroup(t *testing.T, ctx context.Context, client *dbent.Client, name string) *dbent.Group {
t.Helper()
g, err := client.Group.Create().
SetName(name).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err, "create ent group")
return g
}
func TestEntSoftDelete_UserSubscription_DefaultFilterAndSkip(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-user")+"@example.com")
g := createEntGroup(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-group"))
repo := NewUserSubscriptionRepository(client)
sub := &service.UserSubscription{
UserID: u.ID,
GroupID: g.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
}
require.NoError(t, repo.Create(ctx, sub), "create user subscription")
require.NoError(t, repo.Delete(ctx, sub.ID), "soft delete user subscription")
_, err := repo.GetByID(ctx, sub.ID)
require.Error(t, err, "deleted rows should be hidden by default")
_, err = client.UserSubscription.Query().Where(usersubscription.IDEQ(sub.ID)).Only(ctx)
require.Error(t, err, "default ent query should not see soft-deleted rows")
require.True(t, dbent.IsNotFound(err), "expected ent not-found after default soft delete filter")
got, err := client.UserSubscription.Query().
Where(usersubscription.IDEQ(sub.ID)).
Only(mixins.SkipSoftDelete(ctx))
require.NoError(t, err, "SkipSoftDelete should include soft-deleted rows")
require.NotNil(t, got.DeletedAt, "deleted_at should be set after soft delete")
}
func TestEntSoftDelete_UserSubscription_DeleteIdempotent(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-user2")+"@example.com")
g := createEntGroup(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-group2"))
repo := NewUserSubscriptionRepository(client)
sub := &service.UserSubscription{
UserID: u.ID,
GroupID: g.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
}
require.NoError(t, repo.Create(ctx, sub), "create user subscription")
require.NoError(t, repo.Delete(ctx, sub.ID), "first delete")
require.NoError(t, repo.Delete(ctx, sub.ID), "second delete should be idempotent")
}
func TestEntSoftDelete_UserSubscription_ListExcludesDeleted(t *testing.T) {
ctx := context.Background()
client := testEntClient(t)
u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-user3")+"@example.com")
g1 := createEntGroup(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-group3a"))
g2 := createEntGroup(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-group3b"))
repo := NewUserSubscriptionRepository(client)
sub1 := &service.UserSubscription{
UserID: u.ID,
GroupID: g1.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
}
require.NoError(t, repo.Create(ctx, sub1), "create subscription 1")
sub2 := &service.UserSubscription{
UserID: u.ID,
GroupID: g2.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
}
require.NoError(t, repo.Create(ctx, sub2), "create subscription 2")
// 软删除 sub1
require.NoError(t, repo.Delete(ctx, sub1.ID), "soft delete subscription 1")
// ListByUserID 应只返回未删除的订阅
subs, err := repo.ListByUserID(ctx, u.ID)
require.NoError(t, err, "ListByUserID")
require.Len(t, subs, 1, "should only return non-deleted subscriptions")
require.Equal(t, sub2.ID, subs[0].ID, "expected sub2 to be returned")
}

View File

@@ -0,0 +1,42 @@
package repository
import (
"context"
"database/sql"
"errors"
)
type sqlQueryer interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}
// scanSingleRow 执行查询并扫描第一行到 dest。
// 若无结果,可通过 errors.Is(err, sql.ErrNoRows) 判断。
// 如果 Close 失败,会与原始错误合并返回。
// 设计目的:仅依赖 QueryContext避免 QueryRowContext 对 *sql.Tx 的强绑定,
// 让 ent.Tx 也能作为 sqlExecutor/Queryer 使用。
func scanSingleRow(ctx context.Context, q sqlQueryer, query string, args []any, dest ...any) (err error) {
rows, err := q.QueryContext(ctx, query, args...)
if err != nil {
return err
}
defer func() {
if closeErr := rows.Close(); closeErr != nil {
err = errors.Join(err, closeErr)
}
}()
if !rows.Next() {
if err = rows.Err(); err != nil {
return err
}
return sql.ErrNoRows
}
if err = rows.Scan(dest...); err != nil {
return err
}
if err = rows.Err(); err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,91 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const tempUnschedPrefix = "temp_unsched:account:"
var tempUnschedSetScript = redis.NewScript(`
local key = KEYS[1]
local new_until = tonumber(ARGV[1])
local new_value = ARGV[2]
local new_ttl = tonumber(ARGV[3])
local existing = redis.call('GET', key)
if existing then
local ok, existing_data = pcall(cjson.decode, existing)
if ok and existing_data and existing_data.until_unix then
local existing_until = tonumber(existing_data.until_unix)
if existing_until and new_until <= existing_until then
return 0
end
end
end
redis.call('SET', key, new_value, 'EX', new_ttl)
return 1
`)
type tempUnschedCache struct {
rdb *redis.Client
}
func NewTempUnschedCache(rdb *redis.Client) service.TempUnschedCache {
return &tempUnschedCache{rdb: rdb}
}
// SetTempUnsched 设置临时不可调度状态(只延长不缩短)
func (c *tempUnschedCache) SetTempUnsched(ctx context.Context, accountID int64, state *service.TempUnschedState) error {
key := fmt.Sprintf("%s%d", tempUnschedPrefix, accountID)
stateJSON, err := json.Marshal(state)
if err != nil {
return fmt.Errorf("marshal state: %w", err)
}
ttl := time.Until(time.Unix(state.UntilUnix, 0))
if ttl <= 0 {
return nil // 已过期,不设置
}
ttlSeconds := int(ttl.Seconds())
if ttlSeconds < 1 {
ttlSeconds = 1
}
_, err = tempUnschedSetScript.Run(ctx, c.rdb, []string{key}, state.UntilUnix, string(stateJSON), ttlSeconds).Result()
return err
}
// GetTempUnsched 获取临时不可调度状态
func (c *tempUnschedCache) GetTempUnsched(ctx context.Context, accountID int64) (*service.TempUnschedState, error) {
key := fmt.Sprintf("%s%d", tempUnschedPrefix, accountID)
val, err := c.rdb.Get(ctx, key).Result()
if err == redis.Nil {
return nil, nil
}
if err != nil {
return nil, err
}
var state service.TempUnschedState
if err := json.Unmarshal([]byte(val), &state); err != nil {
return nil, fmt.Errorf("unmarshal state: %w", err)
}
return &state, nil
}
// DeleteTempUnsched 删除临时不可调度状态
func (c *tempUnschedCache) DeleteTempUnsched(ctx context.Context, accountID int64) error {
key := fmt.Sprintf("%s%d", tempUnschedPrefix, accountID)
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -0,0 +1,80 @@
package repository
import (
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const timeoutCounterPrefix = "timeout_count:account:"
// timeoutCounterIncrScript 使用 Lua 脚本原子性地增加计数并返回当前值
// 如果 key 不存在,则创建并设置过期时间
var timeoutCounterIncrScript = redis.NewScript(`
local key = KEYS[1]
local ttl = tonumber(ARGV[1])
local count = redis.call('INCR', key)
if count == 1 then
redis.call('EXPIRE', key, ttl)
end
return count
`)
type timeoutCounterCache struct {
rdb *redis.Client
}
// NewTimeoutCounterCache 创建超时计数器缓存实例
func NewTimeoutCounterCache(rdb *redis.Client) service.TimeoutCounterCache {
return &timeoutCounterCache{rdb: rdb}
}
// IncrementTimeoutCount 增加账户的超时计数,返回当前计数值
// windowMinutes 是计数窗口时间(分钟),超过此时间计数器会自动重置
func (c *timeoutCounterCache) IncrementTimeoutCount(ctx context.Context, accountID int64, windowMinutes int) (int64, error) {
key := fmt.Sprintf("%s%d", timeoutCounterPrefix, accountID)
ttlSeconds := windowMinutes * 60
if ttlSeconds < 60 {
ttlSeconds = 60 // 最小1分钟
}
result, err := timeoutCounterIncrScript.Run(ctx, c.rdb, []string{key}, ttlSeconds).Int64()
if err != nil {
return 0, fmt.Errorf("increment timeout count: %w", err)
}
return result, nil
}
// GetTimeoutCount 获取账户当前的超时计数
func (c *timeoutCounterCache) GetTimeoutCount(ctx context.Context, accountID int64) (int64, error) {
key := fmt.Sprintf("%s%d", timeoutCounterPrefix, accountID)
val, err := c.rdb.Get(ctx, key).Int64()
if err == redis.Nil {
return 0, nil
}
if err != nil {
return 0, fmt.Errorf("get timeout count: %w", err)
}
return val, nil
}
// ResetTimeoutCount 重置账户的超时计数
func (c *timeoutCounterCache) ResetTimeoutCount(ctx context.Context, accountID int64) error {
key := fmt.Sprintf("%s%d", timeoutCounterPrefix, accountID)
return c.rdb.Del(ctx, key).Err()
}
// GetTimeoutCountTTL 获取计数器剩余过期时间
func (c *timeoutCounterCache) GetTimeoutCountTTL(ctx context.Context, accountID int64) (time.Duration, error) {
key := fmt.Sprintf("%s%d", timeoutCounterPrefix, accountID)
return c.rdb.TTL(ctx, key).Result()
}

View File

@@ -0,0 +1,63 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/service"
)
const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
type turnstileVerifier struct {
httpClient *http.Client
verifyURL string
}
func NewTurnstileVerifier() service.TurnstileVerifier {
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 10 * time.Second,
ValidateResolvedIP: true,
})
if err != nil {
sharedClient = &http.Client{Timeout: 10 * time.Second}
}
return &turnstileVerifier{
httpClient: sharedClient,
verifyURL: turnstileVerifyURL,
}
}
func (v *turnstileVerifier) VerifyToken(ctx context.Context, secretKey, token, remoteIP string) (*service.TurnstileVerifyResponse, error) {
formData := url.Values{}
formData.Set("secret", secretKey)
formData.Set("response", token)
if remoteIP != "" {
formData.Set("remoteip", remoteIP)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, v.verifyURL, strings.NewReader(formData.Encode()))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := v.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer func() { _ = resp.Body.Close() }()
var result service.TurnstileVerifyResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("decode response: %w", err)
}
return &result, nil
}

View File

@@ -0,0 +1,141 @@
package repository
import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/url"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type TurnstileServiceSuite struct {
suite.Suite
ctx context.Context
verifier *turnstileVerifier
received chan url.Values
}
func (s *TurnstileServiceSuite) SetupTest() {
s.ctx = context.Background()
s.received = make(chan url.Values, 1)
verifier, ok := NewTurnstileVerifier().(*turnstileVerifier)
require.True(s.T(), ok, "type assertion failed")
s.verifier = verifier
}
func (s *TurnstileServiceSuite) setupTransport(handler http.HandlerFunc) {
s.verifier.verifyURL = "http://in-process/turnstile"
s.verifier.httpClient = &http.Client{
Transport: newInProcessTransport(handler, nil),
}
}
func (s *TurnstileServiceSuite) TestVerifyToken_SendsFormAndDecodesJSON() {
s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Capture form data in main goroutine context later
body, _ := io.ReadAll(r.Body)
values, _ := url.ParseQuery(string(body))
s.received <- values
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{Success: true})
}))
resp, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
require.NoError(s.T(), err, "VerifyToken")
require.NotNil(s.T(), resp)
require.True(s.T(), resp.Success, "expected success response")
// Assert form fields in main goroutine
select {
case values := <-s.received:
require.Equal(s.T(), "sk", values.Get("secret"))
require.Equal(s.T(), "token", values.Get("response"))
require.Equal(s.T(), "1.1.1.1", values.Get("remoteip"))
default:
require.Fail(s.T(), "expected server to receive request")
}
}
func (s *TurnstileServiceSuite) TestVerifyToken_ContentType() {
var contentType string
s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
contentType = r.Header.Get("Content-Type")
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{Success: true})
}))
_, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
require.NoError(s.T(), err)
require.True(s.T(), strings.HasPrefix(contentType, "application/x-www-form-urlencoded"), "unexpected content-type: %s", contentType)
}
func (s *TurnstileServiceSuite) TestVerifyToken_EmptyRemoteIP_NotSent() {
s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
values, _ := url.ParseQuery(string(body))
s.received <- values
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{Success: true})
}))
_, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "")
require.NoError(s.T(), err)
select {
case values := <-s.received:
require.Equal(s.T(), "", values.Get("remoteip"), "remoteip should be empty or not sent")
default:
require.Fail(s.T(), "expected server to receive request")
}
}
func (s *TurnstileServiceSuite) TestVerifyToken_RequestError() {
s.verifier.verifyURL = "http://in-process/turnstile"
s.verifier.httpClient = &http.Client{
Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
return nil, errors.New("dial failed")
}),
}
_, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
require.Error(s.T(), err, "expected error when server is closed")
}
func (s *TurnstileServiceSuite) TestVerifyToken_InvalidJSON() {
s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, "not-valid-json")
}))
_, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
require.Error(s.T(), err, "expected error for invalid JSON response")
}
func (s *TurnstileServiceSuite) TestVerifyToken_SuccessFalse() {
s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{
Success: false,
ErrorCodes: []string{"invalid-input-response"},
})
}))
resp, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
require.NoError(s.T(), err, "VerifyToken should not error on success=false")
require.NotNil(s.T(), resp)
require.False(s.T(), resp.Success)
require.Contains(s.T(), resp.ErrorCodes, "invalid-input-response")
}
func TestTurnstileServiceSuite(t *testing.T) {
suite.Run(t, new(TurnstileServiceSuite))
}

View File

@@ -0,0 +1,27 @@
package repository
import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const updateCacheKey = "update:latest"
type updateCache struct {
rdb *redis.Client
}
func NewUpdateCache(rdb *redis.Client) service.UpdateCache {
return &updateCache{rdb: rdb}
}
func (c *updateCache) GetUpdateInfo(ctx context.Context) (string, error) {
return c.rdb.Get(ctx, updateCacheKey).Result()
}
func (c *updateCache) SetUpdateInfo(ctx context.Context, data string, ttl time.Duration) error {
return c.rdb.Set(ctx, updateCacheKey, data, ttl).Err()
}

View File

@@ -0,0 +1,73 @@
//go:build integration
package repository
import (
"errors"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type UpdateCacheSuite struct {
IntegrationRedisSuite
cache *updateCache
}
func (s *UpdateCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewUpdateCache(s.rdb).(*updateCache)
}
func (s *UpdateCacheSuite) TestGetUpdateInfo_Missing() {
_, err := s.cache.GetUpdateInfo(s.ctx)
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing update info")
}
func (s *UpdateCacheSuite) TestSetAndGetUpdateInfo() {
updateTTL := 5 * time.Minute
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v1.2.3", updateTTL), "SetUpdateInfo")
info, err := s.cache.GetUpdateInfo(s.ctx)
require.NoError(s.T(), err, "GetUpdateInfo")
require.Equal(s.T(), "v1.2.3", info, "update info mismatch")
}
func (s *UpdateCacheSuite) TestSetUpdateInfo_TTL() {
updateTTL := 5 * time.Minute
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v1.2.3", updateTTL))
ttl, err := s.rdb.TTL(s.ctx, updateCacheKey).Result()
require.NoError(s.T(), err, "TTL updateCacheKey")
s.AssertTTLWithin(ttl, 1*time.Second, updateTTL)
}
func (s *UpdateCacheSuite) TestSetUpdateInfo_Overwrite() {
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v1.0.0", 5*time.Minute))
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v2.0.0", 5*time.Minute))
info, err := s.cache.GetUpdateInfo(s.ctx)
require.NoError(s.T(), err)
require.Equal(s.T(), "v2.0.0", info, "expected overwritten value")
}
func (s *UpdateCacheSuite) TestSetUpdateInfo_ZeroTTL() {
// TTL=0 means persist forever (no expiry) in Redis SET command
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v0.0.0", 0))
info, err := s.cache.GetUpdateInfo(s.ctx)
require.NoError(s.T(), err)
require.Equal(s.T(), "v0.0.0", info)
ttl, err := s.rdb.TTL(s.ctx, updateCacheKey).Result()
require.NoError(s.T(), err)
// TTL=-1 means no expiry, TTL=-2 means key doesn't exist
require.Equal(s.T(), time.Duration(-1), ttl, "expected TTL=-1 for key with no expiry")
}
func TestUpdateCacheSuite(t *testing.T) {
suite.Run(t, new(UpdateCacheSuite))
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,385 @@
package repository
import (
"context"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// UserAttributeDefinitionRepository implementation
type userAttributeDefinitionRepository struct {
client *dbent.Client
}
// NewUserAttributeDefinitionRepository creates a new repository instance
func NewUserAttributeDefinitionRepository(client *dbent.Client) service.UserAttributeDefinitionRepository {
return &userAttributeDefinitionRepository{client: client}
}
func (r *userAttributeDefinitionRepository) Create(ctx context.Context, def *service.UserAttributeDefinition) error {
client := clientFromContext(ctx, r.client)
created, err := client.UserAttributeDefinition.Create().
SetKey(def.Key).
SetName(def.Name).
SetDescription(def.Description).
SetType(string(def.Type)).
SetOptions(toEntOptions(def.Options)).
SetRequired(def.Required).
SetValidation(toEntValidation(def.Validation)).
SetPlaceholder(def.Placeholder).
SetEnabled(def.Enabled).
Save(ctx)
if err != nil {
return translatePersistenceError(err, nil, service.ErrAttributeKeyExists)
}
def.ID = created.ID
def.DisplayOrder = created.DisplayOrder
def.CreatedAt = created.CreatedAt
def.UpdatedAt = created.UpdatedAt
return nil
}
func (r *userAttributeDefinitionRepository) GetByID(ctx context.Context, id int64) (*service.UserAttributeDefinition, error) {
client := clientFromContext(ctx, r.client)
e, err := client.UserAttributeDefinition.Query().
Where(userattributedefinition.IDEQ(id)).
Only(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrAttributeDefinitionNotFound, nil)
}
return defEntityToService(e), nil
}
func (r *userAttributeDefinitionRepository) GetByKey(ctx context.Context, key string) (*service.UserAttributeDefinition, error) {
client := clientFromContext(ctx, r.client)
e, err := client.UserAttributeDefinition.Query().
Where(userattributedefinition.KeyEQ(key)).
Only(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrAttributeDefinitionNotFound, nil)
}
return defEntityToService(e), nil
}
func (r *userAttributeDefinitionRepository) Update(ctx context.Context, def *service.UserAttributeDefinition) error {
client := clientFromContext(ctx, r.client)
updated, err := client.UserAttributeDefinition.UpdateOneID(def.ID).
SetName(def.Name).
SetDescription(def.Description).
SetType(string(def.Type)).
SetOptions(toEntOptions(def.Options)).
SetRequired(def.Required).
SetValidation(toEntValidation(def.Validation)).
SetPlaceholder(def.Placeholder).
SetDisplayOrder(def.DisplayOrder).
SetEnabled(def.Enabled).
Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrAttributeDefinitionNotFound, service.ErrAttributeKeyExists)
}
def.UpdatedAt = updated.UpdatedAt
return nil
}
func (r *userAttributeDefinitionRepository) Delete(ctx context.Context, id int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.UserAttributeDefinition.Delete().
Where(userattributedefinition.IDEQ(id)).
Exec(ctx)
return translatePersistenceError(err, service.ErrAttributeDefinitionNotFound, nil)
}
func (r *userAttributeDefinitionRepository) List(ctx context.Context, enabledOnly bool) ([]service.UserAttributeDefinition, error) {
client := clientFromContext(ctx, r.client)
q := client.UserAttributeDefinition.Query()
if enabledOnly {
q = q.Where(userattributedefinition.EnabledEQ(true))
}
entities, err := q.Order(dbent.Asc(userattributedefinition.FieldDisplayOrder)).All(ctx)
if err != nil {
return nil, err
}
result := make([]service.UserAttributeDefinition, 0, len(entities))
for _, e := range entities {
result = append(result, *defEntityToService(e))
}
return result, nil
}
func (r *userAttributeDefinitionRepository) UpdateDisplayOrders(ctx context.Context, orders map[int64]int) error {
tx, err := r.client.Tx(ctx)
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
for id, order := range orders {
if _, err := tx.UserAttributeDefinition.UpdateOneID(id).
SetDisplayOrder(order).
Save(ctx); err != nil {
return translatePersistenceError(err, service.ErrAttributeDefinitionNotFound, nil)
}
}
return tx.Commit()
}
func (r *userAttributeDefinitionRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
client := clientFromContext(ctx, r.client)
return client.UserAttributeDefinition.Query().
Where(userattributedefinition.KeyEQ(key)).
Exist(ctx)
}
// UserAttributeValueRepository implementation
type userAttributeValueRepository struct {
client *dbent.Client
}
// NewUserAttributeValueRepository creates a new repository instance
func NewUserAttributeValueRepository(client *dbent.Client) service.UserAttributeValueRepository {
return &userAttributeValueRepository{client: client}
}
func (r *userAttributeValueRepository) GetByUserID(ctx context.Context, userID int64) ([]service.UserAttributeValue, error) {
client := clientFromContext(ctx, r.client)
entities, err := client.UserAttributeValue.Query().
Where(userattributevalue.UserIDEQ(userID)).
All(ctx)
if err != nil {
return nil, err
}
result := make([]service.UserAttributeValue, 0, len(entities))
for _, e := range entities {
result = append(result, service.UserAttributeValue{
ID: e.ID,
UserID: e.UserID,
AttributeID: e.AttributeID,
Value: e.Value,
CreatedAt: e.CreatedAt,
UpdatedAt: e.UpdatedAt,
})
}
return result, nil
}
func (r *userAttributeValueRepository) GetByUserIDs(ctx context.Context, userIDs []int64) ([]service.UserAttributeValue, error) {
if len(userIDs) == 0 {
return []service.UserAttributeValue{}, nil
}
client := clientFromContext(ctx, r.client)
entities, err := client.UserAttributeValue.Query().
Where(userattributevalue.UserIDIn(userIDs...)).
All(ctx)
if err != nil {
return nil, err
}
result := make([]service.UserAttributeValue, 0, len(entities))
for _, e := range entities {
result = append(result, service.UserAttributeValue{
ID: e.ID,
UserID: e.UserID,
AttributeID: e.AttributeID,
Value: e.Value,
CreatedAt: e.CreatedAt,
UpdatedAt: e.UpdatedAt,
})
}
return result, nil
}
func (r *userAttributeValueRepository) UpsertBatch(ctx context.Context, userID int64, inputs []service.UpdateUserAttributeInput) error {
if len(inputs) == 0 {
return nil
}
tx, err := r.client.Tx(ctx)
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
for _, input := range inputs {
// Use upsert (ON CONFLICT DO UPDATE)
err := tx.UserAttributeValue.Create().
SetUserID(userID).
SetAttributeID(input.AttributeID).
SetValue(input.Value).
OnConflictColumns(userattributevalue.FieldUserID, userattributevalue.FieldAttributeID).
UpdateValue().
UpdateUpdatedAt().
Exec(ctx)
if err != nil {
return err
}
}
return tx.Commit()
}
func (r *userAttributeValueRepository) DeleteByAttributeID(ctx context.Context, attributeID int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.UserAttributeValue.Delete().
Where(userattributevalue.AttributeIDEQ(attributeID)).
Exec(ctx)
return err
}
func (r *userAttributeValueRepository) DeleteByUserID(ctx context.Context, userID int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.UserAttributeValue.Delete().
Where(userattributevalue.UserIDEQ(userID)).
Exec(ctx)
return err
}
// Helper functions for entity to service conversion
func defEntityToService(e *dbent.UserAttributeDefinition) *service.UserAttributeDefinition {
if e == nil {
return nil
}
return &service.UserAttributeDefinition{
ID: e.ID,
Key: e.Key,
Name: e.Name,
Description: e.Description,
Type: service.UserAttributeType(e.Type),
Options: toServiceOptions(e.Options),
Required: e.Required,
Validation: toServiceValidation(e.Validation),
Placeholder: e.Placeholder,
DisplayOrder: e.DisplayOrder,
Enabled: e.Enabled,
CreatedAt: e.CreatedAt,
UpdatedAt: e.UpdatedAt,
}
}
// Type conversion helpers (map types <-> service types)
func toEntOptions(opts []service.UserAttributeOption) []map[string]any {
if opts == nil {
return []map[string]any{}
}
result := make([]map[string]any, len(opts))
for i, o := range opts {
result[i] = map[string]any{"value": o.Value, "label": o.Label}
}
return result
}
func toServiceOptions(opts []map[string]any) []service.UserAttributeOption {
if opts == nil {
return []service.UserAttributeOption{}
}
result := make([]service.UserAttributeOption, len(opts))
for i, o := range opts {
result[i] = service.UserAttributeOption{
Value: getString(o, "value"),
Label: getString(o, "label"),
}
}
return result
}
func toEntValidation(v service.UserAttributeValidation) map[string]any {
result := map[string]any{}
if v.MinLength != nil {
result["min_length"] = *v.MinLength
}
if v.MaxLength != nil {
result["max_length"] = *v.MaxLength
}
if v.Min != nil {
result["min"] = *v.Min
}
if v.Max != nil {
result["max"] = *v.Max
}
if v.Pattern != nil {
result["pattern"] = *v.Pattern
}
if v.Message != nil {
result["message"] = *v.Message
}
return result
}
func toServiceValidation(v map[string]any) service.UserAttributeValidation {
result := service.UserAttributeValidation{}
if val := getInt(v, "min_length"); val != nil {
result.MinLength = val
}
if val := getInt(v, "max_length"); val != nil {
result.MaxLength = val
}
if val := getInt(v, "min"); val != nil {
result.Min = val
}
if val := getInt(v, "max"); val != nil {
result.Max = val
}
if val := getStringPtr(v, "pattern"); val != nil {
result.Pattern = val
}
if val := getStringPtr(v, "message"); val != nil {
result.Message = val
}
return result
}
// Helper functions for type conversion
func getString(m map[string]any, key string) string {
if v, ok := m[key]; ok {
if s, ok := v.(string); ok {
return s
}
}
return ""
}
func getStringPtr(m map[string]any, key string) *string {
if v, ok := m[key]; ok {
if s, ok := v.(string); ok {
return &s
}
}
return nil
}
func getInt(m map[string]any, key string) *int {
if v, ok := m[key]; ok {
switch n := v.(type) {
case int:
return &n
case int64:
i := int(n)
return &i
case float64:
i := int(n)
return &i
}
}
return nil
}

View File

@@ -0,0 +1,468 @@
package repository
import (
"context"
"database/sql"
"errors"
"fmt"
"sort"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type userRepository struct {
client *dbent.Client
sql sqlExecutor
}
func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserRepository {
return newUserRepositoryWithSQL(client, sqlDB)
}
func newUserRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *userRepository {
return &userRepository{client: client, sql: sqlq}
}
func (r *userRepository) Create(ctx context.Context, userIn *service.User) error {
if userIn == nil {
return nil
}
// 统一使用 ent 的事务:保证用户与允许分组的更新原子化,
// 并避免基于 *sql.Tx 手动构造 ent client 导致的 ExecQuerier 断言错误。
tx, err := r.client.Tx(ctx)
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
return err
}
var txClient *dbent.Client
if err == nil {
defer func() { _ = tx.Rollback() }()
txClient = tx.Client()
} else {
// 已处于外部事务中ErrTxStarted复用当前 client 并由调用方负责提交/回滚。
txClient = r.client
}
created, err := txClient.User.Create().
SetEmail(userIn.Email).
SetUsername(userIn.Username).
SetNotes(userIn.Notes).
SetPasswordHash(userIn.PasswordHash).
SetRole(userIn.Role).
SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status).
Save(ctx)
if err != nil {
return translatePersistenceError(err, nil, service.ErrEmailExists)
}
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, created.ID, userIn.AllowedGroups); err != nil {
return err
}
if tx != nil {
if err := tx.Commit(); err != nil {
return err
}
}
applyUserEntityToService(userIn, created)
return nil
}
func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User, error) {
m, err := r.client.User.Query().Where(dbuser.IDEQ(id)).Only(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
}
out := userEntityToService(m)
groups, err := r.loadAllowedGroups(ctx, []int64{id})
if err != nil {
return nil, err
}
if v, ok := groups[id]; ok {
out.AllowedGroups = v
}
return out, nil
}
func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service.User, error) {
m, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Only(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
}
out := userEntityToService(m)
groups, err := r.loadAllowedGroups(ctx, []int64{m.ID})
if err != nil {
return nil, err
}
if v, ok := groups[m.ID]; ok {
out.AllowedGroups = v
}
return out, nil
}
func (r *userRepository) Update(ctx context.Context, userIn *service.User) error {
if userIn == nil {
return nil
}
// 使用 ent 事务包裹用户更新与 allowed_groups 同步,避免跨层事务不一致。
tx, err := r.client.Tx(ctx)
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
return err
}
var txClient *dbent.Client
if err == nil {
defer func() { _ = tx.Rollback() }()
txClient = tx.Client()
} else {
// 已处于外部事务中ErrTxStarted复用当前 client 并由调用方负责提交/回滚。
txClient = r.client
}
updated, err := txClient.User.UpdateOneID(userIn.ID).
SetEmail(userIn.Email).
SetUsername(userIn.Username).
SetNotes(userIn.Notes).
SetPasswordHash(userIn.PasswordHash).
SetRole(userIn.Role).
SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status).
Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
}
if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, updated.ID, userIn.AllowedGroups); err != nil {
return err
}
if tx != nil {
if err := tx.Commit(); err != nil {
return err
}
}
userIn.UpdatedAt = updated.UpdatedAt
return nil
}
func (r *userRepository) Delete(ctx context.Context, id int64) error {
affected, err := r.client.User.Delete().Where(dbuser.IDEQ(id)).Exec(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
if affected == 0 {
return service.ErrUserNotFound
}
return nil
}
func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, service.UserListFilters{})
}
func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
q := r.client.User.Query()
if filters.Status != "" {
q = q.Where(dbuser.StatusEQ(filters.Status))
}
if filters.Role != "" {
q = q.Where(dbuser.RoleEQ(filters.Role))
}
if filters.Search != "" {
q = q.Where(
dbuser.Or(
dbuser.EmailContainsFold(filters.Search),
dbuser.UsernameContainsFold(filters.Search),
),
)
}
// If attribute filters are specified, we need to filter by user IDs first
var allowedUserIDs []int64
if len(filters.Attributes) > 0 {
var attrErr error
allowedUserIDs, attrErr = r.filterUsersByAttributes(ctx, filters.Attributes)
if attrErr != nil {
return nil, nil, attrErr
}
if len(allowedUserIDs) == 0 {
// No users match the attribute filters
return []service.User{}, paginationResultFromTotal(0, params), nil
}
q = q.Where(dbuser.IDIn(allowedUserIDs...))
}
total, err := q.Clone().Count(ctx)
if err != nil {
return nil, nil, err
}
users, err := q.
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(dbuser.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
outUsers := make([]service.User, 0, len(users))
if len(users) == 0 {
return outUsers, paginationResultFromTotal(int64(total), params), nil
}
userIDs := make([]int64, 0, len(users))
userMap := make(map[int64]*service.User, len(users))
for i := range users {
userIDs = append(userIDs, users[i].ID)
u := userEntityToService(users[i])
outUsers = append(outUsers, *u)
userMap[u.ID] = &outUsers[len(outUsers)-1]
}
// Batch load active subscriptions with groups to avoid N+1.
subs, err := r.client.UserSubscription.Query().
Where(
usersubscription.UserIDIn(userIDs...),
usersubscription.StatusEQ(service.SubscriptionStatusActive),
).
WithGroup().
All(ctx)
if err != nil {
return nil, nil, err
}
for i := range subs {
if u, ok := userMap[subs[i].UserID]; ok {
u.Subscriptions = append(u.Subscriptions, *userSubscriptionEntityToService(subs[i]))
}
}
allowedGroupsByUser, err := r.loadAllowedGroups(ctx, userIDs)
if err != nil {
return nil, nil, err
}
for id, u := range userMap {
if groups, ok := allowedGroupsByUser[id]; ok {
u.AllowedGroups = groups
}
}
return outUsers, paginationResultFromTotal(int64(total), params), nil
}
// filterUsersByAttributes returns user IDs that match ALL the given attribute filters
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) ([]int64, error) {
if len(attrs) == 0 {
return nil, nil
}
if r.sql == nil {
return nil, fmt.Errorf("sql executor is not configured")
}
clauses := make([]string, 0, len(attrs))
args := make([]any, 0, len(attrs)*2+1)
argIndex := 1
for attrID, value := range attrs {
clauses = append(clauses, fmt.Sprintf("(attribute_id = $%d AND value ILIKE $%d)", argIndex, argIndex+1))
args = append(args, attrID, "%"+value+"%")
argIndex += 2
}
query := fmt.Sprintf(
`SELECT user_id
FROM user_attribute_values
WHERE %s
GROUP BY user_id
HAVING COUNT(DISTINCT attribute_id) = $%d`,
strings.Join(clauses, " OR "),
argIndex,
)
args = append(args, len(attrs))
rows, err := r.sql.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
result := make([]int64, 0)
for rows.Next() {
var userID int64
if scanErr := rows.Scan(&userID); scanErr != nil {
return nil, scanErr
}
result = append(result, userID)
}
if err := rows.Err(); err != nil {
return nil, err
}
return result, nil
}
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
client := clientFromContext(ctx, r.client)
n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount).Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
if n == 0 {
return service.ErrUserNotFound
}
return nil
}
// DeductBalance 扣除用户余额
// 透支策略:允许余额变为负数,确保当前请求能够完成
// 中间件会阻止余额 <= 0 的用户发起后续请求
func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
client := clientFromContext(ctx, r.client)
n, err := client.User.Update().
Where(dbuser.IDEQ(id)).
AddBalance(-amount).
Save(ctx)
if err != nil {
return err
}
if n == 0 {
return service.ErrUserNotFound
}
return nil
}
func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
client := clientFromContext(ctx, r.client)
n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddConcurrency(amount).Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil)
}
if n == 0 {
return service.ErrUserNotFound
}
return nil
}
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx)
}
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
// 仅操作 user_allowed_groups 联接表legacy users.allowed_groups 列已弃用。
affected, err := r.client.UserAllowedGroup.Delete().
Where(userallowedgroup.GroupIDEQ(groupID)).
Exec(ctx)
if err != nil {
return 0, err
}
return int64(affected), nil
}
func (r *userRepository) GetFirstAdmin(ctx context.Context) (*service.User, error) {
m, err := r.client.User.Query().
Where(
dbuser.RoleEQ(service.RoleAdmin),
dbuser.StatusEQ(service.StatusActive),
).
Order(dbent.Asc(dbuser.FieldID)).
First(ctx)
if err != nil {
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
}
out := userEntityToService(m)
groups, err := r.loadAllowedGroups(ctx, []int64{m.ID})
if err != nil {
return nil, err
}
if v, ok := groups[m.ID]; ok {
out.AllowedGroups = v
}
return out, nil
}
func (r *userRepository) loadAllowedGroups(ctx context.Context, userIDs []int64) (map[int64][]int64, error) {
out := make(map[int64][]int64, len(userIDs))
if len(userIDs) == 0 {
return out, nil
}
rows, err := r.client.UserAllowedGroup.Query().
Where(userallowedgroup.UserIDIn(userIDs...)).
All(ctx)
if err != nil {
return nil, err
}
for i := range rows {
out[rows[i].UserID] = append(out[rows[i].UserID], rows[i].GroupID)
}
for userID := range out {
sort.Slice(out[userID], func(i, j int) bool { return out[userID][i] < out[userID][j] })
}
return out, nil
}
// syncUserAllowedGroupsWithClient 在 ent client/事务内同步用户允许分组:
// 仅操作 user_allowed_groups 联接表legacy users.allowed_groups 列已弃用。
func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, client *dbent.Client, userID int64, groupIDs []int64) error {
if client == nil {
return nil
}
// Keep join table as the source of truth for reads.
if _, err := client.UserAllowedGroup.Delete().Where(userallowedgroup.UserIDEQ(userID)).Exec(ctx); err != nil {
return err
}
unique := make(map[int64]struct{}, len(groupIDs))
for _, id := range groupIDs {
if id <= 0 {
continue
}
unique[id] = struct{}{}
}
if len(unique) > 0 {
creates := make([]*dbent.UserAllowedGroupCreate, 0, len(unique))
for groupID := range unique {
creates = append(creates, client.UserAllowedGroup.Create().SetUserID(userID).SetGroupID(groupID))
}
if err := client.UserAllowedGroup.
CreateBulk(creates...).
OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
DoNothing().
Exec(ctx); err != nil {
return err
}
}
return nil
}
func applyUserEntityToService(dst *service.User, src *dbent.User) {
if dst == nil || src == nil {
return
}
dst.ID = src.ID
dst.CreatedAt = src.CreatedAt
dst.UpdatedAt = src.UpdatedAt
}

View File

@@ -0,0 +1,537 @@
//go:build integration
package repository
import (
"context"
"testing"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
)
type UserRepoSuite struct {
suite.Suite
ctx context.Context
client *dbent.Client
repo *userRepository
}
func (s *UserRepoSuite) SetupTest() {
s.ctx = context.Background()
s.client = testEntClient(s.T())
s.repo = newUserRepositoryWithSQL(s.client, integrationDB)
// 清理测试数据,确保每个测试从干净状态开始
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_subscriptions")
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_allowed_groups")
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM users")
}
func TestUserRepoSuite(t *testing.T) {
suite.Run(t, new(UserRepoSuite))
}
func (s *UserRepoSuite) mustCreateUser(u *service.User) *service.User {
s.T().Helper()
if u.Email == "" {
u.Email = "user-" + time.Now().Format(time.RFC3339Nano) + "@example.com"
}
if u.PasswordHash == "" {
u.PasswordHash = "test-password-hash"
}
if u.Role == "" {
u.Role = service.RoleUser
}
if u.Status == "" {
u.Status = service.StatusActive
}
if u.Concurrency == 0 {
u.Concurrency = 5
}
s.Require().NoError(s.repo.Create(s.ctx, u), "create user")
return u
}
func (s *UserRepoSuite) mustCreateGroup(name string) *service.Group {
s.T().Helper()
g, err := s.client.Group.Create().
SetName(name).
SetStatus(service.StatusActive).
Save(s.ctx)
s.Require().NoError(err, "create group")
return groupEntityToService(g)
}
func (s *UserRepoSuite) mustCreateSubscription(userID, groupID int64, mutate func(*dbent.UserSubscriptionCreate)) *dbent.UserSubscription {
s.T().Helper()
now := time.Now()
create := s.client.UserSubscription.Create().
SetUserID(userID).
SetGroupID(groupID).
SetStartsAt(now.Add(-1 * time.Hour)).
SetExpiresAt(now.Add(24 * time.Hour)).
SetStatus(service.SubscriptionStatusActive).
SetAssignedAt(now).
SetNotes("")
if mutate != nil {
mutate(create)
}
sub, err := create.Save(s.ctx)
s.Require().NoError(err, "create subscription")
return sub
}
// --- Create / GetByID / GetByEmail / Update / Delete ---
func (s *UserRepoSuite) TestCreate() {
user := s.mustCreateUser(&service.User{
Email: "create@test.com",
Username: "testuser",
PasswordHash: "test-password-hash",
Role: service.RoleUser,
Status: service.StatusActive,
})
s.Require().NotZero(user.ID, "expected ID to be set")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("create@test.com", got.Email)
}
func (s *UserRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *UserRepoSuite) TestGetByEmail() {
user := s.mustCreateUser(&service.User{Email: "byemail@test.com"})
got, err := s.repo.GetByEmail(s.ctx, user.Email)
s.Require().NoError(err, "GetByEmail")
s.Require().Equal(user.ID, got.ID)
}
func (s *UserRepoSuite) TestGetByEmail_NotFound() {
_, err := s.repo.GetByEmail(s.ctx, "nonexistent@test.com")
s.Require().Error(err, "expected error for non-existent email")
}
func (s *UserRepoSuite) TestUpdate() {
user := s.mustCreateUser(&service.User{Email: "update@test.com", Username: "original"})
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
got.Username = "updated"
s.Require().NoError(s.repo.Update(s.ctx, got), "Update")
updated, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err, "GetByID after update")
s.Require().Equal("updated", updated.Username)
}
func (s *UserRepoSuite) TestDelete() {
user := s.mustCreateUser(&service.User{Email: "delete@test.com"})
err := s.repo.Delete(s.ctx, user.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, user.ID)
s.Require().Error(err, "expected error after delete")
}
// --- List / ListWithFilters ---
func (s *UserRepoSuite) TestList() {
s.mustCreateUser(&service.User{Email: "list1@test.com"})
s.mustCreateUser(&service.User{Email: "list2@test.com"})
users, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
s.Require().Len(users, 2)
s.Require().Equal(int64(2), page.Total)
}
func (s *UserRepoSuite) TestListWithFilters_Status() {
s.mustCreateUser(&service.User{Email: "active@test.com", Status: service.StatusActive})
s.mustCreateUser(&service.User{Email: "disabled@test.com", Status: service.StatusDisabled})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Status: service.StatusActive})
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Equal(service.StatusActive, users[0].Status)
}
func (s *UserRepoSuite) TestListWithFilters_Role() {
s.mustCreateUser(&service.User{Email: "user@test.com", Role: service.RoleUser})
s.mustCreateUser(&service.User{Email: "admin@test.com", Role: service.RoleAdmin})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Role: service.RoleAdmin})
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Equal(service.RoleAdmin, users[0].Role)
}
func (s *UserRepoSuite) TestListWithFilters_Search() {
s.mustCreateUser(&service.User{Email: "alice@test.com", Username: "Alice"})
s.mustCreateUser(&service.User{Email: "bob@test.com", Username: "Bob"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Search: "alice"})
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Contains(users[0].Email, "alice")
}
func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() {
s.mustCreateUser(&service.User{Email: "u1@test.com", Username: "JohnDoe"})
s.mustCreateUser(&service.User{Email: "u2@test.com", Username: "JaneSmith"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Search: "john"})
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Equal("JohnDoe", users[0].Username)
}
func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
user := s.mustCreateUser(&service.User{Email: "sub@test.com", Status: service.StatusActive})
groupActive := s.mustCreateGroup("g-sub-active")
groupExpired := s.mustCreateGroup("g-sub-expired")
_ = s.mustCreateSubscription(user.ID, groupActive.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetStatus(service.SubscriptionStatusActive)
c.SetExpiresAt(time.Now().Add(1 * time.Hour))
})
_ = s.mustCreateSubscription(user.ID, groupExpired.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetStatus(service.SubscriptionStatusExpired)
c.SetExpiresAt(time.Now().Add(-1 * time.Hour))
})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Search: "sub@"})
s.Require().NoError(err, "ListWithFilters")
s.Require().Len(users, 1, "expected 1 user")
s.Require().Len(users[0].Subscriptions, 1, "expected 1 active subscription")
s.Require().NotNil(users[0].Subscriptions[0].Group, "expected subscription group preload")
s.Require().Equal(groupActive.ID, users[0].Subscriptions[0].Group.ID, "group ID mismatch")
}
func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
s.mustCreateUser(&service.User{
Email: "a@example.com",
Username: "Alice",
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
})
target := s.mustCreateUser(&service.User{
Email: "b@example.com",
Username: "Bob",
Role: service.RoleAdmin,
Status: service.StatusActive,
Balance: 1,
})
s.mustCreateUser(&service.User{
Email: "c@example.com",
Role: service.RoleAdmin,
Status: service.StatusDisabled,
})
users, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Status: service.StatusActive, Role: service.RoleAdmin, Search: "b@"})
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
s.Require().Len(users, 1, "ListWithFilters len mismatch")
s.Require().Equal(target.ID, users[0].ID, "ListWithFilters result mismatch")
}
// --- Balance operations ---
func (s *UserRepoSuite) TestUpdateBalance() {
user := s.mustCreateUser(&service.User{Email: "bal@test.com", Balance: 10})
err := s.repo.UpdateBalance(s.ctx, user.ID, 2.5)
s.Require().NoError(err, "UpdateBalance")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().InDelta(12.5, got.Balance, 1e-6)
}
func (s *UserRepoSuite) TestUpdateBalance_Negative() {
user := s.mustCreateUser(&service.User{Email: "balneg@test.com", Balance: 10})
err := s.repo.UpdateBalance(s.ctx, user.ID, -3)
s.Require().NoError(err, "UpdateBalance with negative")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().InDelta(7.0, got.Balance, 1e-6)
}
func (s *UserRepoSuite) TestDeductBalance() {
user := s.mustCreateUser(&service.User{Email: "deduct@test.com", Balance: 10})
err := s.repo.DeductBalance(s.ctx, user.ID, 5)
s.Require().NoError(err, "DeductBalance")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().InDelta(5.0, got.Balance, 1e-6)
}
func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() {
user := s.mustCreateUser(&service.User{Email: "insuf@test.com", Balance: 5})
// 透支策略:允许扣除超过余额的金额
err := s.repo.DeductBalance(s.ctx, user.ID, 999)
s.Require().NoError(err, "DeductBalance should allow overdraft")
// 验证余额变为负数
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().InDelta(-994.0, got.Balance, 1e-6, "Balance should be negative after overdraft")
}
func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
user := s.mustCreateUser(&service.User{Email: "exact@test.com", Balance: 10})
err := s.repo.DeductBalance(s.ctx, user.ID, 10)
s.Require().NoError(err, "DeductBalance exact amount")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().InDelta(0.0, got.Balance, 1e-6)
}
func (s *UserRepoSuite) TestDeductBalance_AllowsOverdraft() {
user := s.mustCreateUser(&service.User{Email: "overdraft@test.com", Balance: 5.0})
// 扣除超过余额的金额 - 应该成功
err := s.repo.DeductBalance(s.ctx, user.ID, 10.0)
s.Require().NoError(err, "DeductBalance should allow overdraft")
// 验证余额为负
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().InDelta(-5.0, got.Balance, 1e-6, "Balance should be -5.0 after overdraft")
}
// --- Concurrency ---
func (s *UserRepoSuite) TestUpdateConcurrency() {
user := s.mustCreateUser(&service.User{Email: "conc@test.com", Concurrency: 5})
err := s.repo.UpdateConcurrency(s.ctx, user.ID, 3)
s.Require().NoError(err, "UpdateConcurrency")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().Equal(8, got.Concurrency)
}
func (s *UserRepoSuite) TestUpdateConcurrency_Negative() {
user := s.mustCreateUser(&service.User{Email: "concneg@test.com", Concurrency: 5})
err := s.repo.UpdateConcurrency(s.ctx, user.ID, -2)
s.Require().NoError(err, "UpdateConcurrency negative")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().Equal(3, got.Concurrency)
}
// --- ExistsByEmail ---
func (s *UserRepoSuite) TestExistsByEmail() {
s.mustCreateUser(&service.User{Email: "exists@test.com"})
exists, err := s.repo.ExistsByEmail(s.ctx, "exists@test.com")
s.Require().NoError(err, "ExistsByEmail")
s.Require().True(exists)
notExists, err := s.repo.ExistsByEmail(s.ctx, "notexists@test.com")
s.Require().NoError(err)
s.Require().False(notExists)
}
// --- RemoveGroupFromAllowedGroups ---
func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups() {
target := s.mustCreateGroup("target-42")
other := s.mustCreateGroup("other-7")
userA := s.mustCreateUser(&service.User{
Email: "a1@example.com",
AllowedGroups: []int64{target.ID, other.ID},
})
s.mustCreateUser(&service.User{
Email: "a2@example.com",
AllowedGroups: []int64{other.ID},
})
affected, err := s.repo.RemoveGroupFromAllowedGroups(s.ctx, target.ID)
s.Require().NoError(err, "RemoveGroupFromAllowedGroups")
s.Require().Equal(int64(1), affected, "expected 1 affected row")
got, err := s.repo.GetByID(s.ctx, userA.ID)
s.Require().NoError(err, "GetByID")
s.Require().NotContains(got.AllowedGroups, target.ID)
s.Require().Contains(got.AllowedGroups, other.ID)
}
func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups_NoMatch() {
groupA := s.mustCreateGroup("nomatch-a")
groupB := s.mustCreateGroup("nomatch-b")
s.mustCreateUser(&service.User{
Email: "nomatch@test.com",
AllowedGroups: []int64{groupA.ID, groupB.ID},
})
affected, err := s.repo.RemoveGroupFromAllowedGroups(s.ctx, 999999)
s.Require().NoError(err)
s.Require().Zero(affected, "expected no affected rows")
}
// --- GetFirstAdmin ---
func (s *UserRepoSuite) TestGetFirstAdmin() {
admin1 := s.mustCreateUser(&service.User{
Email: "admin1@example.com",
Role: service.RoleAdmin,
Status: service.StatusActive,
})
s.mustCreateUser(&service.User{
Email: "admin2@example.com",
Role: service.RoleAdmin,
Status: service.StatusActive,
})
got, err := s.repo.GetFirstAdmin(s.ctx)
s.Require().NoError(err, "GetFirstAdmin")
s.Require().Equal(admin1.ID, got.ID, "GetFirstAdmin mismatch")
}
func (s *UserRepoSuite) TestGetFirstAdmin_NoAdmin() {
s.mustCreateUser(&service.User{
Email: "user@example.com",
Role: service.RoleUser,
Status: service.StatusActive,
})
_, err := s.repo.GetFirstAdmin(s.ctx)
s.Require().Error(err, "expected error when no admin exists")
}
func (s *UserRepoSuite) TestGetFirstAdmin_DisabledAdminIgnored() {
s.mustCreateUser(&service.User{
Email: "disabled@example.com",
Role: service.RoleAdmin,
Status: service.StatusDisabled,
})
activeAdmin := s.mustCreateUser(&service.User{
Email: "active@example.com",
Role: service.RoleAdmin,
Status: service.StatusActive,
})
got, err := s.repo.GetFirstAdmin(s.ctx)
s.Require().NoError(err, "GetFirstAdmin")
s.Require().Equal(activeAdmin.ID, got.ID, "should return only active admin")
}
// --- Combined ---
func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
user1 := s.mustCreateUser(&service.User{
Email: "a@example.com",
Username: "Alice",
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
})
user2 := s.mustCreateUser(&service.User{
Email: "b@example.com",
Username: "Bob",
Role: service.RoleAdmin,
Status: service.StatusActive,
Balance: 1,
})
s.mustCreateUser(&service.User{
Email: "c@example.com",
Role: service.RoleAdmin,
Status: service.StatusDisabled,
})
got, err := s.repo.GetByID(s.ctx, user1.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal(user1.Email, got.Email, "GetByID email mismatch")
gotByEmail, err := s.repo.GetByEmail(s.ctx, user2.Email)
s.Require().NoError(err, "GetByEmail")
s.Require().Equal(user2.ID, gotByEmail.ID, "GetByEmail ID mismatch")
got.Username = "Alice2"
s.Require().NoError(s.repo.Update(s.ctx, got), "Update")
got2, err := s.repo.GetByID(s.ctx, user1.ID)
s.Require().NoError(err, "GetByID after update")
s.Require().Equal("Alice2", got2.Username, "Update did not persist")
s.Require().NoError(s.repo.UpdateBalance(s.ctx, user1.ID, 2.5), "UpdateBalance")
got3, err := s.repo.GetByID(s.ctx, user1.ID)
s.Require().NoError(err, "GetByID after UpdateBalance")
s.Require().InDelta(12.5, got3.Balance, 1e-6)
s.Require().NoError(s.repo.DeductBalance(s.ctx, user1.ID, 5), "DeductBalance")
got4, err := s.repo.GetByID(s.ctx, user1.ID)
s.Require().NoError(err, "GetByID after DeductBalance")
s.Require().InDelta(7.5, got4.Balance, 1e-6)
// 透支策略:允许扣除超过余额的金额
err = s.repo.DeductBalance(s.ctx, user1.ID, 999)
s.Require().NoError(err, "DeductBalance should allow overdraft")
gotOverdraft, err := s.repo.GetByID(s.ctx, user1.ID)
s.Require().NoError(err, "GetByID after overdraft")
s.Require().Less(gotOverdraft.Balance, 0.0, "Balance should be negative after overdraft")
s.Require().NoError(s.repo.UpdateConcurrency(s.ctx, user1.ID, 3), "UpdateConcurrency")
got5, err := s.repo.GetByID(s.ctx, user1.ID)
s.Require().NoError(err, "GetByID after UpdateConcurrency")
s.Require().Equal(user1.Concurrency+3, got5.Concurrency)
params := pagination.PaginationParams{Page: 1, PageSize: 10}
users, page, err := s.repo.ListWithFilters(s.ctx, params, service.UserListFilters{Status: service.StatusActive, Role: service.RoleAdmin, Search: "b@"})
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
s.Require().Len(users, 1, "ListWithFilters len mismatch")
s.Require().Equal(user2.ID, users[0].ID, "ListWithFilters result mismatch")
}
// --- UpdateBalance/UpdateConcurrency 影响行数校验测试 ---
func (s *UserRepoSuite) TestUpdateBalance_NotFound() {
err := s.repo.UpdateBalance(s.ctx, 999999, 10.0)
s.Require().Error(err, "expected error for non-existent user")
s.Require().ErrorIs(err, service.ErrUserNotFound)
}
func (s *UserRepoSuite) TestUpdateConcurrency_NotFound() {
err := s.repo.UpdateConcurrency(s.ctx, 999999, 5)
s.Require().Error(err, "expected error for non-existent user")
s.Require().ErrorIs(err, service.ErrUserNotFound)
}
func (s *UserRepoSuite) TestDeductBalance_NotFound() {
err := s.repo.DeductBalance(s.ctx, 999999, 5)
s.Require().Error(err, "expected error for non-existent user")
// DeductBalance 在用户不存在时返回 ErrUserNotFound
s.Require().ErrorIs(err, service.ErrUserNotFound)
}

Some files were not shown because too many files have changed in this diff Show More