first commit
This commit is contained in:
1425
backend/internal/repository/account_repo.go
Normal file
1425
backend/internal/repository/account_repo.go
Normal file
File diff suppressed because it is too large
Load Diff
587
backend/internal/repository/account_repo_integration_test.go
Normal file
587
backend/internal/repository/account_repo_integration_test.go
Normal file
@@ -0,0 +1,587 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/accountgroup"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type AccountRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
client *dbent.Client
|
||||
repo *accountRepository
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
tx := testEntTx(s.T())
|
||||
s.client = tx.Client()
|
||||
s.repo = newAccountRepositoryWithSQL(s.client, tx)
|
||||
}
|
||||
|
||||
func TestAccountRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(AccountRepoSuite))
|
||||
}
|
||||
|
||||
// --- Create / GetByID / Update / Delete ---
|
||||
|
||||
func (s *AccountRepoSuite) TestCreate() {
|
||||
account := &service.Account{
|
||||
Name: "test-create",
|
||||
Platform: service.PlatformAnthropic,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
Credentials: map[string]any{},
|
||||
Extra: map[string]any{},
|
||||
Concurrency: 3,
|
||||
Priority: 50,
|
||||
Schedulable: true,
|
||||
}
|
||||
|
||||
err := s.repo.Create(s.ctx, account)
|
||||
s.Require().NoError(err, "Create")
|
||||
s.Require().NotZero(account.ID, "expected ID to be set")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal("test-create", got.Name)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestGetByID_NotFound() {
|
||||
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||
s.Require().Error(err, "expected error for non-existent ID")
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdate() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "original"})
|
||||
|
||||
account.Name = "updated"
|
||||
err := s.repo.Update(s.ctx, account)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "GetByID after update")
|
||||
s.Require().Equal("updated", got.Name)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestDelete() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"})
|
||||
|
||||
err := s.repo.Delete(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
|
||||
_, err = s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().Error(err, "expected error after delete")
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestDelete_WithGroupBindings() {
|
||||
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-del"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-del"})
|
||||
mustBindAccountToGroup(s.T(), s.client, account.ID, group.ID, 1)
|
||||
|
||||
err := s.repo.Delete(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "Delete should cascade remove bindings")
|
||||
|
||||
count, err := s.client.AccountGroup.Query().Where(accountgroup.AccountIDEQ(account.ID)).Count(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(count, "expected bindings to be removed")
|
||||
}
|
||||
|
||||
// --- List / ListWithFilters ---
|
||||
|
||||
func (s *AccountRepoSuite) TestList() {
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc1"})
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc2"})
|
||||
|
||||
accounts, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "List")
|
||||
s.Require().Len(accounts, 2)
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestListWithFilters() {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(client *dbent.Client)
|
||||
platform string
|
||||
accType string
|
||||
status string
|
||||
search string
|
||||
wantCount int
|
||||
validate func(accounts []service.Account)
|
||||
}{
|
||||
{
|
||||
name: "filter_by_platform",
|
||||
setup: func(client *dbent.Client) {
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "a1", Platform: service.PlatformAnthropic})
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "a2", Platform: service.PlatformOpenAI})
|
||||
},
|
||||
platform: service.PlatformOpenAI,
|
||||
wantCount: 1,
|
||||
validate: func(accounts []service.Account) {
|
||||
s.Require().Equal(service.PlatformOpenAI, accounts[0].Platform)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter_by_type",
|
||||
setup: func(client *dbent.Client) {
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "t1", Type: service.AccountTypeOAuth})
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "t2", Type: service.AccountTypeAPIKey})
|
||||
},
|
||||
accType: service.AccountTypeAPIKey,
|
||||
wantCount: 1,
|
||||
validate: func(accounts []service.Account) {
|
||||
s.Require().Equal(service.AccountTypeAPIKey, accounts[0].Type)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter_by_status",
|
||||
setup: func(client *dbent.Client) {
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "s1", Status: service.StatusActive})
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "s2", Status: service.StatusDisabled})
|
||||
},
|
||||
status: service.StatusDisabled,
|
||||
wantCount: 1,
|
||||
validate: func(accounts []service.Account) {
|
||||
s.Require().Equal(service.StatusDisabled, accounts[0].Status)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter_by_search",
|
||||
setup: func(client *dbent.Client) {
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "alpha-account"})
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "beta-account"})
|
||||
},
|
||||
search: "alpha",
|
||||
wantCount: 1,
|
||||
validate: func(accounts []service.Account) {
|
||||
s.Require().Contains(accounts[0].Name, "alpha")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
// 每个 case 重新获取隔离资源
|
||||
tx := testEntTx(s.T())
|
||||
client := tx.Client()
|
||||
repo := newAccountRepositoryWithSQL(client, tx)
|
||||
ctx := context.Background()
|
||||
|
||||
tt.setup(client)
|
||||
|
||||
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(accounts, tt.wantCount)
|
||||
if tt.validate != nil {
|
||||
tt.validate(accounts)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- ListByGroup / ListActive / ListByPlatform ---
|
||||
|
||||
func (s *AccountRepoSuite) TestListByGroup() {
|
||||
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-list"})
|
||||
acc1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a1", Status: service.StatusActive})
|
||||
acc2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a2", Status: service.StatusActive})
|
||||
mustBindAccountToGroup(s.T(), s.client, acc1.ID, group.ID, 2)
|
||||
mustBindAccountToGroup(s.T(), s.client, acc2.ID, group.ID, 1)
|
||||
|
||||
accounts, err := s.repo.ListByGroup(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "ListByGroup")
|
||||
s.Require().Len(accounts, 2)
|
||||
// Should be ordered by priority
|
||||
s.Require().Equal(acc2.ID, accounts[0].ID, "expected acc2 first (priority=1)")
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestListActive() {
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{Name: "active1", Status: service.StatusActive})
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{Name: "inactive1", Status: service.StatusDisabled})
|
||||
|
||||
accounts, err := s.repo.ListActive(s.ctx)
|
||||
s.Require().NoError(err, "ListActive")
|
||||
s.Require().Len(accounts, 1)
|
||||
s.Require().Equal("active1", accounts[0].Name)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestListByPlatform() {
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{Name: "p1", Platform: service.PlatformAnthropic, Status: service.StatusActive})
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{Name: "p2", Platform: service.PlatformOpenAI, Status: service.StatusActive})
|
||||
|
||||
accounts, err := s.repo.ListByPlatform(s.ctx, service.PlatformAnthropic)
|
||||
s.Require().NoError(err, "ListByPlatform")
|
||||
s.Require().Len(accounts, 1)
|
||||
s.Require().Equal(service.PlatformAnthropic, accounts[0].Platform)
|
||||
}
|
||||
|
||||
// --- Preload and VirtualFields ---
|
||||
|
||||
func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
|
||||
proxy := mustCreateProxy(s.T(), s.client, &service.Proxy{Name: "p1"})
|
||||
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g1"})
|
||||
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "acc1",
|
||||
ProxyID: &proxy.ID,
|
||||
})
|
||||
mustBindAccountToGroup(s.T(), s.client, account.ID, group.ID, 1)
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().NotNil(got.Proxy, "expected Proxy preload")
|
||||
s.Require().Equal(proxy.ID, got.Proxy.ID)
|
||||
s.Require().Len(got.GroupIDs, 1, "expected GroupIDs to be populated")
|
||||
s.Require().Equal(group.ID, got.GroupIDs[0])
|
||||
s.Require().Len(got.Groups, 1, "expected Groups to be populated")
|
||||
s.Require().Equal(group.ID, got.Groups[0].ID)
|
||||
|
||||
accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc")
|
||||
s.Require().NoError(err, "ListWithFilters")
|
||||
s.Require().Equal(int64(1), page.Total)
|
||||
s.Require().Len(accounts, 1)
|
||||
s.Require().NotNil(accounts[0].Proxy, "expected Proxy preload in list")
|
||||
s.Require().Equal(proxy.ID, accounts[0].Proxy.ID)
|
||||
s.Require().Len(accounts[0].GroupIDs, 1, "expected GroupIDs in list")
|
||||
s.Require().Equal(group.ID, accounts[0].GroupIDs[0])
|
||||
}
|
||||
|
||||
// --- GroupBinding / AddToGroup / RemoveFromGroup / BindGroups / GetGroups ---
|
||||
|
||||
func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() {
|
||||
g1 := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g1"})
|
||||
g2 := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g2"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc"})
|
||||
|
||||
s.Require().NoError(s.repo.AddToGroup(s.ctx, account.ID, g1.ID, 10), "AddToGroup")
|
||||
groups, err := s.repo.GetGroups(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "GetGroups")
|
||||
s.Require().Len(groups, 1, "expected 1 group")
|
||||
s.Require().Equal(g1.ID, groups[0].ID)
|
||||
|
||||
s.Require().NoError(s.repo.RemoveFromGroup(s.ctx, account.ID, g1.ID), "RemoveFromGroup")
|
||||
groups, err = s.repo.GetGroups(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "GetGroups after remove")
|
||||
s.Require().Empty(groups, "expected 0 groups after remove")
|
||||
|
||||
s.Require().NoError(s.repo.BindGroups(s.ctx, account.ID, []int64{g1.ID, g2.ID}), "BindGroups")
|
||||
groups, err = s.repo.GetGroups(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "GetGroups after bind")
|
||||
s.Require().Len(groups, 2, "expected 2 groups after bind")
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestBindGroups_EmptyList() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-empty"})
|
||||
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-empty"})
|
||||
mustBindAccountToGroup(s.T(), s.client, account.ID, group.ID, 1)
|
||||
|
||||
s.Require().NoError(s.repo.BindGroups(s.ctx, account.ID, []int64{}), "BindGroups empty")
|
||||
|
||||
groups, err := s.repo.GetGroups(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Empty(groups, "expected 0 groups after binding empty list")
|
||||
}
|
||||
|
||||
// --- Schedulable ---
|
||||
|
||||
func (s *AccountRepoSuite) TestListSchedulable() {
|
||||
now := time.Now()
|
||||
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-sched"})
|
||||
|
||||
okAcc := mustCreateAccount(s.T(), s.client, &service.Account{Name: "ok", Schedulable: true})
|
||||
mustBindAccountToGroup(s.T(), s.client, okAcc.ID, group.ID, 1)
|
||||
|
||||
future := now.Add(10 * time.Minute)
|
||||
overloaded := mustCreateAccount(s.T(), s.client, &service.Account{Name: "over", Schedulable: true, OverloadUntil: &future})
|
||||
mustBindAccountToGroup(s.T(), s.client, overloaded.ID, group.ID, 1)
|
||||
|
||||
sched, err := s.repo.ListSchedulable(s.ctx)
|
||||
s.Require().NoError(err, "ListSchedulable")
|
||||
ids := idsOfAccounts(sched)
|
||||
s.Require().Contains(ids, okAcc.ID)
|
||||
s.Require().NotContains(ids, overloaded.ID)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_StatusUpdates() {
|
||||
now := time.Now()
|
||||
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-sched"})
|
||||
|
||||
okAcc := mustCreateAccount(s.T(), s.client, &service.Account{Name: "ok", Schedulable: true})
|
||||
mustBindAccountToGroup(s.T(), s.client, okAcc.ID, group.ID, 1)
|
||||
|
||||
future := now.Add(10 * time.Minute)
|
||||
overloaded := mustCreateAccount(s.T(), s.client, &service.Account{Name: "over", Schedulable: true, OverloadUntil: &future})
|
||||
mustBindAccountToGroup(s.T(), s.client, overloaded.ID, group.ID, 1)
|
||||
|
||||
rateLimited := mustCreateAccount(s.T(), s.client, &service.Account{Name: "rl", Schedulable: true})
|
||||
mustBindAccountToGroup(s.T(), s.client, rateLimited.ID, group.ID, 1)
|
||||
s.Require().NoError(s.repo.SetRateLimited(s.ctx, rateLimited.ID, now.Add(10*time.Minute)), "SetRateLimited")
|
||||
|
||||
s.Require().NoError(s.repo.SetError(s.ctx, overloaded.ID, "boom"), "SetError")
|
||||
|
||||
sched, err := s.repo.ListSchedulableByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "ListSchedulableByGroupID")
|
||||
s.Require().Len(sched, 1, "expected only ok account schedulable")
|
||||
s.Require().Equal(okAcc.ID, sched[0].ID)
|
||||
|
||||
s.Require().NoError(s.repo.ClearRateLimit(s.ctx, rateLimited.ID), "ClearRateLimit")
|
||||
sched2, err := s.repo.ListSchedulableByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "ListSchedulableByGroupID after ClearRateLimit")
|
||||
s.Require().Len(sched2, 2, "expected 2 schedulable accounts after ClearRateLimit")
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestListSchedulableByPlatform() {
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true})
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true})
|
||||
|
||||
accounts, err := s.repo.ListSchedulableByPlatform(s.ctx, service.PlatformAnthropic)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(accounts, 1)
|
||||
s.Require().Equal(service.PlatformAnthropic, accounts[0].Platform)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() {
|
||||
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-sp"})
|
||||
a1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true})
|
||||
a2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true})
|
||||
mustBindAccountToGroup(s.T(), s.client, a1.ID, group.ID, 1)
|
||||
mustBindAccountToGroup(s.T(), s.client, a2.ID, group.ID, 2)
|
||||
|
||||
accounts, err := s.repo.ListSchedulableByGroupIDAndPlatform(s.ctx, group.ID, service.PlatformAnthropic)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(accounts, 1)
|
||||
s.Require().Equal(a1.ID, accounts[0].ID)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestSetSchedulable() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-sched", Schedulable: true})
|
||||
|
||||
s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().False(got.Schedulable)
|
||||
}
|
||||
|
||||
// --- SetOverloaded / SetRateLimited / ClearRateLimit ---
|
||||
|
||||
func (s *AccountRepoSuite) TestSetOverloaded() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-over"})
|
||||
until := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(got.OverloadUntil)
|
||||
s.Require().WithinDuration(until, *got.OverloadUntil, time.Second)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestSetRateLimited() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-rl"})
|
||||
resetAt := time.Date(2025, 6, 15, 14, 0, 0, 0, time.UTC)
|
||||
|
||||
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, resetAt))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(got.RateLimitedAt)
|
||||
s.Require().NotNil(got.RateLimitResetAt)
|
||||
s.Require().WithinDuration(resetAt, *got.RateLimitResetAt, time.Second)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestClearRateLimit() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-clear"})
|
||||
until := time.Now().Add(1 * time.Hour)
|
||||
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
|
||||
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, until))
|
||||
|
||||
s.Require().NoError(s.repo.ClearRateLimit(s.ctx, account.ID))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Nil(got.RateLimitedAt)
|
||||
s.Require().Nil(got.RateLimitResetAt)
|
||||
s.Require().Nil(got.OverloadUntil)
|
||||
}
|
||||
|
||||
// --- UpdateLastUsed ---
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateLastUsed() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-used"})
|
||||
s.Require().Nil(account.LastUsedAt)
|
||||
|
||||
s.Require().NoError(s.repo.UpdateLastUsed(s.ctx, account.ID))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(got.LastUsedAt)
|
||||
}
|
||||
|
||||
// --- SetError ---
|
||||
|
||||
func (s *AccountRepoSuite) TestSetError() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-err", Status: service.StatusActive})
|
||||
|
||||
s.Require().NoError(s.repo.SetError(s.ctx, account.ID, "something went wrong"))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(service.StatusError, got.Status)
|
||||
s.Require().Equal("something went wrong", got.ErrorMessage)
|
||||
}
|
||||
|
||||
// --- UpdateSessionWindow ---
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateSessionWindow() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-win"})
|
||||
start := time.Date(2025, 6, 15, 10, 0, 0, 0, time.UTC)
|
||||
end := time.Date(2025, 6, 15, 15, 0, 0, 0, time.UTC)
|
||||
|
||||
s.Require().NoError(s.repo.UpdateSessionWindow(s.ctx, account.ID, &start, &end, "active"))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(got.SessionWindowStart)
|
||||
s.Require().NotNil(got.SessionWindowEnd)
|
||||
s.Require().Equal("active", got.SessionWindowStatus)
|
||||
}
|
||||
|
||||
// --- UpdateExtra ---
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "acc-extra",
|
||||
Extra: map[string]any{"a": "1"},
|
||||
})
|
||||
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"b": "2"}), "UpdateExtra")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal("1", got.Extra["a"])
|
||||
s.Require().Equal("2", got.Extra["b"])
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateExtra_EmptyUpdates() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-extra-empty"})
|
||||
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{}))
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-nil-extra", Extra: nil})
|
||||
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"key": "val"}))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal("val", got.Extra["key"])
|
||||
}
|
||||
|
||||
// --- GetByCRSAccountID ---
|
||||
|
||||
func (s *AccountRepoSuite) TestGetByCRSAccountID() {
|
||||
crsID := "crs-12345"
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "acc-crs",
|
||||
Extra: map[string]any{"crs_account_id": crsID},
|
||||
})
|
||||
|
||||
got, err := s.repo.GetByCRSAccountID(s.ctx, crsID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(got)
|
||||
s.Require().Equal("acc-crs", got.Name)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestGetByCRSAccountID_NotFound() {
|
||||
got, err := s.repo.GetByCRSAccountID(s.ctx, "non-existent")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Nil(got)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestGetByCRSAccountID_EmptyString() {
|
||||
got, err := s.repo.GetByCRSAccountID(s.ctx, "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Nil(got)
|
||||
}
|
||||
|
||||
// --- BulkUpdate ---
|
||||
|
||||
func (s *AccountRepoSuite) TestBulkUpdate() {
|
||||
a1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk1", Priority: 1})
|
||||
a2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk2", Priority: 1})
|
||||
|
||||
newPriority := 99
|
||||
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID, a2.ID}, service.AccountBulkUpdate{
|
||||
Priority: &newPriority,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.Require().GreaterOrEqual(affected, int64(1), "expected at least one affected row")
|
||||
|
||||
got1, _ := s.repo.GetByID(s.ctx, a1.ID)
|
||||
got2, _ := s.repo.GetByID(s.ctx, a2.ID)
|
||||
s.Require().Equal(99, got1.Priority)
|
||||
s.Require().Equal(99, got2.Priority)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() {
|
||||
a1 := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "bulk-cred",
|
||||
Credentials: map[string]any{"existing": "value"},
|
||||
})
|
||||
|
||||
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{
|
||||
Credentials: map[string]any{"new_key": "new_value"},
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
got, _ := s.repo.GetByID(s.ctx, a1.ID)
|
||||
s.Require().Equal("value", got.Credentials["existing"])
|
||||
s.Require().Equal("new_value", got.Credentials["new_key"])
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestBulkUpdate_MergeExtra() {
|
||||
a1 := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "bulk-extra",
|
||||
Extra: map[string]any{"existing": "val"},
|
||||
})
|
||||
|
||||
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{
|
||||
Extra: map[string]any{"new_key": "new_val"},
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
got, _ := s.repo.GetByID(s.ctx, a1.ID)
|
||||
s.Require().Equal("val", got.Extra["existing"])
|
||||
s.Require().Equal("new_val", got.Extra["new_key"])
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestBulkUpdate_EmptyIDs() {
|
||||
affected, err := s.repo.BulkUpdate(s.ctx, []int64{}, service.AccountBulkUpdate{})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(affected)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestBulkUpdate_EmptyUpdates() {
|
||||
a1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk-empty"})
|
||||
|
||||
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(affected)
|
||||
}
|
||||
|
||||
func idsOfAccounts(accounts []service.Account) []int64 {
|
||||
out := make([]int64, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
out = append(out, accounts[i].ID)
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,145 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// uniqueTestValue derives a per-test unique string from prefix and the
// current test name, replacing "/" (sub-test separators) and spaces with "_"
// so the value is safe to use in names, emails, and keys.
func uniqueTestValue(t *testing.T, prefix string) string {
	t.Helper()
	name := strings.ReplaceAll(strings.ReplaceAll(t.Name(), "/", "_"), " ", "_")
	return prefix + "-" + name
}
|
||||
|
||||
func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
entClient := tx.Client()
|
||||
|
||||
targetGroup, err := entClient.Group.Create().
|
||||
SetName(uniqueTestValue(t, "target-group")).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
otherGroup, err := entClient.Group.Create().
|
||||
SetName(uniqueTestValue(t, "other-group")).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
repo := newUserRepositoryWithSQL(entClient, tx)
|
||||
|
||||
u1 := &service.User{
|
||||
Email: uniqueTestValue(t, "u1") + "@example.com",
|
||||
PasswordHash: "test-password-hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Concurrency: 5,
|
||||
AllowedGroups: []int64{targetGroup.ID, otherGroup.ID},
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, u1))
|
||||
|
||||
u2 := &service.User{
|
||||
Email: uniqueTestValue(t, "u2") + "@example.com",
|
||||
PasswordHash: "test-password-hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Concurrency: 5,
|
||||
AllowedGroups: []int64{targetGroup.ID},
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, u2))
|
||||
|
||||
u3 := &service.User{
|
||||
Email: uniqueTestValue(t, "u3") + "@example.com",
|
||||
PasswordHash: "test-password-hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Concurrency: 5,
|
||||
AllowedGroups: []int64{otherGroup.ID},
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, u3))
|
||||
|
||||
affected, err := repo.RemoveGroupFromAllowedGroups(ctx, targetGroup.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), affected)
|
||||
|
||||
u1After, err := repo.GetByID(ctx, u1.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, u1After.AllowedGroups, targetGroup.ID)
|
||||
require.Contains(t, u1After.AllowedGroups, otherGroup.ID)
|
||||
|
||||
u2After, err := repo.GetByID(ctx, u2.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, u2After.AllowedGroups, targetGroup.ID)
|
||||
}
|
||||
|
||||
func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
entClient := tx.Client()
|
||||
|
||||
targetGroup, err := entClient.Group.Create().
|
||||
SetName(uniqueTestValue(t, "delete-cascade-target")).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
otherGroup, err := entClient.Group.Create().
|
||||
SetName(uniqueTestValue(t, "delete-cascade-other")).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
userRepo := newUserRepositoryWithSQL(entClient, tx)
|
||||
groupRepo := newGroupRepositoryWithSQL(entClient, tx)
|
||||
apiKeyRepo := NewAPIKeyRepository(entClient)
|
||||
|
||||
u := &service.User{
|
||||
Email: uniqueTestValue(t, "cascade-user") + "@example.com",
|
||||
PasswordHash: "test-password-hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Concurrency: 5,
|
||||
AllowedGroups: []int64{targetGroup.ID, otherGroup.ID},
|
||||
}
|
||||
require.NoError(t, userRepo.Create(ctx, u))
|
||||
|
||||
key := &service.APIKey{
|
||||
UserID: u.ID,
|
||||
Key: uniqueTestValue(t, "sk-test-delete-cascade"),
|
||||
Name: "test key",
|
||||
GroupID: &targetGroup.ID,
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
require.NoError(t, apiKeyRepo.Create(ctx, key))
|
||||
|
||||
_, err = groupRepo.DeleteCascade(ctx, targetGroup.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Deleted group should be hidden by default queries (soft-delete semantics).
|
||||
_, err = groupRepo.GetByID(ctx, targetGroup.ID)
|
||||
require.ErrorIs(t, err, service.ErrGroupNotFound)
|
||||
|
||||
activeGroups, err := groupRepo.ListActive(ctx)
|
||||
require.NoError(t, err)
|
||||
for _, g := range activeGroups {
|
||||
require.NotEqual(t, targetGroup.ID, g.ID)
|
||||
}
|
||||
|
||||
// User.allowed_groups should no longer include the deleted group.
|
||||
uAfter, err := userRepo.GetByID(ctx, u.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, uAfter.AllowedGroups, targetGroup.ID)
|
||||
require.Contains(t, uAfter.AllowedGroups, otherGroup.ID)
|
||||
|
||||
// API keys bound to the deleted group should have group_id cleared.
|
||||
keyAfter, err := apiKeyRepo.GetByID(ctx, key.ID)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, keyAfter.GroupID)
|
||||
}
|
||||
93
backend/internal/repository/api_key_cache.go
Normal file
93
backend/internal/repository/api_key_cache.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
	// apiKeyRateLimitKeyPrefix namespaces per-user API-key-creation
	// rate-limit counters in Redis.
	apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
	// apiKeyRateLimitDuration is the TTL applied to the creation counter.
	apiKeyRateLimitDuration = 24 * time.Hour
	// apiKeyAuthCachePrefix namespaces cached API-key auth entries.
	apiKeyAuthCachePrefix = "apikey:auth:"
)
|
||||
|
||||
// apiKeyRateLimitKey generates the Redis key for API key creation rate limiting.
|
||||
func apiKeyRateLimitKey(userID int64) string {
|
||||
return fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||
}
|
||||
|
||||
func apiKeyAuthCacheKey(key string) string {
|
||||
return fmt.Sprintf("%s%s", apiKeyAuthCachePrefix, key)
|
||||
}
|
||||
|
||||
type apiKeyCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewAPIKeyCache(rdb *redis.Client) service.APIKeyCache {
|
||||
return &apiKeyCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
|
||||
key := apiKeyRateLimitKey(userID)
|
||||
count, err := c.rdb.Get(ctx, key).Int()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return 0, nil
|
||||
}
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||
key := apiKeyRateLimitKey(userID)
|
||||
pipe := c.rdb.Pipeline()
|
||||
pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, apiKeyRateLimitDuration)
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||
key := apiKeyRateLimitKey(userID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
// IncrementDailyUsage increments the daily-usage counter.
// NOTE(review): despite the parameter name, apiKey is used verbatim as the
// Redis key (no prefix is applied) — callers appear to pass a pre-built
// key such as "daily:sk-..."; confirm against call sites.
func (c *apiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) error {
	return c.rdb.Incr(ctx, apiKey).Err()
}
|
||||
|
||||
// SetDailyUsageExpiry arms a TTL on the daily-usage counter. As with
// IncrementDailyUsage, apiKey is used directly as the Redis key.
func (c *apiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
	return c.rdb.Expire(ctx, apiKey, ttl).Err()
}
|
||||
|
||||
func (c *apiKeyCache) GetAuthCache(ctx context.Context, key string) (*service.APIKeyAuthCacheEntry, error) {
|
||||
val, err := c.rdb.Get(ctx, apiKeyAuthCacheKey(key)).Bytes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var entry service.APIKeyAuthCacheEntry
|
||||
if err := json.Unmarshal(val, &entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &entry, nil
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) SetAuthCache(ctx context.Context, key string, entry *service.APIKeyAuthCacheEntry, ttl time.Duration) error {
|
||||
if entry == nil {
|
||||
return nil
|
||||
}
|
||||
payload, err := json.Marshal(entry)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.rdb.Set(ctx, apiKeyAuthCacheKey(key), payload, ttl).Err()
|
||||
}
|
||||
|
||||
// DeleteAuthCache evicts the cached auth entry for the given API key.
func (c *apiKeyCache) DeleteAuthCache(ctx context.Context, key string) error {
	return c.rdb.Del(ctx, apiKeyAuthCacheKey(key)).Err()
}
|
||||
127
backend/internal/repository/api_key_cache_integration_test.go
Normal file
127
backend/internal/repository/api_key_cache_integration_test.go
Normal file
@@ -0,0 +1,127 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// ApiKeyCacheSuite groups the Redis-backed apiKeyCache integration tests.
// It embeds the shared Redis test harness for provisioning and assertions.
type ApiKeyCacheSuite struct {
	IntegrationRedisSuite
}
|
||||
|
||||
// TestCreateAttemptCount covers the creation-attempt counter: a missing key
// reads as zero, increments raise the count and arm a TTL, and delete resets it.
func (s *ApiKeyCacheSuite) TestCreateAttemptCount() {
	tests := []struct {
		name string
		fn   func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
	}{
		{
			name: "missing_key_returns_zero_nil",
			fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
				userID := int64(1)

				count, err := cache.GetCreateAttemptCount(ctx, userID)

				require.NoError(s.T(), err, "expected nil error for missing key")
				require.Equal(s.T(), 0, count, "expected zero count for missing key")
			},
		},
		{
			name: "increment_increases_count_and_sets_ttl",
			fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
				userID := int64(1)
				key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)

				require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID), "IncrementCreateAttemptCount")
				require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID), "IncrementCreateAttemptCount 2")

				count, err := cache.GetCreateAttemptCount(ctx, userID)
				require.NoError(s.T(), err, "GetCreateAttemptCount")
				require.Equal(s.T(), 2, count, "count mismatch")

				// TTL must have been armed by the pipelined EXPIRE.
				ttl, err := rdb.TTL(ctx, key).Result()
				require.NoError(s.T(), err, "TTL")
				s.AssertTTLWithin(ttl, 1*time.Second, apiKeyRateLimitDuration)
			},
		},
		{
			name: "delete_removes_key",
			fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
				userID := int64(1)

				require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID))
				require.NoError(s.T(), cache.DeleteCreateAttemptCount(ctx, userID), "DeleteCreateAttemptCount")

				count, err := cache.GetCreateAttemptCount(ctx, userID)
				require.NoError(s.T(), err, "expected nil error after delete")
				require.Equal(s.T(), 0, count, "expected zero count after delete")
			},
		},
	}

	for _, tt := range tests {
		s.Run(tt.name, func() {
			// Each case gets its own isolated Redis instance and cache.
			rdb := testRedis(s.T())
			cache := &apiKeyCache{rdb: rdb}
			ctx := context.Background()

			tt.fn(ctx, rdb, cache)
		})
	}
}
|
||||
|
||||
// TestDailyUsage covers the daily-usage counter: increments accumulate on
// the caller-supplied Redis key, and SetDailyUsageExpiry arms a TTL on it.
func (s *ApiKeyCacheSuite) TestDailyUsage() {
	tests := []struct {
		name string
		fn   func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
	}{
		{
			name: "increment_increases_count",
			fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
				dailyKey := "daily:sk-test"

				require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey), "IncrementDailyUsage")
				require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey), "IncrementDailyUsage 2")

				n, err := rdb.Get(ctx, dailyKey).Int()
				require.NoError(s.T(), err, "Get dailyKey")
				require.Equal(s.T(), 2, n, "expected daily usage=2")
			},
		},
		{
			name: "set_expiry_sets_ttl",
			fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
				dailyKey := "daily:sk-test-expiry"

				require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey))
				require.NoError(s.T(), cache.SetDailyUsageExpiry(ctx, dailyKey, 1*time.Hour), "SetDailyUsageExpiry")

				ttl, err := rdb.TTL(ctx, dailyKey).Result()
				require.NoError(s.T(), err, "TTL dailyKey")
				require.Greater(s.T(), ttl, time.Duration(0), "expected ttl > 0")
			},
		},
	}

	for _, tt := range tests {
		s.Run(tt.name, func() {
			// Each case gets its own isolated Redis instance and cache.
			rdb := testRedis(s.T())
			cache := &apiKeyCache{rdb: rdb}
			ctx := context.Background()

			tt.fn(ctx, rdb, cache)
		})
	}
}

// TestApiKeyCacheSuite wires the suite into `go test`.
func TestApiKeyCacheSuite(t *testing.T) {
	suite.Run(t, new(ApiKeyCacheSuite))
}
|
||||
46
backend/internal/repository/api_key_cache_test.go
Normal file
46
backend/internal/repository/api_key_cache_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestApiKeyRateLimitKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "normal_user_id",
|
||||
userID: 123,
|
||||
expected: "apikey:ratelimit:123",
|
||||
},
|
||||
{
|
||||
name: "zero_user_id",
|
||||
userID: 0,
|
||||
expected: "apikey:ratelimit:0",
|
||||
},
|
||||
{
|
||||
name: "negative_user_id",
|
||||
userID: -1,
|
||||
expected: "apikey:ratelimit:-1",
|
||||
},
|
||||
{
|
||||
name: "max_int64",
|
||||
userID: math.MaxInt64,
|
||||
expected: "apikey:ratelimit:9223372036854775807",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := apiKeyRateLimitKey(tc.userID)
|
||||
require.Equal(t, tc.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
435
backend/internal/repository/api_key_repo.go
Normal file
435
backend/internal/repository/api_key_repo.go
Normal file
@@ -0,0 +1,435 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
// apiKeyRepository implements service.APIKeyRepository on top of the ent client.
type apiKeyRepository struct {
	client *dbent.Client
}
|
||||
|
||||
// NewAPIKeyRepository returns an ent-backed implementation of service.APIKeyRepository.
func NewAPIKeyRepository(client *dbent.Client) service.APIKeyRepository {
	return &apiKeyRepository{client: client}
}
|
||||
|
||||
// activeQuery returns a base query that excludes soft-deleted records, so
// deleted keys never surface in normal lookups.
func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery {
	return r.client.APIKey.Query().Where(apikey.DeletedAtIsNil())
}
|
||||
|
||||
func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) error {
|
||||
builder := r.client.APIKey.Create().
|
||||
SetUserID(key.UserID).
|
||||
SetKey(key.Key).
|
||||
SetName(key.Name).
|
||||
SetStatus(key.Status).
|
||||
SetNillableGroupID(key.GroupID)
|
||||
|
||||
if len(key.IPWhitelist) > 0 {
|
||||
builder.SetIPWhitelist(key.IPWhitelist)
|
||||
}
|
||||
if len(key.IPBlacklist) > 0 {
|
||||
builder.SetIPBlacklist(key.IPBlacklist)
|
||||
}
|
||||
|
||||
created, err := builder.Save(ctx)
|
||||
if err == nil {
|
||||
key.ID = created.ID
|
||||
key.CreatedAt = created.CreatedAt
|
||||
key.UpdatedAt = created.UpdatedAt
|
||||
}
|
||||
return translatePersistenceError(err, nil, service.ErrAPIKeyExists)
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.IDEQ(id)).
|
||||
WithUser().
|
||||
WithGroup().
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return apiKeyEntityToService(m), nil
|
||||
}
|
||||
|
||||
// GetKeyAndOwnerID 根据 API Key ID 获取其 key 与所有者(用户)ID。
|
||||
// 相比 GetByID,此方法性能更优,因为:
|
||||
// - 使用 Select() 只查询必要字段,减少数据传输量
|
||||
// - 不加载完整的 API Key 实体及其关联数据(User、Group 等)
|
||||
// - 适用于删除等只需 key 与用户 ID 的场景
|
||||
func (r *apiKeyRepository) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.IDEQ(id)).
|
||||
Select(apikey.FieldKey, apikey.FieldUserID).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return "", 0, service.ErrAPIKeyNotFound
|
||||
}
|
||||
return "", 0, err
|
||||
}
|
||||
return m.Key, m.UserID, nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.KeyEQ(key)).
|
||||
WithUser().
|
||||
WithGroup().
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return apiKeyEntityToService(m), nil
|
||||
}
|
||||
|
||||
// GetByKeyForAuth loads the minimal field set needed on the authentication
// hot path: selected API-key columns plus narrowed User and Group edges.
// Soft-deleted keys are excluded. Returns service.ErrAPIKeyNotFound when
// the key does not exist.
func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) {
	m, err := r.activeQuery().
		Where(apikey.KeyEQ(key)).
		Select(
			apikey.FieldID,
			apikey.FieldUserID,
			apikey.FieldGroupID,
			apikey.FieldStatus,
			apikey.FieldIPWhitelist,
			apikey.FieldIPBlacklist,
		).
		WithUser(func(q *dbent.UserQuery) {
			// Only the user columns auth consults.
			q.Select(
				user.FieldID,
				user.FieldStatus,
				user.FieldRole,
				user.FieldBalance,
				user.FieldConcurrency,
			)
		}).
		WithGroup(func(q *dbent.GroupQuery) {
			// Group status/limit/pricing columns — presumably consumed by
			// quota and billing checks downstream; confirm with callers.
			q.Select(
				group.FieldID,
				group.FieldName,
				group.FieldPlatform,
				group.FieldStatus,
				group.FieldSubscriptionType,
				group.FieldRateMultiplier,
				group.FieldDailyLimitUsd,
				group.FieldWeeklyLimitUsd,
				group.FieldMonthlyLimitUsd,
				group.FieldImagePrice1k,
				group.FieldImagePrice2k,
				group.FieldImagePrice4k,
				group.FieldClaudeCodeOnly,
				group.FieldFallbackGroupID,
			)
		}).
		Only(ctx)
	if err != nil {
		if dbent.IsNotFound(err) {
			return nil, service.ErrAPIKeyNotFound
		}
		return nil, err
	}
	return apiKeyEntityToService(m), nil
}
|
||||
|
||||
// Update persists the mutable fields of an API key (name, status, group
// binding, IP lists). Returns service.ErrAPIKeyNotFound when the row is
// missing or already soft-deleted.
func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) error {
	// Atomic update: the soft-delete check and the write are merged into a
	// single statement to avoid a race. A previous implementation checked
	// Exist first and then called UpdateOneID; a soft delete occurring
	// between the two steps would update an already-deleted record.
	// Update().Where() guarantees only live rows can match.
	// updated_at is set explicitly so it can be backfilled without a
	// second query (which could fail under concurrent deletion).
	now := time.Now()
	builder := r.client.APIKey.Update().
		Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()).
		SetName(key.Name).
		SetStatus(key.Status).
		SetUpdatedAt(now)
	if key.GroupID != nil {
		builder.SetGroupID(*key.GroupID)
	} else {
		builder.ClearGroupID()
	}

	// IP restriction fields: an empty slice clears the stored value.
	if len(key.IPWhitelist) > 0 {
		builder.SetIPWhitelist(key.IPWhitelist)
	} else {
		builder.ClearIPWhitelist()
	}
	if len(key.IPBlacklist) > 0 {
		builder.SetIPBlacklist(key.IPBlacklist)
	} else {
		builder.ClearIPBlacklist()
	}

	affected, err := builder.Save(ctx)
	if err != nil {
		return err
	}
	if affected == 0 {
		// Zero affected rows: the record does not exist or was soft-deleted.
		return service.ErrAPIKeyNotFound
	}

	// Backfill with the same timestamp; re-querying could fail if the row
	// were deleted concurrently.
	key.UpdatedAt = now
	return nil
}
|
||||
|
||||
// Delete soft-deletes an API key by setting deleted_at. Deleting a key that
// is already soft-deleted is an idempotent no-op; a key that never existed
// yields service.ErrAPIKeyNotFound.
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
	// Explicit soft delete: set deleted_at directly rather than relying on
	// hook behavior, so the column is guaranteed to be written.
	affected, err := r.client.APIKey.Update().
		Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
		SetDeletedAt(time.Now()).
		Save(ctx)
	if err != nil {
		if dbent.IsNotFound(err) {
			return service.ErrAPIKeyNotFound
		}
		return err
	}
	if affected == 0 {
		// Nothing matched: the key either never existed or is already
		// soft-deleted. Check with SkipSoftDelete to tell the two apart.
		exists, err := r.client.APIKey.Query().
			Where(apikey.IDEQ(id)).
			Exist(mixins.SkipSoftDelete(ctx))
		if err != nil {
			return err
		}
		if exists {
			return nil
		}
		return service.ErrAPIKeyNotFound
	}
	return nil
}
|
||||
|
||||
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
q := r.activeQuery().Where(apikey.UserIDEQ(userID))
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
keys, err := q.
|
||||
WithGroup().
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(apikey.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outKeys := make([]service.APIKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
|
||||
}
|
||||
|
||||
return outKeys, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
if len(apiKeyIDs) == 0 {
|
||||
return []int64{}, nil
|
||||
}
|
||||
|
||||
ids, err := r.client.APIKey.Query().
|
||||
Where(apikey.UserIDEQ(userID), apikey.IDIn(apiKeyIDs...), apikey.DeletedAtIsNil()).
|
||||
IDs(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// CountByUserID returns the number of live (not soft-deleted) API keys
// owned by the given user.
func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
	count, err := r.activeQuery().Where(apikey.UserIDEQ(userID)).Count(ctx)
	return int64(count), err
}
|
||||
|
||||
func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
count, err := r.activeQuery().Where(apikey.KeyEQ(key)).Count(ctx)
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
q := r.activeQuery().Where(apikey.GroupIDEQ(groupID))
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
keys, err := q.
|
||||
WithUser().
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(apikey.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outKeys := make([]service.APIKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
|
||||
}
|
||||
|
||||
return outKeys, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
// SearchAPIKeys searches API keys by user ID and/or keyword (name)
|
||||
func (r *apiKeyRepository) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
|
||||
q := r.activeQuery()
|
||||
if userID > 0 {
|
||||
q = q.Where(apikey.UserIDEQ(userID))
|
||||
}
|
||||
|
||||
if keyword != "" {
|
||||
q = q.Where(apikey.NameContainsFold(keyword))
|
||||
}
|
||||
|
||||
keys, err := q.Limit(limit).Order(dbent.Desc(apikey.FieldID)).All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outKeys := make([]service.APIKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
|
||||
}
|
||||
return outKeys, nil
|
||||
}
|
||||
|
||||
// ClearGroupIDByGroupID sets group_id to NULL on every live API key bound
// to the given group and returns the number of affected rows.
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
	n, err := r.client.APIKey.Update().
		Where(apikey.GroupIDEQ(groupID), apikey.DeletedAtIsNil()).
		ClearGroupID().
		Save(ctx)
	return int64(n), err
}
|
||||
|
||||
// CountByGroupID returns the number of live API keys bound to the group.
func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
	count, err := r.activeQuery().Where(apikey.GroupIDEQ(groupID)).Count(ctx)
	return int64(count), err
}
|
||||
|
||||
func (r *apiKeyRepository) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
|
||||
keys, err := r.activeQuery().
|
||||
Where(apikey.UserIDEQ(userID)).
|
||||
Select(apikey.FieldKey).
|
||||
Strings(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
|
||||
keys, err := r.activeQuery().
|
||||
Where(apikey.GroupIDEQ(groupID)).
|
||||
Select(apikey.FieldKey).
|
||||
Strings(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// apiKeyEntityToService maps an ent APIKey row (and any loaded User/Group
// edges) onto the service-layer APIKey. Returns nil for a nil input.
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
	if m == nil {
		return nil
	}
	out := &service.APIKey{
		ID:          m.ID,
		UserID:      m.UserID,
		Key:         m.Key,
		Name:        m.Name,
		Status:      m.Status,
		IPWhitelist: m.IPWhitelist,
		IPBlacklist: m.IPBlacklist,
		CreatedAt:   m.CreatedAt,
		UpdatedAt:   m.UpdatedAt,
		GroupID:     m.GroupID,
	}
	// Edges are only mapped when they were eager-loaded by the query.
	if m.Edges.User != nil {
		out.User = userEntityToService(m.Edges.User)
	}
	if m.Edges.Group != nil {
		out.Group = groupEntityToService(m.Edges.Group)
	}
	return out
}
|
||||
|
||||
// userEntityToService maps an ent User row onto the service-layer User.
// Returns nil for a nil input.
func userEntityToService(u *dbent.User) *service.User {
	if u == nil {
		return nil
	}
	return &service.User{
		ID:           u.ID,
		Email:        u.Email,
		Username:     u.Username,
		Notes:        u.Notes,
		PasswordHash: u.PasswordHash,
		Role:         u.Role,
		Balance:      u.Balance,
		Concurrency:  u.Concurrency,
		Status:       u.Status,
		CreatedAt:    u.CreatedAt,
		UpdatedAt:    u.UpdatedAt,
	}
}
|
||||
|
||||
// groupEntityToService maps an ent Group row onto the service-layer Group.
// Returns nil for a nil input.
// NOTE(review): Hydrated is set unconditionally, including for rows loaded
// with partial column selects (e.g. GetByKeyForAuth); confirm consumers do
// not rely on Hydrated implying every field is populated.
func groupEntityToService(g *dbent.Group) *service.Group {
	if g == nil {
		return nil
	}
	return &service.Group{
		ID:                  g.ID,
		Name:                g.Name,
		Description:         derefString(g.Description),
		Platform:            g.Platform,
		RateMultiplier:      g.RateMultiplier,
		IsExclusive:         g.IsExclusive,
		Status:              g.Status,
		Hydrated:            true,
		SubscriptionType:    g.SubscriptionType,
		DailyLimitUSD:       g.DailyLimitUsd,
		WeeklyLimitUSD:      g.WeeklyLimitUsd,
		MonthlyLimitUSD:     g.MonthlyLimitUsd,
		ImagePrice1K:        g.ImagePrice1k,
		ImagePrice2K:        g.ImagePrice2k,
		ImagePrice4K:        g.ImagePrice4k,
		DefaultValidityDays: g.DefaultValidityDays,
		ClaudeCodeOnly:      g.ClaudeCodeOnly,
		FallbackGroupID:     g.FallbackGroupID,
		CreatedAt:           g.CreatedAt,
		UpdatedAt:           g.UpdatedAt,
	}
}
|
||||
|
||||
// derefString returns the pointed-to string, or "" for a nil pointer.
func derefString(s *string) string {
	if s != nil {
		return *s
	}
	return ""
}
|
||||
385
backend/internal/repository/api_key_repo_integration_test.go
Normal file
385
backend/internal/repository/api_key_repo_integration_test.go
Normal file
@@ -0,0 +1,385 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// APIKeyRepoSuite exercises apiKeyRepository against a real database; each
// test runs inside its own transactional ent client for isolation.
type APIKeyRepoSuite struct {
	suite.Suite
	ctx    context.Context
	client *dbent.Client
	repo   *apiKeyRepository
}

// SetupTest opens a fresh transactional ent client before every test.
func (s *APIKeyRepoSuite) SetupTest() {
	s.ctx = context.Background()
	tx := testEntTx(s.T())
	s.client = tx.Client()
	s.repo = NewAPIKeyRepository(s.client).(*apiKeyRepository)
}

// TestAPIKeyRepoSuite wires the suite into `go test`.
func TestAPIKeyRepoSuite(t *testing.T) {
	suite.Run(t, new(APIKeyRepoSuite))
}
|
||||
|
||||
// --- Create / GetByID / GetByKey ---
|
||||
|
||||
// TestCreate verifies Create assigns an ID and that the row is readable back.
func (s *APIKeyRepoSuite) TestCreate() {
	user := s.mustCreateUser("create@test.com")

	key := &service.APIKey{
		UserID: user.ID,
		Key:    "sk-create-test",
		Name:   "Test Key",
		Status: service.StatusActive,
	}

	err := s.repo.Create(s.ctx, key)
	s.Require().NoError(err, "Create")
	s.Require().NotZero(key.ID, "expected ID to be set")

	got, err := s.repo.GetByID(s.ctx, key.ID)
	s.Require().NoError(err, "GetByID")
	s.Require().Equal("sk-create-test", got.Key)
}

// TestGetByID_NotFound verifies a missing ID surfaces an error.
func (s *APIKeyRepoSuite) TestGetByID_NotFound() {
	_, err := s.repo.GetByID(s.ctx, 999999)
	s.Require().Error(err, "expected error for non-existent ID")
}

// TestGetByKey verifies lookup by key string with User/Group edges preloaded.
func (s *APIKeyRepoSuite) TestGetByKey() {
	user := s.mustCreateUser("getbykey@test.com")
	group := s.mustCreateGroup("g-key")

	key := &service.APIKey{
		UserID:  user.ID,
		Key:     "sk-getbykey",
		Name:    "My Key",
		GroupID: &group.ID,
		Status:  service.StatusActive,
	}
	s.Require().NoError(s.repo.Create(s.ctx, key))

	got, err := s.repo.GetByKey(s.ctx, key.Key)
	s.Require().NoError(err, "GetByKey")
	s.Require().Equal(key.ID, got.ID)
	s.Require().NotNil(got.User, "expected User preload")
	s.Require().Equal(user.ID, got.User.ID)
	s.Require().NotNil(got.Group, "expected Group preload")
	s.Require().Equal(group.ID, got.Group.ID)
}

// TestGetByKey_NotFound verifies a missing key string surfaces an error.
func (s *APIKeyRepoSuite) TestGetByKey_NotFound() {
	_, err := s.repo.GetByKey(s.ctx, "non-existent-key")
	s.Require().Error(err, "expected error for non-existent key")
}
|
||||
|
||||
// --- Update ---
|
||||
|
||||
// TestUpdate verifies Update changes name/status but leaves key and
// ownership untouched.
func (s *APIKeyRepoSuite) TestUpdate() {
	user := s.mustCreateUser("update@test.com")
	key := &service.APIKey{
		UserID: user.ID,
		Key:    "sk-update",
		Name:   "Original",
		Status: service.StatusActive,
	}
	s.Require().NoError(s.repo.Create(s.ctx, key))

	key.Name = "Renamed"
	key.Status = service.StatusDisabled
	err := s.repo.Update(s.ctx, key)
	s.Require().NoError(err, "Update")

	got, err := s.repo.GetByID(s.ctx, key.ID)
	s.Require().NoError(err, "GetByID after update")
	s.Require().Equal("sk-update", got.Key, "Update should not change key")
	s.Require().Equal(user.ID, got.UserID, "Update should not change user_id")
	s.Require().Equal("Renamed", got.Name)
	s.Require().Equal(service.StatusDisabled, got.Status)
}

// TestUpdate_ClearGroupID verifies a nil GroupID in Update clears the binding.
func (s *APIKeyRepoSuite) TestUpdate_ClearGroupID() {
	user := s.mustCreateUser("cleargroup@test.com")
	group := s.mustCreateGroup("g-clear")
	key := &service.APIKey{
		UserID:  user.ID,
		Key:     "sk-clear-group",
		Name:    "Group Key",
		GroupID: &group.ID,
		Status:  service.StatusActive,
	}
	s.Require().NoError(s.repo.Create(s.ctx, key))

	key.GroupID = nil
	err := s.repo.Update(s.ctx, key)
	s.Require().NoError(err, "Update")

	got, err := s.repo.GetByID(s.ctx, key.ID)
	s.Require().NoError(err)
	s.Require().Nil(got.GroupID, "expected GroupID to be cleared")
}
|
||||
|
||||
// --- Delete ---
|
||||
|
||||
// TestDelete verifies a deleted key is no longer visible via GetByID.
func (s *APIKeyRepoSuite) TestDelete() {
	user := s.mustCreateUser("delete@test.com")
	key := &service.APIKey{
		UserID: user.ID,
		Key:    "sk-delete",
		Name:   "Delete Me",
		Status: service.StatusActive,
	}
	s.Require().NoError(s.repo.Create(s.ctx, key))

	err := s.repo.Delete(s.ctx, key.ID)
	s.Require().NoError(err, "Delete")

	_, err = s.repo.GetByID(s.ctx, key.ID)
	s.Require().Error(err, "expected error after delete")
}

// --- ListByUserID / CountByUserID ---

// TestListByUserID verifies listing returns all of the user's keys with totals.
func (s *APIKeyRepoSuite) TestListByUserID() {
	user := s.mustCreateUser("listbyuser@test.com")
	s.mustCreateApiKey(user.ID, "sk-list-1", "Key 1", nil)
	s.mustCreateApiKey(user.ID, "sk-list-2", "Key 2", nil)

	keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
	s.Require().NoError(err, "ListByUserID")
	s.Require().Len(keys, 2)
	s.Require().Equal(int64(2), page.Total)
}

// TestListByUserID_Pagination verifies page size and total page count.
func (s *APIKeyRepoSuite) TestListByUserID_Pagination() {
	user := s.mustCreateUser("paging@test.com")
	for i := 0; i < 5; i++ {
		s.mustCreateApiKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil)
	}

	keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2})
	s.Require().NoError(err)
	s.Require().Len(keys, 2)
	s.Require().Equal(int64(5), page.Total)
	s.Require().Equal(3, page.Pages)
}

// TestCountByUserID verifies the per-user key count.
func (s *APIKeyRepoSuite) TestCountByUserID() {
	user := s.mustCreateUser("count@test.com")
	s.mustCreateApiKey(user.ID, "sk-count-1", "K1", nil)
	s.mustCreateApiKey(user.ID, "sk-count-2", "K2", nil)

	count, err := s.repo.CountByUserID(s.ctx, user.ID)
	s.Require().NoError(err, "CountByUserID")
	s.Require().Equal(int64(2), count)
}
|
||||
|
||||
// --- ListByGroupID / CountByGroupID ---
|
||||
|
||||
// TestListByGroupID verifies group-scoped listing excludes ungrouped keys
// and preloads the User edge.
func (s *APIKeyRepoSuite) TestListByGroupID() {
	user := s.mustCreateUser("listbygroup@test.com")
	group := s.mustCreateGroup("g-list")

	s.mustCreateApiKey(user.ID, "sk-grp-1", "K1", &group.ID)
	s.mustCreateApiKey(user.ID, "sk-grp-2", "K2", &group.ID)
	s.mustCreateApiKey(user.ID, "sk-grp-3", "K3", nil) // no group

	keys, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
	s.Require().NoError(err, "ListByGroupID")
	s.Require().Len(keys, 2)
	s.Require().Equal(int64(2), page.Total)
	// User preloaded
	s.Require().NotNil(keys[0].User)
}

// TestCountByGroupID verifies the per-group key count.
func (s *APIKeyRepoSuite) TestCountByGroupID() {
	user := s.mustCreateUser("countgroup@test.com")
	group := s.mustCreateGroup("g-count")
	s.mustCreateApiKey(user.ID, "sk-gc-1", "K1", &group.ID)

	count, err := s.repo.CountByGroupID(s.ctx, group.ID)
	s.Require().NoError(err, "CountByGroupID")
	s.Require().Equal(int64(1), count)
}

// --- ExistsByKey ---

// TestExistsByKey verifies both the present and absent cases.
func (s *APIKeyRepoSuite) TestExistsByKey() {
	user := s.mustCreateUser("exists@test.com")
	s.mustCreateApiKey(user.ID, "sk-exists", "K", nil)

	exists, err := s.repo.ExistsByKey(s.ctx, "sk-exists")
	s.Require().NoError(err, "ExistsByKey")
	s.Require().True(exists)

	notExists, err := s.repo.ExistsByKey(s.ctx, "sk-not-exists")
	s.Require().NoError(err)
	s.Require().False(notExists)
}
|
||||
|
||||
// --- SearchAPIKeys ---
|
||||
|
||||
// TestSearchAPIKeys verifies case-insensitive keyword matching on key names.
func (s *APIKeyRepoSuite) TestSearchAPIKeys() {
	user := s.mustCreateUser("search@test.com")
	s.mustCreateApiKey(user.ID, "sk-search-1", "Production Key", nil)
	s.mustCreateApiKey(user.ID, "sk-search-2", "Development Key", nil)

	found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "prod", 10)
	s.Require().NoError(err, "SearchAPIKeys")
	s.Require().Len(found, 1)
	s.Require().Contains(found[0].Name, "Production")
}

// TestSearchAPIKeys_NoKeyword verifies an empty keyword returns every key.
func (s *APIKeyRepoSuite) TestSearchAPIKeys_NoKeyword() {
	user := s.mustCreateUser("searchnokw@test.com")
	s.mustCreateApiKey(user.ID, "sk-nk-1", "K1", nil)
	s.mustCreateApiKey(user.ID, "sk-nk-2", "K2", nil)

	found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "", 10)
	s.Require().NoError(err)
	s.Require().Len(found, 2)
}

// TestSearchAPIKeys_NoUserID verifies userID=0 searches across all users.
func (s *APIKeyRepoSuite) TestSearchAPIKeys_NoUserID() {
	user := s.mustCreateUser("searchnouid@test.com")
	s.mustCreateApiKey(user.ID, "sk-nu-1", "TestKey", nil)

	found, err := s.repo.SearchAPIKeys(s.ctx, 0, "testkey", 10)
	s.Require().NoError(err)
	s.Require().Len(found, 1)
}

// --- ClearGroupIDByGroupID ---

// TestClearGroupIDByGroupID verifies the bulk unbind touches only keys in
// the group and reports the affected-row count.
func (s *APIKeyRepoSuite) TestClearGroupIDByGroupID() {
	user := s.mustCreateUser("cleargrp@test.com")
	group := s.mustCreateGroup("g-clear-bulk")

	k1 := s.mustCreateApiKey(user.ID, "sk-clr-1", "K1", &group.ID)
	k2 := s.mustCreateApiKey(user.ID, "sk-clr-2", "K2", &group.ID)
	s.mustCreateApiKey(user.ID, "sk-clr-3", "K3", nil) // no group

	affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
	s.Require().NoError(err, "ClearGroupIDByGroupID")
	s.Require().Equal(int64(2), affected)

	got1, _ := s.repo.GetByID(s.ctx, k1.ID)
	got2, _ := s.repo.GetByID(s.ctx, k2.ID)
	s.Require().Nil(got1.GroupID)
	s.Require().Nil(got2.GroupID)

	count, _ := s.repo.CountByGroupID(s.ctx, group.ID)
	s.Require().Zero(count)
}
|
||||
|
||||
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
|
||||
|
||||
// TestCRUD_Search_ClearGroupID is the original end-to-end flow kept as one
// combined scenario: create, read, update (including group clear), list,
// exists, search, and bulk group unbind.
func (s *APIKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
	user := s.mustCreateUser("k@example.com")
	group := s.mustCreateGroup("g-k")
	key := s.mustCreateApiKey(user.ID, "sk-test-1", "My Key", &group.ID)
	// NOTE(review): mustCreateApiKey already sets GroupID; this assignment
	// appears redundant — confirm and consider removing.
	key.GroupID = &group.ID

	got, err := s.repo.GetByKey(s.ctx, key.Key)
	s.Require().NoError(err, "GetByKey")
	s.Require().Equal(key.ID, got.ID)
	s.Require().NotNil(got.User)
	s.Require().Equal(user.ID, got.User.ID)
	s.Require().NotNil(got.Group)
	s.Require().Equal(group.ID, got.Group.ID)

	key.Name = "Renamed"
	key.Status = service.StatusDisabled
	key.GroupID = nil
	s.Require().NoError(s.repo.Update(s.ctx, key), "Update")

	got2, err := s.repo.GetByID(s.ctx, key.ID)
	s.Require().NoError(err, "GetByID")
	s.Require().Equal("sk-test-1", got2.Key, "Update should not change key")
	s.Require().Equal(user.ID, got2.UserID, "Update should not change user_id")
	s.Require().Equal("Renamed", got2.Name)
	s.Require().Equal(service.StatusDisabled, got2.Status)
	s.Require().Nil(got2.GroupID)

	keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
	s.Require().NoError(err, "ListByUserID")
	s.Require().Equal(int64(1), page.Total)
	s.Require().Len(keys, 1)

	exists, err := s.repo.ExistsByKey(s.ctx, "sk-test-1")
	s.Require().NoError(err, "ExistsByKey")
	s.Require().True(exists, "expected key to exist")

	found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "renam", 10)
	s.Require().NoError(err, "SearchAPIKeys")
	s.Require().Len(found, 1)
	s.Require().Equal(key.ID, found[0].ID)

	// ClearGroupIDByGroupID
	k2 := s.mustCreateApiKey(user.ID, "sk-test-2", "Group Key", &group.ID)
	// NOTE(review): redundant for the same reason as above.
	k2.GroupID = &group.ID

	countBefore, err := s.repo.CountByGroupID(s.ctx, group.ID)
	s.Require().NoError(err, "CountByGroupID")
	s.Require().Equal(int64(1), countBefore, "expected 1 key in group before clear")

	affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
	s.Require().NoError(err, "ClearGroupIDByGroupID")
	s.Require().Equal(int64(1), affected, "expected 1 affected row")

	got3, err := s.repo.GetByID(s.ctx, k2.ID)
	s.Require().NoError(err, "GetByID")
	s.Require().Nil(got3.GroupID, "expected GroupID cleared")

	countAfter, err := s.repo.CountByGroupID(s.ctx, group.ID)
	s.Require().NoError(err, "CountByGroupID after clear")
	s.Require().Equal(int64(0), countAfter, "expected 0 keys in group after clear")
}
|
||||
|
||||
func (s *APIKeyRepoSuite) mustCreateUser(email string) *service.User {
|
||||
s.T().Helper()
|
||||
|
||||
u, err := s.client.User.Create().
|
||||
SetEmail(email).
|
||||
SetPasswordHash("test-password-hash").
|
||||
SetStatus(service.StatusActive).
|
||||
SetRole(service.RoleUser).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err, "create user")
|
||||
return userEntityToService(u)
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) mustCreateGroup(name string) *service.Group {
|
||||
s.T().Helper()
|
||||
|
||||
g, err := s.client.Group.Create().
|
||||
SetName(name).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err, "create group")
|
||||
return groupEntityToService(g)
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, groupID *int64) *service.APIKey {
|
||||
s.T().Helper()
|
||||
|
||||
k := &service.APIKey{
|
||||
UserID: userID,
|
||||
Key: key,
|
||||
Name: name,
|
||||
GroupID: groupID,
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, k), "create api key")
|
||||
return k
|
||||
}
|
||||
183
backend/internal/repository/billing_cache.go
Normal file
183
backend/internal/repository/billing_cache.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// Cache key layout and lifetime for billing data.
const (
	billingBalanceKeyPrefix = "billing:balance:" // per-user balance: billing:balance:<userID>
	billingSubKeyPrefix     = "billing:sub:"     // subscription hash: billing:sub:<userID>:<groupID>
	billingCacheTTL         = 5 * time.Minute    // staleness bound for all billing cache entries
)

// billingBalanceKey returns the Redis key holding the cached balance for userID.
func billingBalanceKey(userID int64) string {
	return billingBalanceKeyPrefix + strconv.FormatInt(userID, 10)
}

// billingSubKey returns the Redis key for the user/group subscription hash.
func billingSubKey(userID, groupID int64) string {
	return billingSubKeyPrefix + strconv.FormatInt(userID, 10) + ":" + strconv.FormatInt(groupID, 10)
}
|
||||
|
||||
// Field names of the subscription hash stored under billingSubKey. These must
// match the literals used inside updateSubUsageScript.
const (
	subFieldStatus       = "status"
	subFieldExpiresAt    = "expires_at"
	subFieldDailyUsage   = "daily_usage"
	subFieldWeeklyUsage  = "weekly_usage"
	subFieldMonthlyUsage = "monthly_usage"
	subFieldVersion      = "version"
)
|
||||
|
||||
var (
	// deductBalanceScript atomically subtracts ARGV[1] from the cached balance
	// at KEYS[1] and refreshes its TTL (ARGV[2] seconds). If the key is absent
	// the script is a no-op and returns 0, so a cache miss never materializes
	// a bogus balance.
	deductBalanceScript = redis.NewScript(`
		local current = redis.call('GET', KEYS[1])
		if current == false then
			return 0
		end
		local newVal = tonumber(current) - tonumber(ARGV[1])
		redis.call('SET', KEYS[1], newVal)
		redis.call('EXPIRE', KEYS[1], ARGV[2])
		return 1
	`)

	// updateSubUsageScript atomically increments the daily/weekly/monthly
	// usage fields of the subscription hash at KEYS[1] by ARGV[1] and
	// refreshes its TTL (ARGV[2] seconds). Absent keys are left untouched
	// (returns 0) so usage is never tracked on a stale/missing entry.
	updateSubUsageScript = redis.NewScript(`
		local exists = redis.call('EXISTS', KEYS[1])
		if exists == 0 then
			return 0
		end
		local cost = tonumber(ARGV[1])
		redis.call('HINCRBYFLOAT', KEYS[1], 'daily_usage', cost)
		redis.call('HINCRBYFLOAT', KEYS[1], 'weekly_usage', cost)
		redis.call('HINCRBYFLOAT', KEYS[1], 'monthly_usage', cost)
		redis.call('EXPIRE', KEYS[1], ARGV[2])
		return 1
	`)
)
|
||||
|
||||
// billingCache is the Redis-backed implementation of service.BillingCache.
// It caches user balances as plain string keys and subscription state as
// hashes, both with a short TTL (billingCacheTTL).
type billingCache struct {
	rdb *redis.Client
}

// NewBillingCache wraps the given Redis client in a service.BillingCache.
func NewBillingCache(rdb *redis.Client) service.BillingCache {
	return &billingCache{rdb: rdb}
}
|
||||
|
||||
func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
|
||||
key := billingBalanceKey(userID)
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return strconv.ParseFloat(val, 64)
|
||||
}
|
||||
|
||||
func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
|
||||
key := billingBalanceKey(userID)
|
||||
return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err()
|
||||
}
|
||||
|
||||
func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
|
||||
key := billingBalanceKey(userID)
|
||||
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *billingCache) InvalidateUserBalance(ctx context.Context, userID int64) error {
|
||||
key := billingBalanceKey(userID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*service.SubscriptionCacheData, error) {
|
||||
key := billingSubKey(userID, groupID)
|
||||
result, err := c.rdb.HGetAll(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil, redis.Nil
|
||||
}
|
||||
return c.parseSubscriptionCache(result)
|
||||
}
|
||||
|
||||
func (c *billingCache) parseSubscriptionCache(data map[string]string) (*service.SubscriptionCacheData, error) {
|
||||
result := &service.SubscriptionCacheData{}
|
||||
|
||||
result.Status = data[subFieldStatus]
|
||||
if result.Status == "" {
|
||||
return nil, errors.New("invalid cache: missing status")
|
||||
}
|
||||
|
||||
if expiresStr, ok := data[subFieldExpiresAt]; ok {
|
||||
expiresAt, err := strconv.ParseInt(expiresStr, 10, 64)
|
||||
if err == nil {
|
||||
result.ExpiresAt = time.Unix(expiresAt, 0)
|
||||
}
|
||||
}
|
||||
|
||||
if dailyStr, ok := data[subFieldDailyUsage]; ok {
|
||||
result.DailyUsage, _ = strconv.ParseFloat(dailyStr, 64)
|
||||
}
|
||||
|
||||
if weeklyStr, ok := data[subFieldWeeklyUsage]; ok {
|
||||
result.WeeklyUsage, _ = strconv.ParseFloat(weeklyStr, 64)
|
||||
}
|
||||
|
||||
if monthlyStr, ok := data[subFieldMonthlyUsage]; ok {
|
||||
result.MonthlyUsage, _ = strconv.ParseFloat(monthlyStr, 64)
|
||||
}
|
||||
|
||||
if versionStr, ok := data[subFieldVersion]; ok {
|
||||
result.Version, _ = strconv.ParseInt(versionStr, 10, 64)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *service.SubscriptionCacheData) error {
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := billingSubKey(userID, groupID)
|
||||
|
||||
fields := map[string]any{
|
||||
subFieldStatus: data.Status,
|
||||
subFieldExpiresAt: data.ExpiresAt.Unix(),
|
||||
subFieldDailyUsage: data.DailyUsage,
|
||||
subFieldWeeklyUsage: data.WeeklyUsage,
|
||||
subFieldMonthlyUsage: data.MonthlyUsage,
|
||||
subFieldVersion: data.Version,
|
||||
}
|
||||
|
||||
pipe := c.rdb.Pipeline()
|
||||
pipe.HSet(ctx, key, fields)
|
||||
pipe.Expire(ctx, key, billingCacheTTL)
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
|
||||
key := billingSubKey(userID, groupID)
|
||||
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
|
||||
key := billingSubKey(userID, groupID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
283
backend/internal/repository/billing_cache_integration_test.go
Normal file
283
backend/internal/repository/billing_cache_integration_test.go
Normal file
@@ -0,0 +1,283 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// BillingCacheSuite groups the Redis-backed billing cache integration tests;
// it embeds IntegrationRedisSuite for the shared Redis setup/teardown helpers.
type BillingCacheSuite struct {
	IntegrationRedisSuite
}
|
||||
|
||||
// TestUserBalance covers the balance half of the billing cache: miss
// signalling, best-effort deduction, TTL stamping/refresh, and invalidation.
// Each subtest gets its own client/cache from testRedis — presumably isolated
// state per case; confirm testRedis flushes or namespaces between calls.
func (s *BillingCacheSuite) TestUserBalance() {
	tests := []struct {
		name string
		fn   func(ctx context.Context, rdb *redis.Client, cache service.BillingCache)
	}{
		{
			// A GET on an absent key must surface redis.Nil to the caller.
			name: "missing_key_returns_redis_nil",
			fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
				_, err := cache.GetUserBalance(ctx, 1)
				require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing balance key")
			},
		},
		{
			// Deducting from a non-existent balance must neither error nor
			// create the key (the Lua script is a no-op on a miss).
			name: "deduct_on_nonexistent_is_noop",
			fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
				userID := int64(1)
				balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)

				require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 1), "DeductUserBalance should not error")

				_, err := rdb.Get(ctx, balanceKey).Result()
				require.ErrorIs(s.T(), err, redis.Nil, "expected missing key after deduct on non-existent")
			},
		},
		{
			// Set must round-trip the value and stamp the billing TTL.
			name: "set_and_get_with_ttl",
			fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
				userID := int64(2)
				balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)

				require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 10.5), "SetUserBalance")

				got, err := cache.GetUserBalance(ctx, userID)
				require.NoError(s.T(), err, "GetUserBalance")
				require.Equal(s.T(), 10.5, got, "balance mismatch")

				ttl, err := rdb.TTL(ctx, balanceKey).Result()
				require.NoError(s.T(), err, "TTL")
				s.AssertTTLWithin(ttl, 1*time.Second, billingCacheTTL)
			},
		},
		{
			// 10.5 - 2.25 = 8.25 exactly (both representable in binary floats).
			name: "deduct_reduces_balance",
			fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
				userID := int64(3)

				require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 10.5), "SetUserBalance")
				require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 2.25), "DeductUserBalance")

				got, err := cache.GetUserBalance(ctx, userID)
				require.NoError(s.T(), err, "GetUserBalance after deduct")
				require.Equal(s.T(), 8.25, got, "deduct mismatch")
			},
		},
		{
			// Invalidation removes the key entirely; a subsequent read misses.
			name: "invalidate_removes_key",
			fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
				userID := int64(100)
				balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)

				require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 50.0), "SetUserBalance")

				exists, err := rdb.Exists(ctx, balanceKey).Result()
				require.NoError(s.T(), err, "Exists")
				require.Equal(s.T(), int64(1), exists, "expected balance key to exist")

				require.NoError(s.T(), cache.InvalidateUserBalance(ctx, userID), "InvalidateUserBalance")

				exists, err = rdb.Exists(ctx, balanceKey).Result()
				require.NoError(s.T(), err, "Exists after invalidate")
				require.Equal(s.T(), int64(0), exists, "expected balance key to be removed after invalidate")

				_, err = cache.GetUserBalance(ctx, userID)
				require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after invalidate")
			},
		},
		{
			// The deduct script re-EXPIREs the key, so the TTL window resets.
			name: "deduct_refreshes_ttl",
			fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
				userID := int64(103)
				balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)

				require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 100.0), "SetUserBalance")

				ttl1, err := rdb.TTL(ctx, balanceKey).Result()
				require.NoError(s.T(), err, "TTL before deduct")
				s.AssertTTLWithin(ttl1, 1*time.Second, billingCacheTTL)

				require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 25.0), "DeductUserBalance")

				balance, err := cache.GetUserBalance(ctx, userID)
				require.NoError(s.T(), err, "GetUserBalance")
				require.Equal(s.T(), 75.0, balance, "expected balance 75.0")

				ttl2, err := rdb.TTL(ctx, balanceKey).Result()
				require.NoError(s.T(), err, "TTL after deduct")
				s.AssertTTLWithin(ttl2, 1*time.Second, billingCacheTTL)
			},
		},
	}

	for _, tt := range tests {
		s.Run(tt.name, func() {
			rdb := testRedis(s.T())
			cache := NewBillingCache(rdb)
			ctx := context.Background()

			tt.fn(ctx, rdb, cache)
		})
	}
}
|
||||
|
||||
// TestSubscriptionCache covers the subscription-hash half of the billing
// cache: miss signalling, no-op usage updates, round-trip + TTL, atomic
// usage increments, invalidation, and the corrupt-entry (missing status)
// parsing error.
func (s *BillingCacheSuite) TestSubscriptionCache() {
	tests := []struct {
		name string
		fn   func(ctx context.Context, rdb *redis.Client, cache service.BillingCache)
	}{
		{
			// An absent hash is normalized to redis.Nil by GetSubscriptionCache.
			name: "missing_key_returns_redis_nil",
			fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
				userID := int64(10)
				groupID := int64(20)

				_, err := cache.GetSubscriptionCache(ctx, userID, groupID)
				require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing subscription key")
			},
		},
		{
			// The usage script must not create the hash when it does not exist.
			name: "update_usage_on_nonexistent_is_noop",
			fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
				userID := int64(11)
				groupID := int64(21)
				subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)

				require.NoError(s.T(), cache.UpdateSubscriptionUsage(ctx, userID, groupID, 1.0), "UpdateSubscriptionUsage should not error")

				exists, err := rdb.Exists(ctx, subKey).Result()
				require.NoError(s.T(), err, "Exists")
				require.Equal(s.T(), int64(0), exists, "expected missing subscription key after UpdateSubscriptionUsage on non-existent")
			},
		},
		{
			// Set must round-trip the decoded fields and stamp the billing TTL.
			name: "set_and_get_with_ttl",
			fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
				userID := int64(12)
				groupID := int64(22)
				subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)

				data := &service.SubscriptionCacheData{
					Status:       "active",
					ExpiresAt:    time.Now().Add(1 * time.Hour),
					DailyUsage:   1.0,
					WeeklyUsage:  2.0,
					MonthlyUsage: 3.0,
					Version:      7,
				}
				require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")

				gotSub, err := cache.GetSubscriptionCache(ctx, userID, groupID)
				require.NoError(s.T(), err, "GetSubscriptionCache")
				require.Equal(s.T(), "active", gotSub.Status)
				require.Equal(s.T(), int64(7), gotSub.Version)
				require.Equal(s.T(), 1.0, gotSub.DailyUsage)

				ttl, err := rdb.TTL(ctx, subKey).Result()
				require.NoError(s.T(), err, "TTL subKey")
				s.AssertTTLWithin(ttl, 1*time.Second, billingCacheTTL)
			},
		},
		{
			// One update of 0.5 must bump daily/weekly/monthly together.
			name: "update_usage_increments_all_fields",
			fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
				userID := int64(13)
				groupID := int64(23)

				data := &service.SubscriptionCacheData{
					Status:       "active",
					ExpiresAt:    time.Now().Add(1 * time.Hour),
					DailyUsage:   1.0,
					WeeklyUsage:  2.0,
					MonthlyUsage: 3.0,
					Version:      1,
				}
				require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")

				require.NoError(s.T(), cache.UpdateSubscriptionUsage(ctx, userID, groupID, 0.5), "UpdateSubscriptionUsage")

				gotSub, err := cache.GetSubscriptionCache(ctx, userID, groupID)
				require.NoError(s.T(), err, "GetSubscriptionCache after update")
				require.Equal(s.T(), 1.5, gotSub.DailyUsage)
				require.Equal(s.T(), 2.5, gotSub.WeeklyUsage)
				require.Equal(s.T(), 3.5, gotSub.MonthlyUsage)
			},
		},
		{
			// Invalidation deletes the whole hash; subsequent reads miss.
			name: "invalidate_removes_key",
			fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
				userID := int64(101)
				groupID := int64(10)
				subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)

				data := &service.SubscriptionCacheData{
					Status:       "active",
					ExpiresAt:    time.Now().Add(1 * time.Hour),
					DailyUsage:   1.0,
					WeeklyUsage:  2.0,
					MonthlyUsage: 3.0,
					Version:      1,
				}
				require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")

				exists, err := rdb.Exists(ctx, subKey).Result()
				require.NoError(s.T(), err, "Exists")
				require.Equal(s.T(), int64(1), exists, "expected subscription key to exist")

				require.NoError(s.T(), cache.InvalidateSubscriptionCache(ctx, userID, groupID), "InvalidateSubscriptionCache")

				exists, err = rdb.Exists(ctx, subKey).Result()
				require.NoError(s.T(), err, "Exists after invalidate")
				require.Equal(s.T(), int64(0), exists, "expected subscription key to be removed after invalidate")

				_, err = cache.GetSubscriptionCache(ctx, userID, groupID)
				require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after invalidate")
			},
		},
		{
			// A hash written without the status field is corrupt: the parser
			// must return its own error, distinct from a redis.Nil miss.
			name: "missing_status_returns_parsing_error",
			fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
				userID := int64(102)
				groupID := int64(11)
				subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)

				fields := map[string]any{
					"expires_at":    time.Now().Add(1 * time.Hour).Unix(),
					"daily_usage":   1.0,
					"weekly_usage":  2.0,
					"monthly_usage": 3.0,
					"version":       1,
				}
				require.NoError(s.T(), rdb.HSet(ctx, subKey, fields).Err(), "HSet")

				_, err := cache.GetSubscriptionCache(ctx, userID, groupID)
				require.Error(s.T(), err, "expected error for missing status field")
				require.NotErrorIs(s.T(), err, redis.Nil, "expected parsing error, not redis.Nil")
				require.Equal(s.T(), "invalid cache: missing status", err.Error())
			},
		},
	}

	for _, tt := range tests {
		s.Run(tt.name, func() {
			rdb := testRedis(s.T())
			cache := NewBillingCache(rdb)
			ctx := context.Background()

			tt.fn(ctx, rdb, cache)
		})
	}
}
|
||||
|
||||
// TestBillingCacheSuite wires BillingCacheSuite into the standard `go test`
// runner (integration build tag only).
func TestBillingCacheSuite(t *testing.T) {
	suite.Run(t, new(BillingCacheSuite))
}
|
||||
87
backend/internal/repository/billing_cache_test.go
Normal file
87
backend/internal/repository/billing_cache_test.go
Normal file
@@ -0,0 +1,87 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBillingBalanceKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "normal_user_id",
|
||||
userID: 123,
|
||||
expected: "billing:balance:123",
|
||||
},
|
||||
{
|
||||
name: "zero_user_id",
|
||||
userID: 0,
|
||||
expected: "billing:balance:0",
|
||||
},
|
||||
{
|
||||
name: "negative_user_id",
|
||||
userID: -1,
|
||||
expected: "billing:balance:-1",
|
||||
},
|
||||
{
|
||||
name: "max_int64",
|
||||
userID: math.MaxInt64,
|
||||
expected: "billing:balance:9223372036854775807",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := billingBalanceKey(tc.userID)
|
||||
require.Equal(t, tc.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBillingSubKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
groupID int64
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "normal_ids",
|
||||
userID: 123,
|
||||
groupID: 456,
|
||||
expected: "billing:sub:123:456",
|
||||
},
|
||||
{
|
||||
name: "zero_ids",
|
||||
userID: 0,
|
||||
groupID: 0,
|
||||
expected: "billing:sub:0:0",
|
||||
},
|
||||
{
|
||||
name: "negative_ids",
|
||||
userID: -1,
|
||||
groupID: -2,
|
||||
expected: "billing:sub:-1:-2",
|
||||
},
|
||||
{
|
||||
name: "max_int64_ids",
|
||||
userID: math.MaxInt64,
|
||||
groupID: math.MaxInt64,
|
||||
expected: "billing:sub:9223372036854775807:9223372036854775807",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := billingSubKey(tc.userID, tc.groupID)
|
||||
require.Equal(t, tc.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
248
backend/internal/repository/claude_oauth_service.go
Normal file
248
backend/internal/repository/claude_oauth_service.go
Normal file
@@ -0,0 +1,248 @@
|
||||
package repository
|
||||
|
||||
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
	"github.com/Wei-Shaw/sub2api/internal/service"
	"github.com/Wei-Shaw/sub2api/internal/util/logredact"

	"github.com/imroc/req/v3"
)
|
||||
|
||||
// NewClaudeOAuthClient builds the production Claude OAuth client, pointing at
// claude.ai for the browser-session endpoints and the shared oauth.TokenURL
// for the token exchange, using the default proxy-aware HTTP client factory.
func NewClaudeOAuthClient() service.ClaudeOAuthClient {
	return &claudeOAuthService{
		baseURL:       "https://claude.ai",
		tokenURL:      oauth.TokenURL,
		clientFactory: createReqClient,
	}
}
|
||||
|
||||
// claudeOAuthService implements service.ClaudeOAuthClient against the Claude
// web/OAuth endpoints. baseURL and tokenURL are fields (rather than
// constants) so tests can redirect them to a local server; clientFactory
// builds an HTTP client honoring an optional proxy URL.
type claudeOAuthService struct {
	baseURL       string
	tokenURL      string
	clientFactory func(proxyURL string) *req.Client
}
|
||||
|
||||
func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
|
||||
client := s.clientFactory(proxyURL)
|
||||
|
||||
var orgs []struct {
|
||||
UUID string `json:"uuid"`
|
||||
}
|
||||
|
||||
targetURL := s.baseURL + "/api/organizations"
|
||||
log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL)
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetCookies(&http.Cookie{
|
||||
Name: "sessionKey",
|
||||
Value: sessionKey,
|
||||
}).
|
||||
SetSuccessResult(&orgs).
|
||||
Get(targetURL)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[OAuth] Step 1 FAILED - Request error: %v", err)
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 1 Response - Status: %d", resp.StatusCode)
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
if len(orgs) == 0 {
|
||||
return "", fmt.Errorf("no organizations found")
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 1 SUCCESS - Got org UUID: %s", orgs[0].UUID)
|
||||
return orgs[0].UUID, nil
|
||||
}
|
||||
|
||||
func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
|
||||
client := s.clientFactory(proxyURL)
|
||||
|
||||
authURL := fmt.Sprintf("%s/v1/oauth/%s/authorize", s.baseURL, orgUUID)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"response_type": "code",
|
||||
"client_id": oauth.ClientID,
|
||||
"organization_uuid": orgUUID,
|
||||
"redirect_uri": oauth.RedirectURI,
|
||||
"scope": scope,
|
||||
"state": state,
|
||||
"code_challenge": codeChallenge,
|
||||
"code_challenge_method": "S256",
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL)
|
||||
reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody))
|
||||
log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
|
||||
|
||||
var result struct {
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
}
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetCookies(&http.Cookie{
|
||||
Name: "sessionKey",
|
||||
Value: sessionKey,
|
||||
}).
|
||||
SetHeader("Accept", "application/json").
|
||||
SetHeader("Accept-Language", "en-US,en;q=0.9").
|
||||
SetHeader("Cache-Control", "no-cache").
|
||||
SetHeader("Origin", "https://claude.ai").
|
||||
SetHeader("Referer", "https://claude.ai/new").
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetBody(reqBody).
|
||||
SetSuccessResult(&result).
|
||||
Post(authURL)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[OAuth] Step 2 FAILED - Request error: %v", err)
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
if result.RedirectURI == "" {
|
||||
return "", fmt.Errorf("no redirect_uri in response")
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(result.RedirectURI)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse redirect_uri: %w", err)
|
||||
}
|
||||
|
||||
queryParams := parsedURL.Query()
|
||||
authCode := queryParams.Get("code")
|
||||
responseState := queryParams.Get("state")
|
||||
|
||||
if authCode == "" {
|
||||
return "", fmt.Errorf("no authorization code in redirect_uri")
|
||||
}
|
||||
|
||||
fullCode := authCode
|
||||
if responseState != "" {
|
||||
fullCode = authCode + "#" + responseState
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code")
|
||||
return fullCode, nil
|
||||
}
|
||||
|
||||
func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) {
|
||||
client := s.clientFactory(proxyURL)
|
||||
|
||||
// Parse code which may contain state in format "authCode#state"
|
||||
authCode := code
|
||||
codeState := ""
|
||||
if idx := strings.Index(code, "#"); idx != -1 {
|
||||
authCode = code[:idx]
|
||||
codeState = code[idx+1:]
|
||||
}
|
||||
|
||||
reqBody := map[string]any{
|
||||
"code": authCode,
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": oauth.ClientID,
|
||||
"redirect_uri": oauth.RedirectURI,
|
||||
"code_verifier": codeVerifier,
|
||||
}
|
||||
|
||||
if codeState != "" {
|
||||
reqBody["state"] = codeState
|
||||
}
|
||||
|
||||
// Setup token requires longer expiration (1 year)
|
||||
if isSetupToken {
|
||||
reqBody["expires_in"] = 31536000 // 365 * 24 * 60 * 60 seconds
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL)
|
||||
reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody))
|
||||
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
|
||||
|
||||
var tokenResp oauth.TokenResponse
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetBody(reqBody).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(s.tokenURL)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err)
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 3 SUCCESS - Got access token")
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
|
||||
client := s.clientFactory(proxyURL)
|
||||
|
||||
// 使用 JSON 格式(与 ExchangeCodeForToken 保持一致)
|
||||
// Anthropic OAuth API 期望 JSON 格式的请求体
|
||||
reqBody := map[string]any{
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refreshToken,
|
||||
"client_id": oauth.ClientID,
|
||||
}
|
||||
|
||||
var tokenResp oauth.TokenResponse
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetBody(reqBody).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(s.tokenURL)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
func createReqClient(proxyURL string) *req.Client {
|
||||
// 禁用 CookieJar,确保每次授权都是干净的会话
|
||||
client := req.C().
|
||||
SetTimeout(60 * time.Second).
|
||||
ImpersonateChrome().
|
||||
SetCookieJar(nil) // 禁用 CookieJar
|
||||
|
||||
if strings.TrimSpace(proxyURL) != "" {
|
||||
client.SetProxyURL(strings.TrimSpace(proxyURL))
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
396
backend/internal/repository/claude_oauth_service_test.go
Normal file
396
backend/internal/repository/claude_oauth_service_test.go
Normal file
@@ -0,0 +1,396 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
|
||||
"github.com/imroc/req/v3"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// ClaudeOAuthServiceSuite tests claudeOAuthService against an in-process
// round-tripper; client is re-created and re-pointed per subtest.
type ClaudeOAuthServiceSuite struct {
	suite.Suite
	client *claudeOAuthService
}
|
||||
|
||||
// requestCapture holds captured request data for assertions in the main goroutine.
type requestCapture struct {
	path        string         // request URL path as seen by the handler
	method      string         // HTTP method (populated only by tests that assert on it)
	cookies     []*http.Cookie // cookies attached to the request
	body        []byte         // raw request body, when a test captures it
	bodyJSON    map[string]any // body decoded as JSON, when a test captures it
	contentType string         // Content-Type header, when a test captures it
}
|
||||
|
||||
// newTestReqClient returns a req client whose underlying transport is
// replaced by rt, routing all requests in-process instead of over the network.
func newTestReqClient(rt http.RoundTripper) *req.Client {
	c := req.C()
	c.GetClient().Transport = rt
	return c
}
|
||||
|
||||
// TestGetOrganizationUUID drives GetOrganizationUUID through an in-process
// transport, covering the success path (including the session cookie and
// request path), a non-2xx status, and an unparsable JSON body.
func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() {
	tests := []struct {
		name       string
		handler    http.HandlerFunc
		wantErr    bool
		errContain string
		wantUUID   string
		validate   func(captured requestCapture)
	}{
		{
			name: "success",
			handler: func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("Content-Type", "application/json")
				_, _ = w.Write([]byte(`[{"uuid":"org-1"}]`))
			},
			wantUUID: "org-1",
			validate: func(captured requestCapture) {
				// The service must hit /api/organizations with exactly the
				// sessionKey cookie it was given.
				require.Equal(s.T(), "/api/organizations", captured.path, "unexpected path")
				require.Len(s.T(), captured.cookies, 1, "expected 1 cookie")
				require.Equal(s.T(), "sessionKey", captured.cookies[0].Name)
				require.Equal(s.T(), "sess", captured.cookies[0].Value)
			},
		},
		{
			// A 401 must be reported with the status code in the error text.
			name: "non_200_returns_error",
			handler: func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusUnauthorized)
				_, _ = w.Write([]byte("unauthorized"))
			},
			wantErr:    true,
			errContain: "401",
		},
		{
			// A 200 with a non-JSON body must still fail.
			name: "invalid_json_returns_error",
			handler: func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("Content-Type", "application/json")
				_, _ = w.Write([]byte("not-json"))
			},
			wantErr: true,
		},
	}

	for _, tt := range tests {
		s.Run(tt.name, func() {
			var captured requestCapture

			// Wrap the case handler so request metadata is captured before
			// the case-specific response is written.
			rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				captured.path = r.URL.Path
				captured.cookies = r.Cookies()
				tt.handler(w, r)
			}), nil)

			// Build the production client, then redirect it at the in-process
			// transport.
			client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
			require.True(s.T(), ok, "type assertion failed")
			s.client = client
			s.client.baseURL = "http://in-process"
			s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }

			got, err := s.client.GetOrganizationUUID(context.Background(), "sess", "")

			if tt.wantErr {
				require.Error(s.T(), err)
				if tt.errContain != "" {
					require.ErrorContains(s.T(), err, tt.errContain)
				}
				return
			}

			require.NoError(s.T(), err)
			require.Equal(s.T(), tt.wantUUID, got)
			if tt.validate != nil {
				tt.validate(captured)
			}
		})
	}
}
|
||||
|
||||
// TestGetAuthorizationCode verifies claudeOAuthService.GetAuthorizationCode:
// the code and state are parsed out of the returned redirect_uri and joined
// as "code#state"; a redirect_uri without a code is an error. It also checks
// the outgoing request shape (path, method, cookie, JSON body fields).
func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() {
	tests := []struct {
		name     string
		handler  http.HandlerFunc              // fake authorize endpoint behavior
		wantErr  bool                          // expect an error from the call
		wantCode string                        // expected "code#state" result on success
		validate func(captured requestCapture) // optional extra assertions on the captured request
	}{
		{
			name: "parses_redirect_uri",
			handler: func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("Content-Type", "application/json")
				_ = json.NewEncoder(w).Encode(map[string]string{
					"redirect_uri": oauth.RedirectURI + "?code=AUTH&state=STATE",
				})
			},
			wantCode: "AUTH#STATE",
			validate: func(captured requestCapture) {
				require.True(s.T(), strings.HasPrefix(captured.path, "/v1/oauth/") && strings.HasSuffix(captured.path, "/authorize"), "unexpected path: %s", captured.path)
				require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
				require.Len(s.T(), captured.cookies, 1, "expected 1 cookie")
				require.Equal(s.T(), "sess", captured.cookies[0].Value)
				require.Equal(s.T(), "org-1", captured.bodyJSON["organization_uuid"])
				require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
				require.Equal(s.T(), oauth.RedirectURI, captured.bodyJSON["redirect_uri"])
				require.Equal(s.T(), "st", captured.bodyJSON["state"])
			},
		},
		{
			name: "missing_code_returns_error",
			handler: func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("Content-Type", "application/json")
				_ = json.NewEncoder(w).Encode(map[string]string{
					"redirect_uri": oauth.RedirectURI + "?state=STATE", // no code
				})
			},
			wantErr: true,
		},
	}

	for _, tt := range tests {
		s.Run(tt.name, func() {
			var captured requestCapture

			// Capture request details before delegating to the per-case handler.
			rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				captured.path = r.URL.Path
				captured.method = r.Method
				captured.cookies = r.Cookies()
				captured.body, _ = io.ReadAll(r.Body)
				_ = json.Unmarshal(captured.body, &captured.bodyJSON)
				tt.handler(w, r)
			}), nil)

			client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
			require.True(s.T(), ok, "type assertion failed")
			s.client = client
			s.client.baseURL = "http://in-process"
			s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }

			code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeProfile, "cc", "st", "")

			if tt.wantErr {
				require.Error(s.T(), err)
				return
			}

			require.NoError(s.T(), err)
			require.Equal(s.T(), tt.wantCode, code)
			if tt.validate != nil {
				tt.validate(captured)
			}
		})
	}
}
||||
|
||||
// TestExchangeCodeForToken verifies claudeOAuthService.ExchangeCodeForToken:
// a "code#state" input is split and both parts sent; setup-token exchanges
// additionally include expires_in (one year); non-200 responses are errors.
func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() {
	tests := []struct {
		name         string
		handler      http.HandlerFunc              // fake token endpoint behavior
		code         string                        // authorization code input, possibly "code#state"
		isSetupToken bool                          // request a long-lived setup token
		wantErr      bool                          // expect an error from the call
		wantResp     *oauth.TokenResponse          // expected access/refresh tokens on success
		validate     func(captured requestCapture) // optional extra assertions on the captured request
	}{
		{
			name: "sends_state_when_embedded",
			handler: func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("Content-Type", "application/json")
				_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
					AccessToken:  "at",
					TokenType:    "bearer",
					ExpiresIn:    3600,
					RefreshToken: "rt",
					Scope:        "s",
				})
			},
			code:         "AUTH#STATE2",
			isSetupToken: false,
			wantResp: &oauth.TokenResponse{
				AccessToken:  "at",
				RefreshToken: "rt",
			},
			validate: func(captured requestCapture) {
				require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
				require.True(s.T(), strings.HasPrefix(captured.contentType, "application/json"), "unexpected content-type")
				require.Equal(s.T(), "AUTH", captured.bodyJSON["code"])
				require.Equal(s.T(), "STATE2", captured.bodyJSON["state"])
				require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
				require.Equal(s.T(), oauth.RedirectURI, captured.bodyJSON["redirect_uri"])
				require.Equal(s.T(), "ver", captured.bodyJSON["code_verifier"])
				// Regular OAuth should not include expires_in
				require.Nil(s.T(), captured.bodyJSON["expires_in"], "regular OAuth should not include expires_in")
			},
		},
		{
			name: "setup_token_includes_expires_in",
			handler: func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("Content-Type", "application/json")
				_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
					AccessToken: "at",
					TokenType:   "bearer",
					ExpiresIn:   31536000,
				})
			},
			code:         "AUTH",
			isSetupToken: true,
			wantResp: &oauth.TokenResponse{
				AccessToken: "at",
			},
			validate: func(captured requestCapture) {
				// Setup token should include expires_in with 1 year value
				require.Equal(s.T(), float64(31536000), captured.bodyJSON["expires_in"],
					"setup token should include expires_in: 31536000")
			},
		},
		{
			name: "non_200_returns_error",
			handler: func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusBadRequest)
				_, _ = w.Write([]byte("bad request"))
			},
			code:         "AUTH",
			isSetupToken: false,
			wantErr:      true,
		},
	}

	for _, tt := range tests {
		s.Run(tt.name, func() {
			var captured requestCapture

			// Capture request details before delegating to the per-case handler.
			rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				captured.method = r.Method
				captured.contentType = r.Header.Get("Content-Type")
				captured.body, _ = io.ReadAll(r.Body)
				_ = json.Unmarshal(captured.body, &captured.bodyJSON)
				tt.handler(w, r)
			}), nil)

			client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
			require.True(s.T(), ok, "type assertion failed")
			s.client = client
			s.client.tokenURL = "http://in-process/token"
			s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }

			resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "", tt.isSetupToken)

			if tt.wantErr {
				require.Error(s.T(), err)
				return
			}

			require.NoError(s.T(), err)
			require.Equal(s.T(), tt.wantResp.AccessToken, resp.AccessToken)
			require.Equal(s.T(), tt.wantResp.RefreshToken, resp.RefreshToken)
			if tt.validate != nil {
				tt.validate(captured)
			}
		})
	}
}
||||
|
||||
// TestRefreshToken verifies claudeOAuthService.RefreshToken: the refresh
// request is sent as JSON (not form-encoded) with grant_type/refresh_token/
// client_id, rotated refresh tokens are returned, and non-200 responses fail.
func (s *ClaudeOAuthServiceSuite) TestRefreshToken() {
	tests := []struct {
		name     string
		handler  http.HandlerFunc              // fake token endpoint behavior
		wantErr  bool                          // expect an error from the call
		wantResp *oauth.TokenResponse          // expected access/refresh tokens on success
		validate func(captured requestCapture) // optional extra assertions on the captured request
	}{
		{
			name: "sends_json_format",
			handler: func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("Content-Type", "application/json")
				_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
					AccessToken:  "new_access_token",
					TokenType:    "bearer",
					ExpiresIn:    28800,
					RefreshToken: "new_refresh_token",
					Scope:        "user:profile user:inference",
				})
			},
			wantResp: &oauth.TokenResponse{
				AccessToken:  "new_access_token",
				RefreshToken: "new_refresh_token",
			},
			validate: func(captured requestCapture) {
				require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
				// Verify the request uses JSON format (not form encoding).
				require.True(s.T(), strings.HasPrefix(captured.contentType, "application/json"),
					"expected JSON content-type, got: %s", captured.contentType)
				// Verify the JSON body contents.
				require.Equal(s.T(), "refresh_token", captured.bodyJSON["grant_type"])
				require.Equal(s.T(), "rt", captured.bodyJSON["refresh_token"])
				require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
			},
		},
		{
			name: "returns_new_refresh_token",
			handler: func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("Content-Type", "application/json")
				_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
					AccessToken:  "at",
					TokenType:    "bearer",
					ExpiresIn:    28800,
					RefreshToken: "rotated_rt", // Anthropic rotates refresh tokens
				})
			},
			wantResp: &oauth.TokenResponse{
				AccessToken:  "at",
				RefreshToken: "rotated_rt",
			},
		},
		{
			name: "non_200_returns_error",
			handler: func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusUnauthorized)
				_, _ = w.Write([]byte(`{"error":"invalid_grant"}`))
			},
			wantErr: true,
		},
	}

	for _, tt := range tests {
		s.Run(tt.name, func() {
			var captured requestCapture

			// Capture request details before delegating to the per-case handler.
			rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				captured.method = r.Method
				captured.contentType = r.Header.Get("Content-Type")
				captured.body, _ = io.ReadAll(r.Body)
				_ = json.Unmarshal(captured.body, &captured.bodyJSON)
				tt.handler(w, r)
			}), nil)

			client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
			require.True(s.T(), ok, "type assertion failed")
			s.client = client
			s.client.tokenURL = "http://in-process/token"
			s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }

			resp, err := s.client.RefreshToken(context.Background(), "rt", "")

			if tt.wantErr {
				require.Error(s.T(), err)
				return
			}

			require.NoError(s.T(), err)
			require.Equal(s.T(), tt.wantResp.AccessToken, resp.AccessToken)
			require.Equal(s.T(), tt.wantResp.RefreshToken, resp.RefreshToken)
			if tt.validate != nil {
				tt.validate(captured)
			}
		})
	}
}
|
||||
// TestClaudeOAuthServiceSuite runs the ClaudeOAuthServiceSuite under go test.
func TestClaudeOAuthServiceSuite(t *testing.T) {
	suite.Run(t, new(ClaudeOAuthServiceSuite))
}
62
backend/internal/repository/claude_usage_service.go
Normal file
62
backend/internal/repository/claude_usage_service.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage"
|
||||
|
||||
// claudeUsageService fetches usage statistics from the Anthropic OAuth usage API.
type claudeUsageService struct {
	usageURL          string // endpoint to query; overridable in tests
	allowPrivateHosts bool   // passed to the shared HTTP client's SSRF guard
}
|
||||
// NewClaudeUsageFetcher returns a ClaudeUsageFetcher pointed at the default
// Anthropic usage endpoint, with private-host access disabled.
func NewClaudeUsageFetcher() service.ClaudeUsageFetcher {
	return &claudeUsageService{usageURL: defaultClaudeUsageURL}
}
||||
|
||||
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 30 * time.Second,
|
||||
ValidateResolvedIP: true,
|
||||
AllowPrivateHosts: s.allowPrivateHosts,
|
||||
})
|
||||
if err != nil {
|
||||
client = &http.Client{Timeout: 30 * time.Second}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", s.usageURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request failed: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var usageResp service.ClaudeUsageResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&usageResp); err != nil {
|
||||
return nil, fmt.Errorf("decode response failed: %w", err)
|
||||
}
|
||||
|
||||
return &usageResp, nil
|
||||
}
|
||||
117
backend/internal/repository/claude_usage_service_test.go
Normal file
117
backend/internal/repository/claude_usage_service_test.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// ClaudeUsageServiceSuite groups tests for claudeUsageService against a local
// httptest server standing in for the Anthropic usage endpoint.
type ClaudeUsageServiceSuite struct {
	suite.Suite
	srv     *httptest.Server    // per-test fake server; closed in TearDownTest
	fetcher *claudeUsageService // service under test, pointed at srv.URL
}
|
||||
func (s *ClaudeUsageServiceSuite) TearDownTest() {
|
||||
if s.srv != nil {
|
||||
s.srv.Close()
|
||||
s.srv = nil
|
||||
}
|
||||
}
|
||||
|
||||
// usageRequestCapture holds captured request data for assertions in the main goroutine.
type usageRequestCapture struct {
	authorization string // Authorization header value seen by the fake server
	anthropicBeta string // anthropic-beta header value seen by the fake server
}
|
||||
// TestFetchUsage_Success verifies a successful fetch: utilization values are
// decoded per window, the Bearer token and anthropic-beta headers are sent,
// and an unparseable proxy URL falls back to a working direct client.
func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
	var captured usageRequestCapture

	s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		captured.authorization = r.Header.Get("Authorization")
		captured.anthropicBeta = r.Header.Get("anthropic-beta")

		w.Header().Set("Content-Type", "application/json")
		_, _ = io.WriteString(w, `{
			"five_hour": {"utilization": 12.5, "resets_at": "2025-01-01T00:00:00Z"},
			"seven_day": {"utilization": 34.0, "resets_at": "2025-01-08T00:00:00Z"},
			"seven_day_sonnet": {"utilization": 56.0, "resets_at": "2025-01-08T00:00:00Z"}
		}`)
	}))

	s.fetcher = &claudeUsageService{
		usageURL:          s.srv.URL,
		allowPrivateHosts: true,
	}

	// Deliberately malformed proxy URL: FetchUsage should fall back to a
	// plain client rather than fail.
	resp, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url")
	require.NoError(s.T(), err, "FetchUsage")
	require.Equal(s.T(), 12.5, resp.FiveHour.Utilization, "FiveHour utilization mismatch")
	require.Equal(s.T(), 34.0, resp.SevenDay.Utilization, "SevenDay utilization mismatch")
	require.Equal(s.T(), 56.0, resp.SevenDaySonnet.Utilization, "SevenDaySonnet utilization mismatch")

	// Assertions on captured request data
	require.Equal(s.T(), "Bearer at", captured.authorization, "Authorization header mismatch")
	require.Equal(s.T(), "oauth-2025-04-20", captured.anthropicBeta, "anthropic-beta header mismatch")
}
||||
|
||||
// TestFetchUsage_NonOK verifies that a non-200 response produces an error
// containing both the status code and the response body.
func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() {
	s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusUnauthorized)
		_, _ = io.WriteString(w, "nope")
	}))

	s.fetcher = &claudeUsageService{
		usageURL:          s.srv.URL,
		allowPrivateHosts: true,
	}

	_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
	require.Error(s.T(), err)
	require.ErrorContains(s.T(), err, "status 401")
	require.ErrorContains(s.T(), err, "nope")
}
||||
|
||||
// TestFetchUsage_BadJSON verifies that a 200 response with an undecodable
// body surfaces a decode error.
func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() {
	s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		_, _ = io.WriteString(w, "not-json")
	}))

	s.fetcher = &claudeUsageService{
		usageURL:          s.srv.URL,
		allowPrivateHosts: true,
	}

	_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
	require.Error(s.T(), err)
	require.ErrorContains(s.T(), err, "decode response failed")
}
||||
|
||||
// TestFetchUsage_ContextCancel verifies that an already-cancelled context
// aborts the request with an error instead of blocking on a slow server.
func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() {
	s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Never respond - simulate slow server
		<-r.Context().Done()
	}))

	s.fetcher = &claudeUsageService{
		usageURL:          s.srv.URL,
		allowPrivateHosts: true,
	}

	ctx, cancel := context.WithCancel(context.Background())
	cancel() // Cancel immediately

	_, err := s.fetcher.FetchUsage(ctx, "at", "")
	require.Error(s.T(), err, "expected error for cancelled context")
}
||||
|
||||
// TestClaudeUsageServiceSuite runs the ClaudeUsageServiceSuite under go test.
func TestClaudeUsageServiceSuite(t *testing.T) {
	suite.Run(t, new(ClaudeUsageServiceSuite))
}
||||
391
backend/internal/repository/concurrency_cache.go
Normal file
391
backend/internal/repository/concurrency_cache.go
Normal file
@@ -0,0 +1,391 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// Concurrency-control cache key constants.
//
// Performance notes:
// The previous implementation used SCAN to walk independent slot keys
// (concurrency:account:{id}:{requestID}); under high concurrency SCAN needs
// multiple round trips and degrades noticeably with many keys.
//
// The new implementation uses Redis sorted sets instead:
//  1. One key per account/user; members are requestIDs, scores are timestamps.
//  2. ZCARD returns the concurrency count atomically in O(1).
//  3. ZREMRANGEBYSCORE evicts expired slots, avoiding manual TTL bookkeeping.
//  4. Counting completes in a single Redis call, reducing network round trips.
const (
	// Concurrency slot key prefix (sorted set).
	// Format: concurrency:account:{accountID}
	accountSlotKeyPrefix = "concurrency:account:"
	// Format: concurrency:user:{userID}
	userSlotKeyPrefix = "concurrency:user:"
	// Wait-queue counter key format: concurrency:wait:{userID}
	waitQueueKeyPrefix = "concurrency:wait:"
	// Account-level wait-queue counter key format: wait:account:{accountID}
	accountWaitKeyPrefix = "wait:account:"

	// Default slot TTL in minutes; may be overridden via configuration.
	defaultSlotTTLMinutes = 15
)
||||
|
||||
var (
	// acquireScript counts slots in a sorted set and adds one when below the
	// limit. It uses the Redis TIME command for the timestamp so that multiple
	// app instances with skewed clocks stay consistent.
	// KEYS[1] = sorted-set key (concurrency:account:{id} / concurrency:user:{id})
	// ARGV[1] = maxConcurrency
	// ARGV[2] = TTL (seconds)
	// ARGV[3] = requestID
	acquireScript = redis.NewScript(`
		local key = KEYS[1]
		local maxConcurrency = tonumber(ARGV[1])
		local ttl = tonumber(ARGV[2])
		local requestID = ARGV[3]

		-- 使用 Redis 服务器时间,确保多实例时钟一致
		local timeResult = redis.call('TIME')
		local now = tonumber(timeResult[1])
		local expireBefore = now - ttl

		-- 清理过期槽位
		redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)

		-- 检查是否已存在(支持重试场景刷新时间戳)
		local exists = redis.call('ZSCORE', key, requestID)
		if exists ~= false then
			redis.call('ZADD', key, now, requestID)
			redis.call('EXPIRE', key, ttl)
			return 1
		end

		-- 检查是否达到并发上限
		local count = redis.call('ZCARD', key)
		if count < maxConcurrency then
			redis.call('ZADD', key, now, requestID)
			redis.call('EXPIRE', key, ttl)
			return 1
		end

		return 0
	`)

	// getCountScript counts slots in the sorted set after evicting expired
	// entries. Uses the Redis TIME command for server time.
	// KEYS[1] = sorted-set key
	// ARGV[1] = TTL (seconds)
	getCountScript = redis.NewScript(`
		local key = KEYS[1]
		local ttl = tonumber(ARGV[1])

		-- 使用 Redis 服务器时间
		local timeResult = redis.call('TIME')
		local now = tonumber(timeResult[1])
		local expireBefore = now - ttl

		redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
		return redis.call('ZCARD', key)
	`)

	// incrementWaitScript - refreshes TTL on each increment to keep queue depth accurate
	// KEYS[1] = wait queue key
	// ARGV[1] = maxWait
	// ARGV[2] = TTL in seconds
	incrementWaitScript = redis.NewScript(`
		local current = redis.call('GET', KEYS[1])
		if current == false then
			current = 0
		else
			current = tonumber(current)
		end

		if current >= tonumber(ARGV[1]) then
			return 0
		end

		local newVal = redis.call('INCR', KEYS[1])

		-- Refresh TTL so long-running traffic doesn't expire active queue counters.
		redis.call('EXPIRE', KEYS[1], ARGV[2])

		return 1
	`)

	// incrementAccountWaitScript - account-level wait queue count (refresh TTL on each increment)
	incrementAccountWaitScript = redis.NewScript(`
		local current = redis.call('GET', KEYS[1])
		if current == false then
			current = 0
		else
			current = tonumber(current)
		end

		if current >= tonumber(ARGV[1]) then
			return 0
		end

		local newVal = redis.call('INCR', KEYS[1])

		-- Refresh TTL so long-running traffic doesn't expire active queue counters.
		redis.call('EXPIRE', KEYS[1], ARGV[2])

		return 1
	`)

	// decrementWaitScript - same as before
	decrementWaitScript = redis.NewScript(`
		local current = redis.call('GET', KEYS[1])
		if current ~= false and tonumber(current) > 0 then
			redis.call('DECR', KEYS[1])
		end
		return 1
	`)

	// getAccountsLoadBatchScript - batch load query with expired slot cleanup
	// ARGV[1] = slot TTL (seconds)
	// ARGV[2..n] = accountID1, maxConcurrency1, accountID2, maxConcurrency2, ...
	getAccountsLoadBatchScript = redis.NewScript(`
		local result = {}
		local slotTTL = tonumber(ARGV[1])

		-- Get current server time
		local timeResult = redis.call('TIME')
		local nowSeconds = tonumber(timeResult[1])
		local cutoffTime = nowSeconds - slotTTL

		local i = 2
		while i <= #ARGV do
			local accountID = ARGV[i]
			local maxConcurrency = tonumber(ARGV[i + 1])

			local slotKey = 'concurrency:account:' .. accountID

			-- Clean up expired slots before counting
			redis.call('ZREMRANGEBYSCORE', slotKey, '-inf', cutoffTime)
			local currentConcurrency = redis.call('ZCARD', slotKey)

			local waitKey = 'wait:account:' .. accountID
			local waitingCount = redis.call('GET', waitKey)
			if waitingCount == false then
				waitingCount = 0
			else
				waitingCount = tonumber(waitingCount)
			end

			local loadRate = 0
			if maxConcurrency > 0 then
				loadRate = math.floor((currentConcurrency + waitingCount) * 100 / maxConcurrency)
			end

			table.insert(result, accountID)
			table.insert(result, currentConcurrency)
			table.insert(result, waitingCount)
			table.insert(result, loadRate)

			i = i + 2
		end

		return result
	`)

	// cleanupExpiredSlotsScript - remove expired slots
	// KEYS[1] = concurrency:account:{accountID}
	// ARGV[1] = TTL (seconds)
	cleanupExpiredSlotsScript = redis.NewScript(`
		local key = KEYS[1]
		local ttl = tonumber(ARGV[1])
		local timeResult = redis.call('TIME')
		local now = tonumber(timeResult[1])
		local expireBefore = now - ttl
		return redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
	`)
)
||||
|
||||
// concurrencyCache implements service.ConcurrencyCache on top of Redis
// sorted sets (slots) and plain counters (wait queues).
type concurrencyCache struct {
	rdb                 *redis.Client
	slotTTLSeconds      int // slot expiry (seconds)
	waitQueueTTLSeconds int // wait-queue counter expiry (seconds)
}
||||
|
||||
// NewConcurrencyCache creates the concurrency-control cache.
// slotTTLMinutes: slot expiry in minutes; zero or negative uses the default of 15 minutes.
// waitQueueTTLSeconds: wait-queue expiry in seconds; zero or negative falls back to the slot TTL.
func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int, waitQueueTTLSeconds int) service.ConcurrencyCache {
	if slotTTLMinutes <= 0 {
		slotTTLMinutes = defaultSlotTTLMinutes
	}
	if waitQueueTTLSeconds <= 0 {
		waitQueueTTLSeconds = slotTTLMinutes * 60
	}
	return &concurrencyCache{
		rdb:                 rdb,
		slotTTLSeconds:      slotTTLMinutes * 60,
		waitQueueTTLSeconds: waitQueueTTLSeconds,
	}
}
||||
|
||||
// Helper functions for key generation
|
||||
func accountSlotKey(accountID int64) string {
|
||||
return fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
}
|
||||
|
||||
func userSlotKey(userID int64) string {
|
||||
return fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
|
||||
}
|
||||
|
||||
func waitQueueKey(userID int64) string {
|
||||
return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
}
|
||||
|
||||
func accountWaitKey(accountID int64) string {
|
||||
return fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
||||
}
|
||||
|
||||
// Account slot operations

// AcquireAccountSlot tries to claim a concurrency slot for the account.
// Returns true when the slot was granted (or the requestID already held one
// and its timestamp was refreshed), false when the account is at capacity.
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
	key := accountSlotKey(accountID)
	// The timestamp is taken inside the Lua script via the Redis TIME command,
	// keeping multiple instances clock-consistent.
	result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, c.slotTTLSeconds, requestID).Int()
	if err != nil {
		return false, err
	}
	return result == 1, nil
}
||||
|
||||
func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
|
||||
key := accountSlotKey(accountID)
|
||||
return c.rdb.ZRem(ctx, key, requestID).Err()
|
||||
}
|
||||
|
||||
// GetAccountConcurrency returns the number of active (non-expired) slots for
// the account, evicting expired entries as a side effect.
func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
	key := accountSlotKey(accountID)
	// The timestamp is taken inside the Lua script via the Redis TIME command.
	result, err := getCountScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Int()
	if err != nil {
		return 0, err
	}
	return result, nil
}
||||
|
||||
// User slot operations

// AcquireUserSlot tries to claim a concurrency slot for the user; semantics
// mirror AcquireAccountSlot but against the per-user sorted set.
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
	key := userSlotKey(userID)
	// The timestamp is taken inside the Lua script via the Redis TIME command,
	// keeping multiple instances clock-consistent.
	result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, c.slotTTLSeconds, requestID).Int()
	if err != nil {
		return false, err
	}
	return result == 1, nil
}
||||
|
||||
func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
|
||||
key := userSlotKey(userID)
|
||||
return c.rdb.ZRem(ctx, key, requestID).Err()
|
||||
}
|
||||
|
||||
// GetUserConcurrency returns the number of active (non-expired) slots for the
// user, evicting expired entries as a side effect.
func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
	key := userSlotKey(userID)
	// The timestamp is taken inside the Lua script via the Redis TIME command.
	result, err := getCountScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Int()
	if err != nil {
		return 0, err
	}
	return result, nil
}
||||
|
||||
// Wait queue operations

// IncrementWaitCount bumps the user's wait-queue counter, refusing (false)
// once maxWait is reached. The counter's TTL is refreshed on each increment.
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
	key := waitQueueKey(userID)
	result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int()
	if err != nil {
		return false, err
	}
	return result == 1, nil
}
||||
|
||||
func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
|
||||
key := waitQueueKey(userID)
|
||||
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
// Account wait queue operations

// IncrementAccountWaitCount bumps the account's wait-queue counter, refusing
// (false) once maxWait is reached. The counter's TTL is refreshed each time.
func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
	key := accountWaitKey(accountID)
	result, err := incrementAccountWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int()
	if err != nil {
		return false, err
	}
	return result == 1, nil
}
||||
|
||||
func (c *concurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
|
||||
key := accountWaitKey(accountID)
|
||||
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||||
key := accountWaitKey(accountID)
|
||||
val, err := c.rdb.Get(ctx, key).Int()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
return 0, err
|
||||
}
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return 0, nil
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// GetAccountsLoadBatch returns load info (active slots, waiters, load rate)
// for every given account in a single Lua round trip. The script emits a flat
// array of [accountID, currentConcurrency, waitingCount, loadRate] tuples,
// which is decoded here back into a map keyed by account ID.
func (c *concurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
	if len(accounts) == 0 {
		return map[int64]*service.AccountLoadInfo{}, nil
	}

	// ARGV layout expected by the script: slotTTL, then (id, maxConcurrency) pairs.
	args := []any{c.slotTTLSeconds}
	for _, acc := range accounts {
		args = append(args, acc.ID, acc.MaxConcurrency)
	}

	result, err := getAccountsLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice()
	if err != nil {
		return nil, err
	}

	loadMap := make(map[int64]*service.AccountLoadInfo)
	for i := 0; i < len(result); i += 4 {
		// Guard against a truncated trailing tuple.
		if i+3 >= len(result) {
			break
		}

		// Redis returns mixed types (string IDs, int64 counts); normalize
		// through %v + strconv. Parse errors deliberately yield zero values.
		accountID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64)
		currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1]))
		waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2]))
		loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3]))

		loadMap[accountID] = &service.AccountLoadInfo{
			AccountID:          accountID,
			CurrentConcurrency: currentConcurrency,
			WaitingCount:       waitingCount,
			LoadRate:           loadRate,
		}
	}

	return loadMap, nil
}
||||
|
||||
func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
|
||||
key := accountSlotKey(accountID)
|
||||
_, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result()
|
||||
return err
|
||||
}
|
||||
135
backend/internal/repository/concurrency_cache_benchmark_test.go
Normal file
135
backend/internal/repository/concurrency_cache_benchmark_test.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// TTL configuration used by the benchmarks.
const benchSlotTTLMinutes = 15

// benchSlotTTL is the slot TTL as a time.Duration, derived from benchSlotTTLMinutes.
var benchSlotTTL = time.Duration(benchSlotTTLMinutes) * time.Minute
// BenchmarkAccountConcurrency compares counting active slots via a sorted set
// (the production path behind GetAccountConcurrency) against counting
// per-slot string keys with SCAN.
func BenchmarkAccountConcurrency(b *testing.B) {
	rdb := newBenchmarkRedisClient(b)
	defer func() {
		_ = rdb.Close()
	}()

	// NewConcurrencyCache returns the service interface; unwrap the concrete
	// type to reach its internals. The ok result is ignored: a failed
	// assertion would leave cache nil and panic below, surfacing the bug.
	cache, _ := NewConcurrencyCache(rdb, benchSlotTTLMinutes, int(benchSlotTTL.Seconds())).(*concurrencyCache)
	ctx := context.Background()

	for _, size := range []int{10, 100, 1000} {
		size := size // per-iteration copy for the closures below
		b.Run(fmt.Sprintf("zset/slots=%d", size), func(b *testing.B) {
			// Unique account per sub-benchmark avoids key collisions between runs.
			accountID := time.Now().UnixNano()
			key := accountSlotKey(accountID)

			// Seed the sorted set outside the timed region.
			b.StopTimer()
			members := make([]redis.Z, 0, size)
			now := float64(time.Now().Unix())
			for i := 0; i < size; i++ {
				members = append(members, redis.Z{
					Score:  now,
					Member: fmt.Sprintf("req_%d", i),
				})
			}
			if err := rdb.ZAdd(ctx, key, members...).Err(); err != nil {
				b.Fatalf("初始化有序集合失败: %v", err)
			}
			if err := rdb.Expire(ctx, key, benchSlotTTL).Err(); err != nil {
				b.Fatalf("设置有序集合 TTL 失败: %v", err)
			}
			b.StartTimer()

			b.ReportAllocs()
			for i := 0; i < b.N; i++ {
				if _, err := cache.GetAccountConcurrency(ctx, accountID); err != nil {
					b.Fatalf("获取并发数量失败: %v", err)
				}
			}

			// Cleanup outside the timed region.
			b.StopTimer()
			if err := rdb.Del(ctx, key).Err(); err != nil {
				b.Fatalf("清理有序集合失败: %v", err)
			}
		})

		b.Run(fmt.Sprintf("scan/slots=%d", size), func(b *testing.B) {
			accountID := time.Now().UnixNano()
			pattern := fmt.Sprintf("%s%d:*", accountSlotKeyPrefix, accountID)
			keys := make([]string, 0, size)

			// Seed one string key per slot via a pipeline (single round trip).
			b.StopTimer()
			pipe := rdb.Pipeline()
			for i := 0; i < size; i++ {
				key := fmt.Sprintf("%s%d:req_%d", accountSlotKeyPrefix, accountID, i)
				keys = append(keys, key)
				pipe.Set(ctx, key, "1", benchSlotTTL)
			}
			if _, err := pipe.Exec(ctx); err != nil {
				b.Fatalf("初始化扫描键失败: %v", err)
			}
			b.StartTimer()

			b.ReportAllocs()
			for i := 0; i < b.N; i++ {
				if _, err := scanSlotCount(ctx, rdb, pattern); err != nil {
					b.Fatalf("SCAN 计数失败: %v", err)
				}
			}

			// Cleanup outside the timed region.
			b.StopTimer()
			if err := rdb.Del(ctx, keys...).Err(); err != nil {
				b.Fatalf("清理扫描键失败: %v", err)
			}
		})
	}
}
|
||||
|
||||
func scanSlotCount(ctx context.Context, rdb *redis.Client, pattern string) (int, error) {
|
||||
var cursor uint64
|
||||
count := 0
|
||||
for {
|
||||
keys, nextCursor, err := rdb.Scan(ctx, cursor, pattern, 100).Result()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
count += len(keys)
|
||||
if nextCursor == 0 {
|
||||
break
|
||||
}
|
||||
cursor = nextCursor
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func newBenchmarkRedisClient(b *testing.B) *redis.Client {
|
||||
b.Helper()
|
||||
|
||||
redisURL := os.Getenv("TEST_REDIS_URL")
|
||||
if redisURL == "" {
|
||||
b.Skip("未设置 TEST_REDIS_URL,跳过 Redis 基准测试")
|
||||
}
|
||||
|
||||
opt, err := redis.ParseURL(redisURL)
|
||||
if err != nil {
|
||||
b.Fatalf("解析 TEST_REDIS_URL 失败: %v", err)
|
||||
}
|
||||
|
||||
client := redis.NewClient(opt)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := client.Ping(ctx).Err(); err != nil {
|
||||
b.Fatalf("Redis 连接失败: %v", err)
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
@@ -0,0 +1,412 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// Slot TTL configuration used by the tests (15 minutes, matching the default).
const testSlotTTLMinutes = 15

// testSlotTTL is the TTL as a Duration, used for TTL assertions.
var testSlotTTL = time.Duration(testSlotTTLMinutes) * time.Minute
|
||||
|
||||
// ConcurrencyCacheSuite exercises the Redis-backed concurrency cache against
// a real Redis instance (integration build tag).
type ConcurrencyCacheSuite struct {
	IntegrationRedisSuite
	cache service.ConcurrencyCache // system under test, rebuilt per test
}
|
||||
|
||||
// SetupTest resets the shared Redis fixture and builds a fresh cache with the
// test TTL configuration before every test.
func (s *ConcurrencyCacheSuite) SetupTest() {
	s.IntegrationRedisSuite.SetupTest()
	s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds()))
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
|
||||
accountID := int64(10)
|
||||
reqID1, reqID2, reqID3 := "req1", "req2", "req3"
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID1)
|
||||
require.NoError(s.T(), err, "AcquireAccountSlot 1")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID2)
|
||||
require.NoError(s.T(), err, "AcquireAccountSlot 2")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID3)
|
||||
require.NoError(s.T(), err, "AcquireAccountSlot 3")
|
||||
require.False(s.T(), ok, "expected third acquire to fail")
|
||||
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err, "GetAccountConcurrency")
|
||||
require.Equal(s.T(), 2, cur, "concurrency mismatch")
|
||||
|
||||
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID1), "ReleaseAccountSlot")
|
||||
|
||||
cur, err = s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err, "GetAccountConcurrency after release")
|
||||
require.Equal(s.T(), 1, cur, "expected 1 after release")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() {
|
||||
accountID := int64(11)
|
||||
reqID := "req_ttl_test"
|
||||
slotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, reqID)
|
||||
require.NoError(s.T(), err, "AcquireAccountSlot")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
|
||||
require.NoError(s.T(), err, "TTL")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_DuplicateReqID() {
|
||||
accountID := int64(12)
|
||||
reqID := "dup-req"
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID)
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
// Acquiring with same reqID should be idempotent
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID)
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 1, cur, "expected concurrency=1 (idempotent)")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_ReleaseIdempotent() {
|
||||
accountID := int64(13)
|
||||
reqID := "release-test"
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 1, reqID)
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID), "ReleaseAccountSlot")
|
||||
// Releasing again should not error
|
||||
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID), "ReleaseAccountSlot again")
|
||||
// Releasing non-existent should not error
|
||||
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, "non-existent"), "ReleaseAccountSlot non-existent")
|
||||
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 0, cur)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_MaxZero() {
|
||||
accountID := int64(14)
|
||||
reqID := "max-zero-test"
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 0, reqID)
|
||||
require.NoError(s.T(), err)
|
||||
require.False(s.T(), ok, "expected acquire to fail with max=0")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestUserSlot_AcquireAndRelease() {
|
||||
userID := int64(42)
|
||||
reqID1, reqID2 := "req1", "req2"
|
||||
|
||||
ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 1, reqID1)
|
||||
require.NoError(s.T(), err, "AcquireUserSlot")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.AcquireUserSlot(s.ctx, userID, 1, reqID2)
|
||||
require.NoError(s.T(), err, "AcquireUserSlot 2")
|
||||
require.False(s.T(), ok, "expected second acquire to fail at max=1")
|
||||
|
||||
cur, err := s.cache.GetUserConcurrency(s.ctx, userID)
|
||||
require.NoError(s.T(), err, "GetUserConcurrency")
|
||||
require.Equal(s.T(), 1, cur, "expected concurrency=1")
|
||||
|
||||
require.NoError(s.T(), s.cache.ReleaseUserSlot(s.ctx, userID, reqID1), "ReleaseUserSlot")
|
||||
// Releasing a non-existent slot should not error
|
||||
require.NoError(s.T(), s.cache.ReleaseUserSlot(s.ctx, userID, "non-existent"), "ReleaseUserSlot non-existent")
|
||||
|
||||
cur, err = s.cache.GetUserConcurrency(s.ctx, userID)
|
||||
require.NoError(s.T(), err, "GetUserConcurrency after release")
|
||||
require.Equal(s.T(), 0, cur, "expected concurrency=0 after release")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
|
||||
userID := int64(200)
|
||||
reqID := "req_ttl_test"
|
||||
slotKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
|
||||
|
||||
ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 5, reqID)
|
||||
require.NoError(s.T(), err, "AcquireUserSlot")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
|
||||
require.NoError(s.T(), err, "TTL")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
|
||||
userID := int64(20)
|
||||
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
|
||||
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 2)
|
||||
require.NoError(s.T(), err, "IncrementWaitCount 1")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.IncrementWaitCount(s.ctx, userID, 2)
|
||||
require.NoError(s.T(), err, "IncrementWaitCount 2")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.IncrementWaitCount(s.ctx, userID, 2)
|
||||
require.NoError(s.T(), err, "IncrementWaitCount 3")
|
||||
require.False(s.T(), ok, "expected wait increment over max to fail")
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
|
||||
require.NoError(s.T(), err, "TTL waitKey")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
|
||||
|
||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
|
||||
|
||||
val, err := s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey")
|
||||
}
|
||||
require.Equal(s.T(), 1, val, "expected wait count 1")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
|
||||
userID := int64(300)
|
||||
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
|
||||
// Test decrement on non-existent key - should not error and should not create negative value
|
||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on non-existent key")
|
||||
|
||||
// Verify no key was created or it's not negative
|
||||
val, err := s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey")
|
||||
}
|
||||
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count after decrement on empty")
|
||||
|
||||
// Set count to 1, then decrement twice
|
||||
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 5)
|
||||
require.NoError(s.T(), err, "IncrementWaitCount")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
// Decrement once (1 -> 0)
|
||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
|
||||
|
||||
// Decrement again on 0 - should not go negative
|
||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on zero")
|
||||
|
||||
// Verify count is 0, not negative
|
||||
val, err = s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey after double decrement")
|
||||
}
|
||||
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {
|
||||
accountID := int64(30)
|
||||
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
||||
|
||||
ok, err := s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
|
||||
require.NoError(s.T(), err, "IncrementAccountWaitCount 1")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
|
||||
require.NoError(s.T(), err, "IncrementAccountWaitCount 2")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
|
||||
require.NoError(s.T(), err, "IncrementAccountWaitCount 3")
|
||||
require.False(s.T(), ok, "expected account wait increment over max to fail")
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
|
||||
require.NoError(s.T(), err, "TTL account waitKey")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
|
||||
|
||||
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount")
|
||||
|
||||
val, err := s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey")
|
||||
}
|
||||
require.Equal(s.T(), 1, val, "expected account wait count 1")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_DecrementNoNegative() {
|
||||
accountID := int64(301)
|
||||
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
||||
|
||||
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount on non-existent key")
|
||||
|
||||
val, err := s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey")
|
||||
}
|
||||
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative account wait count after decrement on empty")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() {
|
||||
// When no slots exist, GetAccountConcurrency should return 0
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, 999)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 0, cur)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestGetUserConcurrency_Missing() {
|
||||
// When no slots exist, GetUserConcurrency should return 0
|
||||
cur, err := s.cache.GetUserConcurrency(s.ctx, 999)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 0, cur)
|
||||
}
|
||||
|
||||
// TestGetAccountsLoadBatch checks the batch load query across accounts with
// different slot/wait states. Currently skipped pending a CI flake fix.
func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch() {
	s.T().Skip("TODO: Fix this test - CurrentConcurrency returns 0 instead of expected value in CI")
	// Setup: Create accounts with different load states
	account1 := int64(100)
	account2 := int64(101)
	account3 := int64(102)

	// Account 1: 2/3 slots used, 1 waiting
	ok, err := s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req1")
	require.NoError(s.T(), err)
	require.True(s.T(), ok)
	ok, err = s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req2")
	require.NoError(s.T(), err)
	require.True(s.T(), ok)
	ok, err = s.cache.IncrementAccountWaitCount(s.ctx, account1, 5)
	require.NoError(s.T(), err)
	require.True(s.T(), ok)

	// Account 2: 1/2 slots used, 0 waiting
	ok, err = s.cache.AcquireAccountSlot(s.ctx, account2, 2, "req3")
	require.NoError(s.T(), err)
	require.True(s.T(), ok)

	// Account 3: 0/1 slots used, 0 waiting (idle)

	// Query batch load
	accounts := []service.AccountWithConcurrency{
		{ID: account1, MaxConcurrency: 3},
		{ID: account2, MaxConcurrency: 2},
		{ID: account3, MaxConcurrency: 1},
	}

	loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, accounts)
	require.NoError(s.T(), err)
	require.Len(s.T(), loadMap, 3)

	// Verify account1: (2 + 1) / 3 = 100%
	load1 := loadMap[account1]
	require.NotNil(s.T(), load1)
	require.Equal(s.T(), account1, load1.AccountID)
	require.Equal(s.T(), 2, load1.CurrentConcurrency)
	require.Equal(s.T(), 1, load1.WaitingCount)
	require.Equal(s.T(), 100, load1.LoadRate)

	// Verify account2: (1 + 0) / 2 = 50%
	load2 := loadMap[account2]
	require.NotNil(s.T(), load2)
	require.Equal(s.T(), account2, load2.AccountID)
	require.Equal(s.T(), 1, load2.CurrentConcurrency)
	require.Equal(s.T(), 0, load2.WaitingCount)
	require.Equal(s.T(), 50, load2.LoadRate)

	// Verify account3: (0 + 0) / 1 = 0%
	load3 := loadMap[account3]
	require.NotNil(s.T(), load3)
	require.Equal(s.T(), account3, load3.AccountID)
	require.Equal(s.T(), 0, load3.CurrentConcurrency)
	require.Equal(s.T(), 0, load3.WaitingCount)
	require.Equal(s.T(), 0, load3.LoadRate)
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch_Empty() {
|
||||
// Test with empty account list
|
||||
loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, []service.AccountWithConcurrency{})
|
||||
require.NoError(s.T(), err)
|
||||
require.Empty(s.T(), loadMap)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots() {
|
||||
accountID := int64(200)
|
||||
slotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
|
||||
// Acquire 3 slots
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req3")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
// Verify 3 slots exist
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 3, cur)
|
||||
|
||||
// Manually set old timestamps for req1 and req2 (simulate expired slots)
|
||||
now := time.Now().Unix()
|
||||
expiredTime := now - int64(testSlotTTL.Seconds()) - 10 // 10 seconds past TTL
|
||||
err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req1"}).Err()
|
||||
require.NoError(s.T(), err)
|
||||
err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req2"}).Err()
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
// Run cleanup
|
||||
err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
// Verify only 1 slot remains (req3)
|
||||
cur, err = s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 1, cur)
|
||||
|
||||
// Verify req3 still exists
|
||||
members, err := s.rdb.ZRange(s.ctx, slotKey, 0, -1).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.Len(s.T(), members, 1)
|
||||
require.Equal(s.T(), "req3", members[0])
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots_NoExpired() {
|
||||
accountID := int64(201)
|
||||
|
||||
// Acquire 2 fresh slots
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
// Run cleanup (should not remove anything)
|
||||
err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
// Verify both slots still exist
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 2, cur)
|
||||
}
|
||||
|
||||
// TestConcurrencyCacheSuite runs the ConcurrencyCacheSuite integration tests.
func TestConcurrencyCacheSuite(t *testing.T) {
	suite.Run(t, new(ConcurrencyCacheSuite))
}
|
||||
392
backend/internal/repository/dashboard_aggregation_repo.go
Normal file
392
backend/internal/repository/dashboard_aggregation_repo.go
Normal file
@@ -0,0 +1,392 @@
|
||||
package repository
|
||||
|
||||
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"log"
	"strings"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
	"github.com/Wei-Shaw/sub2api/internal/service"
	"github.com/lib/pq"
)
|
||||
|
||||
// dashboardAggregationRepository pre-aggregates usage logs into hourly/daily
// dashboard tables. PostgreSQL-only (relies on AT TIME ZONE, ON CONFLICT and
// declarative partitioning).
type dashboardAggregationRepository struct {
	sql sqlExecutor // shared *sql.DB or transaction executor
}
|
||||
|
||||
// NewDashboardAggregationRepository creates the dashboard pre-aggregation
// repository. It returns nil (disabling aggregation) when the DB handle is
// nil or the driver is not PostgreSQL, since the aggregation SQL is
// Postgres-specific.
func NewDashboardAggregationRepository(sqlDB *sql.DB) service.DashboardAggregationRepository {
	if sqlDB == nil {
		return nil
	}
	if !isPostgresDriver(sqlDB) {
		log.Printf("[DashboardAggregation] 检测到非 PostgreSQL 驱动,已自动禁用预聚合")
		return nil
	}
	return newDashboardAggregationRepositoryWithSQL(sqlDB)
}
|
||||
|
||||
func newDashboardAggregationRepositoryWithSQL(sqlq sqlExecutor) *dashboardAggregationRepository {
|
||||
return &dashboardAggregationRepository{sql: sqlq}
|
||||
}
|
||||
|
||||
func isPostgresDriver(db *sql.DB) bool {
|
||||
if db == nil {
|
||||
return false
|
||||
}
|
||||
_, ok := db.Driver().(*pq.Driver)
|
||||
return ok
|
||||
}
|
||||
|
||||
// AggregateRange re-computes dashboard aggregates for usage logs created in
// [start, end). Times are normalized to the application time zone, then
// expanded outward to full hour and day bucket boundaries so the
// partially-filled trailing bucket is always (re)computed.
func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error {
	loc := timezone.Location()
	startLocal := start.In(loc)
	endLocal := end.In(loc)
	// Empty or inverted range: nothing to do.
	if !endLocal.After(startLocal) {
		return nil
	}

	// Round out to hour boundaries; extend the end to cover a partial hour.
	hourStart := startLocal.Truncate(time.Hour)
	hourEnd := endLocal.Truncate(time.Hour)
	if endLocal.After(hourEnd) {
		hourEnd = hourEnd.Add(time.Hour)
	}

	// Round out to day boundaries; extend the end to cover a partial day.
	dayStart := truncateToDay(startLocal)
	dayEnd := truncateToDay(endLocal)
	if endLocal.After(dayEnd) {
		dayEnd = dayEnd.Add(24 * time.Hour)
	}

	// Aggregate on bucket boundaries, allowing the bucket containing end to
	// be covered for its remaining interval.
	if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil {
		return err
	}
	if err := r.insertDailyActiveUsers(ctx, hourStart, hourEnd); err != nil {
		return err
	}
	if err := r.upsertHourlyAggregates(ctx, hourStart, hourEnd); err != nil {
		return err
	}
	if err := r.upsertDailyAggregates(ctx, dayStart, dayEnd); err != nil {
		return err
	}
	return nil
}
|
||||
|
||||
func (r *dashboardAggregationRepository) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
|
||||
var ts time.Time
|
||||
query := "SELECT last_aggregated_at FROM usage_dashboard_aggregation_watermark WHERE id = 1"
|
||||
if err := scanSingleRow(ctx, r.sql, query, nil, &ts); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return time.Unix(0, 0).UTC(), nil
|
||||
}
|
||||
return time.Time{}, err
|
||||
}
|
||||
return ts.UTC(), nil
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error {
|
||||
query := `
|
||||
INSERT INTO usage_dashboard_aggregation_watermark (id, last_aggregated_at, updated_at)
|
||||
VALUES (1, $1, NOW())
|
||||
ON CONFLICT (id)
|
||||
DO UPDATE SET last_aggregated_at = EXCLUDED.last_aggregated_at, updated_at = EXCLUDED.updated_at
|
||||
`
|
||||
_, err := r.sql.ExecContext(ctx, query, aggregatedAt.UTC())
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error {
|
||||
hourlyCutoffUTC := hourlyCutoff.UTC()
|
||||
dailyCutoffUTC := dailyCutoff.UTC()
|
||||
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_hourly WHERE bucket_start < $1", hourlyCutoffUTC); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_hourly_users WHERE bucket_start < $1", hourlyCutoffUTC); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_daily WHERE bucket_date < $1::date", dailyCutoffUTC); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_daily_users WHERE bucket_date < $1::date", dailyCutoffUTC); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
|
||||
isPartitioned, err := r.isUsageLogsPartitioned(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if isPartitioned {
|
||||
return r.dropUsageLogsPartitions(ctx, cutoff)
|
||||
}
|
||||
_, err = r.sql.ExecContext(ctx, "DELETE FROM usage_logs WHERE created_at < $1", cutoff.UTC())
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||
isPartitioned, err := r.isUsageLogsPartitioned(ctx)
|
||||
if err != nil || !isPartitioned {
|
||||
return err
|
||||
}
|
||||
monthStart := truncateToMonthUTC(now)
|
||||
prevMonth := monthStart.AddDate(0, -1, 0)
|
||||
nextMonth := monthStart.AddDate(0, 1, 0)
|
||||
|
||||
for _, m := range []time.Time{prevMonth, monthStart, nextMonth} {
|
||||
if err := r.createUsageLogsPartition(ctx, m); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// insertHourlyActiveUsers records, for each hour bucket in [start, end), the
// distinct users that produced usage logs. Buckets are computed in the
// configured application time zone; duplicates are ignored on conflict.
func (r *dashboardAggregationRepository) insertHourlyActiveUsers(ctx context.Context, start, end time.Time) error {
	tzName := timezone.Name()
	query := `
	INSERT INTO usage_dashboard_hourly_users (bucket_start, user_id)
	SELECT DISTINCT
		date_trunc('hour', created_at AT TIME ZONE $3) AT TIME ZONE $3 AS bucket_start,
		user_id
	FROM usage_logs
	WHERE created_at >= $1 AND created_at < $2
	ON CONFLICT DO NOTHING
	`
	_, err := r.sql.ExecContext(ctx, query, start, end, tzName)
	return err
}
|
||||
|
||||
// insertDailyActiveUsers derives daily distinct-user rows from the hourly
// user table for buckets in [start, end), converting bucket_start to a local
// calendar date. Duplicates are ignored on conflict.
func (r *dashboardAggregationRepository) insertDailyActiveUsers(ctx context.Context, start, end time.Time) error {
	tzName := timezone.Name()
	query := `
	INSERT INTO usage_dashboard_daily_users (bucket_date, user_id)
	SELECT DISTINCT
		(bucket_start AT TIME ZONE $3)::date AS bucket_date,
		user_id
	FROM usage_dashboard_hourly_users
	WHERE bucket_start >= $1 AND bucket_start < $2
	ON CONFLICT DO NOTHING
	`
	_, err := r.sql.ExecContext(ctx, query, start, end, tzName)
	return err
}
|
||||
|
||||
// upsertHourlyAggregates recomputes hourly rollups (request/token/cost/
// duration sums plus active-user counts) for buckets in [start, end) straight
// from usage_logs, overwriting any existing bucket rows on conflict so a
// partially-filled bucket can be recomputed safely.
func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Context, start, end time.Time) error {
	tzName := timezone.Name()
	query := `
	WITH hourly AS (
		SELECT
			date_trunc('hour', created_at AT TIME ZONE $3) AT TIME ZONE $3 AS bucket_start,
			COUNT(*) AS total_requests,
			COALESCE(SUM(input_tokens), 0) AS input_tokens,
			COALESCE(SUM(output_tokens), 0) AS output_tokens,
			COALESCE(SUM(cache_creation_tokens), 0) AS cache_creation_tokens,
			COALESCE(SUM(cache_read_tokens), 0) AS cache_read_tokens,
			COALESCE(SUM(total_cost), 0) AS total_cost,
			COALESCE(SUM(actual_cost), 0) AS actual_cost,
			COALESCE(SUM(COALESCE(duration_ms, 0)), 0) AS total_duration_ms
		FROM usage_logs
		WHERE created_at >= $1 AND created_at < $2
		GROUP BY 1
	),
	user_counts AS (
		SELECT bucket_start, COUNT(*) AS active_users
		FROM usage_dashboard_hourly_users
		WHERE bucket_start >= $1 AND bucket_start < $2
		GROUP BY bucket_start
	)
	INSERT INTO usage_dashboard_hourly (
		bucket_start,
		total_requests,
		input_tokens,
		output_tokens,
		cache_creation_tokens,
		cache_read_tokens,
		total_cost,
		actual_cost,
		total_duration_ms,
		active_users,
		computed_at
	)
	SELECT
		hourly.bucket_start,
		hourly.total_requests,
		hourly.input_tokens,
		hourly.output_tokens,
		hourly.cache_creation_tokens,
		hourly.cache_read_tokens,
		hourly.total_cost,
		hourly.actual_cost,
		hourly.total_duration_ms,
		COALESCE(user_counts.active_users, 0) AS active_users,
		NOW()
	FROM hourly
	LEFT JOIN user_counts ON user_counts.bucket_start = hourly.bucket_start
	ON CONFLICT (bucket_start)
	DO UPDATE SET
		total_requests = EXCLUDED.total_requests,
		input_tokens = EXCLUDED.input_tokens,
		output_tokens = EXCLUDED.output_tokens,
		cache_creation_tokens = EXCLUDED.cache_creation_tokens,
		cache_read_tokens = EXCLUDED.cache_read_tokens,
		total_cost = EXCLUDED.total_cost,
		actual_cost = EXCLUDED.actual_cost,
		total_duration_ms = EXCLUDED.total_duration_ms,
		active_users = EXCLUDED.active_users,
		computed_at = EXCLUDED.computed_at
	`
	_, err := r.sql.ExecContext(ctx, query, start, end, tzName)
	return err
}
|
||||
|
||||
// upsertDailyAggregates rolls the hourly aggregate table up into daily rows
// for buckets in [start, end), joining in the daily distinct-user counts.
// Existing bucket rows are overwritten on conflict so the trailing partial
// day can be recomputed safely.
func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Context, start, end time.Time) error {
	tzName := timezone.Name()
	query := `
	WITH daily AS (
		SELECT
			(bucket_start AT TIME ZONE $5)::date AS bucket_date,
			COALESCE(SUM(total_requests), 0) AS total_requests,
			COALESCE(SUM(input_tokens), 0) AS input_tokens,
			COALESCE(SUM(output_tokens), 0) AS output_tokens,
			COALESCE(SUM(cache_creation_tokens), 0) AS cache_creation_tokens,
			COALESCE(SUM(cache_read_tokens), 0) AS cache_read_tokens,
			COALESCE(SUM(total_cost), 0) AS total_cost,
			COALESCE(SUM(actual_cost), 0) AS actual_cost,
			COALESCE(SUM(total_duration_ms), 0) AS total_duration_ms
		FROM usage_dashboard_hourly
		WHERE bucket_start >= $1 AND bucket_start < $2
		GROUP BY (bucket_start AT TIME ZONE $5)::date
	),
	user_counts AS (
		SELECT bucket_date, COUNT(*) AS active_users
		FROM usage_dashboard_daily_users
		WHERE bucket_date >= $3::date AND bucket_date < $4::date
		GROUP BY bucket_date
	)
	INSERT INTO usage_dashboard_daily (
		bucket_date,
		total_requests,
		input_tokens,
		output_tokens,
		cache_creation_tokens,
		cache_read_tokens,
		total_cost,
		actual_cost,
		total_duration_ms,
		active_users,
		computed_at
	)
	SELECT
		daily.bucket_date,
		daily.total_requests,
		daily.input_tokens,
		daily.output_tokens,
		daily.cache_creation_tokens,
		daily.cache_read_tokens,
		daily.total_cost,
		daily.actual_cost,
		daily.total_duration_ms,
		COALESCE(user_counts.active_users, 0) AS active_users,
		NOW()
	FROM daily
	LEFT JOIN user_counts ON user_counts.bucket_date = daily.bucket_date
	ON CONFLICT (bucket_date)
	DO UPDATE SET
		total_requests = EXCLUDED.total_requests,
		input_tokens = EXCLUDED.input_tokens,
		output_tokens = EXCLUDED.output_tokens,
		cache_creation_tokens = EXCLUDED.cache_creation_tokens,
		cache_read_tokens = EXCLUDED.cache_read_tokens,
		total_cost = EXCLUDED.total_cost,
		actual_cost = EXCLUDED.actual_cost,
		total_duration_ms = EXCLUDED.total_duration_ms,
		active_users = EXCLUDED.active_users,
		computed_at = EXCLUDED.computed_at
	`
	// $1/$2 bound the hourly source rows; $3/$4 bound the daily user counts.
	_, err := r.sql.ExecContext(ctx, query, start, end, start, end, tzName)
	return err
}
|
||||
|
||||
// isUsageLogsPartitioned reports whether the usage_logs table is a
// declaratively partitioned table, by probing the pg_partitioned_table
// catalog joined against pg_class.
// NOTE(review): relname is matched without a schema qualifier — this assumes
// only one usage_logs relation is visible; confirm if multiple schemas exist.
func (r *dashboardAggregationRepository) isUsageLogsPartitioned(ctx context.Context) (bool, error) {
	query := `
		SELECT EXISTS(
			SELECT 1
			FROM pg_partitioned_table pt
			JOIN pg_class c ON c.oid = pt.partrelid
			WHERE c.relname = 'usage_logs'
		)
	`
	var partitioned bool
	// scanSingleRow runs the query with no args (nil) and scans one column.
	if err := scanSingleRow(ctx, r.sql, query, nil, &partitioned); err != nil {
		return false, err
	}
	return partitioned, nil
}
|
||||
|
||||
func (r *dashboardAggregationRepository) dropUsageLogsPartitions(ctx context.Context, cutoff time.Time) error {
|
||||
rows, err := r.sql.QueryContext(ctx, `
|
||||
SELECT c.relname
|
||||
FROM pg_inherits
|
||||
JOIN pg_class c ON c.oid = pg_inherits.inhrelid
|
||||
JOIN pg_class p ON p.oid = pg_inherits.inhparent
|
||||
WHERE p.relname = 'usage_logs'
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
}()
|
||||
|
||||
cutoffMonth := truncateToMonthUTC(cutoff)
|
||||
for rows.Next() {
|
||||
var name string
|
||||
if err := rows.Scan(&name); err != nil {
|
||||
return err
|
||||
}
|
||||
if !strings.HasPrefix(name, "usage_logs_") {
|
||||
continue
|
||||
}
|
||||
suffix := strings.TrimPrefix(name, "usage_logs_")
|
||||
month, err := time.Parse("200601", suffix)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
month = month.UTC()
|
||||
if month.Before(cutoffMonth) {
|
||||
if _, err := r.sql.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", pq.QuoteIdentifier(name))); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) createUsageLogsPartition(ctx context.Context, month time.Time) error {
|
||||
monthStart := truncateToMonthUTC(month)
|
||||
nextMonth := monthStart.AddDate(0, 1, 0)
|
||||
name := fmt.Sprintf("usage_logs_%s", monthStart.Format("200601"))
|
||||
query := fmt.Sprintf(
|
||||
"CREATE TABLE IF NOT EXISTS %s PARTITION OF usage_logs FOR VALUES FROM (%s) TO (%s)",
|
||||
pq.QuoteIdentifier(name),
|
||||
pq.QuoteLiteral(monthStart.Format("2006-01-02")),
|
||||
pq.QuoteLiteral(nextMonth.Format("2006-01-02")),
|
||||
)
|
||||
_, err := r.sql.ExecContext(ctx, query)
|
||||
return err
|
||||
}
|
||||
|
||||
// truncateToDay truncates t to the start of its day, delegating to the
// application-wide timezone package so the configured zone is respected.
func truncateToDay(t time.Time) time.Time {
	return timezone.StartOfDay(t)
}
|
||||
|
||||
func truncateToMonthUTC(t time.Time) time.Time {
|
||||
t = t.UTC()
|
||||
return time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, time.UTC)
|
||||
}
|
||||
58
backend/internal/repository/dashboard_cache.go
Normal file
58
backend/internal/repository/dashboard_cache.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package repository
|
||||
|
||||
import (
	"context"
	"errors"
	"strings"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/config"
	"github.com/Wei-Shaw/sub2api/internal/service"
	"github.com/redis/go-redis/v9"
)
|
||||
|
||||
// dashboardStatsCacheKey is the (unprefixed) Redis key holding the cached
// dashboard statistics payload; the version suffix allows format evolution.
const dashboardStatsCacheKey = "dashboard:stats:v1"

// dashboardCache implements service.DashboardStatsCache on top of Redis,
// optionally namespacing its key with a configurable prefix.
type dashboardCache struct {
	rdb       *redis.Client
	keyPrefix string // normalized to end with ":" when non-empty
}
|
||||
|
||||
func NewDashboardCache(rdb *redis.Client, cfg *config.Config) service.DashboardStatsCache {
|
||||
prefix := "sub2api:"
|
||||
if cfg != nil {
|
||||
prefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix)
|
||||
}
|
||||
if prefix != "" && !strings.HasSuffix(prefix, ":") {
|
||||
prefix += ":"
|
||||
}
|
||||
return &dashboardCache{
|
||||
rdb: rdb,
|
||||
keyPrefix: prefix,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *dashboardCache) GetDashboardStats(ctx context.Context) (string, error) {
|
||||
val, err := c.rdb.Get(ctx, c.buildKey()).Result()
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
return "", service.ErrDashboardStatsCacheMiss
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (c *dashboardCache) SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error {
|
||||
return c.rdb.Set(ctx, c.buildKey(), data, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *dashboardCache) buildKey() string {
|
||||
if c.keyPrefix == "" {
|
||||
return dashboardStatsCacheKey
|
||||
}
|
||||
return c.keyPrefix + dashboardStatsCacheKey
|
||||
}
|
||||
|
||||
func (c *dashboardCache) DeleteDashboardStats(ctx context.Context) error {
|
||||
return c.rdb.Del(ctx, c.buildKey()).Err()
|
||||
}
|
||||
28
backend/internal/repository/dashboard_cache_test.go
Normal file
28
backend/internal/repository/dashboard_cache_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewDashboardCacheKeyPrefix(t *testing.T) {
|
||||
cache := NewDashboardCache(nil, &config.Config{
|
||||
Dashboard: config.DashboardCacheConfig{
|
||||
KeyPrefix: "prod",
|
||||
},
|
||||
})
|
||||
impl, ok := cache.(*dashboardCache)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "prod:", impl.keyPrefix)
|
||||
|
||||
cache = NewDashboardCache(nil, &config.Config{
|
||||
Dashboard: config.DashboardCacheConfig{
|
||||
KeyPrefix: "staging:",
|
||||
},
|
||||
})
|
||||
impl, ok = cache.(*dashboardCache)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "staging:", impl.keyPrefix)
|
||||
}
|
||||
32
backend/internal/repository/db_pool.go
Normal file
32
backend/internal/repository/db_pool.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
// dbPoolSettings captures the connection-pool parameters applied to a
// *sql.DB, with the config's minute-based integers already converted to
// time.Duration values.
type dbPoolSettings struct {
	MaxOpenConns    int           // maximum number of open connections
	MaxIdleConns    int           // maximum number of idle connections retained
	ConnMaxLifetime time.Duration // maximum total lifetime of a connection
	ConnMaxIdleTime time.Duration // maximum idle time before a connection is closed
}
|
||||
|
||||
func buildDBPoolSettings(cfg *config.Config) dbPoolSettings {
|
||||
return dbPoolSettings{
|
||||
MaxOpenConns: cfg.Database.MaxOpenConns,
|
||||
MaxIdleConns: cfg.Database.MaxIdleConns,
|
||||
ConnMaxLifetime: time.Duration(cfg.Database.ConnMaxLifetimeMinutes) * time.Minute,
|
||||
ConnMaxIdleTime: time.Duration(cfg.Database.ConnMaxIdleTimeMinutes) * time.Minute,
|
||||
}
|
||||
}
|
||||
|
||||
func applyDBPoolSettings(db *sql.DB, cfg *config.Config) {
|
||||
settings := buildDBPoolSettings(cfg)
|
||||
db.SetMaxOpenConns(settings.MaxOpenConns)
|
||||
db.SetMaxIdleConns(settings.MaxIdleConns)
|
||||
db.SetConnMaxLifetime(settings.ConnMaxLifetime)
|
||||
db.SetConnMaxIdleTime(settings.ConnMaxIdleTime)
|
||||
}
|
||||
50
backend/internal/repository/db_pool_test.go
Normal file
50
backend/internal/repository/db_pool_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
func TestBuildDBPoolSettings(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Database: config.DatabaseConfig{
|
||||
MaxOpenConns: 50,
|
||||
MaxIdleConns: 10,
|
||||
ConnMaxLifetimeMinutes: 30,
|
||||
ConnMaxIdleTimeMinutes: 5,
|
||||
},
|
||||
}
|
||||
|
||||
settings := buildDBPoolSettings(cfg)
|
||||
require.Equal(t, 50, settings.MaxOpenConns)
|
||||
require.Equal(t, 10, settings.MaxIdleConns)
|
||||
require.Equal(t, 30*time.Minute, settings.ConnMaxLifetime)
|
||||
require.Equal(t, 5*time.Minute, settings.ConnMaxIdleTime)
|
||||
}
|
||||
|
||||
func TestApplyDBPoolSettings(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Database: config.DatabaseConfig{
|
||||
MaxOpenConns: 40,
|
||||
MaxIdleConns: 8,
|
||||
ConnMaxLifetimeMinutes: 15,
|
||||
ConnMaxIdleTimeMinutes: 3,
|
||||
},
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", "host=127.0.0.1 port=5432 user=postgres sslmode=disable")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = db.Close()
|
||||
})
|
||||
|
||||
applyDBPoolSettings(db, cfg)
|
||||
stats := db.Stats()
|
||||
require.Equal(t, 40, stats.MaxOpenConnections)
|
||||
}
|
||||
52
backend/internal/repository/email_cache.go
Normal file
52
backend/internal/repository/email_cache.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// verifyCodeKeyPrefix is the Redis key namespace for email verification codes.
const verifyCodeKeyPrefix = "verify_code:"

// verifyCodeKey returns the Redis key under which the verification code for
// the given email address is stored.
func verifyCodeKey(email string) string {
	key := verifyCodeKeyPrefix + email
	return key
}
|
||||
|
||||
// emailCache implements service.EmailCache using Redis as the backing store
// for email verification codes.
type emailCache struct {
	rdb *redis.Client
}

// NewEmailCache wires a Redis client into a service.EmailCache.
func NewEmailCache(rdb *redis.Client) service.EmailCache {
	return &emailCache{rdb: rdb}
}
|
||||
|
||||
func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*service.VerificationCodeData, error) {
|
||||
key := verifyCodeKey(email)
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var data service.VerificationCodeData
|
||||
if err := json.Unmarshal([]byte(val), &data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data *service.VerificationCodeData, ttl time.Duration) error {
|
||||
key := verifyCodeKey(email)
|
||||
val, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.rdb.Set(ctx, key, val, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) error {
|
||||
key := verifyCodeKey(email)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
92
backend/internal/repository/email_cache_integration_test.go
Normal file
92
backend/internal/repository/email_cache_integration_test.go
Normal file
@@ -0,0 +1,92 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// EmailCacheSuite is the integration test suite for the Redis-backed
// EmailCache (compiled only under the "integration" build tag).
type EmailCacheSuite struct {
	IntegrationRedisSuite
	cache service.EmailCache
}

// SetupTest builds a fresh EmailCache on top of the per-test Redis client.
func (s *EmailCacheSuite) SetupTest() {
	s.IntegrationRedisSuite.SetupTest()
	s.cache = NewEmailCache(s.rdb)
}

// A missing key surfaces as redis.Nil — the cache does not translate it.
func (s *EmailCacheSuite) TestGetVerificationCode_Missing() {
	_, err := s.cache.GetVerificationCode(s.ctx, "nonexistent@example.com")
	require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing verification code")
}

// Round-trip: the stored code and attempt counter come back intact.
func (s *EmailCacheSuite) TestSetAndGetVerificationCode() {
	email := "a@example.com"
	emailTTL := 2 * time.Minute
	data := &service.VerificationCodeData{Code: "123456", Attempts: 1, CreatedAt: time.Now()}

	require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, emailTTL), "SetVerificationCode")

	got, err := s.cache.GetVerificationCode(s.ctx, email)
	require.NoError(s.T(), err, "GetVerificationCode")
	require.Equal(s.T(), "123456", got.Code)
	require.Equal(s.T(), 1, got.Attempts)
}

// The TTL passed to SetVerificationCode is applied to the Redis key.
func (s *EmailCacheSuite) TestVerificationCode_TTL() {
	email := "ttl@example.com"
	emailTTL := 2 * time.Minute
	data := &service.VerificationCodeData{Code: "654321", Attempts: 0, CreatedAt: time.Now()}

	require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, emailTTL), "SetVerificationCode")

	emailKey := verifyCodeKeyPrefix + email
	ttl, err := s.rdb.TTL(s.ctx, emailKey).Result()
	require.NoError(s.T(), err, "TTL emailKey")
	s.AssertTTLWithin(ttl, 1*time.Second, emailTTL)
}

// Delete removes the stored code; subsequent reads see redis.Nil.
func (s *EmailCacheSuite) TestDeleteVerificationCode() {
	email := "delete@example.com"
	data := &service.VerificationCodeData{Code: "999999", Attempts: 0, CreatedAt: time.Now()}

	require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, 2*time.Minute), "SetVerificationCode")

	// Verify it exists
	_, err := s.cache.GetVerificationCode(s.ctx, email)
	require.NoError(s.T(), err, "GetVerificationCode before delete")

	// Delete
	require.NoError(s.T(), s.cache.DeleteVerificationCode(s.ctx, email), "DeleteVerificationCode")

	// Verify it's gone
	_, err = s.cache.GetVerificationCode(s.ctx, email)
	require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
}

// Deleting an absent key is a no-op, mirroring Redis DEL semantics.
func (s *EmailCacheSuite) TestDeleteVerificationCode_NonExistent() {
	// Deleting a non-existent key should not error
	require.NoError(s.T(), s.cache.DeleteVerificationCode(s.ctx, "nonexistent@example.com"), "DeleteVerificationCode non-existent")
}

// A corrupt payload yields a JSON decoding error, distinct from redis.Nil.
func (s *EmailCacheSuite) TestGetVerificationCode_JSONCorruption() {
	emailKey := verifyCodeKeyPrefix + "corrupted@example.com"

	require.NoError(s.T(), s.rdb.Set(s.ctx, emailKey, "not-json", 1*time.Minute).Err(), "Set invalid JSON")

	_, err := s.cache.GetVerificationCode(s.ctx, "corrupted@example.com")
	require.Error(s.T(), err, "expected error for corrupted JSON")
	require.False(s.T(), errors.Is(err, redis.Nil), "expected decoding error, not redis.Nil")
}

// TestEmailCacheSuite registers the suite with "go test".
func TestEmailCacheSuite(t *testing.T) {
	suite.Run(t, new(EmailCacheSuite))
}
|
||||
45
backend/internal/repository/email_cache_test.go
Normal file
45
backend/internal/repository/email_cache_test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestVerifyCodeKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
email string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "normal_email",
|
||||
email: "user@example.com",
|
||||
expected: "verify_code:user@example.com",
|
||||
},
|
||||
{
|
||||
name: "empty_email",
|
||||
email: "",
|
||||
expected: "verify_code:",
|
||||
},
|
||||
{
|
||||
name: "email_with_plus",
|
||||
email: "user+tag@example.com",
|
||||
expected: "verify_code:user+tag@example.com",
|
||||
},
|
||||
{
|
||||
name: "email_with_special_chars",
|
||||
email: "user.name+tag@sub.domain.com",
|
||||
expected: "verify_code:user.name+tag@sub.domain.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := verifyCodeKey(tc.email)
|
||||
require.Equal(t, tc.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
69
backend/internal/repository/ent.go
Normal file
69
backend/internal/repository/ent.go
Normal file
@@ -0,0 +1,69 @@
|
||||
// Package repository provides the application's infrastructure-layer components,
// including database connection initialization, ORM client management, Redis
// connections, database migrations, and other core facilities.
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/migrations"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
_ "github.com/lib/pq" // PostgreSQL 驱动,通过副作用导入注册驱动
|
||||
)
|
||||
|
||||
// InitEnt initializes the Ent ORM client and returns it together with the
// underlying *sql.DB.
//
// The function performs the following steps:
//  1. Initializes the global timezone so time handling is consistent.
//  2. Establishes the PostgreSQL database connection.
//  3. Applies database migrations so the schema matches the code.
//  4. Creates and returns the Ent client instance.
//
// Important: the caller is responsible for closing the returned ent.Client
// (closing it also closes the underlying driver/db).
//
// Parameters:
//   - cfg: application configuration, including database connection info and
//     the timezone setting.
//
// Returns:
//   - *ent.Client: the Ent ORM client used for database operations.
//   - *sql.DB: the underlying SQL connection, usable for raw SQL.
//   - error: any error encountered during initialization.
func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
	// Initialize the timezone first so all time operations use one zone.
	// This matters for cross-timezone deployments and consistent timestamps.
	if err := timezone.Init(cfg.Timezone); err != nil {
		return nil, nil, err
	}

	// Build a DSN carrying the timezone; PostgreSQL then handles time
	// correctly at the database layer as well.
	dsn := cfg.Database.DSNWithTimezone(cfg.Timezone)

	// Open the PostgreSQL connection using Ent's SQL driver.
	// dialect.Postgres selects the PostgreSQL dialect for SQL generation.
	drv, err := entsql.Open(dialect.Postgres, dsn)
	if err != nil {
		return nil, nil, err
	}
	applyDBPoolSettings(drv.DB(), cfg)

	// Ensure the database schema is ready. The SQL migration files are the
	// source of truth — more controllable than Ent auto-migration and able
	// to express complex migration scenarios.
	migrationCtx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
	defer cancel()
	if err := applyMigrationsFS(migrationCtx, drv.DB(), migrations.FS); err != nil {
		_ = drv.Close() // close the driver on migration failure to avoid leaking resources
		return nil, nil, err
	}

	// Create the Ent client bound to the configured driver.
	client := ent.NewClient(ent.Driver(drv))
	return client, drv.DB(), nil
}
|
||||
97
backend/internal/repository/error_translate.go
Normal file
97
backend/internal/repository/error_translate.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
// clientFromContext returns the transaction client stored in the context, or
// the provided default client when no transaction is in flight.
//
// This helper lets repository methods participate in transactions:
//   - if a transaction exists in the context (set via ent.NewTxContext),
//     the transaction's client is returned;
//   - otherwise the supplied default client is used.
//
// Usage example:
//
//	func (r *someRepo) SomeMethod(ctx context.Context) error {
//		client := clientFromContext(ctx, r.client)
//		return client.SomeEntity.Create().Save(ctx)
//	}
func clientFromContext(ctx context.Context, defaultClient *dbent.Client) *dbent.Client {
	if tx := dbent.TxFromContext(ctx); tx != nil {
		return tx.Client()
	}
	return defaultClient
}
|
||||
|
||||
// translatePersistenceError translates database-layer errors into
// business-layer errors.
//
// This is the repository layer's core error-handling function; it keeps
// database details from leaking upward. With a uniform translation the
// business layer can use semantically named errors (e.g. ErrUserNotFound)
// instead of database-specific ones (e.g. sql.ErrNoRows).
//
// Parameters:
//   - err: the original database error
//   - notFound: business error returned when the record does not exist
//     (nil means "do not translate this case")
//   - conflict: business error returned on unique-constraint violations
//     (nil means "do not translate this case")
//
// Returns:
//   - the translated business error, or the original error when no rule matches.
//
// Example:
//
//	err := translatePersistenceError(dbErr, service.ErrUserNotFound, service.ErrEmailExists)
func translatePersistenceError(err error, notFound, conflict *infraerrors.ApplicationError) error {
	if err == nil {
		return nil
	}

	// Support both Ent ORM and standard database/sql not-found behavior:
	// Ent uses its own NotFoundError while the standard library uses
	// sql.ErrNoRows. Handling both keeps the business-error mapping consistent.
	if notFound != nil && (errors.Is(err, sql.ErrNoRows) || dbent.IsNotFound(err)) {
		return notFound.WithCause(err)
	}

	// Unique-constraint conflicts (e.g. email already exists, duplicate name).
	if conflict != nil && isUniqueConstraintViolation(err) {
		return conflict.WithCause(err)
	}

	// No rule matched — return the original error.
	return err
}
|
||||
|
||||
// isUniqueConstraintViolation reports whether err is a unique-constraint
// violation.
//
// Detection is layered:
//  1. the PostgreSQL-specific error code 23505 (unique_violation);
//  2. generic keywords in the error message.
//
// The layered check keeps this compatible with different drivers and ORMs.
func isUniqueConstraintViolation(err error) bool {
	if err == nil {
		return false
	}

	// Prefer the PostgreSQL error code — the most precise signal.
	// Code 23505 corresponds to unique_violation; see
	// https://www.postgresql.org/docs/current/errcodes-appendix.html
	var pgErr *pq.Error
	if errors.As(err, &pgErr) {
		return pgErr.Code == "23505"
	}

	// Fall back to message matching (covers other scenarios). These keywords
	// cover PostgreSQL, MySQL and other mainstream databases.
	msg := strings.ToLower(err.Error())
	return strings.Contains(msg, "duplicate key") ||
		strings.Contains(msg, "unique constraint") ||
		strings.Contains(msg, "duplicate entry")
}
|
||||
391
backend/internal/repository/fixtures_integration_test.go
Normal file
391
backend/internal/repository/fixtures_integration_test.go
Normal file
@@ -0,0 +1,391 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// mustCreateUser inserts a user row, filling sensible defaults for any
// zero-valued field, then copies the generated ID/timestamps back into u and
// creates one user_allowed_groups row per entry in u.AllowedGroups.
// Fails the test on any error.
func mustCreateUser(t *testing.T, client *dbent.Client, u *service.User) *service.User {
	t.Helper()
	ctx := context.Background()

	if u.Email == "" {
		// Nanosecond timestamp keeps generated emails unique within a run.
		u.Email = "user-" + time.Now().Format(time.RFC3339Nano) + "@example.com"
	}
	if u.PasswordHash == "" {
		u.PasswordHash = "test-password-hash"
	}
	if u.Role == "" {
		u.Role = service.RoleUser
	}
	if u.Status == "" {
		u.Status = service.StatusActive
	}
	if u.Concurrency == 0 {
		u.Concurrency = 5
	}

	create := client.User.Create().
		SetEmail(u.Email).
		SetPasswordHash(u.PasswordHash).
		SetRole(u.Role).
		SetStatus(u.Status).
		SetBalance(u.Balance).
		SetConcurrency(u.Concurrency).
		SetUsername(u.Username).
		SetNotes(u.Notes)
	if !u.CreatedAt.IsZero() {
		create.SetCreatedAt(u.CreatedAt)
	}
	if !u.UpdatedAt.IsZero() {
		create.SetUpdatedAt(u.UpdatedAt)
	}

	created, err := create.Save(ctx)
	require.NoError(t, err, "create user")

	u.ID = created.ID
	u.CreatedAt = created.CreatedAt
	u.UpdatedAt = created.UpdatedAt

	if len(u.AllowedGroups) > 0 {
		for _, groupID := range u.AllowedGroups {
			_, err := client.UserAllowedGroup.Create().
				SetUserID(u.ID).
				SetGroupID(groupID).
				Save(ctx)
			require.NoError(t, err, "create user_allowed_groups row")
		}
	}

	return u
}
|
||||
|
||||
// mustCreateGroup inserts a group row with defaults for zero-valued fields,
// copying the generated ID/timestamps back into g. Fails the test on error.
func mustCreateGroup(t *testing.T, client *dbent.Client, g *service.Group) *service.Group {
	t.Helper()
	ctx := context.Background()

	if g.Platform == "" {
		g.Platform = service.PlatformAnthropic
	}
	if g.Status == "" {
		g.Status = service.StatusActive
	}
	if g.SubscriptionType == "" {
		g.SubscriptionType = service.SubscriptionTypeStandard
	}

	create := client.Group.Create().
		SetName(g.Name).
		SetPlatform(g.Platform).
		SetStatus(g.Status).
		SetSubscriptionType(g.SubscriptionType).
		SetRateMultiplier(g.RateMultiplier).
		SetIsExclusive(g.IsExclusive)
	if g.Description != "" {
		create.SetDescription(g.Description)
	}
	// Spend limits are nullable; only set when the fixture provides them.
	if g.DailyLimitUSD != nil {
		create.SetDailyLimitUsd(*g.DailyLimitUSD)
	}
	if g.WeeklyLimitUSD != nil {
		create.SetWeeklyLimitUsd(*g.WeeklyLimitUSD)
	}
	if g.MonthlyLimitUSD != nil {
		create.SetMonthlyLimitUsd(*g.MonthlyLimitUSD)
	}
	if !g.CreatedAt.IsZero() {
		create.SetCreatedAt(g.CreatedAt)
	}
	if !g.UpdatedAt.IsZero() {
		create.SetUpdatedAt(g.UpdatedAt)
	}

	created, err := create.Save(ctx)
	require.NoError(t, err, "create group")

	g.ID = created.ID
	g.CreatedAt = created.CreatedAt
	g.UpdatedAt = created.UpdatedAt
	return g
}
|
||||
|
||||
// mustCreateProxy inserts a proxy row with defaults for zero-valued fields,
// copying the generated ID/timestamps back into p. Fails the test on error.
func mustCreateProxy(t *testing.T, client *dbent.Client, p *service.Proxy) *service.Proxy {
	t.Helper()
	ctx := context.Background()

	if p.Protocol == "" {
		p.Protocol = "http"
	}
	if p.Host == "" {
		p.Host = "127.0.0.1"
	}
	if p.Port == 0 {
		p.Port = 8080
	}
	if p.Status == "" {
		p.Status = service.StatusActive
	}

	create := client.Proxy.Create().
		SetName(p.Name).
		SetProtocol(p.Protocol).
		SetHost(p.Host).
		SetPort(p.Port).
		SetStatus(p.Status)
	// Credentials are optional; only set when provided.
	if p.Username != "" {
		create.SetUsername(p.Username)
	}
	if p.Password != "" {
		create.SetPassword(p.Password)
	}
	if !p.CreatedAt.IsZero() {
		create.SetCreatedAt(p.CreatedAt)
	}
	if !p.UpdatedAt.IsZero() {
		create.SetUpdatedAt(p.UpdatedAt)
	}

	created, err := create.Save(ctx)
	require.NoError(t, err, "create proxy")

	p.ID = created.ID
	p.CreatedAt = created.CreatedAt
	p.UpdatedAt = created.UpdatedAt
	return p
}
|
||||
|
||||
// mustCreateAccount inserts an account row with defaults for zero-valued
// fields, copying the generated ID/timestamps back into a. Fails the test on
// error.
func mustCreateAccount(t *testing.T, client *dbent.Client, a *service.Account) *service.Account {
	t.Helper()
	ctx := context.Background()

	if a.Platform == "" {
		a.Platform = service.PlatformAnthropic
	}
	if a.Type == "" {
		a.Type = service.AccountTypeOAuth
	}
	if a.Status == "" {
		a.Status = service.StatusActive
	}
	if a.Concurrency == 0 {
		a.Concurrency = 3
	}
	if a.Priority == 0 {
		a.Priority = 50
	}
	// NOTE(review): this forces Schedulable to true whenever it is false, so
	// a fixture can never create a non-schedulable account — confirm intended.
	if !a.Schedulable {
		a.Schedulable = true
	}
	if a.Credentials == nil {
		a.Credentials = map[string]any{}
	}
	if a.Extra == nil {
		a.Extra = map[string]any{}
	}

	create := client.Account.Create().
		SetName(a.Name).
		SetPlatform(a.Platform).
		SetType(a.Type).
		SetCredentials(a.Credentials).
		SetExtra(a.Extra).
		SetConcurrency(a.Concurrency).
		SetPriority(a.Priority).
		SetStatus(a.Status).
		SetSchedulable(a.Schedulable).
		SetErrorMessage(a.ErrorMessage)

	// Nullable scheduling/rate-limit fields: only set when provided.
	if a.ProxyID != nil {
		create.SetProxyID(*a.ProxyID)
	}
	if a.LastUsedAt != nil {
		create.SetLastUsedAt(*a.LastUsedAt)
	}
	if a.RateLimitedAt != nil {
		create.SetRateLimitedAt(*a.RateLimitedAt)
	}
	if a.RateLimitResetAt != nil {
		create.SetRateLimitResetAt(*a.RateLimitResetAt)
	}
	if a.OverloadUntil != nil {
		create.SetOverloadUntil(*a.OverloadUntil)
	}
	if a.SessionWindowStart != nil {
		create.SetSessionWindowStart(*a.SessionWindowStart)
	}
	if a.SessionWindowEnd != nil {
		create.SetSessionWindowEnd(*a.SessionWindowEnd)
	}
	if a.SessionWindowStatus != "" {
		create.SetSessionWindowStatus(a.SessionWindowStatus)
	}
	if !a.CreatedAt.IsZero() {
		create.SetCreatedAt(a.CreatedAt)
	}
	if !a.UpdatedAt.IsZero() {
		create.SetUpdatedAt(a.UpdatedAt)
	}

	created, err := create.Save(ctx)
	require.NoError(t, err, "create account")

	a.ID = created.ID
	a.CreatedAt = created.CreatedAt
	a.UpdatedAt = created.UpdatedAt
	return a
}
|
||||
|
||||
// mustCreateApiKey inserts an API-key row with defaults for zero-valued
// fields, copying the generated ID/timestamps back into k. Fails the test on
// error. k.UserID must reference an existing user.
func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.APIKey) *service.APIKey {
	t.Helper()
	ctx := context.Background()

	if k.Status == "" {
		k.Status = service.StatusActive
	}
	if k.Key == "" {
		// Sub-second timestamp keeps generated keys unique within a run.
		k.Key = "sk-" + time.Now().Format("150405.000000")
	}
	if k.Name == "" {
		k.Name = "default"
	}

	create := client.APIKey.Create().
		SetUserID(k.UserID).
		SetKey(k.Key).
		SetName(k.Name).
		SetStatus(k.Status)
	if k.GroupID != nil {
		create.SetGroupID(*k.GroupID)
	}
	if !k.CreatedAt.IsZero() {
		create.SetCreatedAt(k.CreatedAt)
	}
	if !k.UpdatedAt.IsZero() {
		create.SetUpdatedAt(k.UpdatedAt)
	}

	created, err := create.Save(ctx)
	require.NoError(t, err, "create api key")

	k.ID = created.ID
	k.CreatedAt = created.CreatedAt
	k.UpdatedAt = created.UpdatedAt
	return k
}
|
||||
|
||||
// mustCreateRedeemCode inserts a redeem-code row with defaults for
// zero-valued fields, copying the generated ID/CreatedAt back into c.
// Fails the test on error.
func mustCreateRedeemCode(t *testing.T, client *dbent.Client, c *service.RedeemCode) *service.RedeemCode {
	t.Helper()
	ctx := context.Background()

	if c.Status == "" {
		c.Status = service.StatusUnused
	}
	if c.Type == "" {
		c.Type = service.RedeemTypeBalance
	}
	if c.Code == "" {
		// Sub-second timestamp keeps generated codes unique within a run.
		c.Code = "rc-" + time.Now().Format("150405.000000")
	}

	create := client.RedeemCode.Create().
		SetCode(c.Code).
		SetType(c.Type).
		SetValue(c.Value).
		SetStatus(c.Status).
		SetNotes(c.Notes).
		SetValidityDays(c.ValidityDays)
	// Usage/group fields are nullable; only set when provided.
	if c.UsedBy != nil {
		create.SetUsedBy(*c.UsedBy)
	}
	if c.UsedAt != nil {
		create.SetUsedAt(*c.UsedAt)
	}
	if c.GroupID != nil {
		create.SetGroupID(*c.GroupID)
	}
	if !c.CreatedAt.IsZero() {
		create.SetCreatedAt(c.CreatedAt)
	}

	created, err := create.Save(ctx)
	require.NoError(t, err, "create redeem code")

	c.ID = created.ID
	c.CreatedAt = created.CreatedAt
	return c
}
|
||||
|
||||
// mustCreateSubscription inserts a user-subscription row. Zero-valued times
// default relative to now (started one hour ago, expiring in 24h), and the
// generated ID/timestamps are copied back into s. Fails the test on error.
func mustCreateSubscription(t *testing.T, client *dbent.Client, s *service.UserSubscription) *service.UserSubscription {
	t.Helper()
	ctx := context.Background()

	if s.Status == "" {
		s.Status = service.SubscriptionStatusActive
	}
	now := time.Now()
	if s.StartsAt.IsZero() {
		s.StartsAt = now.Add(-1 * time.Hour)
	}
	if s.ExpiresAt.IsZero() {
		s.ExpiresAt = now.Add(24 * time.Hour)
	}
	if s.AssignedAt.IsZero() {
		s.AssignedAt = now
	}
	if s.CreatedAt.IsZero() {
		s.CreatedAt = now
	}
	if s.UpdatedAt.IsZero() {
		s.UpdatedAt = now
	}

	create := client.UserSubscription.Create().
		SetUserID(s.UserID).
		SetGroupID(s.GroupID).
		SetStartsAt(s.StartsAt).
		SetExpiresAt(s.ExpiresAt).
		SetStatus(s.Status).
		SetAssignedAt(s.AssignedAt).
		SetNotes(s.Notes).
		SetDailyUsageUsd(s.DailyUsageUSD).
		SetWeeklyUsageUsd(s.WeeklyUsageUSD).
		SetMonthlyUsageUsd(s.MonthlyUsageUSD)

	if s.AssignedBy != nil {
		create.SetAssignedBy(*s.AssignedBy)
	}
	// NOTE(review): both IsZero guards below always pass at this point,
	// because CreatedAt/UpdatedAt were defaulted to now above.
	if !s.CreatedAt.IsZero() {
		create.SetCreatedAt(s.CreatedAt)
	}
	if !s.UpdatedAt.IsZero() {
		create.SetUpdatedAt(s.UpdatedAt)
	}

	created, err := create.Save(ctx)
	require.NoError(t, err, "create user subscription")

	s.ID = created.ID
	s.CreatedAt = created.CreatedAt
	s.UpdatedAt = created.UpdatedAt
	return s
}
|
||||
|
||||
// mustBindAccountToGroup creates an account_group row linking the account to
// the group with the given scheduling priority. Fails the test on error.
func mustBindAccountToGroup(t *testing.T, client *dbent.Client, accountID, groupID int64, priority int) {
	t.Helper()
	ctx := context.Background()

	_, err := client.AccountGroup.Create().
		SetAccountID(accountID).
		SetGroupID(groupID).
		SetPriority(priority).
		Save(ctx)
	require.NoError(t, err, "create account_group")
}
|
||||
41
backend/internal/repository/gateway_cache.go
Normal file
41
backend/internal/repository/gateway_cache.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const stickySessionPrefix = "sticky_session:"
|
||||
|
||||
type gatewayCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewGatewayCache(rdb *redis.Client) service.GatewayCache {
|
||||
return &gatewayCache{rdb: rdb}
|
||||
}
|
||||
|
||||
// buildSessionKey 构建 session key,包含 groupID 实现分组隔离
|
||||
// 格式: sticky_session:{groupID}:{sessionHash}
|
||||
func buildSessionKey(groupID int64, sessionHash string) string {
|
||||
return fmt.Sprintf("%s%d:%s", stickySessionPrefix, groupID, sessionHash)
|
||||
}
|
||||
|
||||
func (c *gatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
|
||||
key := buildSessionKey(groupID, sessionHash)
|
||||
return c.rdb.Get(ctx, key).Int64()
|
||||
}
|
||||
|
||||
func (c *gatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
|
||||
key := buildSessionKey(groupID, sessionHash)
|
||||
return c.rdb.Set(ctx, key, accountID, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
|
||||
key := buildSessionKey(groupID, sessionHash)
|
||||
return c.rdb.Expire(ctx, key, ttl).Err()
|
||||
}
|
||||
@@ -0,0 +1,96 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type GatewayCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
cache service.GatewayCache
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewGatewayCache(s.rdb)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGetSessionAccountID_Missing() {
|
||||
_, err := s.cache.GetSessionAccountID(s.ctx, 1, "nonexistent")
|
||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing session")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestSetAndGetSessionAccountID() {
|
||||
sessionID := "s1"
|
||||
accountID := int64(99)
|
||||
groupID := int64(1)
|
||||
sessionTTL := 1 * time.Minute
|
||||
|
||||
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
|
||||
|
||||
sid, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
|
||||
require.NoError(s.T(), err, "GetSessionAccountID")
|
||||
require.Equal(s.T(), accountID, sid, "session id mismatch")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestSessionAccountID_TTL() {
|
||||
sessionID := "s2"
|
||||
accountID := int64(100)
|
||||
groupID := int64(1)
|
||||
sessionTTL := 1 * time.Minute
|
||||
|
||||
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
|
||||
|
||||
sessionKey := buildSessionKey(groupID, sessionID)
|
||||
ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result()
|
||||
require.NoError(s.T(), err, "TTL sessionKey after Set")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, sessionTTL)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestRefreshSessionTTL() {
|
||||
sessionID := "s3"
|
||||
accountID := int64(101)
|
||||
groupID := int64(1)
|
||||
initialTTL := 1 * time.Minute
|
||||
refreshTTL := 3 * time.Minute
|
||||
|
||||
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, initialTTL), "SetSessionAccountID")
|
||||
|
||||
require.NoError(s.T(), s.cache.RefreshSessionTTL(s.ctx, groupID, sessionID, refreshTTL), "RefreshSessionTTL")
|
||||
|
||||
sessionKey := buildSessionKey(groupID, sessionID)
|
||||
ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result()
|
||||
require.NoError(s.T(), err, "TTL after Refresh")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, refreshTTL)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestRefreshSessionTTL_MissingKey() {
|
||||
// RefreshSessionTTL on a missing key should not error (no-op)
|
||||
err := s.cache.RefreshSessionTTL(s.ctx, 1, "missing-session", 1*time.Minute)
|
||||
require.NoError(s.T(), err, "RefreshSessionTTL on missing key should not error")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
|
||||
sessionID := "corrupted"
|
||||
groupID := int64(1)
|
||||
sessionKey := buildSessionKey(groupID, sessionID)
|
||||
|
||||
// Set a non-integer value
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, sessionKey, "not-a-number", 1*time.Minute).Err(), "Set invalid value")
|
||||
|
||||
_, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
|
||||
require.Error(s.T(), err, "expected error for corrupted value")
|
||||
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
|
||||
}
|
||||
|
||||
func TestGatewayCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(GatewayCacheSuite))
|
||||
}
|
||||
250
backend/internal/repository/gateway_routing_integration_test.go
Normal file
250
backend/internal/repository/gateway_routing_integration_test.go
Normal file
@@ -0,0 +1,250 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// GatewayRoutingSuite 测试网关路由相关的数据库查询
|
||||
// 验证账户选择和分流逻辑在真实数据库环境下的行为
|
||||
type GatewayRoutingSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
client *dbent.Client
|
||||
accountRepo *accountRepository
|
||||
}
|
||||
|
||||
func (s *GatewayRoutingSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
tx := testEntTx(s.T())
|
||||
s.client = tx.Client()
|
||||
s.accountRepo = newAccountRepositoryWithSQL(s.client, tx)
|
||||
}
|
||||
|
||||
func TestGatewayRoutingSuite(t *testing.T) {
|
||||
suite.Run(t, new(GatewayRoutingSuite))
|
||||
}
|
||||
|
||||
// TestListSchedulableByPlatforms_GeminiAndAntigravity verifies the
// multi-platform account query: only accounts on the requested platforms are
// returned.
func (s *GatewayRoutingSuite) TestListSchedulableByPlatforms_GeminiAndAntigravity() {
	// Create one account per platform.
	geminiAcc := mustCreateAccount(s.T(), s.client, &service.Account{
		Name:        "gemini-oauth",
		Platform:    service.PlatformGemini,
		Type:        service.AccountTypeOAuth,
		Status:      service.StatusActive,
		Schedulable: true,
		Priority:    1,
	})

	antigravityAcc := mustCreateAccount(s.T(), s.client, &service.Account{
		Name:        "antigravity-oauth",
		Platform:    service.PlatformAntigravity,
		Type:        service.AccountTypeOAuth,
		Status:      service.StatusActive,
		Schedulable: true,
		Priority:    2,
		Credentials: map[string]any{
			"access_token":  "test-token",
			"refresh_token": "test-refresh",
			"project_id":    "test-project",
		},
	})

	// An anthropic account that must NOT be selected.
	mustCreateAccount(s.T(), s.client, &service.Account{
		Name:        "anthropic-oauth",
		Platform:    service.PlatformAnthropic,
		Type:        service.AccountTypeOAuth,
		Status:      service.StatusActive,
		Schedulable: true,
		Priority:    0,
	})

	// Query the gemini + antigravity platforms.
	accounts, err := s.accountRepo.ListSchedulableByPlatforms(s.ctx, []string{
		service.PlatformGemini,
		service.PlatformAntigravity,
	})

	s.Require().NoError(err)
	s.Require().Len(accounts, 2, "应返回 gemini 和 antigravity 两个账户")

	// Verify the platforms of the returned accounts.
	platforms := make(map[string]bool)
	for _, acc := range accounts {
		platforms[acc.Platform] = true
	}
	s.Require().True(platforms[service.PlatformGemini], "应包含 gemini 账户")
	s.Require().True(platforms[service.PlatformAntigravity], "应包含 antigravity 账户")
	s.Require().False(platforms[service.PlatformAnthropic], "不应包含 anthropic 账户")

	// Verify the account IDs match the fixtures.
	ids := make(map[int64]bool)
	for _, acc := range accounts {
		ids[acc.ID] = true
	}
	s.Require().True(ids[geminiAcc.ID])
	s.Require().True(ids[antigravityAcc.ID])
}

// TestListSchedulableByGroupIDAndPlatforms_WithGroupBinding verifies that the
// query is filtered by group membership.
func (s *GatewayRoutingSuite) TestListSchedulableByGroupIDAndPlatforms_WithGroupBinding() {
	// Create a gemini group.
	group := mustCreateGroup(s.T(), s.client, &service.Group{
		Name:     "gemini-group",
		Platform: service.PlatformGemini,
		Status:   service.StatusActive,
	})

	// Create two accounts.
	boundAcc := mustCreateAccount(s.T(), s.client, &service.Account{
		Name:        "bound-antigravity",
		Platform:    service.PlatformAntigravity,
		Status:      service.StatusActive,
		Schedulable: true,
	})
	unboundAcc := mustCreateAccount(s.T(), s.client, &service.Account{
		Name:        "unbound-antigravity",
		Platform:    service.PlatformAntigravity,
		Status:      service.StatusActive,
		Schedulable: true,
	})

	// Bind only one of them to the group.
	mustBindAccountToGroup(s.T(), s.client, boundAcc.ID, group.ID, 1)

	// Query accounts inside the group.
	accounts, err := s.accountRepo.ListSchedulableByGroupIDAndPlatforms(s.ctx, group.ID, []string{
		service.PlatformGemini,
		service.PlatformAntigravity,
	})

	s.Require().NoError(err)
	s.Require().Len(accounts, 1, "应只返回绑定到分组的账户")
	s.Require().Equal(boundAcc.ID, accounts[0].ID)

	// Confirm the unbound account is absent from the result.
	for _, acc := range accounts {
		s.Require().NotEqual(unboundAcc.ID, acc.ID, "不应包含未绑定的账户")
	}
}

// TestListSchedulableByPlatform_Antigravity verifies the single-platform
// query.
func (s *GatewayRoutingSuite) TestListSchedulableByPlatform_Antigravity() {
	// Create accounts on multiple platforms.
	mustCreateAccount(s.T(), s.client, &service.Account{
		Name:        "gemini-1",
		Platform:    service.PlatformGemini,
		Status:      service.StatusActive,
		Schedulable: true,
	})

	antigravity := mustCreateAccount(s.T(), s.client, &service.Account{
		Name:        "antigravity-1",
		Platform:    service.PlatformAntigravity,
		Status:      service.StatusActive,
		Schedulable: true,
	})

	// Query only the antigravity platform.
	accounts, err := s.accountRepo.ListSchedulableByPlatform(s.ctx, service.PlatformAntigravity)

	s.Require().NoError(err)
	s.Require().Len(accounts, 1)
	s.Require().Equal(antigravity.ID, accounts[0].ID)
	s.Require().Equal(service.PlatformAntigravity, accounts[0].Platform)
}

// TestSchedulableFilter_ExcludesInactive verifies that non-schedulable and
// errored accounts are filtered out.
func (s *GatewayRoutingSuite) TestSchedulableFilter_ExcludesInactive() {
	// Create a schedulable account.
	activeAcc := mustCreateAccount(s.T(), s.client, &service.Account{
		Name:        "active-antigravity",
		Platform:    service.PlatformAntigravity,
		Status:      service.StatusActive,
		Schedulable: true,
	})

	// Create a non-schedulable account (create then update, because the
	// fixture sets Schedulable=true by default).
	inactiveAcc := mustCreateAccount(s.T(), s.client, &service.Account{
		Name:     "inactive-antigravity",
		Platform: service.PlatformAntigravity,
		Status:   service.StatusActive,
	})
	s.Require().NoError(s.client.Account.UpdateOneID(inactiveAcc.ID).SetSchedulable(false).Exec(s.ctx))

	// Create an account in error status.
	mustCreateAccount(s.T(), s.client, &service.Account{
		Name:        "error-antigravity",
		Platform:    service.PlatformAntigravity,
		Status:      service.StatusError,
		Schedulable: true,
	})

	accounts, err := s.accountRepo.ListSchedulableByPlatform(s.ctx, service.PlatformAntigravity)

	s.Require().NoError(err)
	s.Require().Len(accounts, 1, "应只返回可调度的 active 账户")
	s.Require().Equal(activeAcc.ID, accounts[0].ID)
}

// TestPlatformRoutingDecision verifies the platform routing decision.
// It mirrors the handler-layer routing logic applied after account selection.
func (s *GatewayRoutingSuite) TestPlatformRoutingDecision() {
	// Create accounts on both platforms.
	geminiAcc := mustCreateAccount(s.T(), s.client, &service.Account{
		Name:        "gemini-route-test",
		Platform:    service.PlatformGemini,
		Status:      service.StatusActive,
		Schedulable: true,
	})

	antigravityAcc := mustCreateAccount(s.T(), s.client, &service.Account{
		Name:        "antigravity-route-test",
		Platform:    service.PlatformAntigravity,
		Status:      service.StatusActive,
		Schedulable: true,
	})

	tests := []struct {
		name            string
		accountID       int64
		expectedService string
	}{
		{
			name:            "Gemini账户路由到ForwardNative",
			accountID:       geminiAcc.ID,
			expectedService: "GeminiMessagesCompatService.ForwardNative",
		},
		{
			name:            "Antigravity账户路由到ForwardGemini",
			accountID:       antigravityAcc.ID,
			expectedService: "AntigravityGatewayService.ForwardGemini",
		},
	}

	for _, tt := range tests {
		s.Run(tt.name, func() {
			// Load the account from the database.
			account, err := s.accountRepo.GetByID(s.ctx, tt.accountID)
			s.Require().NoError(err)

			// Mirror the handler-layer routing decision.
			var routedService string
			if account.Platform == service.PlatformAntigravity {
				routedService = "AntigravityGatewayService.ForwardGemini"
			} else {
				routedService = "GeminiMessagesCompatService.ForwardNative"
			}

			s.Require().Equal(tt.expectedService, routedService)
		})
	}
}
|
||||
119
backend/internal/repository/gemini_oauth_client.go
Normal file
119
backend/internal/repository/gemini_oauth_client.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
type geminiOAuthClient struct {
|
||||
tokenURL string
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func NewGeminiOAuthClient(cfg *config.Config) service.GeminiOAuthClient {
|
||||
return &geminiOAuthClient{
|
||||
tokenURL: geminicli.TokenURL,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, codeVerifier, redirectURI, proxyURL string) (*geminicli.TokenResponse, error) {
|
||||
client := createGeminiReqClient(proxyURL)
|
||||
|
||||
// Use different OAuth clients based on oauthType:
|
||||
// - code_assist: always use built-in Gemini CLI OAuth client (public)
|
||||
// - google_one: always use built-in Gemini CLI OAuth client (public)
|
||||
// - ai_studio: requires a user-provided OAuth client
|
||||
oauthCfgInput := geminicli.OAuthConfig{
|
||||
ClientID: c.cfg.Gemini.OAuth.ClientID,
|
||||
ClientSecret: c.cfg.Gemini.OAuth.ClientSecret,
|
||||
Scopes: c.cfg.Gemini.OAuth.Scopes,
|
||||
}
|
||||
if oauthType == "code_assist" || oauthType == "google_one" {
|
||||
// Force use of built-in Gemini CLI OAuth client
|
||||
oauthCfgInput.ClientID = ""
|
||||
oauthCfgInput.ClientSecret = ""
|
||||
}
|
||||
|
||||
oauthCfg, err := geminicli.EffectiveOAuthConfig(oauthCfgInput, oauthType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "authorization_code")
|
||||
formData.Set("client_id", oauthCfg.ClientID)
|
||||
formData.Set("client_secret", oauthCfg.ClientSecret)
|
||||
formData.Set("code", code)
|
||||
formData.Set("code_verifier", codeVerifier)
|
||||
formData.Set("redirect_uri", redirectURI)
|
||||
|
||||
var tokenResp geminicli.TokenResponse
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetFormDataFromValues(formData).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(c.tokenURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, geminicli.SanitizeBodyForLogs(resp.String()))
|
||||
}
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) {
|
||||
client := createGeminiReqClient(proxyURL)
|
||||
|
||||
oauthCfgInput := geminicli.OAuthConfig{
|
||||
ClientID: c.cfg.Gemini.OAuth.ClientID,
|
||||
ClientSecret: c.cfg.Gemini.OAuth.ClientSecret,
|
||||
Scopes: c.cfg.Gemini.OAuth.Scopes,
|
||||
}
|
||||
if oauthType == "code_assist" || oauthType == "google_one" {
|
||||
// Force use of built-in Gemini CLI OAuth client
|
||||
oauthCfgInput.ClientID = ""
|
||||
oauthCfgInput.ClientSecret = ""
|
||||
}
|
||||
|
||||
oauthCfg, err := geminicli.EffectiveOAuthConfig(oauthCfgInput, oauthType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "refresh_token")
|
||||
formData.Set("refresh_token", refreshToken)
|
||||
formData.Set("client_id", oauthCfg.ClientID)
|
||||
formData.Set("client_secret", oauthCfg.ClientSecret)
|
||||
|
||||
var tokenResp geminicli.TokenResponse
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetFormDataFromValues(formData).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(c.tokenURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, geminicli.SanitizeBodyForLogs(resp.String()))
|
||||
}
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
func createGeminiReqClient(proxyURL string) *req.Client {
|
||||
return getSharedReqClient(reqClientOptions{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 60 * time.Second,
|
||||
})
|
||||
}
|
||||
49
backend/internal/repository/gemini_token_cache.go
Normal file
49
backend/internal/repository/gemini_token_cache.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
geminiTokenKeyPrefix = "gemini:token:"
|
||||
geminiRefreshLockKeyPrefix = "gemini:refresh_lock:"
|
||||
)
|
||||
|
||||
type geminiTokenCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewGeminiTokenCache(rdb *redis.Client) service.GeminiTokenCache {
|
||||
return &geminiTokenCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *geminiTokenCache) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
|
||||
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey)
|
||||
return c.rdb.Get(ctx, key).Result()
|
||||
}
|
||||
|
||||
func (c *geminiTokenCache) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
|
||||
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey)
|
||||
return c.rdb.Set(ctx, key, token, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *geminiTokenCache) DeleteAccessToken(ctx context.Context, cacheKey string) error {
|
||||
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func (c *geminiTokenCache) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
|
||||
key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey)
|
||||
return c.rdb.SetNX(ctx, key, 1, ttl).Result()
|
||||
}
|
||||
|
||||
func (c *geminiTokenCache) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
|
||||
key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type GeminiTokenCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
cache service.GeminiTokenCache
|
||||
}
|
||||
|
||||
func (s *GeminiTokenCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewGeminiTokenCache(s.rdb)
|
||||
}
|
||||
|
||||
func (s *GeminiTokenCacheSuite) TestDeleteAccessToken() {
|
||||
cacheKey := "project-123"
|
||||
token := "token-value"
|
||||
require.NoError(s.T(), s.cache.SetAccessToken(s.ctx, cacheKey, token, time.Minute))
|
||||
|
||||
got, err := s.cache.GetAccessToken(s.ctx, cacheKey)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), token, got)
|
||||
|
||||
require.NoError(s.T(), s.cache.DeleteAccessToken(s.ctx, cacheKey))
|
||||
|
||||
_, err = s.cache.GetAccessToken(s.ctx, cacheKey)
|
||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
|
||||
}
|
||||
|
||||
func (s *GeminiTokenCacheSuite) TestDeleteAccessToken_MissingKey() {
|
||||
require.NoError(s.T(), s.cache.DeleteAccessToken(s.ctx, "missing-key"))
|
||||
}
|
||||
|
||||
func TestGeminiTokenCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(GeminiTokenCacheSuite))
|
||||
}
|
||||
28
backend/internal/repository/gemini_token_cache_test.go
Normal file
28
backend/internal/repository/gemini_token_cache_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGeminiTokenCache_DeleteAccessToken_RedisError(t *testing.T) {
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: "127.0.0.1:1",
|
||||
DialTimeout: 50 * time.Millisecond,
|
||||
ReadTimeout: 50 * time.Millisecond,
|
||||
WriteTimeout: 50 * time.Millisecond,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
_ = rdb.Close()
|
||||
})
|
||||
|
||||
cache := NewGeminiTokenCache(rdb)
|
||||
err := cache.DeleteAccessToken(context.Background(), "broken")
|
||||
require.Error(t, err)
|
||||
}
|
||||
104
backend/internal/repository/geminicli_codeassist_client.go
Normal file
104
backend/internal/repository/geminicli_codeassist_client.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
type geminiCliCodeAssistClient struct {
|
||||
baseURL string
|
||||
}
|
||||
|
||||
func NewGeminiCliCodeAssistClient() service.GeminiCliCodeAssistClient {
|
||||
return &geminiCliCodeAssistClient{baseURL: geminicli.GeminiCliBaseURL}
|
||||
}
|
||||
|
||||
func (c *geminiCliCodeAssistClient) LoadCodeAssist(ctx context.Context, accessToken, proxyURL string, reqBody *geminicli.LoadCodeAssistRequest) (*geminicli.LoadCodeAssistResponse, error) {
|
||||
if reqBody == nil {
|
||||
reqBody = defaultLoadCodeAssistRequest()
|
||||
}
|
||||
|
||||
var out geminicli.LoadCodeAssistResponse
|
||||
resp, err := createGeminiCliReqClient(proxyURL).R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+accessToken).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetHeader("User-Agent", geminicli.GeminiCLIUserAgent).
|
||||
SetBody(reqBody).
|
||||
SetSuccessResult(&out).
|
||||
Post(c.baseURL + "/v1internal:loadCodeAssist")
|
||||
if err != nil {
|
||||
fmt.Printf("[CodeAssist] LoadCodeAssist request error: %v\n", err)
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
if !resp.IsSuccessState() {
|
||||
body := geminicli.SanitizeBodyForLogs(resp.String())
|
||||
fmt.Printf("[CodeAssist] LoadCodeAssist failed: status %d, body: %s\n", resp.StatusCode, body)
|
||||
return nil, fmt.Errorf("loadCodeAssist failed: status %d, body: %s", resp.StatusCode, body)
|
||||
}
|
||||
fmt.Printf("[CodeAssist] LoadCodeAssist success: status %d, response: %+v\n", resp.StatusCode, out)
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken, proxyURL string, reqBody *geminicli.OnboardUserRequest) (*geminicli.OnboardUserResponse, error) {
|
||||
if reqBody == nil {
|
||||
reqBody = defaultOnboardUserRequest()
|
||||
}
|
||||
|
||||
fmt.Printf("[CodeAssist] OnboardUser request body: %+v\n", reqBody)
|
||||
|
||||
var out geminicli.OnboardUserResponse
|
||||
resp, err := createGeminiCliReqClient(proxyURL).R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+accessToken).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetHeader("User-Agent", geminicli.GeminiCLIUserAgent).
|
||||
SetBody(reqBody).
|
||||
SetSuccessResult(&out).
|
||||
Post(c.baseURL + "/v1internal:onboardUser")
|
||||
if err != nil {
|
||||
fmt.Printf("[CodeAssist] OnboardUser request error: %v\n", err)
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
if !resp.IsSuccessState() {
|
||||
body := geminicli.SanitizeBodyForLogs(resp.String())
|
||||
fmt.Printf("[CodeAssist] OnboardUser failed: status %d, body: %s\n", resp.StatusCode, body)
|
||||
return nil, fmt.Errorf("onboardUser failed: status %d, body: %s", resp.StatusCode, body)
|
||||
}
|
||||
fmt.Printf("[CodeAssist] OnboardUser success: status %d, response: %+v\n", resp.StatusCode, out)
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func createGeminiCliReqClient(proxyURL string) *req.Client {
|
||||
return getSharedReqClient(reqClientOptions{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 30 * time.Second,
|
||||
})
|
||||
}
|
||||
|
||||
func defaultLoadCodeAssistRequest() *geminicli.LoadCodeAssistRequest {
|
||||
return &geminicli.LoadCodeAssistRequest{
|
||||
Metadata: geminicli.LoadCodeAssistMetadata{
|
||||
IDEType: "ANTIGRAVITY",
|
||||
Platform: "PLATFORM_UNSPECIFIED",
|
||||
PluginType: "GEMINI",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func defaultOnboardUserRequest() *geminicli.OnboardUserRequest {
|
||||
return &geminicli.OnboardUserRequest{
|
||||
TierID: "LEGACY",
|
||||
Metadata: geminicli.LoadCodeAssistMetadata{
|
||||
IDEType: "ANTIGRAVITY",
|
||||
Platform: "PLATFORM_UNSPECIFIED",
|
||||
PluginType: "GEMINI",
|
||||
},
|
||||
}
|
||||
}
|
||||
136
backend/internal/repository/github_release_service.go
Normal file
136
backend/internal/repository/github_release_service.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type githubReleaseClient struct {
|
||||
httpClient *http.Client
|
||||
downloadHTTPClient *http.Client
|
||||
}
|
||||
|
||||
// NewGitHubReleaseClient 创建 GitHub Release 客户端
|
||||
// proxyURL 为空时直连 GitHub,支持 http/https/socks5/socks5h 协议
|
||||
func NewGitHubReleaseClient(proxyURL string) service.GitHubReleaseClient {
|
||||
sharedClient, err := httpclient.GetClient(httpclient.Options{
|
||||
Timeout: 30 * time.Second,
|
||||
ProxyURL: proxyURL,
|
||||
})
|
||||
if err != nil {
|
||||
sharedClient = &http.Client{Timeout: 30 * time.Second}
|
||||
}
|
||||
|
||||
// 下载客户端需要更长的超时时间
|
||||
downloadClient, err := httpclient.GetClient(httpclient.Options{
|
||||
Timeout: 10 * time.Minute,
|
||||
ProxyURL: proxyURL,
|
||||
})
|
||||
if err != nil {
|
||||
downloadClient = &http.Client{Timeout: 10 * time.Minute}
|
||||
}
|
||||
|
||||
return &githubReleaseClient{
|
||||
httpClient: sharedClient,
|
||||
downloadHTTPClient: downloadClient,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *githubReleaseClient) FetchLatestRelease(ctx context.Context, repo string) (*service.GitHubRelease, error) {
|
||||
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", repo)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Accept", "application/vnd.github.v3+json")
|
||||
req.Header.Set("User-Agent", "Sub2API-Updater")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("GitHub API returned %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var release service.GitHubRelease
|
||||
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &release, nil
|
||||
}
|
||||
|
||||
func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string, maxSize int64) error {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 使用预配置的下载客户端(已包含代理配置)
|
||||
resp, err := c.downloadHTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("download returned %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// SECURITY: Check Content-Length if available
|
||||
if resp.ContentLength > maxSize {
|
||||
return fmt.Errorf("file too large: %d bytes (max %d)", resp.ContentLength, maxSize)
|
||||
}
|
||||
|
||||
out, err := os.Create(dest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = out.Close() }()
|
||||
|
||||
// SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong
|
||||
limited := io.LimitReader(resp.Body, maxSize+1)
|
||||
written, err := io.Copy(out, limited)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if we hit the limit (downloaded more than maxSize)
|
||||
if written > maxSize {
|
||||
_ = os.Remove(dest) // Clean up partial file (best-effort)
|
||||
return fmt.Errorf("download exceeded maximum size of %d bytes", maxSize)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *githubReleaseClient) FetchChecksumFile(ctx context.Context, url string) ([]byte, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
317
backend/internal/repository/github_release_service_test.go
Normal file
317
backend/internal/repository/github_release_service_test.go
Normal file
@@ -0,0 +1,317 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// GitHubReleaseServiceSuite hosts httptest-backed tests for
// githubReleaseClient.
type GitHubReleaseServiceSuite struct {
	suite.Suite
	srv     *httptest.Server     // per-test server; closed in TearDownTest
	client  *githubReleaseClient // client under test
	tempDir string               // scratch directory for download targets
}
|
||||
|
||||
// testTransport rewrites every outgoing request so it lands on the local test
// server, preserving method, body, and headers.
type testTransport struct {
	testServerURL string
}

// RoundTrip implements http.RoundTripper by pointing the request at the test
// server (keeping only the original path) and delegating to the default
// transport.
func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	rewritten, err := http.NewRequestWithContext(req.Context(), req.Method, t.testServerURL+req.URL.Path, req.Body)
	if err != nil {
		return nil, err
	}
	rewritten.Header = req.Header
	return http.DefaultTransport.RoundTrip(rewritten)
}
|
||||
|
||||
func newTestGitHubReleaseClient() *githubReleaseClient {
|
||||
return &githubReleaseClient{
|
||||
httpClient: &http.Client{},
|
||||
downloadHTTPClient: &http.Client{},
|
||||
}
|
||||
}
|
||||
|
||||
// SetupTest allocates a fresh scratch directory before every test; the
// testing framework removes it automatically afterwards.
func (s *GitHubReleaseServiceSuite) SetupTest() {
	s.tempDir = s.T().TempDir()
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TearDownTest() {
|
||||
if s.srv != nil {
|
||||
s.srv.Close()
|
||||
s.srv = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLength() {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Length", "100")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(bytes.Repeat([]byte("a"), 100))
|
||||
}))
|
||||
|
||||
s.client = newTestGitHubReleaseClient()
|
||||
|
||||
dest := filepath.Join(s.tempDir, "file1.bin")
|
||||
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
|
||||
require.Error(s.T(), err, "expected error for oversized download with Content-Length")
|
||||
|
||||
_, statErr := os.Stat(dest)
|
||||
require.Error(s.T(), statErr, "expected file to not exist for rejected download")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Force chunked encoding (unknown Content-Length) by flushing headers before writing.
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if fl, ok := w.(http.Flusher); ok {
|
||||
fl.Flush()
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
_, _ = w.Write(bytes.Repeat([]byte("b"), 10))
|
||||
if fl, ok := w.(http.Flusher); ok {
|
||||
fl.Flush()
|
||||
}
|
||||
}
|
||||
}))
|
||||
|
||||
s.client = newTestGitHubReleaseClient()
|
||||
|
||||
dest := filepath.Join(s.tempDir, "file2.bin")
|
||||
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
|
||||
require.Error(s.T(), err, "expected error for oversized chunked download")
|
||||
|
||||
_, statErr := os.Stat(dest)
|
||||
require.Error(s.T(), statErr, "expected file to be cleaned up for oversized chunked download")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if fl, ok := w.(http.Flusher); ok {
|
||||
fl.Flush()
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
_, _ = w.Write(bytes.Repeat([]byte("b"), 10))
|
||||
if fl, ok := w.(http.Flusher); ok {
|
||||
fl.Flush()
|
||||
}
|
||||
}
|
||||
}))
|
||||
|
||||
s.client = newTestGitHubReleaseClient()
|
||||
|
||||
dest := filepath.Join(s.tempDir, "file3.bin")
|
||||
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 200)
|
||||
require.NoError(s.T(), err, "expected success")
|
||||
|
||||
b, err := os.ReadFile(dest)
|
||||
require.NoError(s.T(), err, "read")
|
||||
require.True(s.T(), strings.HasPrefix(string(b), "b"), "downloaded content should start with 'b'")
|
||||
require.Len(s.T(), b, 100, "downloaded content length mismatch")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
|
||||
s.client = newTestGitHubReleaseClient()
|
||||
|
||||
dest := filepath.Join(s.tempDir, "notfound.bin")
|
||||
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 100)
|
||||
require.Error(s.T(), err, "expected error for 404")
|
||||
|
||||
_, statErr := os.Stat(dest)
|
||||
require.Error(s.T(), statErr, "expected file to not exist for 404")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("sum"))
|
||||
}))
|
||||
|
||||
s.client = newTestGitHubReleaseClient()
|
||||
|
||||
body, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
|
||||
require.NoError(s.T(), err, "FetchChecksumFile")
|
||||
require.Equal(s.T(), "sum", string(body), "checksum body mismatch")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
|
||||
s.client = newTestGitHubReleaseClient()
|
||||
|
||||
_, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
|
||||
require.Error(s.T(), err, "expected error for non-200")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
<-r.Context().Done()
|
||||
}))
|
||||
|
||||
s.client = newTestGitHubReleaseClient()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
dest := filepath.Join(s.tempDir, "cancelled.bin")
|
||||
err := s.client.DownloadFile(ctx, s.srv.URL, dest, 100)
|
||||
require.Error(s.T(), err, "expected error for cancelled context")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidURL() {
|
||||
s.client = newTestGitHubReleaseClient()
|
||||
|
||||
dest := filepath.Join(s.tempDir, "invalid.bin")
|
||||
err := s.client.DownloadFile(context.Background(), "://invalid-url", dest, 100)
|
||||
require.Error(s.T(), err, "expected error for invalid URL")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("content"))
|
||||
}))
|
||||
|
||||
s.client = newTestGitHubReleaseClient()
|
||||
|
||||
// Use a path that cannot be created (directory doesn't exist)
|
||||
dest := filepath.Join(s.tempDir, "nonexistent", "subdir", "file.bin")
|
||||
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 100)
|
||||
require.Error(s.T(), err, "expected error for invalid destination path")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_InvalidURL() {
|
||||
s.client = newTestGitHubReleaseClient()
|
||||
|
||||
_, err := s.client.FetchChecksumFile(context.Background(), "://invalid-url")
|
||||
require.Error(s.T(), err, "expected error for invalid URL")
|
||||
}
|
||||
|
||||
// TestFetchLatestRelease_Success verifies the latest-release endpoint path,
// the required GitHub API headers (Accept and User-Agent), and JSON decoding
// of tag, name and asset fields.
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
	releaseJSON := `{
		"tag_name": "v1.0.0",
		"name": "Release 1.0.0",
		"body": "Release notes",
		"html_url": "https://github.com/test/repo/releases/v1.0.0",
		"assets": [
			{
				"name": "app-linux-amd64.tar.gz",
				"browser_download_url": "https://github.com/test/repo/releases/download/v1.0.0/app-linux-amd64.tar.gz"
			}
		]
	}`

	s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Assert the client sends the canonical GitHub API request.
		require.Equal(s.T(), "/repos/test/repo/releases/latest", r.URL.Path)
		require.Equal(s.T(), "application/vnd.github.v3+json", r.Header.Get("Accept"))
		require.Equal(s.T(), "Sub2API-Updater", r.Header.Get("User-Agent"))
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte(releaseJSON))
	}))

	// Use custom transport to redirect requests to test server
	s.client = &githubReleaseClient{
		httpClient: &http.Client{
			Transport: &testTransport{testServerURL: s.srv.URL},
		},
		downloadHTTPClient: &http.Client{},
	}

	release, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
	require.NoError(s.T(), err)
	require.Equal(s.T(), "v1.0.0", release.TagName)
	require.Equal(s.T(), "Release 1.0.0", release.Name)
	require.Len(s.T(), release.Assets, 1)
	require.Equal(s.T(), "app-linux-amd64.tar.gz", release.Assets[0].Name)
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
|
||||
s.client = &githubReleaseClient{
|
||||
httpClient: &http.Client{
|
||||
Transport: &testTransport{testServerURL: s.srv.URL},
|
||||
},
|
||||
downloadHTTPClient: &http.Client{},
|
||||
}
|
||||
|
||||
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
|
||||
require.Error(s.T(), err)
|
||||
require.Contains(s.T(), err.Error(), "404")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("not valid json"))
|
||||
}))
|
||||
|
||||
s.client = &githubReleaseClient{
|
||||
httpClient: &http.Client{
|
||||
Transport: &testTransport{testServerURL: s.srv.URL},
|
||||
},
|
||||
downloadHTTPClient: &http.Client{},
|
||||
}
|
||||
|
||||
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
|
||||
require.Error(s.T(), err)
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
<-r.Context().Done()
|
||||
}))
|
||||
|
||||
s.client = &githubReleaseClient{
|
||||
httpClient: &http.Client{
|
||||
Transport: &testTransport{testServerURL: s.srv.URL},
|
||||
},
|
||||
downloadHTTPClient: &http.Client{},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
_, err := s.client.FetchLatestRelease(ctx, "test/repo")
|
||||
require.Error(s.T(), err)
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_ContextCancel() {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
<-r.Context().Done()
|
||||
}))
|
||||
|
||||
s.client = newTestGitHubReleaseClient()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
_, err := s.client.FetchChecksumFile(ctx, s.srv.URL)
|
||||
require.Error(s.T(), err)
|
||||
}
|
||||
|
||||
// TestGitHubReleaseServiceSuite is the standard testify entry point that runs
// the whole suite under `go test`.
func TestGitHubReleaseServiceSuite(t *testing.T) {
	suite.Run(t, new(GitHubReleaseServiceSuite))
}
|
||||
413
backend/internal/repository/group_repo.go
Normal file
413
backend/internal/repository/group_repo.go
Normal file
@@ -0,0 +1,413 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"log"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
// sqlExecutor is the minimal raw-SQL surface the repository needs. It is
// satisfied by *sql.DB (production, see NewGroupRepository) and by *dbent.Tx
// (tests run every repository call inside one transaction).
type sqlExecutor interface {
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}
|
||||
|
||||
// groupRepository implements service.GroupRepository on top of ent, with a
// raw sqlExecutor for statements ent does not cover here (join-table counts,
// row locks, outbox enqueueing).
type groupRepository struct {
	client *dbent.Client
	sql sqlExecutor
}
|
||||
|
||||
// NewGroupRepository builds the production GroupRepository backed by the ent
// client and a raw *sql.DB.
func NewGroupRepository(client *dbent.Client, sqlDB *sql.DB) service.GroupRepository {
	return newGroupRepositoryWithSQL(client, sqlDB)
}
|
||||
|
||||
// newGroupRepositoryWithSQL is the injectable constructor used by tests to
// substitute the raw SQL executor (a transaction, or a spy).
func newGroupRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *groupRepository {
	return &groupRepository{client: client, sql: sqlq}
}
|
||||
|
||||
// Create inserts groupIn and, on success, copies the generated ID and
// timestamps back into it and enqueues a scheduler-outbox "group changed"
// event. Outbox failures are only logged so the create itself is not undone.
// Duplicate names surface as service.ErrGroupExists via
// translatePersistenceError.
func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) error {
	builder := r.client.Group.Create().
		SetName(groupIn.Name).
		SetDescription(groupIn.Description).
		SetPlatform(groupIn.Platform).
		SetRateMultiplier(groupIn.RateMultiplier).
		SetIsExclusive(groupIn.IsExclusive).
		SetStatus(groupIn.Status).
		SetSubscriptionType(groupIn.SubscriptionType).
		SetNillableDailyLimitUsd(groupIn.DailyLimitUSD).
		SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD).
		SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD).
		SetNillableImagePrice1k(groupIn.ImagePrice1K).
		SetNillableImagePrice2k(groupIn.ImagePrice2K).
		SetNillableImagePrice4k(groupIn.ImagePrice4K).
		SetDefaultValidityDays(groupIn.DefaultValidityDays).
		SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
		SetNillableFallbackGroupID(groupIn.FallbackGroupID)

	created, err := builder.Save(ctx)
	if err == nil {
		groupIn.ID = created.ID
		groupIn.CreatedAt = created.CreatedAt
		groupIn.UpdatedAt = created.UpdatedAt
		if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupIn.ID, nil); err != nil {
			log.Printf("[SchedulerOutbox] enqueue group create failed: group=%d err=%v", groupIn.ID, err)
		}
	}
	// Translates the save error (nil passes through as nil).
	return translatePersistenceError(err, nil, service.ErrGroupExists)
}
|
||||
|
||||
// GetByID returns the group with AccountCount populated. The count lookup is
// best-effort: its error is deliberately discarded and AccountCount stays 0
// on failure. NOTE(review): confirm silently dropping the count error is
// intended rather than surfacing it.
func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group, error) {
	out, err := r.GetByIDLite(ctx, id)
	if err != nil {
		return nil, err
	}
	count, _ := r.GetAccountCount(ctx, out.ID)
	out.AccountCount = count
	return out, nil
}
|
||||
|
||||
func (r *groupRepository) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) {
|
||||
// AccountCount is intentionally not loaded here; use GetByID when needed.
|
||||
m, err := r.client.Group.Query().
|
||||
Where(group.IDEQ(id)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
|
||||
}
|
||||
|
||||
return groupEntityToService(m), nil
|
||||
}
|
||||
|
||||
// Update persists all mutable fields of groupIn and refreshes its UpdatedAt
// from the stored row. On success a scheduler-outbox "group changed" event is
// enqueued best-effort (failures are logged, not returned). Not-found and
// duplicate-name errors are translated to the service sentinels.
func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) error {
	builder := r.client.Group.UpdateOneID(groupIn.ID).
		SetName(groupIn.Name).
		SetDescription(groupIn.Description).
		SetPlatform(groupIn.Platform).
		SetRateMultiplier(groupIn.RateMultiplier).
		SetIsExclusive(groupIn.IsExclusive).
		SetStatus(groupIn.Status).
		SetSubscriptionType(groupIn.SubscriptionType).
		SetNillableDailyLimitUsd(groupIn.DailyLimitUSD).
		SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD).
		SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD).
		SetNillableImagePrice1k(groupIn.ImagePrice1K).
		SetNillableImagePrice2k(groupIn.ImagePrice2K).
		SetNillableImagePrice4k(groupIn.ImagePrice4K).
		SetDefaultValidityDays(groupIn.DefaultValidityDays).
		SetClaudeCodeOnly(groupIn.ClaudeCodeOnly)

	// FallbackGroupID: clear when nil, otherwise set (SetNillable would skip
	// the clear case).
	if groupIn.FallbackGroupID != nil {
		builder = builder.SetFallbackGroupID(*groupIn.FallbackGroupID)
	} else {
		builder = builder.ClearFallbackGroupID()
	}

	updated, err := builder.Save(ctx)
	if err != nil {
		return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
	}
	groupIn.UpdatedAt = updated.UpdatedAt
	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupIn.ID, nil); err != nil {
		log.Printf("[SchedulerOutbox] enqueue group update failed: group=%d err=%v", groupIn.ID, err)
	}
	return nil
}
|
||||
|
||||
func (r *groupRepository) Delete(ctx context.Context, id int64) error {
|
||||
_, err := r.client.Group.Delete().Where(group.IDEQ(id)).Exec(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrGroupNotFound, nil)
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue group delete failed: group=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// List returns one page of groups with no filtering; it is a thin wrapper
// over ListWithFilters with every filter left empty.
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
	return r.ListWithFilters(ctx, params, "", "", "", nil)
}
|
||||
|
||||
// ListWithFilters returns one page of groups matching the optional filters.
// Empty-string / nil filters are skipped; search matches name or description
// case-insensitively. Results are ordered by ID and AccountCount is filled in
// with a single batched count query, best-effort: a count failure is silently
// ignored and counts stay 0.
func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
	q := r.client.Group.Query()

	if platform != "" {
		q = q.Where(group.PlatformEQ(platform))
	}
	if status != "" {
		q = q.Where(group.StatusEQ(status))
	}
	if search != "" {
		q = q.Where(group.Or(
			group.NameContainsFold(search),
			group.DescriptionContainsFold(search),
		))
	}
	if isExclusive != nil {
		q = q.Where(group.IsExclusiveEQ(*isExclusive))
	}

	// Count first so the pagination total reflects the same filter set.
	total, err := q.Count(ctx)
	if err != nil {
		return nil, nil, err
	}

	groups, err := q.
		Offset(params.Offset()).
		Limit(params.Limit()).
		Order(dbent.Asc(group.FieldID)).
		All(ctx)
	if err != nil {
		return nil, nil, err
	}

	// Convert to service structs while collecting IDs for the batch count.
	groupIDs := make([]int64, 0, len(groups))
	outGroups := make([]service.Group, 0, len(groups))
	for i := range groups {
		g := groupEntityToService(groups[i])
		outGroups = append(outGroups, *g)
		groupIDs = append(groupIDs, g.ID)
	}

	// Best-effort enrichment: count errors are deliberately swallowed.
	counts, err := r.loadAccountCounts(ctx, groupIDs)
	if err == nil {
		for i := range outGroups {
			outGroups[i].AccountCount = counts[outGroups[i].ID]
		}
	}

	return outGroups, paginationResultFromTotal(int64(total), params), nil
}
|
||||
|
||||
// ListActive returns every group whose status is active, ordered by ID, with
// AccountCount populated via one batched count query (best-effort: a count
// failure leaves all counts at 0).
func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, error) {
	groups, err := r.client.Group.Query().
		Where(group.StatusEQ(service.StatusActive)).
		Order(dbent.Asc(group.FieldID)).
		All(ctx)
	if err != nil {
		return nil, err
	}

	// Convert to service structs while collecting IDs for the batch count.
	groupIDs := make([]int64, 0, len(groups))
	outGroups := make([]service.Group, 0, len(groups))
	for i := range groups {
		g := groupEntityToService(groups[i])
		outGroups = append(outGroups, *g)
		groupIDs = append(groupIDs, g.ID)
	}

	// Best-effort enrichment: count errors are deliberately swallowed.
	counts, err := r.loadAccountCounts(ctx, groupIDs)
	if err == nil {
		for i := range outGroups {
			outGroups[i].AccountCount = counts[outGroups[i].ID]
		}
	}

	return outGroups, nil
}
|
||||
|
||||
// ListActiveByPlatform is ListActive restricted to one platform: active
// groups of the given platform ordered by ID, with AccountCount populated
// best-effort via one batched count query.
func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
	groups, err := r.client.Group.Query().
		Where(group.StatusEQ(service.StatusActive), group.PlatformEQ(platform)).
		Order(dbent.Asc(group.FieldID)).
		All(ctx)
	if err != nil {
		return nil, err
	}

	// Convert to service structs while collecting IDs for the batch count.
	groupIDs := make([]int64, 0, len(groups))
	outGroups := make([]service.Group, 0, len(groups))
	for i := range groups {
		g := groupEntityToService(groups[i])
		outGroups = append(outGroups, *g)
		groupIDs = append(groupIDs, g.ID)
	}

	// Best-effort enrichment: count errors are deliberately swallowed.
	counts, err := r.loadAccountCounts(ctx, groupIDs)
	if err == nil {
		for i := range outGroups {
			outGroups[i].AccountCount = counts[outGroups[i].ID]
		}
	}

	return outGroups, nil
}
|
||||
|
||||
// ExistsByName reports whether a group with the exact name already exists.
func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
	return r.client.Group.Query().Where(group.NameEQ(name)).Exist(ctx)
}
|
||||
|
||||
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
||||
var count int64
|
||||
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", []any{groupID}, &count); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
res, err := r.sql.ExecContext(ctx, "DELETE FROM account_groups WHERE group_id = $1", groupID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
affected, _ := res.RowsAffected()
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue group account clear failed: group=%d err=%v", groupID, err)
|
||||
}
|
||||
return affected, nil
|
||||
}
|
||||
|
||||
// DeleteCascade removes a group and every dependent record atomically:
// subscriptions are soft-deleted, bound API keys are detached, join-table
// rows are removed, and finally the group itself is deleted. It returns the
// IDs of users whose (still active) subscriptions were cancelled so callers
// can notify them. The outbox event is enqueued only after the transaction
// commits.
func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
	g, err := r.client.Group.Query().Where(group.IDEQ(id)).Only(ctx)
	if err != nil {
		return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
	}
	groupSvc := groupEntityToService(g)

	// Wrap everything in an ent transaction: this avoids the driver-assertion
	// problems of hand-building an ent client on top of *sql.Tx, and keeps
	// the cascade atomic.
	tx, err := r.client.Tx(ctx)
	if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
		return nil, err
	}
	exec := r.client
	txClient := r.client
	if err == nil {
		// Rollback is a no-op after a successful Commit below.
		defer func() { _ = tx.Rollback() }()
		exec = tx.Client()
		txClient = exec
	}
	// When err is dbent.ErrTxStarted, reuse the current client, which already
	// participates in the ambient transaction.

	// Lock the group row to avoid concurrent writes while we cascade.
	// Manual scan via exec.QueryContext so the lock is taken inside the same
	// transaction and "not found" can be distinguished from other errors.
	rows, err := exec.QueryContext(ctx, "SELECT id FROM groups WHERE id = $1 AND deleted_at IS NULL FOR UPDATE", id)
	if err != nil {
		return nil, err
	}
	var lockedID int64
	if rows.Next() {
		if err := rows.Scan(&lockedID); err != nil {
			_ = rows.Close()
			return nil, err
		}
	}
	if err := rows.Close(); err != nil {
		return nil, err
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	if lockedID == 0 {
		return nil, service.ErrGroupNotFound
	}

	var affectedUserIDs []int64
	if groupSvc.IsSubscriptionType() {
		// 1. Only query subscriptions that are not soft-deleted, so users
		// whose subscription was already cancelled are not notified again.
		rows, err := exec.QueryContext(ctx, "SELECT user_id FROM user_subscriptions WHERE group_id = $1 AND deleted_at IS NULL", id)
		if err != nil {
			return nil, err
		}
		for rows.Next() {
			var userID int64
			if scanErr := rows.Scan(&userID); scanErr != nil {
				_ = rows.Close()
				return nil, scanErr
			}
			affectedUserIDs = append(affectedUserIDs, userID)
		}
		if err := rows.Close(); err != nil {
			return nil, err
		}
		if err := rows.Err(); err != nil {
			return nil, err
		}

		// Soft-delete the subscriptions: set deleted_at instead of removing
		// the rows outright.
		if _, err := exec.ExecContext(ctx, "UPDATE user_subscriptions SET deleted_at = NOW() WHERE group_id = $1 AND deleted_at IS NULL", id); err != nil {
			return nil, err
		}
	}

	// 2. Clear group_id for api keys bound to this group.
	// Only rows that are not soft-deleted are updated, to keep audit history
	// intact and stay consistent with APIKeyRepository's soft-delete
	// semantics.
	if _, err := txClient.APIKey.Update().
		Where(apikey.GroupIDEQ(id), apikey.DeletedAtIsNil()).
		ClearGroupID().
		Save(ctx); err != nil {
		return nil, err
	}

	// 3. Remove the group id from user_allowed_groups join table.
	// The legacy users.allowed_groups column is deprecated and no longer
	// synchronized.
	if _, err := exec.ExecContext(ctx, "DELETE FROM user_allowed_groups WHERE group_id = $1", id); err != nil {
		return nil, err
	}

	// 4. Delete account_groups join rows.
	if _, err := exec.ExecContext(ctx, "DELETE FROM account_groups WHERE group_id = $1", id); err != nil {
		return nil, err
	}

	// 5. Soft-delete group itself.
	if _, err := txClient.Group.Delete().Where(group.IDEQ(id)).Exec(ctx); err != nil {
		return nil, err
	}

	// Commit only when we own the transaction (tx is nil when we joined an
	// already-started one).
	if tx != nil {
		if err := tx.Commit(); err != nil {
			return nil, err
		}
	}
	// Enqueue after commit so a rollback never leaves a stale event behind;
	// failures are only logged (best-effort).
	if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil {
		log.Printf("[SchedulerOutbox] enqueue group cascade delete failed: group=%d err=%v", id, err)
	}

	return affectedUserIDs, nil
}
|
||||
|
||||
// loadAccountCounts returns a map of group ID -> number of bound accounts for
// the given IDs using a single grouped query; IDs with no bindings simply
// have no entry (callers read the zero value). The named results let the
// deferred Close promote a close error that would otherwise be lost — but
// only when no earlier error was already recorded.
func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (counts map[int64]int64, err error) {
	counts = make(map[int64]int64, len(groupIDs))
	if len(groupIDs) == 0 {
		return counts, nil
	}

	rows, err := r.sql.QueryContext(
		ctx,
		"SELECT group_id, COUNT(*) FROM account_groups WHERE group_id = ANY($1) GROUP BY group_id",
		pq.Array(groupIDs),
	)
	if err != nil {
		return nil, err
	}
	defer func() {
		// Promote the close error only if the function was otherwise
		// succeeding; nil out counts so callers never see a partial map.
		if closeErr := rows.Close(); closeErr != nil && err == nil {
			err = closeErr
			counts = nil
		}
	}()

	for rows.Next() {
		var groupID int64
		var count int64
		if err = rows.Scan(&groupID, &count); err != nil {
			return nil, err
		}
		counts[groupID] = count
	}
	if err = rows.Err(); err != nil {
		return nil, err
	}

	return counts, nil
}
|
||||
677
backend/internal/repository/group_repo_integration_test.go
Normal file
677
backend/internal/repository/group_repo_integration_test.go
Normal file
@@ -0,0 +1,677 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// GroupRepoSuite runs groupRepository integration tests. Each test executes
// inside its own database transaction (see SetupTest) so tests are isolated
// from one another.
type GroupRepoSuite struct {
	suite.Suite
	ctx context.Context
	tx *dbent.Tx // per-test transaction; also serves as the raw sqlExecutor
	repo *groupRepository
}
|
||||
|
||||
type forbidSQLExecutor struct {
|
||||
called bool
|
||||
}
|
||||
|
||||
func (s *forbidSQLExecutor) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
|
||||
s.called = true
|
||||
return nil, errors.New("unexpected sql exec")
|
||||
}
|
||||
|
||||
func (s *forbidSQLExecutor) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
|
||||
s.called = true
|
||||
return nil, errors.New("unexpected sql query")
|
||||
}
|
||||
|
||||
// SetupTest opens a fresh test transaction and builds the repository on it
// (the transaction doubles as the raw sqlExecutor), giving each test an
// isolated database state.
func (s *GroupRepoSuite) SetupTest() {
	s.ctx = context.Background()
	tx := testEntTx(s.T())
	s.tx = tx
	s.repo = newGroupRepositoryWithSQL(tx.Client(), tx)
}
|
||||
|
||||
// TestGroupRepoSuite is the standard testify entry point for the suite.
func TestGroupRepoSuite(t *testing.T) {
	suite.Run(t, new(GroupRepoSuite))
}
|
||||
|
||||
// --- Create / GetByID / Update / Delete ---
|
||||
|
||||
func (s *GroupRepoSuite) TestCreate() {
|
||||
group := &service.Group{
|
||||
Name: "test-create",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
|
||||
err := s.repo.Create(s.ctx, group)
|
||||
s.Require().NoError(err, "Create")
|
||||
s.Require().NotZero(group.ID, "expected ID to be set")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal("test-create", got.Name)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestGetByID_NotFound() {
|
||||
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||
s.Require().Error(err, "expected error for non-existent ID")
|
||||
s.Require().ErrorIs(err, service.ErrGroupNotFound)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestGetByIDLite_DoesNotUseAccountCount() {
|
||||
group := &service.Group{
|
||||
Name: "lite-group",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, group))
|
||||
|
||||
spy := &forbidSQLExecutor{}
|
||||
repo := newGroupRepositoryWithSQL(s.tx.Client(), spy)
|
||||
|
||||
got, err := repo.GetByIDLite(s.ctx, group.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(group.ID, got.ID)
|
||||
s.Require().False(spy.called, "expected no direct sql executor usage")
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestUpdate() {
|
||||
group := &service.Group{
|
||||
Name: "original",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, group))
|
||||
|
||||
group.Name = "updated"
|
||||
err := s.repo.Update(s.ctx, group)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "GetByID after update")
|
||||
s.Require().Equal("updated", got.Name)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestDelete() {
|
||||
group := &service.Group{
|
||||
Name: "to-delete",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, group))
|
||||
|
||||
err := s.repo.Delete(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
|
||||
_, err = s.repo.GetByID(s.ctx, group.ID)
|
||||
s.Require().Error(err, "expected error after delete")
|
||||
s.Require().ErrorIs(err, service.ErrGroupNotFound)
|
||||
}
|
||||
|
||||
// --- List / ListWithFilters ---
|
||||
|
||||
func (s *GroupRepoSuite) TestList() {
|
||||
baseGroups, basePage, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "List base")
|
||||
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "g1",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "g2",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
|
||||
groups, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "List")
|
||||
s.Require().Len(groups, len(baseGroups)+2)
|
||||
s.Require().Equal(basePage.Total+2, page.Total)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListWithFilters_Platform() {
|
||||
baseGroups, _, err := s.repo.ListWithFilters(
|
||||
s.ctx,
|
||||
pagination.PaginationParams{Page: 1, PageSize: 10},
|
||||
service.PlatformOpenAI,
|
||||
"",
|
||||
"",
|
||||
nil,
|
||||
)
|
||||
s.Require().NoError(err, "ListWithFilters base")
|
||||
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "g1",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "g2",
|
||||
Platform: service.PlatformOpenAI,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", "", nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(groups, len(baseGroups)+1)
|
||||
// Verify all groups are OpenAI platform
|
||||
for _, g := range groups {
|
||||
s.Require().Equal(service.PlatformOpenAI, g.Platform)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListWithFilters_Status() {
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "g1",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "g2",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusDisabled,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, "", nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(groups, 1)
|
||||
s.Require().Equal(service.StatusDisabled, groups[0].Status)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "g1",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "g2",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: true,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
|
||||
isExclusive := true
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", &isExclusive)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(groups, 1)
|
||||
s.Require().True(groups[0].IsExclusive)
|
||||
}
|
||||
|
||||
// TestListWithFilters_Search exercises the free-text search parameter of
// ListWithFilters across several subtests: matching on name, matching on
// description, empty results for a nonexistent term, case-insensitivity,
// and escaping of SQL LIKE wildcards (% and _).
//
// Each subtest builds its own repository on a fresh transaction so the
// subtests do not see each other's groups.
func (s *GroupRepoSuite) TestListWithFilters_Search() {
	// newRepo returns an isolated group repository bound to a fresh test
	// transaction, plus a context for it.
	newRepo := func() (*groupRepository, context.Context) {
		tx := testEntTx(s.T())
		return newGroupRepositoryWithSQL(tx.Client(), tx), context.Background()
	}

	// containsID reports whether any group in the slice has the given ID.
	containsID := func(groups []service.Group, id int64) bool {
		for i := range groups {
			if groups[i].ID == id {
				return true
			}
		}
		return false
	}

	// mustCreate persists the group and asserts an ID was assigned.
	mustCreate := func(repo *groupRepository, ctx context.Context, g *service.Group) *service.Group {
		s.Require().NoError(repo.Create(ctx, g))
		s.Require().NotZero(g.ID)
		return g
	}

	// newGroup builds a standard active anthropic group with the given name.
	newGroup := func(name string) *service.Group {
		return &service.Group{
			Name:             name,
			Platform:         service.PlatformAnthropic,
			RateMultiplier:   1.0,
			IsExclusive:      false,
			Status:           service.StatusActive,
			SubscriptionType: service.SubscriptionTypeStandard,
		}
	}

	s.Run("search_name_should_match", func() {
		repo, ctx := newRepo()

		target := mustCreate(repo, ctx, newGroup("it-group-search-name-target"))
		other := mustCreate(repo, ctx, newGroup("it-group-search-name-other"))

		groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "name-target", nil)
		s.Require().NoError(err)
		s.Require().True(containsID(groups, target.ID), "expected target group to match by name")
		s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
	})

	s.Run("search_description_should_match", func() {
		repo, ctx := newRepo()

		target := newGroup("it-group-search-desc-target")
		target.Description = "something about desc-needle in here"
		target = mustCreate(repo, ctx, target)

		other := newGroup("it-group-search-desc-other")
		other.Description = "nothing to see here"
		other = mustCreate(repo, ctx, other)

		groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "desc-needle", nil)
		s.Require().NoError(err)
		s.Require().True(containsID(groups, target.ID), "expected target group to match by description")
		s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
	})

	s.Run("search_nonexistent_should_return_empty", func() {
		repo, ctx := newRepo()

		_ = mustCreate(repo, ctx, newGroup("it-group-search-nonexistent-baseline"))

		// Use the test name to build a search term that cannot collide with
		// any real group name.
		search := s.T().Name() + "__no_such_group__"
		groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", search, nil)
		s.Require().NoError(err)
		s.Require().Empty(groups)
	})

	s.Run("search_should_be_case_insensitive", func() {
		repo, ctx := newRepo()

		target := mustCreate(repo, ctx, newGroup("MiXeDCaSe-Needle"))
		other := mustCreate(repo, ctx, newGroup("it-group-search-case-other"))

		groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "mixedcase-needle", nil)
		s.Require().NoError(err)
		s.Require().True(containsID(groups, target.ID), "expected case-insensitive match")
		s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
	})

	s.Run("search_should_escape_like_wildcards", func() {
		repo, ctx := newRepo()

		// "%" in the search term must match literally, not as a wildcard.
		percentTarget := mustCreate(repo, ctx, newGroup("it-group-search-100%-target"))
		percentOther := mustCreate(repo, ctx, newGroup("it-group-search-100X-other"))

		groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "100%", nil)
		s.Require().NoError(err)
		s.Require().True(containsID(groups, percentTarget.ID), "expected literal %% match")
		s.Require().False(containsID(groups, percentOther.ID), "expected %% not to act as wildcard")

		// "_" must likewise match literally, not any single character.
		underscoreTarget := mustCreate(repo, ctx, newGroup("it-group-search-ab_cd-target"))
		underscoreOther := mustCreate(repo, ctx, newGroup("it-group-search-abXcd-other"))

		groups, _, err = repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "ab_cd", nil)
		s.Require().NoError(err)
		s.Require().True(containsID(groups, underscoreTarget.ID), "expected literal _ match")
		s.Require().False(containsID(groups, underscoreOther.ID), "expected _ not to act as wildcard")
	})
}
|
||||
|
||||
// TestListWithFilters_AccountCount verifies that ListWithFilters populates
// the AccountCount field per group. One account is linked to two groups via
// raw SQL; filtering down to the exclusive group must report a count of 1.
func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
	g1 := &service.Group{
		Name:             "g1",
		Platform:         service.PlatformAnthropic,
		RateMultiplier:   1.0,
		IsExclusive:      false,
		Status:           service.StatusActive,
		SubscriptionType: service.SubscriptionTypeStandard,
	}
	g2 := &service.Group{
		Name:             "g2",
		Platform:         service.PlatformAnthropic,
		RateMultiplier:   1.0,
		IsExclusive:      true,
		Status:           service.StatusActive,
		SubscriptionType: service.SubscriptionTypeStandard,
	}
	s.Require().NoError(s.repo.Create(s.ctx, g1))
	s.Require().NoError(s.repo.Create(s.ctx, g2))

	// Insert one account directly via SQL and link it to both groups.
	var accountID int64
	s.Require().NoError(scanSingleRow(
		s.ctx,
		s.tx,
		"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
		[]any{"acc1", service.PlatformAnthropic, service.AccountTypeOAuth},
		&accountID,
	))
	_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g1.ID, 1)
	s.Require().NoError(err)
	_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g2.ID, 1)
	s.Require().NoError(err)

	// Filtering on exclusivity should isolate g2, which carries exactly one
	// linked account.
	// NOTE(review): assumes no seeded exclusive anthropic groups exist in the
	// test database — verify against the test fixtures.
	isExclusive := true
	groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, "", &isExclusive)
	s.Require().NoError(err, "ListWithFilters")
	s.Require().Equal(int64(1), page.Total)
	s.Require().Len(groups, 1)
	s.Require().Equal(g2.ID, groups[0].ID, "ListWithFilters returned wrong group")
	s.Require().Equal(int64(1), groups[0].AccountCount, "AccountCount mismatch")
}
|
||||
|
||||
// --- ListActive / ListActiveByPlatform ---
|
||||
|
||||
func (s *GroupRepoSuite) TestListActive() {
|
||||
baseGroups, err := s.repo.ListActive(s.ctx)
|
||||
s.Require().NoError(err, "ListActive base")
|
||||
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "active1",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "inactive1",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusDisabled,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
|
||||
groups, err := s.repo.ListActive(s.ctx)
|
||||
s.Require().NoError(err, "ListActive")
|
||||
s.Require().Len(groups, len(baseGroups)+1)
|
||||
// Verify our test group is in the results
|
||||
var found bool
|
||||
for _, g := range groups {
|
||||
if g.Name == "active1" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
s.Require().True(found, "active1 group should be in results")
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListActiveByPlatform() {
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "g1",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "g2",
|
||||
Platform: service.PlatformOpenAI,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "g3",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusDisabled,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
|
||||
groups, err := s.repo.ListActiveByPlatform(s.ctx, service.PlatformAnthropic)
|
||||
s.Require().NoError(err, "ListActiveByPlatform")
|
||||
// 1 default anthropic group + 1 test active anthropic group = 2 total
|
||||
s.Require().Len(groups, 2)
|
||||
// Verify our test group is in the results
|
||||
var found bool
|
||||
for _, g := range groups {
|
||||
if g.Name == "g1" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
s.Require().True(found, "g1 group should be in results")
|
||||
}
|
||||
|
||||
// --- ExistsByName ---
|
||||
|
||||
func (s *GroupRepoSuite) TestExistsByName() {
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "existing-group",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
|
||||
exists, err := s.repo.ExistsByName(s.ctx, "existing-group")
|
||||
s.Require().NoError(err, "ExistsByName")
|
||||
s.Require().True(exists)
|
||||
|
||||
notExists, err := s.repo.ExistsByName(s.ctx, "non-existing")
|
||||
s.Require().NoError(err)
|
||||
s.Require().False(notExists)
|
||||
}
|
||||
|
||||
// --- GetAccountCount ---
|
||||
|
||||
func (s *GroupRepoSuite) TestGetAccountCount() {
|
||||
group := &service.Group{
|
||||
Name: "g-count",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, group))
|
||||
|
||||
var a1 int64
|
||||
s.Require().NoError(scanSingleRow(
|
||||
s.ctx,
|
||||
s.tx,
|
||||
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
|
||||
[]any{"a1", service.PlatformAnthropic, service.AccountTypeOAuth},
|
||||
&a1,
|
||||
))
|
||||
var a2 int64
|
||||
s.Require().NoError(scanSingleRow(
|
||||
s.ctx,
|
||||
s.tx,
|
||||
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
|
||||
[]any{"a2", service.PlatformAnthropic, service.AccountTypeOAuth},
|
||||
&a2,
|
||||
))
|
||||
|
||||
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a1, group.ID, 1)
|
||||
s.Require().NoError(err)
|
||||
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a2, group.ID, 2)
|
||||
s.Require().NoError(err)
|
||||
|
||||
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "GetAccountCount")
|
||||
s.Require().Equal(int64(2), count)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
|
||||
group := &service.Group{
|
||||
Name: "g-empty",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, group))
|
||||
|
||||
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(count)
|
||||
}
|
||||
|
||||
// --- DeleteAccountGroupsByGroupID ---
|
||||
|
||||
// TestDeleteAccountGroupsByGroupID verifies that deleting a group's account
// links removes exactly the linked rows and leaves the count at zero.
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
	g := &service.Group{
		Name:             "g-del",
		Platform:         service.PlatformAnthropic,
		RateMultiplier:   1.0,
		IsExclusive:      false,
		Status:           service.StatusActive,
		SubscriptionType: service.SubscriptionTypeStandard,
	}
	s.Require().NoError(s.repo.Create(s.ctx, g))
	// Insert one account via raw SQL and link it to the group.
	var accountID int64
	s.Require().NoError(scanSingleRow(
		s.ctx,
		s.tx,
		"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
		[]any{"acc-del", service.PlatformAnthropic, service.AccountTypeOAuth},
		&accountID,
	))
	_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g.ID, 1)
	s.Require().NoError(err)

	// Exactly one link row must be removed.
	affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID)
	s.Require().NoError(err, "DeleteAccountGroupsByGroupID")
	s.Require().Equal(int64(1), affected, "expected 1 affected row")

	// And the remaining link count must be zero.
	count, err := s.repo.GetAccountCount(s.ctx, g.ID)
	s.Require().NoError(err, "GetAccountCount")
	s.Require().Equal(int64(0), count, "expected 0 account groups")
}
|
||||
|
||||
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
|
||||
g := &service.Group{
|
||||
Name: "g-multi",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, g))
|
||||
|
||||
insertAccount := func(name string) int64 {
|
||||
var id int64
|
||||
s.Require().NoError(scanSingleRow(
|
||||
s.ctx,
|
||||
s.tx,
|
||||
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
|
||||
[]any{name, service.PlatformAnthropic, service.AccountTypeOAuth},
|
||||
&id,
|
||||
))
|
||||
return id
|
||||
}
|
||||
a1 := insertAccount("a1")
|
||||
a2 := insertAccount("a2")
|
||||
a3 := insertAccount("a3")
|
||||
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a1, g.ID, 1)
|
||||
s.Require().NoError(err)
|
||||
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a2, g.ID, 2)
|
||||
s.Require().NoError(err)
|
||||
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a3, g.ID, 3)
|
||||
s.Require().NoError(err)
|
||||
|
||||
affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(int64(3), affected)
|
||||
|
||||
count, _ := s.repo.GetAccountCount(s.ctx, g.ID)
|
||||
s.Require().Zero(count)
|
||||
}
|
||||
|
||||
// --- Soft-delete filtering tests ---
|
||||
|
||||
// TestDelete_SoftDelete_NotVisibleInList verifies that a soft-deleted group
// disappears from List results and is no longer reachable via GetByID.
func (s *GroupRepoSuite) TestDelete_SoftDelete_NotVisibleInList() {
	group := &service.Group{
		Name:             "to-soft-delete",
		Platform:         service.PlatformAnthropic,
		RateMultiplier:   1.0,
		IsExclusive:      false,
		Status:           service.StatusActive,
		SubscriptionType: service.SubscriptionTypeStandard,
	}
	s.Require().NoError(s.repo.Create(s.ctx, group))

	// Capture the list size before deletion.
	listBefore, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 100})
	s.Require().NoError(err)
	beforeCount := len(listBefore)

	// Soft delete.
	err = s.repo.Delete(s.ctx, group.ID)
	s.Require().NoError(err, "Delete (soft delete)")

	// The list must no longer contain the soft-deleted group.
	listAfter, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 100})
	s.Require().NoError(err)
	s.Require().Len(listAfter, beforeCount-1, "soft deleted group should not appear in list")

	// GetByID must not find it either.
	_, err = s.repo.GetByID(s.ctx, group.ID)
	s.Require().Error(err)
	s.Require().ErrorIs(err, service.ErrGroupNotFound)
}
|
||||
|
||||
// TestDelete_SoftDeletedGroup_lockForUpdate verifies that a soft-deleted
// group is filtered out when looked up again, which exercises the
// deleted_at IS NULL predicate used by the repository's row-locking query.
func (s *GroupRepoSuite) TestDelete_SoftDeletedGroup_lockForUpdate() {
	group := &service.Group{
		Name:             "lock-soft-delete",
		Platform:         service.PlatformAnthropic,
		RateMultiplier:   1.0,
		IsExclusive:      false,
		Status:           service.StatusActive,
		SubscriptionType: service.SubscriptionTypeStandard,
	}
	s.Require().NoError(s.repo.Create(s.ctx, group))

	// Soft delete.
	err := s.repo.Delete(s.ctx, group.ID)
	s.Require().NoError(err)

	// A soft-deleted group must yield ErrGroupNotFound from GetByID.
	// This proves the deleted_at IS NULL filter in lockForUpdate works.
	_, err = s.repo.GetByID(s.ctx, group.ID)
	s.Require().Error(err, "should fail to get soft-deleted group")
	s.Require().ErrorIs(err, service.ErrGroupNotFound)
}
|
||||
653
backend/internal/repository/http_upstream.go
Normal file
653
backend/internal/repository/http_upstream.go
Normal file
@@ -0,0 +1,653 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
)
|
||||
|
||||
// Default configuration constants.
// These values are used as fallback defaults when the config file does not
// specify them.
const (
	// directProxyKey: cache-key identifier used when no proxy is configured
	directProxyKey = "direct"
	// defaultMaxIdleConns: default maximum total number of idle connections.
	// With HTTP/2 a single connection is multiplexed, so 240 supports high concurrency.
	defaultMaxIdleConns = 240
	// defaultMaxIdleConnsPerHost: default maximum idle connections per host
	defaultMaxIdleConnsPerHost = 120
	// defaultMaxConnsPerHost: default maximum connections per host (including active ones).
	// Once the limit is reached, new requests wait instead of creating more connections.
	defaultMaxConnsPerHost = 240
	// defaultIdleConnTimeout: default idle-connection timeout (90 seconds).
	// Idle connections are closed after this to free system resources
	// (recommended to be shorter than the upstream load balancer's timeout).
	defaultIdleConnTimeout = 90 * time.Second
	// defaultResponseHeaderTimeout: default timeout waiting for response headers (5 minutes).
	// LLM requests may queue for a long time, so a generous timeout is needed.
	defaultResponseHeaderTimeout = 300 * time.Second
	// defaultMaxUpstreamClients: default maximum number of cached clients.
	// When exceeded, the least recently used client is evicted.
	defaultMaxUpstreamClients = 5000
	// defaultClientIdleTTLSeconds: default client idle reclamation threshold (15 minutes)
	defaultClientIdleTTLSeconds = 900
)

// errUpstreamClientLimitReached is returned when the client cache is full
// and no idle client can be evicted to make room.
var errUpstreamClientLimitReached = errors.New("upstream client cache limit reached")
|
||||
|
||||
// poolSettings holds connection-pool configuration parameters.
// It bundles the pool-related knobs needed to build a Transport.
type poolSettings struct {
	maxIdleConns          int           // maximum total idle connections
	maxIdleConnsPerHost   int           // maximum idle connections per host
	maxConnsPerHost       int           // maximum connections per host (including active)
	idleConnTimeout       time.Duration // idle connection timeout
	responseHeaderTimeout time.Duration // timeout waiting for response headers
}
|
||||
|
||||
// upstreamClientEntry is a cache entry for an upstream client.
// It records the client instance plus the metadata used for pool management
// and the eviction policy.
type upstreamClientEntry struct {
	client   *http.Client // HTTP client instance
	proxyKey string       // proxy identifier (used to detect proxy changes)
	poolKey  string       // pool-configuration identifier (used to detect config changes)
	lastUsed int64        // last-used timestamp (ns), for LRU eviction; accessed atomically
	inFlight int64        // requests currently in flight; not evictable while > 0; accessed atomically
}
|
||||
|
||||
// httpUpstreamService is a generic HTTP upstream service.
// It sends requests to arbitrary HTTP APIs (Claude, OpenAI, ...) with
// optional proxy support.
//
// Architecture:
//   - Client instances are cached per isolation strategy (proxy/account/account_proxy)
//   - Each client owns an independent Transport connection pool
//   - Eviction combines LRU with an idle-time TTL
//
// Performance notes:
//  1. Caching clients per isolation strategy avoids repeatedly creating http.Client values
//  2. Reusing the Transport pool cuts TCP handshake and TLS negotiation overhead
//  3. Account-level isolation plus idle reclamation reduces connection-level correlation risk
//  4. Once the connection cap is hit, requests wait for a free connection instead of creating more
//  5. Only idle clients are reclaimed, so active requests are never interrupted
//  6. With HTTP/2 multiplexing, the connection cap is not a cap on concurrent requests
//  7. A proxy change drops the old pool, so the wrong proxy is never reused
//  8. Under account isolation, the account concurrency maps to the pool cap
type httpUpstreamService struct {
	cfg     *config.Config                  // global configuration
	mu      sync.RWMutex                    // guards reads/writes of the clients map
	clients map[string]*upstreamClientEntry // client cache; key depends on the isolation strategy
}
|
||||
|
||||
// NewHTTPUpstream creates the generic HTTP upstream service.
// Transports are built lazily using the connection-pool parameters from the
// configuration.
//
// Parameters:
//   - cfg: global configuration, including pool parameters and the isolation strategy
//
// Returns:
//   - a service.HTTPUpstream implementation
func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream {
	return &httpUpstreamService{
		cfg:     cfg,
		clients: make(map[string]*upstreamClientEntry),
	}
}
|
||||
|
||||
// Do executes an HTTP request.
// It obtains (or creates) a client according to the isolation strategy and
// tracks the request's lifecycle against that client.
//
// Parameters:
//   - req: the HTTP request
//   - proxyURL: proxy address; empty string means a direct connection
//   - accountID: account ID, used for account-level isolation
//   - accountConcurrency: account concurrency limit, used to size the pool dynamically
//
// Returns:
//   - *http.Response: the response (Body is wrapped; closing it updates the in-flight count)
//   - error: request error
//
// Notes:
//   - The caller MUST close resp.Body, otherwise the inFlight counter leaks.
//   - Clients with inFlight > 0 are never evicted, so active requests are not interrupted.
func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
	if err := s.validateRequestHost(req); err != nil {
		return nil, err
	}

	// Obtain or create the matching client and mark the request as in flight.
	entry, err := s.acquireClient(proxyURL, accountID, accountConcurrency)
	if err != nil {
		return nil, err
	}

	// Execute the request.
	resp, err := entry.client.Do(req)
	if err != nil {
		// Request failed: release the in-flight slot immediately.
		atomic.AddInt64(&entry.inFlight, -1)
		atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano())
		return nil, err
	}

	// Wrap the body so that closing it decrements the counter and refreshes
	// the last-used timestamp. This keeps streaming responses (e.g. SSE)
	// from being evicted before they are fully read.
	resp.Body = wrapTrackedBody(resp.Body, func() {
		atomic.AddInt64(&entry.inFlight, -1)
		atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano())
	})

	return resp, nil
}
|
||||
|
||||
func (s *httpUpstreamService) shouldValidateResolvedIP() bool {
|
||||
if s.cfg == nil {
|
||||
return false
|
||||
}
|
||||
if !s.cfg.Security.URLAllowlist.Enabled {
|
||||
return false
|
||||
}
|
||||
return !s.cfg.Security.URLAllowlist.AllowPrivateHosts
|
||||
}
|
||||
|
||||
func (s *httpUpstreamService) validateRequestHost(req *http.Request) error {
|
||||
if !s.shouldValidateResolvedIP() {
|
||||
return nil
|
||||
}
|
||||
if req == nil || req.URL == nil {
|
||||
return errors.New("request url is nil")
|
||||
}
|
||||
host := strings.TrimSpace(req.URL.Hostname())
|
||||
if host == "" {
|
||||
return errors.New("request host is empty")
|
||||
}
|
||||
if err := urlvalidator.ValidateResolvedIP(host); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *httpUpstreamService) redirectChecker(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= 10 {
|
||||
return errors.New("stopped after 10 redirects")
|
||||
}
|
||||
return s.validateRequestHost(req)
|
||||
}
|
||||
|
||||
// acquireClient fetches or creates a client and marks one request in flight.
// Used on the request path so the entry cannot be evicted after acquisition.
func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) {
	return s.getClientEntry(proxyURL, accountID, accountConcurrency, true, true)
}
|
||||
|
||||
// getOrCreateClient fetches or creates a client.
// The cache key is derived from the isolation strategy and the arguments;
// proxy and pool-configuration changes trigger a rebuild.
//
// Parameters:
//   - proxyURL: proxy address
//   - accountID: account ID
//   - accountConcurrency: account concurrency limit
//
// Returns:
//   - *upstreamClientEntry: the cached client entry
//     (may be nil if building the transport failed — the error is discarded here)
//
// Isolation strategies:
//   - proxy: isolate by proxy address; the same proxy shares one client
//   - account: isolate by account; one client per account (rebuilt on proxy change)
//   - account_proxy: isolate by account+proxy pair; the finest granularity
func (s *httpUpstreamService) getOrCreateClient(proxyURL string, accountID int64, accountConcurrency int) *upstreamClientEntry {
	entry, _ := s.getClientEntry(proxyURL, accountID, accountConcurrency, false, false)
	return entry
}
|
||||
|
||||
// getClientEntry fetches or creates a client cache entry.
// With markInFlight=true an in-flight request is recorded, protecting the
// entry from eviction on the request path.
// With enforceLimit=true the cache size is capped; when the cap is reached
// and nothing can be evicted, an error is returned.
func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, accountConcurrency int, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) {
	// Resolve the isolation mode.
	isolation := s.getIsolationMode()
	// Normalize and parse the proxy URL.
	proxyKey, parsedProxy := normalizeProxyURL(proxyURL)
	// Build the cache key (depends on the isolation strategy).
	cacheKey := buildCacheKey(isolation, proxyKey, accountID)
	// Build the pool-configuration key (used to detect config changes).
	poolKey := s.buildPoolKey(isolation, accountConcurrency)

	now := time.Now()
	nowUnix := now.UnixNano()

	// Fast path under the read lock: on a cache hit return immediately,
	// keeping lock contention low.
	s.mu.RLock()
	if entry, ok := s.clients[cacheKey]; ok && s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) {
		atomic.StoreInt64(&entry.lastUsed, nowUnix)
		if markInFlight {
			atomic.AddInt64(&entry.inFlight, 1)
		}
		s.mu.RUnlock()
		return entry, nil
	}
	s.mu.RUnlock()

	// Slow path under the write lock: create or rebuild the client.
	// Re-check the cache first (another goroutine may have filled it).
	s.mu.Lock()
	if entry, ok := s.clients[cacheKey]; ok {
		if s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) {
			atomic.StoreInt64(&entry.lastUsed, nowUnix)
			if markInFlight {
				atomic.AddInt64(&entry.inFlight, 1)
			}
			s.mu.Unlock()
			return entry, nil
		}
		// Stale entry (proxy or pool config changed): drop it before rebuilding.
		s.removeClientLocked(cacheKey, entry)
	}

	// At the cache cap, try to evict; if nothing can be evicted, refuse to
	// create a new client.
	if enforceLimit && s.maxUpstreamClients() > 0 {
		s.evictIdleLocked(now)
		if len(s.clients) >= s.maxUpstreamClients() {
			if !s.evictOldestIdleLocked() {
				s.mu.Unlock()
				return nil, errUpstreamClientLimitReached
			}
		}
	}

	// Cache miss or rebuild required: create a new client.
	settings := s.resolvePoolSettings(isolation, accountConcurrency)
	transport, err := buildUpstreamTransport(settings, parsedProxy)
	if err != nil {
		s.mu.Unlock()
		return nil, fmt.Errorf("build transport: %w", err)
	}
	client := &http.Client{Transport: transport}
	if s.shouldValidateResolvedIP() {
		client.CheckRedirect = s.redirectChecker
	}
	entry := &upstreamClientEntry{
		client:   client,
		proxyKey: proxyKey,
		poolKey:  poolKey,
	}
	atomic.StoreInt64(&entry.lastUsed, nowUnix)
	if markInFlight {
		atomic.StoreInt64(&entry.inFlight, 1)
	}
	s.clients[cacheKey] = entry

	// Apply the eviction policy: idle-timeout eviction first, then trim any
	// overflow past the size cap.
	s.evictIdleLocked(now)
	s.evictOverLimitLocked()
	s.mu.Unlock()
	return entry, nil
}
|
||||
|
||||
// shouldReuseEntry 判断缓存条目是否可复用
|
||||
// 若代理或连接池配置发生变化,则需要重建客户端
|
||||
func (s *httpUpstreamService) shouldReuseEntry(entry *upstreamClientEntry, isolation, proxyKey, poolKey string) bool {
|
||||
if entry == nil {
|
||||
return false
|
||||
}
|
||||
if isolation == config.ConnectionPoolIsolationAccount && entry.proxyKey != proxyKey {
|
||||
return false
|
||||
}
|
||||
if entry.poolKey != poolKey {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// removeClientLocked removes a client from the cache (caller must hold s.mu)
// and closes its idle connections.
//
// Parameters:
//   - key: cache key
//   - entry: client entry
func (s *httpUpstreamService) removeClientLocked(key string, entry *upstreamClientEntry) {
	delete(s.clients, key)
	if entry != nil && entry.client != nil {
		// Close idle connections to free system resources.
		// Note: this does not interrupt active connections.
		entry.client.CloseIdleConnections()
	}
}
|
||||
|
||||
// evictIdleLocked evicts clients that have been idle past the TTL (caller
// must hold s.mu). It walks every client and removes entries with no active
// requests whose last use is older than the TTL.
//
// Parameters:
//   - now: current time
func (s *httpUpstreamService) evictIdleLocked(now time.Time) {
	ttl := s.clientIdleTTL()
	if ttl <= 0 {
		return
	}
	// Compute the eviction cutoff.
	cutoff := now.Add(-ttl).UnixNano()
	for key, entry := range s.clients {
		// Skip clients with active requests.
		if atomic.LoadInt64(&entry.inFlight) != 0 {
			continue
		}
		// Evict idle clients past the cutoff.
		if atomic.LoadInt64(&entry.lastUsed) <= cutoff {
			s.removeClientLocked(key, entry)
		}
	}
}
|
||||
|
||||
// evictOldestIdleLocked 淘汰最久未使用且无活跃请求的客户端(需持有锁)
|
||||
func (s *httpUpstreamService) evictOldestIdleLocked() bool {
|
||||
var (
|
||||
oldestKey string
|
||||
oldestEntry *upstreamClientEntry
|
||||
oldestTime int64
|
||||
)
|
||||
// 查找最久未使用且无活跃请求的客户端
|
||||
for key, entry := range s.clients {
|
||||
// 跳过有活跃请求的客户端
|
||||
if atomic.LoadInt64(&entry.inFlight) != 0 {
|
||||
continue
|
||||
}
|
||||
lastUsed := atomic.LoadInt64(&entry.lastUsed)
|
||||
if oldestEntry == nil || lastUsed < oldestTime {
|
||||
oldestKey = key
|
||||
oldestEntry = entry
|
||||
oldestTime = lastUsed
|
||||
}
|
||||
}
|
||||
// 所有客户端都有活跃请求,无法淘汰
|
||||
if oldestEntry == nil {
|
||||
return false
|
||||
}
|
||||
s.removeClientLocked(oldestKey, oldestEntry)
|
||||
return true
|
||||
}
|
||||
|
||||
// evictOverLimitLocked 淘汰超出数量限制的客户端(需持有锁)
|
||||
// 使用 LRU 策略,优先淘汰最久未使用且无活跃请求的客户端
|
||||
func (s *httpUpstreamService) evictOverLimitLocked() bool {
|
||||
maxClients := s.maxUpstreamClients()
|
||||
if maxClients <= 0 {
|
||||
return false
|
||||
}
|
||||
evicted := false
|
||||
// 循环淘汰直到满足数量限制
|
||||
for len(s.clients) > maxClients {
|
||||
if !s.evictOldestIdleLocked() {
|
||||
return evicted
|
||||
}
|
||||
evicted = true
|
||||
}
|
||||
return evicted
|
||||
}
|
||||
|
||||
// getIsolationMode 获取连接池隔离模式
|
||||
// 从配置中读取,无效值回退到 account_proxy 模式
|
||||
//
|
||||
// 返回:
|
||||
// - string: 隔离模式(proxy/account/account_proxy)
|
||||
func (s *httpUpstreamService) getIsolationMode() string {
|
||||
if s.cfg == nil {
|
||||
return config.ConnectionPoolIsolationAccountProxy
|
||||
}
|
||||
mode := strings.ToLower(strings.TrimSpace(s.cfg.Gateway.ConnectionPoolIsolation))
|
||||
if mode == "" {
|
||||
return config.ConnectionPoolIsolationAccountProxy
|
||||
}
|
||||
switch mode {
|
||||
case config.ConnectionPoolIsolationProxy, config.ConnectionPoolIsolationAccount, config.ConnectionPoolIsolationAccountProxy:
|
||||
return mode
|
||||
default:
|
||||
return config.ConnectionPoolIsolationAccountProxy
|
||||
}
|
||||
}
|
||||
|
||||
// maxUpstreamClients 获取最大客户端缓存数量
|
||||
// 从配置中读取,无效值使用默认值
|
||||
func (s *httpUpstreamService) maxUpstreamClients() int {
|
||||
if s.cfg == nil {
|
||||
return defaultMaxUpstreamClients
|
||||
}
|
||||
if s.cfg.Gateway.MaxUpstreamClients > 0 {
|
||||
return s.cfg.Gateway.MaxUpstreamClients
|
||||
}
|
||||
return defaultMaxUpstreamClients
|
||||
}
|
||||
|
||||
// clientIdleTTL 获取客户端空闲回收阈值
|
||||
// 从配置中读取,无效值使用默认值
|
||||
func (s *httpUpstreamService) clientIdleTTL() time.Duration {
|
||||
if s.cfg == nil {
|
||||
return time.Duration(defaultClientIdleTTLSeconds) * time.Second
|
||||
}
|
||||
if s.cfg.Gateway.ClientIdleTTLSeconds > 0 {
|
||||
return time.Duration(s.cfg.Gateway.ClientIdleTTLSeconds) * time.Second
|
||||
}
|
||||
return time.Duration(defaultClientIdleTTLSeconds) * time.Second
|
||||
}
|
||||
|
||||
// resolvePoolSettings 解析连接池配置
|
||||
// 根据隔离策略和账户并发数动态调整连接池参数
|
||||
//
|
||||
// 参数:
|
||||
// - isolation: 隔离模式
|
||||
// - accountConcurrency: 账户并发限制
|
||||
//
|
||||
// 返回:
|
||||
// - poolSettings: 连接池配置
|
||||
//
|
||||
// 说明:
|
||||
// - 账户隔离模式下,连接池大小与账户并发数对应
|
||||
// - 这确保了单账户不会占用过多连接资源
|
||||
func (s *httpUpstreamService) resolvePoolSettings(isolation string, accountConcurrency int) poolSettings {
|
||||
settings := defaultPoolSettings(s.cfg)
|
||||
// 账户隔离模式下,根据账户并发数调整连接池大小
|
||||
if (isolation == config.ConnectionPoolIsolationAccount || isolation == config.ConnectionPoolIsolationAccountProxy) && accountConcurrency > 0 {
|
||||
settings.maxIdleConns = accountConcurrency
|
||||
settings.maxIdleConnsPerHost = accountConcurrency
|
||||
settings.maxConnsPerHost = accountConcurrency
|
||||
}
|
||||
return settings
|
||||
}
|
||||
|
||||
// buildPoolKey 构建连接池配置键
|
||||
// 用于检测配置变更,配置变更时需要重建客户端
|
||||
//
|
||||
// 参数:
|
||||
// - isolation: 隔离模式
|
||||
// - accountConcurrency: 账户并发限制
|
||||
//
|
||||
// 返回:
|
||||
// - string: 配置键
|
||||
func (s *httpUpstreamService) buildPoolKey(isolation string, accountConcurrency int) string {
|
||||
if isolation == config.ConnectionPoolIsolationAccount || isolation == config.ConnectionPoolIsolationAccountProxy {
|
||||
if accountConcurrency > 0 {
|
||||
return fmt.Sprintf("account:%d", accountConcurrency)
|
||||
}
|
||||
}
|
||||
return "default"
|
||||
}
|
||||
|
||||
// buildCacheKey 构建客户端缓存键
|
||||
// 根据隔离策略决定缓存键的组成
|
||||
//
|
||||
// 参数:
|
||||
// - isolation: 隔离模式
|
||||
// - proxyKey: 代理标识
|
||||
// - accountID: 账户 ID
|
||||
//
|
||||
// 返回:
|
||||
// - string: 缓存键
|
||||
//
|
||||
// 缓存键格式:
|
||||
// - proxy 模式: "proxy:{proxyKey}"
|
||||
// - account 模式: "account:{accountID}"
|
||||
// - account_proxy 模式: "account:{accountID}|proxy:{proxyKey}"
|
||||
func buildCacheKey(isolation, proxyKey string, accountID int64) string {
|
||||
switch isolation {
|
||||
case config.ConnectionPoolIsolationAccount:
|
||||
return fmt.Sprintf("account:%d", accountID)
|
||||
case config.ConnectionPoolIsolationAccountProxy:
|
||||
return fmt.Sprintf("account:%d|proxy:%s", accountID, proxyKey)
|
||||
default:
|
||||
return fmt.Sprintf("proxy:%s", proxyKey)
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeProxyURL 标准化代理 URL
|
||||
// 处理空值和解析错误,返回标准化的键和解析后的 URL
|
||||
//
|
||||
// 参数:
|
||||
// - raw: 原始代理 URL 字符串
|
||||
//
|
||||
// 返回:
|
||||
// - string: 标准化的代理键(空或解析失败返回 "direct")
|
||||
// - *url.URL: 解析后的 URL(空或解析失败返回 nil)
|
||||
func normalizeProxyURL(raw string) (string, *url.URL) {
|
||||
proxyURL := strings.TrimSpace(raw)
|
||||
if proxyURL == "" {
|
||||
return directProxyKey, nil
|
||||
}
|
||||
parsed, err := url.Parse(proxyURL)
|
||||
if err != nil {
|
||||
return directProxyKey, nil
|
||||
}
|
||||
parsed.Scheme = strings.ToLower(parsed.Scheme)
|
||||
parsed.Host = strings.ToLower(parsed.Host)
|
||||
parsed.Path = ""
|
||||
parsed.RawPath = ""
|
||||
parsed.RawQuery = ""
|
||||
parsed.Fragment = ""
|
||||
parsed.ForceQuery = false
|
||||
if hostname := parsed.Hostname(); hostname != "" {
|
||||
port := parsed.Port()
|
||||
if (parsed.Scheme == "http" && port == "80") || (parsed.Scheme == "https" && port == "443") {
|
||||
port = ""
|
||||
}
|
||||
hostname = strings.ToLower(hostname)
|
||||
if port != "" {
|
||||
parsed.Host = net.JoinHostPort(hostname, port)
|
||||
} else {
|
||||
parsed.Host = hostname
|
||||
}
|
||||
}
|
||||
return parsed.String(), parsed
|
||||
}
|
||||
|
||||
// defaultPoolSettings 获取默认连接池配置
|
||||
// 从全局配置中读取,无效值使用常量默认值
|
||||
//
|
||||
// 参数:
|
||||
// - cfg: 全局配置
|
||||
//
|
||||
// 返回:
|
||||
// - poolSettings: 连接池配置
|
||||
func defaultPoolSettings(cfg *config.Config) poolSettings {
|
||||
maxIdleConns := defaultMaxIdleConns
|
||||
maxIdleConnsPerHost := defaultMaxIdleConnsPerHost
|
||||
maxConnsPerHost := defaultMaxConnsPerHost
|
||||
idleConnTimeout := defaultIdleConnTimeout
|
||||
responseHeaderTimeout := defaultResponseHeaderTimeout
|
||||
|
||||
if cfg != nil {
|
||||
if cfg.Gateway.MaxIdleConns > 0 {
|
||||
maxIdleConns = cfg.Gateway.MaxIdleConns
|
||||
}
|
||||
if cfg.Gateway.MaxIdleConnsPerHost > 0 {
|
||||
maxIdleConnsPerHost = cfg.Gateway.MaxIdleConnsPerHost
|
||||
}
|
||||
if cfg.Gateway.MaxConnsPerHost >= 0 {
|
||||
maxConnsPerHost = cfg.Gateway.MaxConnsPerHost
|
||||
}
|
||||
if cfg.Gateway.IdleConnTimeoutSeconds > 0 {
|
||||
idleConnTimeout = time.Duration(cfg.Gateway.IdleConnTimeoutSeconds) * time.Second
|
||||
}
|
||||
if cfg.Gateway.ResponseHeaderTimeout > 0 {
|
||||
responseHeaderTimeout = time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
|
||||
}
|
||||
}
|
||||
|
||||
return poolSettings{
|
||||
maxIdleConns: maxIdleConns,
|
||||
maxIdleConnsPerHost: maxIdleConnsPerHost,
|
||||
maxConnsPerHost: maxConnsPerHost,
|
||||
idleConnTimeout: idleConnTimeout,
|
||||
responseHeaderTimeout: responseHeaderTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// buildUpstreamTransport 构建上游请求的 Transport
|
||||
// 使用配置文件中的连接池参数,支持生产环境调优
|
||||
//
|
||||
// 参数:
|
||||
// - settings: 连接池配置
|
||||
// - proxyURL: 代理 URL(nil 表示直连)
|
||||
//
|
||||
// 返回:
|
||||
// - *http.Transport: 配置好的 Transport 实例
|
||||
// - error: 代理配置错误
|
||||
//
|
||||
// Transport 参数说明:
|
||||
// - MaxIdleConns: 所有主机的最大空闲连接总数
|
||||
// - MaxIdleConnsPerHost: 每主机最大空闲连接数(影响连接复用率)
|
||||
// - MaxConnsPerHost: 每主机最大连接数(达到后新请求等待)
|
||||
// - IdleConnTimeout: 空闲连接超时(超时后关闭)
|
||||
// - ResponseHeaderTimeout: 等待响应头超时(不影响流式传输)
|
||||
func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) (*http.Transport, error) {
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: settings.maxIdleConns,
|
||||
MaxIdleConnsPerHost: settings.maxIdleConnsPerHost,
|
||||
MaxConnsPerHost: settings.maxConnsPerHost,
|
||||
IdleConnTimeout: settings.idleConnTimeout,
|
||||
ResponseHeaderTimeout: settings.responseHeaderTimeout,
|
||||
}
|
||||
if err := proxyutil.ConfigureTransportProxy(transport, proxyURL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return transport, nil
|
||||
}
|
||||
|
||||
// trackedBody 带跟踪功能的响应体包装器
|
||||
// 在 Close 时执行回调,用于更新请求计数
|
||||
type trackedBody struct {
|
||||
io.ReadCloser // 原始响应体
|
||||
once sync.Once
|
||||
onClose func() // 关闭时的回调函数
|
||||
}
|
||||
|
||||
// Close 关闭响应体并执行回调
|
||||
// 使用 sync.Once 确保回调只执行一次
|
||||
func (b *trackedBody) Close() error {
|
||||
err := b.ReadCloser.Close()
|
||||
if b.onClose != nil {
|
||||
b.once.Do(b.onClose)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// wrapTrackedBody 包装响应体以跟踪关闭事件
|
||||
// 用于在响应体关闭时更新 inFlight 计数
|
||||
//
|
||||
// 参数:
|
||||
// - body: 原始响应体
|
||||
// - onClose: 关闭时的回调函数
|
||||
//
|
||||
// 返回:
|
||||
// - io.ReadCloser: 包装后的响应体
|
||||
func wrapTrackedBody(body io.ReadCloser, onClose func()) io.ReadCloser {
|
||||
if body == nil {
|
||||
return body
|
||||
}
|
||||
return &trackedBody{ReadCloser: body, onClose: onClose}
|
||||
}
|
||||
70
backend/internal/repository/http_upstream_benchmark_test.go
Normal file
70
backend/internal/repository/http_upstream_benchmark_test.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
// httpClientSink prevents the compiler from optimizing away benchmark
// assignments — a common Go benchmarking pattern that keeps results honest.
var httpClientSink *http.Client

// BenchmarkHTTPUpstreamProxyClient compares creating a fresh proxy client
// per request against reusing a cached one.
//
// Goals:
//   - verify the performance gain of connection-pool reuse over per-request
//     client construction
//   - quantify the allocation difference
//
// Expected: the reuse sub-benchmark is significantly faster than the
// create-new one, and performs zero allocations per op.
func BenchmarkHTTPUpstreamProxyClient(b *testing.B) {
	// Build the benchmark configuration.
	cfg := &config.Config{
		Gateway: config.GatewayConfig{ResponseHeaderTimeout: 300},
	}
	upstream := NewHTTPUpstream(cfg)
	svc, ok := upstream.(*httpUpstreamService)
	if !ok {
		b.Fatalf("类型断言失败,无法获取 httpUpstreamService")
	}

	proxyURL := "http://127.0.0.1:8080"
	b.ReportAllocs() // report allocation statistics

	// Sub-benchmark "新建" (create-new): build a fresh http.Client on every
	// iteration, modeling the pre-optimization behavior.
	b.Run("新建", func(b *testing.B) {
		parsedProxy, err := url.Parse(proxyURL)
		if err != nil {
			b.Fatalf("解析代理地址失败: %v", err)
		}
		settings := defaultPoolSettings(cfg)
		for i := 0; i < b.N; i++ {
			// Each iteration allocates a new client, Transport included.
			transport, err := buildUpstreamTransport(settings, parsedProxy)
			if err != nil {
				b.Fatalf("创建 Transport 失败: %v", err)
			}
			httpClientSink = &http.Client{
				Transport: transport,
			}
		}
	})

	// Sub-benchmark "复用" (reuse): fetch the cached client, modeling the
	// optimized path.
	b.Run("复用", func(b *testing.B) {
		// Warm-up: make sure the client is already cached.
		entry := svc.getOrCreateClient(proxyURL, 1, 1)
		client := entry.client
		b.ResetTimer() // exclude warm-up time from the measurement
		for i := 0; i < b.N; i++ {
			// Use the cached client directly; no allocation expected.
			httpClientSink = client
		}
	})
}
|
||||
291
backend/internal/repository/http_upstream_test.go
Normal file
291
backend/internal/repository/http_upstream_test.go
Normal file
@@ -0,0 +1,291 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// HTTPUpstreamSuite is the test suite for the HTTP upstream service.
// It uses testify/suite so each case starts from fresh state via SetupTest.
type HTTPUpstreamSuite struct {
	suite.Suite
	cfg *config.Config // per-test configuration, rebuilt in SetupTest
}

// SetupTest runs before every test case. It builds a minimal configuration
// that allows private hosts (so local httptest servers are reachable);
// individual tests override Gateway settings as needed.
func (s *HTTPUpstreamSuite) SetupTest() {
	s.cfg = &config.Config{
		Security: config.SecurityConfig{
			URLAllowlist: config.URLAllowlistConfig{
				AllowPrivateHosts: true,
			},
		},
	}
}

// newService constructs the concrete *httpUpstreamService under test so
// assertions can inspect its internal state.
func (s *HTTPUpstreamSuite) newService() *httpUpstreamService {
	up := NewHTTPUpstream(s.cfg)
	svc, ok := up.(*httpUpstreamService)
	require.True(s.T(), ok, "expected *httpUpstreamService")
	return svc
}
|
||||
|
||||
// TestDefaultResponseHeaderTimeout verifies the default response-header
// timeout of 300 seconds is used when none is configured.
func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() {
	svc := s.newService()
	entry := svc.getOrCreateClient("", 0, 0)
	transport, ok := entry.client.Transport.(*http.Transport)
	require.True(s.T(), ok, "expected *http.Transport")
	require.Equal(s.T(), 300*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch")
}

// TestCustomResponseHeaderTimeout verifies a configured timeout value is
// propagated onto the Transport.
func (s *HTTPUpstreamSuite) TestCustomResponseHeaderTimeout() {
	s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 7}
	svc := s.newService()
	entry := svc.getOrCreateClient("", 0, 0)
	transport, ok := entry.client.Transport.(*http.Transport)
	require.True(s.T(), ok, "expected *http.Transport")
	require.Equal(s.T(), 7*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch")
}

// TestGetOrCreateClient_InvalidURLFallsBackToDirect verifies that an
// unparseable proxy URL falls back to a direct (no-proxy) client.
func (s *HTTPUpstreamSuite) TestGetOrCreateClient_InvalidURLFallsBackToDirect() {
	svc := s.newService()
	entry := svc.getOrCreateClient("://bad-proxy-url", 1, 1)
	require.Equal(s.T(), directProxyKey, entry.proxyKey, "expected direct proxy fallback")
}

// TestNormalizeProxyURL_Canonicalizes verifies that equivalent proxy
// addresses map to the same cache key.
func (s *HTTPUpstreamSuite) TestNormalizeProxyURL_Canonicalizes() {
	key1, _ := normalizeProxyURL("http://proxy.local:8080")
	key2, _ := normalizeProxyURL("http://proxy.local:8080/")
	require.Equal(s.T(), key1, key2, "expected normalized proxy keys to match")
}
|
||||
|
||||
// TestAcquireClient_OverLimitReturnsError verifies the cache-size guard:
// when the limit is reached and no entry can be evicted, acquiring a new
// client fails with an error.
func (s *HTTPUpstreamSuite) TestAcquireClient_OverLimitReturnsError() {
	s.cfg.Gateway = config.GatewayConfig{
		ConnectionPoolIsolation: config.ConnectionPoolIsolationAccountProxy,
		MaxUpstreamClients:      1,
	}
	svc := s.newService()
	entry1, err := svc.acquireClient("http://proxy-a:8080", 1, 1)
	require.NoError(s.T(), err, "expected first acquire to succeed")
	require.NotNil(s.T(), entry1, "expected entry")

	entry2, err := svc.acquireClient("http://proxy-b:8080", 2, 1)
	require.Error(s.T(), err, "expected error when cache limit reached")
	require.Nil(s.T(), entry2, "expected nil entry when cache limit reached")
}

// TestDo_WithoutProxy_GoesDirect verifies that with an empty proxy URL the
// request goes straight to the target server.
func (s *HTTPUpstreamSuite) TestDo_WithoutProxy_GoesDirect() {
	// Fake upstream server.
	upstream := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_, _ = io.WriteString(w, "direct")
	}))
	s.T().Cleanup(upstream.Close)

	up := NewHTTPUpstream(s.cfg)

	req, err := http.NewRequest(http.MethodGet, upstream.URL+"/x", nil)
	require.NoError(s.T(), err, "NewRequest")
	resp, err := up.Do(req, "", 1, 1)
	require.NoError(s.T(), err, "Do")
	defer func() { _ = resp.Body.Close() }()
	b, _ := io.ReadAll(resp.Body)
	require.Equal(s.T(), "direct", string(b), "unexpected body")
}
|
||||
|
||||
// TestDo_WithHTTPProxy_UsesProxy verifies that requests are forwarded
// through the HTTP proxy using absolute-form request URIs, as the HTTP
// proxy convention requires.
func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() {
	// Channel recording what the fake proxy received.
	seen := make(chan string, 1)
	// Fake proxy server.
	proxySrv := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		seen <- r.RequestURI // record the request URI
		_, _ = io.WriteString(w, "proxied")
	}))
	s.T().Cleanup(proxySrv.Close)

	s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 1}
	up := NewHTTPUpstream(s.cfg)

	// Request an external address; it should be routed via the proxy.
	req, err := http.NewRequest(http.MethodGet, "http://example.com/test", nil)
	require.NoError(s.T(), err, "NewRequest")
	resp, err := up.Do(req, proxySrv.URL, 1, 1)
	require.NoError(s.T(), err, "Do")
	defer func() { _ = resp.Body.Close() }()
	b, _ := io.ReadAll(resp.Body)
	require.Equal(s.T(), "proxied", string(b), "unexpected body")

	// The proxy must have seen an absolute-form URI.
	select {
	case uri := <-seen:
		require.Equal(s.T(), "http://example.com/test", uri, "expected absolute-form request URI")
	default:
		require.Fail(s.T(), "expected proxy to receive request")
	}
}

// TestDo_EmptyProxy_UsesDirect verifies that an empty proxy string behaves
// exactly like a direct connection.
func (s *HTTPUpstreamSuite) TestDo_EmptyProxy_UsesDirect() {
	upstream := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_, _ = io.WriteString(w, "direct-empty")
	}))
	s.T().Cleanup(upstream.Close)

	up := NewHTTPUpstream(s.cfg)
	req, err := http.NewRequest(http.MethodGet, upstream.URL+"/y", nil)
	require.NoError(s.T(), err, "NewRequest")
	resp, err := up.Do(req, "", 1, 1)
	require.NoError(s.T(), err, "Do with empty proxy")
	defer func() { _ = resp.Body.Close() }()
	b, _ := io.ReadAll(resp.Body)
	require.Equal(s.T(), "direct-empty", string(b))
}
|
||||
|
||||
// TestAccountIsolation_DifferentAccounts verifies that under account
// isolation each account gets its own connection pool.
func (s *HTTPUpstreamSuite) TestAccountIsolation_DifferentAccounts() {
	s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount}
	svc := s.newService()
	// Same proxy, different accounts.
	entry1 := svc.getOrCreateClient("http://proxy.local:8080", 1, 3)
	entry2 := svc.getOrCreateClient("http://proxy.local:8080", 2, 3)
	require.NotSame(s.T(), entry1, entry2, "不同账号不应共享连接池")
	require.Equal(s.T(), 2, len(svc.clients), "账号隔离应缓存两个客户端")
}

// TestAccountProxyIsolation_DifferentProxy verifies that account+proxy
// isolation creates independent pools when the same account uses different
// proxies.
func (s *HTTPUpstreamSuite) TestAccountProxyIsolation_DifferentProxy() {
	s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccountProxy}
	svc := s.newService()
	// Same account, different proxies.
	entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 3)
	entry2 := svc.getOrCreateClient("http://proxy-b:8080", 1, 3)
	require.NotSame(s.T(), entry1, entry2, "账号+代理隔离应区分不同代理")
	require.Equal(s.T(), 2, len(svc.clients), "账号+代理隔离应缓存两个客户端")
}

// TestAccountModeProxyChangeClearsPool verifies that in account mode a
// proxy change drops the old pool, so requests never reuse the wrong proxy.
func (s *HTTPUpstreamSuite) TestAccountModeProxyChangeClearsPool() {
	s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount}
	svc := s.newService()
	// Same account, switching proxies over time.
	entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 3)
	entry2 := svc.getOrCreateClient("http://proxy-b:8080", 1, 3)
	require.NotSame(s.T(), entry1, entry2, "账号切换代理应创建新连接池")
	require.Equal(s.T(), 1, len(svc.clients), "账号模式下应仅保留一个连接池")
	require.False(s.T(), hasEntry(svc, entry1), "旧连接池应被清理")
}
|
||||
|
||||
// TestAccountConcurrencyOverridesPoolSettings verifies that under account
// isolation the pool limits track the account's concurrency value.
func (s *HTTPUpstreamSuite) TestAccountConcurrencyOverridesPoolSettings() {
	s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount}
	svc := s.newService()
	// Account concurrency of 12.
	entry := svc.getOrCreateClient("", 1, 12)
	transport, ok := entry.client.Transport.(*http.Transport)
	require.True(s.T(), ok, "expected *http.Transport")
	// All pool limits should equal the concurrency value.
	require.Equal(s.T(), 12, transport.MaxConnsPerHost, "MaxConnsPerHost mismatch")
	require.Equal(s.T(), 12, transport.MaxIdleConns, "MaxIdleConns mismatch")
	require.Equal(s.T(), 12, transport.MaxIdleConnsPerHost, "MaxIdleConnsPerHost mismatch")
}

// TestAccountConcurrencyFallbackToDefault verifies that a zero account
// concurrency falls back to the globally configured pool sizes.
func (s *HTTPUpstreamSuite) TestAccountConcurrencyFallbackToDefault() {
	s.cfg.Gateway = config.GatewayConfig{
		ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount,
		MaxIdleConns:            77,
		MaxIdleConnsPerHost:     55,
		MaxConnsPerHost:         66,
	}
	svc := s.newService()
	// Zero account concurrency: the global config applies.
	entry := svc.getOrCreateClient("", 1, 0)
	transport, ok := entry.client.Transport.(*http.Transport)
	require.True(s.T(), ok, "expected *http.Transport")
	require.Equal(s.T(), 66, transport.MaxConnsPerHost, "MaxConnsPerHost fallback mismatch")
	require.Equal(s.T(), 77, transport.MaxIdleConns, "MaxIdleConns fallback mismatch")
	require.Equal(s.T(), 55, transport.MaxIdleConnsPerHost, "MaxIdleConnsPerHost fallback mismatch")
}
|
||||
|
||||
// TestEvictOverLimitRemovesOldestIdle verifies LRU eviction when the cache
// exceeds its size limit: the least recently used idle client goes first.
func (s *HTTPUpstreamSuite) TestEvictOverLimitRemovesOldestIdle() {
	s.cfg.Gateway = config.GatewayConfig{
		ConnectionPoolIsolation: config.ConnectionPoolIsolationAccountProxy,
		MaxUpstreamClients:      2, // cache at most two clients
	}
	svc := s.newService()
	// Two clients with distinct last-used timestamps.
	entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 1)
	entry2 := svc.getOrCreateClient("http://proxy-b:8080", 2, 1)
	atomic.StoreInt64(&entry1.lastUsed, time.Now().Add(-2*time.Hour).UnixNano()) // oldest
	atomic.StoreInt64(&entry2.lastUsed, time.Now().Add(-time.Hour).UnixNano())
	// A third client triggers eviction.
	_ = svc.getOrCreateClient("http://proxy-c:8080", 3, 1)

	require.LessOrEqual(s.T(), len(svc.clients), 2, "应保持在缓存上限内")
	require.False(s.T(), hasEntry(svc, entry1), "最久未使用的连接池应被清理")
}

// TestIdleTTLDoesNotEvictActive verifies that a client with in-flight
// requests survives idle-TTL eviction.
func (s *HTTPUpstreamSuite) TestIdleTTLDoesNotEvictActive() {
	s.cfg.Gateway = config.GatewayConfig{
		ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount,
		ClientIdleTTLSeconds:    1, // 1-second idle timeout
	}
	svc := s.newService()
	entry1 := svc.getOrCreateClient("", 1, 1)
	// Stale last-used time, but with an active request.
	atomic.StoreInt64(&entry1.lastUsed, time.Now().Add(-2*time.Minute).UnixNano())
	atomic.StoreInt64(&entry1.inFlight, 1) // simulate an in-flight request
	// Creating another client triggers the eviction sweep.
	_ = svc.getOrCreateClient("", 2, 1)

	require.True(s.T(), hasEntry(svc, entry1), "有活跃请求时不应回收")
}
|
||||
|
||||
// TestHTTPUpstreamSuite runs the test suite.
func TestHTTPUpstreamSuite(t *testing.T) {
	suite.Run(t, new(HTTPUpstreamSuite))
}

// hasEntry reports whether target is still present in the service's client
// cache; helper for eviction assertions.
func hasEntry(svc *httpUpstreamService, target *upstreamClientEntry) bool {
	for _, entry := range svc.clients {
		if entry == target {
			return true
		}
	}
	return false
}
|
||||
51
backend/internal/repository/identity_cache.go
Normal file
51
backend/internal/repository/identity_cache.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
	fingerprintKeyPrefix = "fingerprint:" // Redis key namespace for fingerprints
	fingerprintTTL       = 24 * time.Hour // cached fingerprints expire after a day
)

// fingerprintKey builds the Redis key under which an account's fingerprint
// is cached.
func fingerprintKey(accountID int64) string {
	return fingerprintKeyPrefix + fmt.Sprint(accountID)
}
|
||||
|
||||
// identityCache is a Redis-backed implementation of service.IdentityCache.
type identityCache struct {
	rdb *redis.Client // shared Redis client
}

// NewIdentityCache returns a service.IdentityCache backed by the given
// Redis client.
func NewIdentityCache(rdb *redis.Client) service.IdentityCache {
	return &identityCache{rdb: rdb}
}
|
||||
|
||||
func (c *identityCache) GetFingerprint(ctx context.Context, accountID int64) (*service.Fingerprint, error) {
|
||||
key := fingerprintKey(accountID)
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var fp service.Fingerprint
|
||||
if err := json.Unmarshal([]byte(val), &fp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &fp, nil
|
||||
}
|
||||
|
||||
func (c *identityCache) SetFingerprint(ctx context.Context, accountID int64, fp *service.Fingerprint) error {
|
||||
key := fingerprintKey(accountID)
|
||||
val, err := json.Marshal(fp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.rdb.Set(ctx, key, val, fingerprintTTL).Err()
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// IdentityCacheSuite exercises identityCache against a real Redis instance
// provided by the integration harness.
type IdentityCacheSuite struct {
	IntegrationRedisSuite
	cache *identityCache // system under test, rebuilt per test
}

// SetupTest prepares a fresh Redis state and cache instance for each test.
func (s *IdentityCacheSuite) SetupTest() {
	s.IntegrationRedisSuite.SetupTest()
	s.cache = NewIdentityCache(s.rdb).(*identityCache)
}

// TestGetFingerprint_Missing expects redis.Nil when no fingerprint exists.
func (s *IdentityCacheSuite) TestGetFingerprint_Missing() {
	_, err := s.cache.GetFingerprint(s.ctx, 1)
	require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing fingerprint")
}

// TestSetAndGetFingerprint round-trips a fingerprint through Redis.
func (s *IdentityCacheSuite) TestSetAndGetFingerprint() {
	fp := &service.Fingerprint{ClientID: "c1", UserAgent: "ua"}
	require.NoError(s.T(), s.cache.SetFingerprint(s.ctx, 1, fp), "SetFingerprint")
	gotFP, err := s.cache.GetFingerprint(s.ctx, 1)
	require.NoError(s.T(), err, "GetFingerprint")
	require.Equal(s.T(), "c1", gotFP.ClientID)
	require.Equal(s.T(), "ua", gotFP.UserAgent)
}

// TestFingerprint_TTL checks that stored fingerprints carry the 24h TTL.
func (s *IdentityCacheSuite) TestFingerprint_TTL() {
	fp := &service.Fingerprint{ClientID: "c1", UserAgent: "ua"}
	require.NoError(s.T(), s.cache.SetFingerprint(s.ctx, 2, fp))

	fpKey := fmt.Sprintf("%s%d", fingerprintKeyPrefix, 2)
	ttl, err := s.rdb.TTL(s.ctx, fpKey).Result()
	require.NoError(s.T(), err, "TTL fpKey")
	s.AssertTTLWithin(ttl, 1*time.Second, fingerprintTTL)
}

// TestGetFingerprint_JSONCorruption expects a decode error (not redis.Nil)
// when the stored value is not valid JSON.
func (s *IdentityCacheSuite) TestGetFingerprint_JSONCorruption() {
	fpKey := fmt.Sprintf("%s%d", fingerprintKeyPrefix, 999)
	require.NoError(s.T(), s.rdb.Set(s.ctx, fpKey, "invalid-json-data", 1*time.Minute).Err(), "Set invalid JSON")

	_, err := s.cache.GetFingerprint(s.ctx, 999)
	require.Error(s.T(), err, "expected error for corrupted JSON")
	require.False(s.T(), errors.Is(err, redis.Nil), "expected decoding error, not redis.Nil")
}

// TestSetFingerprint_Nil documents that storing a nil fingerprint succeeds
// (json.Marshal(nil) yields "null").
func (s *IdentityCacheSuite) TestSetFingerprint_Nil() {
	err := s.cache.SetFingerprint(s.ctx, 100, nil)
	require.NoError(s.T(), err, "SetFingerprint(nil) should succeed")
}

// TestIdentityCacheSuite runs the suite.
func TestIdentityCacheSuite(t *testing.T) {
	suite.Run(t, new(IdentityCacheSuite))
}
|
||||
46
backend/internal/repository/identity_cache_test.go
Normal file
46
backend/internal/repository/identity_cache_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestFingerprintKey covers key formatting for normal, zero, negative, and
// maximum account IDs.
func TestFingerprintKey(t *testing.T) {
	tests := []struct {
		name      string
		accountID int64
		expected  string
	}{
		{
			name:      "normal_account_id",
			accountID: 123,
			expected:  "fingerprint:123",
		},
		{
			name:      "zero_account_id",
			accountID: 0,
			expected:  "fingerprint:0",
		},
		{
			name:      "negative_account_id",
			accountID: -1,
			expected:  "fingerprint:-1",
		},
		{
			name:      "max_int64",
			accountID: math.MaxInt64,
			expected:  "fingerprint:9223372036854775807",
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			got := fingerprintKey(tc.accountID)
			require.Equal(t, tc.expected, got)
		})
	}
}
|
||||
63
backend/internal/repository/inprocess_transport_test.go
Normal file
63
backend/internal/repository/inprocess_transport_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }
|
||||
|
||||
// newInProcessTransport adapts an http.HandlerFunc into an http.RoundTripper without opening sockets.
|
||||
// It captures the request body (if any) and then rewinds it before invoking the handler.
|
||||
func newInProcessTransport(handler http.HandlerFunc, capture func(r *http.Request, body []byte)) http.RoundTripper {
|
||||
return roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
var body []byte
|
||||
if r.Body != nil {
|
||||
body, _ = io.ReadAll(r.Body)
|
||||
_ = r.Body.Close()
|
||||
r.Body = io.NopCloser(bytes.NewReader(body))
|
||||
}
|
||||
if capture != nil {
|
||||
capture(r, body)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
handler(rec, r)
|
||||
return rec.Result(), nil
|
||||
})
|
||||
}
|
||||
|
||||
var (
|
||||
canListenOnce sync.Once
|
||||
canListen bool
|
||||
canListenErr error
|
||||
)
|
||||
|
||||
func localListenerAvailable() bool {
|
||||
canListenOnce.Do(func() {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
canListenErr = err
|
||||
canListen = false
|
||||
return
|
||||
}
|
||||
_ = ln.Close()
|
||||
canListen = true
|
||||
})
|
||||
return canListen
|
||||
}
|
||||
|
||||
func newLocalTestServer(tb testing.TB, handler http.Handler) *httptest.Server {
|
||||
tb.Helper()
|
||||
if !localListenerAvailable() {
|
||||
tb.Skipf("local listeners are not permitted in this environment: %v", canListenErr)
|
||||
}
|
||||
return httptest.NewServer(handler)
|
||||
}
|
||||
408
backend/internal/repository/integration_harness_test.go
Normal file
408
backend/internal/repository/integration_harness_test.go
Normal file
@@ -0,0 +1,408 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
_ "github.com/Wei-Shaw/sub2api/ent/runtime"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
_ "github.com/lib/pq"
|
||||
redisclient "github.com/redis/go-redis/v9"
|
||||
tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres"
|
||||
tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
|
||||
)
|
||||
|
||||
// Container image tags used by the integration test harness.
const (
	redisImageTag    = "redis:8.4-alpine"
	postgresImageTag = "postgres:18.1-alpine3.23"
)
|
||||
|
||||
// Shared integration-test resources, initialized once in TestMain and used by
// the per-test helpers (testTx, testEntTx, testRedis, ...).
var (
	integrationDB        *sql.DB
	integrationEntClient *dbent.Client
	integrationRedis     *redisclient.Client

	// redisNamespaceSeq disambiguates key prefixes when two tests start in
	// the same nanosecond (see testRedis).
	redisNamespaceSeq uint64
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
ctx := context.Background()
|
||||
|
||||
if err := timezone.Init("UTC"); err != nil {
|
||||
log.Printf("failed to init timezone: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if !dockerIsAvailable(ctx) {
|
||||
// In CI we expect Docker to be available so integration tests should fail loudly.
|
||||
if os.Getenv("CI") != "" {
|
||||
log.Printf("docker is not available (CI=true); failing integration tests")
|
||||
os.Exit(1)
|
||||
}
|
||||
log.Printf("docker is not available; skipping integration tests (start Docker to enable)")
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
postgresImage := selectDockerImage(ctx, postgresImageTag)
|
||||
pgContainer, err := tcpostgres.Run(
|
||||
ctx,
|
||||
postgresImage,
|
||||
tcpostgres.WithDatabase("sub2api_test"),
|
||||
tcpostgres.WithUsername("postgres"),
|
||||
tcpostgres.WithPassword("postgres"),
|
||||
tcpostgres.BasicWaitStrategies(),
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("failed to start postgres container: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer func() { _ = pgContainer.Terminate(ctx) }()
|
||||
|
||||
redisContainer, err := tcredis.Run(
|
||||
ctx,
|
||||
redisImageTag,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("failed to start redis container: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer func() { _ = redisContainer.Terminate(ctx) }()
|
||||
|
||||
dsn, err := pgContainer.ConnectionString(ctx, "sslmode=disable", "TimeZone=UTC")
|
||||
if err != nil {
|
||||
log.Printf("failed to get postgres dsn: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
integrationDB, err = openSQLWithRetry(ctx, dsn, 30*time.Second)
|
||||
if err != nil {
|
||||
log.Printf("failed to open sql db: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if err := ApplyMigrations(ctx, integrationDB); err != nil {
|
||||
log.Printf("failed to apply db migrations: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// 创建 ent client 用于集成测试
|
||||
drv := entsql.OpenDB(dialect.Postgres, integrationDB)
|
||||
integrationEntClient = dbent.NewClient(dbent.Driver(drv))
|
||||
|
||||
redisHost, err := redisContainer.Host(ctx)
|
||||
if err != nil {
|
||||
log.Printf("failed to get redis host: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
redisPort, err := redisContainer.MappedPort(ctx, "6379/tcp")
|
||||
if err != nil {
|
||||
log.Printf("failed to get redis port: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
integrationRedis = redisclient.NewClient(&redisclient.Options{
|
||||
Addr: fmt.Sprintf("%s:%d", redisHost, redisPort.Int()),
|
||||
DB: 0,
|
||||
})
|
||||
if err := integrationRedis.Ping(ctx).Err(); err != nil {
|
||||
log.Printf("failed to ping redis: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
code := m.Run()
|
||||
|
||||
_ = integrationEntClient.Close()
|
||||
_ = integrationRedis.Close()
|
||||
_ = integrationDB.Close()
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
// dockerIsAvailable reports whether a usable Docker daemon can be reached,
// by running `docker info` with the current process environment.
func dockerIsAvailable(ctx context.Context) bool {
	probe := exec.CommandContext(ctx, "docker", "info")
	probe.Env = os.Environ()
	if err := probe.Run(); err != nil {
		return false
	}
	return true
}
|
||||
|
||||
func selectDockerImage(ctx context.Context, preferred string) string {
|
||||
if dockerImageExists(ctx, preferred) {
|
||||
return preferred
|
||||
}
|
||||
|
||||
return preferred
|
||||
}
|
||||
|
||||
// dockerImageExists reports whether the given image is already present in the
// local Docker image store, via `docker image inspect`.
func dockerImageExists(ctx context.Context, image string) bool {
	inspect := exec.CommandContext(ctx, "docker", "image", "inspect", image)
	inspect.Env = os.Environ()
	// Discard all output; only the exit status matters.
	inspect.Stdout, inspect.Stderr = nil, nil
	return inspect.Run() == nil
}
|
||||
|
||||
// openSQLWithRetry opens a Postgres connection and retries (every 250ms,
// with a 2s ping per attempt) until the database answers a ping or the
// overall timeout elapses. This absorbs the window between a container
// reporting "started" and Postgres actually accepting connections.
func openSQLWithRetry(ctx context.Context, dsn string, timeout time.Duration) (*sql.DB, error) {
	deadline := time.Now().Add(timeout)
	var lastErr error

	for time.Now().Before(deadline) {
		db, err := sql.Open("postgres", dsn)
		if err != nil {
			lastErr = err
			time.Sleep(250 * time.Millisecond)
			continue
		}

		if err := pingWithTimeout(ctx, db, 2*time.Second); err != nil {
			lastErr = err
			// Close the handle before retrying so failed attempts do not
			// accumulate open pools.
			_ = db.Close()
			time.Sleep(250 * time.Millisecond)
			continue
		}

		return db, nil
	}

	return nil, fmt.Errorf("db not ready after %s: %w", timeout, lastErr)
}
|
||||
|
||||
// pingWithTimeout pings db with a per-call deadline derived from ctx.
func pingWithTimeout(ctx context.Context, db *sql.DB, timeout time.Duration) error {
	bounded, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()
	return db.PingContext(bounded)
}
|
||||
|
||||
// testTx begins a raw *sql.Tx on the shared integration database and
// registers a rollback for test cleanup, so every test runs isolated and
// leaves no rows behind.
func testTx(t *testing.T) *sql.Tx {
	t.Helper()

	tx, err := integrationDB.BeginTx(context.Background(), nil)
	require.NoError(t, err, "begin tx")
	t.Cleanup(func() {
		// Rollback may fail if the test committed; best-effort by design.
		_ = tx.Rollback()
	})
	return tx
}
|
||||
|
||||
// testEntClient returns the global ent client, for tests exercising code that
// manages its own transactions internally (e.g. Create/Update methods).
// Note: operations through this client really write to the database and are
// NOT rolled back automatically when the test ends.
func testEntClient(t *testing.T) *dbent.Client {
	t.Helper()
	return integrationEntClient
}
|
||||
|
||||
// testEntTx returns an ent transaction for tests that need transactional
// isolation. It is rolled back automatically when the test finishes, leaving
// the database state untouched.
func testEntTx(t *testing.T) *dbent.Tx {
	t.Helper()

	tx, err := integrationEntClient.Tx(context.Background())
	require.NoError(t, err, "begin ent tx")
	t.Cleanup(func() {
		// Best-effort: rollback errors (e.g. already committed) are ignored.
		_ = tx.Rollback()
	})
	return tx
}
|
||||
|
||||
// testEntSQLTx is deprecated: do not use it in new tests.
// An ent client created on top of a *sql.Tx panics when client.Tx() is called.
// For code that manages transactions internally, use testEntClient.
// For tests that need transactional isolation, use testEntTx.
//
// Deprecated: Use testEntClient or testEntTx instead.
func testEntSQLTx(t *testing.T) (*dbent.Client, *sql.Tx) {
	t.Helper()

	// Fail immediately so legacy callers cannot trigger the nested-transaction panic.
	t.Fatalf("testEntSQLTx 已弃用:请使用 testEntClient 或 testEntTx")
	// Unreachable at runtime (Fatalf stops the test), but required by the compiler.
	return nil, nil
}
|
||||
|
||||
// testRedis returns a Redis client whose every command is transparently
// namespaced with a per-test key prefix (via prefixHook), so concurrent tests
// sharing the one Redis container cannot collide. Cleanup SCANs the prefix
// and UNLINKs all keys the test created.
func testRedis(t *testing.T) *redisclient.Client {
	t.Helper()

	// Prefix combines test name, nanosecond timestamp, and a global sequence
	// number to stay unique even if two tests start in the same nanosecond.
	prefix := fmt.Sprintf(
		"it:%s:%d:%d:",
		sanitizeRedisNamespace(t.Name()),
		time.Now().UnixNano(),
		atomic.AddUint64(&redisNamespaceSeq, 1),
	)

	// Clone the shared client's options so the hook only affects this client.
	opts := *integrationRedis.Options()
	rdb := redisclient.NewClient(&opts)
	rdb.AddHook(prefixHook{prefix: prefix})

	t.Cleanup(func() {
		ctx := context.Background()

		// Scan with the UN-hooked shared client: the pattern already carries
		// the prefix, and the hooked client would double-prefix it.
		var cursor uint64
		for {
			keys, nextCursor, err := integrationRedis.Scan(ctx, cursor, prefix+"*", 500).Result()
			require.NoError(t, err, "scan redis keys for cleanup")
			if len(keys) > 0 {
				require.NoError(t, integrationRedis.Unlink(ctx, keys...).Err(), "unlink redis keys for cleanup")
			}

			cursor = nextCursor
			if cursor == 0 {
				break
			}
		}

		_ = rdb.Close()
	})

	return rdb
}
|
||||
|
||||
// assertTTLWithin fails the test unless min <= ttl <= max (inclusive).
func assertTTLWithin(t *testing.T, ttl time.Duration, min, max time.Duration) {
	t.Helper()
	require.GreaterOrEqual(t, ttl, min, "ttl should be >= min")
	require.LessOrEqual(t, ttl, max, "ttl should be <= max")
}
|
||||
|
||||
// sanitizeRedisNamespace makes a test name safe for use inside a Redis key
// prefix by replacing subtest separators ("/") and spaces with underscores.
func sanitizeRedisNamespace(name string) string {
	return strings.NewReplacer("/", "_", " ", "_").Replace(name)
}
|
||||
|
||||
// prefixHook is a go-redis hook that rewrites the key argument(s) of each
// command to carry a per-test namespace prefix (see prefixCmd).
type prefixHook struct {
	prefix string
}
|
||||
|
||||
// DialHook is a pass-through; connection dialing needs no key rewriting.
func (h prefixHook) DialHook(next redisclient.DialHook) redisclient.DialHook { return next }
|
||||
|
||||
// ProcessHook prefixes the key argument(s) of a single command before it is
// sent to Redis.
func (h prefixHook) ProcessHook(next redisclient.ProcessHook) redisclient.ProcessHook {
	return func(ctx context.Context, cmd redisclient.Cmder) error {
		h.prefixCmd(cmd)
		return next(ctx, cmd)
	}
}
|
||||
|
||||
// ProcessPipelineHook prefixes the key argument(s) of every command in a
// pipeline before it is sent to Redis.
func (h prefixHook) ProcessPipelineHook(next redisclient.ProcessPipelineHook) redisclient.ProcessPipelineHook {
	return func(ctx context.Context, cmds []redisclient.Cmder) error {
		for _, cmd := range cmds {
			h.prefixCmd(cmd)
		}
		return next(ctx, cmds)
	}
}
|
||||
|
||||
// prefixCmd rewrites the key argument(s) of cmd in place, prepending
// h.prefix. Only commands explicitly listed in the switch below are
// rewritten — any command a test uses that is not listed here will operate
// on UN-prefixed keys and escape cleanup. Idempotent: an argument already
// carrying the prefix is left alone (matters when a command is retried).
func (h prefixHook) prefixCmd(cmd redisclient.Cmder) {
	args := cmd.Args()
	// args[0] is the command name; anything shorter has no key to rewrite.
	if len(args) < 2 {
		return
	}

	// prefixOne rewrites args[i] if it is a non-empty string or []byte that
	// does not already carry the prefix.
	prefixOne := func(i int) {
		if i < 0 || i >= len(args) {
			return
		}

		switch v := args[i].(type) {
		case string:
			if v != "" && !strings.HasPrefix(v, h.prefix) {
				args[i] = h.prefix + v
			}
		case []byte:
			s := string(v)
			if s != "" && !strings.HasPrefix(s, h.prefix) {
				args[i] = []byte(h.prefix + s)
			}
		}
	}

	switch strings.ToLower(cmd.Name()) {
	// Single-key commands: the key is always args[1].
	case "get", "set", "setnx", "setex", "psetex", "incr", "decr", "incrby", "expire", "pexpire", "ttl", "pttl",
		"hgetall", "hget", "hset", "hdel", "hincrbyfloat", "exists",
		"zadd", "zcard", "zrange", "zrangebyscore", "zrem", "zremrangebyscore", "zrevrange", "zrevrangebyscore", "zscore":
		prefixOne(1)
	// Variadic key commands: every argument after the name is a key.
	case "del", "unlink":
		for i := 1; i < len(args); i++ {
			prefixOne(i)
		}
	// Scripts: args are [cmd, script, numkeys, key1..keyN, argv...];
	// only the declared keys are rewritten, never the script or ARGV.
	case "eval", "evalsha", "eval_ro", "evalsha_ro":
		if len(args) < 3 {
			return
		}
		numKeys, err := strconv.Atoi(fmt.Sprint(args[2]))
		if err != nil || numKeys <= 0 {
			return
		}
		for i := 0; i < numKeys && 3+i < len(args); i++ {
			prefixOne(3 + i)
		}
	// SCAN: rewrite only the pattern that follows the MATCH option.
	case "scan":
		for i := 2; i+1 < len(args); i++ {
			if strings.EqualFold(fmt.Sprint(args[i]), "match") {
				prefixOne(i + 1)
				break
			}
		}
	}
}
|
||||
|
||||
// IntegrationRedisSuite provides a base suite for Redis integration tests.
// Embedding suites should call SetupTest to initialize ctx and rdb.
type IntegrationRedisSuite struct {
	suite.Suite
	ctx context.Context             // fresh background context per test
	rdb *redisclient.Client         // namespaced client from testRedis
}
|
||||
|
||||
// SetupTest initializes ctx and rdb for each test method. The client from
// testRedis is key-prefixed per test and cleaned up automatically.
func (s *IntegrationRedisSuite) SetupTest() {
	s.ctx = context.Background()
	s.rdb = testRedis(s.T())
}
|
||||
|
||||
// RequireNoError is a convenience method wrapping require.NoError with s.T().
func (s *IntegrationRedisSuite) RequireNoError(err error, msgAndArgs ...any) {
	s.T().Helper()
	require.NoError(s.T(), err, msgAndArgs...)
}
|
||||
|
||||
// AssertTTLWithin asserts that ttl is within [min, max] (inclusive).
func (s *IntegrationRedisSuite) AssertTTLWithin(ttl, min, max time.Duration) {
	s.T().Helper()
	assertTTLWithin(s.T(), ttl, min, max)
}
|
||||
|
||||
// IntegrationDBSuite provides a base suite for DB integration tests.
// Embedding suites should call SetupTest to initialize ctx and client.
type IntegrationDBSuite struct {
	suite.Suite
	ctx    context.Context // fresh background context per test
	client *dbent.Client   // transactional client derived from tx
	tx     *dbent.Tx       // per-test transaction, rolled back on cleanup
}
|
||||
|
||||
// SetupTest initializes ctx and client for each test method.
func (s *IntegrationDBSuite) SetupTest() {
	s.ctx = context.Background()
	// Uniformly use an ent.Tx so every test gets an isolated transaction
	// that is rolled back automatically on cleanup.
	tx := testEntTx(s.T())
	s.tx = tx
	s.client = tx.Client()
}
|
||||
|
||||
// RequireNoError is a convenience method wrapping require.NoError with s.T().
func (s *IntegrationDBSuite) RequireNoError(err error, msgAndArgs ...any) {
	s.T().Helper()
	require.NoError(s.T(), err, msgAndArgs...)
}
|
||||
302
backend/internal/repository/migrations_runner.go
Normal file
302
backend/internal/repository/migrations_runner.go
Normal file
@@ -0,0 +1,302 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/migrations"
|
||||
)
|
||||
|
||||
// schemaMigrationsTableDDL defines the DDL for the migration bookkeeping table,
// which tracks the applied migration files and their checksums:
//   - filename: migration file name; primary key uniquely identifying each migration
//   - checksum: SHA256 of the file content, used to detect later tampering
//   - applied_at: timestamp when the migration was applied
const schemaMigrationsTableDDL = `
CREATE TABLE IF NOT EXISTS schema_migrations (
	filename TEXT PRIMARY KEY,
	checksum TEXT NOT NULL,
	applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
`
|
||||
|
||||
// atlasSchemaRevisionsTableDDL mirrors the revision-tracking table that the
// Atlas migration tool maintains; created here so a legacy database can be
// aligned with an Atlas baseline (see ensureAtlasBaselineAligned).
const atlasSchemaRevisionsTableDDL = `
CREATE TABLE IF NOT EXISTS atlas_schema_revisions (
	version TEXT PRIMARY KEY,
	description TEXT NOT NULL,
	type INTEGER NOT NULL,
	applied INTEGER NOT NULL DEFAULT 0,
	total INTEGER NOT NULL DEFAULT 0,
	executed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
	execution_time BIGINT NOT NULL DEFAULT 0,
	error TEXT NULL,
	error_stmt TEXT NULL,
	hash TEXT NOT NULL DEFAULT '',
	partial_hashes TEXT[] NULL,
	operator_version TEXT NULL
);
`
|
||||
|
||||
// migrationsAdvisoryLockID is the PostgreSQL advisory-lock ID used to
// serialize migration runs. In multi-instance deployments the lock ensures
// only one instance applies migrations at a time. Any stable int64 works, as
// long as it does not collide with other locks in the same database.
const migrationsAdvisoryLockID int64 = 694208311321144027

// migrationsLockRetryInterval is how often we re-try pg_try_advisory_lock
// while another instance holds the migrations lock.
const migrationsLockRetryInterval = 500 * time.Millisecond
|
||||
|
||||
// ApplyMigrations applies the embedded SQL migration files to the given database.
//
// It is safe to call on every application start:
//   - already-applied migrations are skipped automatically (keyed by filename)
//   - a modified migration file (checksum mismatch) yields an error
//   - a PostgreSQL advisory lock makes concurrent multi-instance runs safe
//
// Parameters:
//   - ctx: context for timeout control and cancellation
//   - db:  database connection
//
// Returns any error encountered while migrating.
func ApplyMigrations(ctx context.Context, db *sql.DB) error {
	if db == nil {
		return errors.New("nil sql db")
	}
	return applyMigrationsFS(ctx, db, migrations.FS)
}
|
||||
|
||||
// applyMigrationsFS is the core migration implementation. It reads SQL
// migration files from the given filesystem and applies them in order.
//
// Flow:
//  1. Acquire the PostgreSQL advisory lock to prevent concurrent migration runs.
//  2. Ensure the schema_migrations table exists.
//  3. Read all .sql files, sorted by filename.
//  4. For each migration file:
//     - compute the SHA256 checksum of its content
//     - check whether it was already applied (by filename)
//     - if applied, verify the checksum still matches
//     - if not applied, execute it inside a transaction and record it
//  5. Release the advisory lock.
//
// Parameters:
//   - ctx:  context
//   - db:   database connection
//   - fsys: filesystem containing the migration files (typically an embed.FS)
func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
	if db == nil {
		return errors.New("nil sql db")
	}

	// Take the distributed lock so that, in multi-instance deployments, only
	// one instance runs migrations. This is PostgreSQL's advisory-lock mechanism.
	if err := pgAdvisoryLock(ctx, db); err != nil {
		return err
	}
	defer func() {
		// Always release the lock, whether or not migration succeeded.
		// context.Background() ensures release even if ctx was cancelled.
		_ = pgAdvisoryUnlock(context.Background(), db)
	}()

	// Create the bookkeeping table if it does not exist. It records every
	// applied migration together with its checksum.
	if _, err := db.ExecContext(ctx, schemaMigrationsTableDDL); err != nil {
		return fmt.Errorf("create schema_migrations: %w", err)
	}

	// Auto-align the Atlas baseline when a legacy schema_migrations table
	// exists but atlas_schema_revisions is missing.
	if err := ensureAtlasBaselineAligned(ctx, db, fsys); err != nil {
		return err
	}

	// Collect all .sql migration files and sort them by name. Naming
	// convention: zero-padded numeric prefixes (001_init.sql, 002_add_users.sql).
	files, err := fs.Glob(fsys, "*.sql")
	if err != nil {
		return fmt.Errorf("list migrations: %w", err)
	}
	sort.Strings(files) // apply migrations in filename order

	for _, name := range files {
		// Read the migration file content.
		contentBytes, err := fs.ReadFile(fsys, name)
		if err != nil {
			return fmt.Errorf("read migration %s: %w", name, err)
		}

		content := strings.TrimSpace(string(contentBytes))
		if content == "" {
			continue // skip empty files
		}

		// SHA256 of the content detects post-apply modification. This is a
		// tamper guard: if an applied migration file changes, refuse to start.
		sum := sha256.Sum256([]byte(content))
		checksum := hex.EncodeToString(sum[:])

		// Has this migration been applied already?
		var existing string
		rowErr := db.QueryRowContext(ctx, "SELECT checksum FROM schema_migrations WHERE filename = $1", name).Scan(&existing)
		if rowErr == nil {
			// Already applied: verify the checksum still matches.
			if existing != checksum {
				// A mismatch means the file was edited after being applied —
				// dangerous; the correct fix is a NEW migration file.
				return fmt.Errorf(
					"migration %s checksum mismatch (db=%s file=%s)\n"+
						"This means the migration file was modified after being applied to the database.\n"+
						"Solutions:\n"+
						"  1. Revert to original: git log --oneline -- migrations/%s && git checkout <commit> -- migrations/%s\n"+
						"  2. For new changes, create a new migration file instead of modifying existing ones\n"+
						"Note: Modifying applied migrations breaks the immutability principle and can cause inconsistencies across environments",
					name, existing, checksum, name, name,
				)
			}
			continue // applied and checksum matches: skip
		}
		if !errors.Is(rowErr, sql.ErrNoRows) {
			return fmt.Errorf("check migration %s: %w", name, rowErr)
		}

		// Not applied yet: run it inside a transaction so the migration is
		// atomic — it either fully succeeds or fully rolls back.
		tx, err := db.BeginTx(ctx, nil)
		if err != nil {
			return fmt.Errorf("begin migration %s: %w", name, err)
		}

		// Execute the migration SQL.
		if _, err := tx.ExecContext(ctx, content); err != nil {
			_ = tx.Rollback()
			return fmt.Errorf("apply migration %s: %w", name, err)
		}

		// Record completion: filename plus checksum.
		if _, err := tx.ExecContext(ctx, "INSERT INTO schema_migrations (filename, checksum) VALUES ($1, $2)", name, checksum); err != nil {
			_ = tx.Rollback()
			return fmt.Errorf("record migration %s: %w", name, err)
		}

		// Commit the transaction.
		if err := tx.Commit(); err != nil {
			_ = tx.Rollback()
			return fmt.Errorf("commit migration %s: %w", name, err)
		}
	}

	return nil
}
|
||||
|
||||
// ensureAtlasBaselineAligned seeds the atlas_schema_revisions table with a
// single baseline row when a legacy schema_migrations table exists but Atlas
// has no revision history yet. This lets a database that was migrated by the
// legacy runner be picked up by Atlas without re-applying everything.
// No-op when there is no legacy table, or when Atlas already has revisions.
func ensureAtlasBaselineAligned(ctx context.Context, db *sql.DB, fsys fs.FS) error {
	hasLegacy, err := tableExists(ctx, db, "schema_migrations")
	if err != nil {
		return fmt.Errorf("check schema_migrations: %w", err)
	}
	if !hasLegacy {
		return nil
	}

	hasAtlas, err := tableExists(ctx, db, "atlas_schema_revisions")
	if err != nil {
		return fmt.Errorf("check atlas_schema_revisions: %w", err)
	}
	if !hasAtlas {
		if _, err := db.ExecContext(ctx, atlasSchemaRevisionsTableDDL); err != nil {
			return fmt.Errorf("create atlas_schema_revisions: %w", err)
		}
	}

	// If Atlas already tracks any revision, the baseline is in place.
	var count int
	if err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM atlas_schema_revisions").Scan(&count); err != nil {
		return fmt.Errorf("count atlas_schema_revisions: %w", err)
	}
	if count > 0 {
		return nil
	}

	version, description, hash, err := latestMigrationBaseline(fsys)
	if err != nil {
		return fmt.Errorf("atlas baseline version: %w", err)
	}

	if _, err := db.ExecContext(ctx, `
		INSERT INTO atlas_schema_revisions (version, description, type, applied, total, executed_at, execution_time, hash)
		VALUES ($1, $2, $3, 0, 0, NOW(), 0, $4)
	`, version, description, 1, hash); err != nil {
		return fmt.Errorf("insert atlas baseline: %w", err)
	}
	return nil
}
|
||||
|
||||
// tableExists reports whether a table with the given name exists in the
// public schema, via information_schema.
func tableExists(ctx context.Context, db *sql.DB, tableName string) (bool, error) {
	var exists bool
	err := db.QueryRowContext(ctx, `
		SELECT EXISTS (
			SELECT 1
			FROM information_schema.tables
			WHERE table_schema = 'public' AND table_name = $1
		)
	`, tableName).Scan(&exists)
	return exists, err
}
|
||||
|
||||
func latestMigrationBaseline(fsys fs.FS) (string, string, string, error) {
|
||||
files, err := fs.Glob(fsys, "*.sql")
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
if len(files) == 0 {
|
||||
return "baseline", "baseline", "", nil
|
||||
}
|
||||
sort.Strings(files)
|
||||
name := files[len(files)-1]
|
||||
contentBytes, err := fs.ReadFile(fsys, name)
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
content := strings.TrimSpace(string(contentBytes))
|
||||
sum := sha256.Sum256([]byte(content))
|
||||
hash := hex.EncodeToString(sum[:])
|
||||
version := strings.TrimSuffix(name, ".sql")
|
||||
return version, version, hash, nil
|
||||
}
|
||||
|
||||
// pgAdvisoryLock acquires the PostgreSQL advisory lock that serializes
// migration runs. Advisory locks are lightweight and not tied to any
// database object, making them well suited to application-level distributed
// locking. The non-blocking pg_try_advisory_lock is retried every
// migrationsLockRetryInterval until acquired or ctx is cancelled.
func pgAdvisoryLock(ctx context.Context, db *sql.DB) error {
	ticker := time.NewTicker(migrationsLockRetryInterval)
	defer ticker.Stop()

	for {
		var locked bool
		if err := db.QueryRowContext(ctx, "SELECT pg_try_advisory_lock($1)", migrationsAdvisoryLockID).Scan(&locked); err != nil {
			return fmt.Errorf("acquire migrations lock: %w", err)
		}
		if locked {
			return nil
		}
		// Lock held elsewhere: wait for the next tick or give up on cancel.
		select {
		case <-ctx.Done():
			return fmt.Errorf("acquire migrations lock: %w", ctx.Err())
		case <-ticker.C:
		}
	}
}
|
||||
|
||||
// pgAdvisoryUnlock 释放 PostgreSQL Advisory Lock。
|
||||
// 必须在获取锁后确保释放,否则会阻塞其他实例的迁移操作。
|
||||
func pgAdvisoryUnlock(ctx context.Context, db *sql.DB) error {
|
||||
_, err := db.ExecContext(ctx, "SELECT pg_advisory_unlock($1)", migrationsAdvisoryLockID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("release migrations lock: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate re-applies the
// migrations (they must no-op the second time) and then spot-checks the
// resulting schema: column types/lengths/nullability that the repository
// queries rely on, plus the existence of several tables.
func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
	tx := testTx(t)

	// Re-apply migrations to verify idempotency (no errors, no duplicate rows).
	require.NoError(t, ApplyMigrations(context.Background(), integrationDB))

	// schema_migrations should have at least the current migration set.
	var applied int
	require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT COUNT(*) FROM schema_migrations").Scan(&applied))
	require.GreaterOrEqual(t, applied, 7, "expected schema_migrations to contain applied migrations")

	// users: columns required by repository queries
	requireColumn(t, tx, "users", "username", "character varying", 100, false)
	requireColumn(t, tx, "users", "notes", "text", 0, false)

	// accounts: schedulable and rate-limit fields
	requireColumn(t, tx, "accounts", "notes", "text", 0, true)
	requireColumn(t, tx, "accounts", "schedulable", "boolean", 0, false)
	requireColumn(t, tx, "accounts", "rate_limited_at", "timestamp with time zone", 0, true)
	requireColumn(t, tx, "accounts", "rate_limit_reset_at", "timestamp with time zone", 0, true)
	requireColumn(t, tx, "accounts", "overload_until", "timestamp with time zone", 0, true)
	requireColumn(t, tx, "accounts", "session_window_status", "character varying", 20, true)

	// api_keys: key length should be 128
	requireColumn(t, tx, "api_keys", "key", "character varying", 128, false)

	// redeem_codes: subscription fields
	requireColumn(t, tx, "redeem_codes", "group_id", "bigint", 0, true)
	requireColumn(t, tx, "redeem_codes", "validity_days", "integer", 0, false)

	// usage_logs: billing_type used by filters/stats
	requireColumn(t, tx, "usage_logs", "billing_type", "smallint", 0, false)

	// settings table should exist
	var settingsRegclass sql.NullString
	require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.settings')").Scan(&settingsRegclass))
	require.True(t, settingsRegclass.Valid, "expected settings table to exist")

	// user_allowed_groups table should exist
	var uagRegclass sql.NullString
	require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.user_allowed_groups')").Scan(&uagRegclass))
	require.True(t, uagRegclass.Valid, "expected user_allowed_groups table to exist")

	// user_subscriptions: deleted_at for soft delete support (migration 012)
	requireColumn(t, tx, "user_subscriptions", "deleted_at", "timestamp with time zone", 0, true)

	// orphan_allowed_groups_audit table should exist (migration 013)
	var orphanAuditRegclass sql.NullString
	require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.orphan_allowed_groups_audit')").Scan(&orphanAuditRegclass))
	require.True(t, orphanAuditRegclass.Valid, "expected orphan_allowed_groups_audit table to exist")

	// account_groups: created_at should be timestamptz
	requireColumn(t, tx, "account_groups", "created_at", "timestamp with time zone", 0, false)

	// user_allowed_groups: created_at should be timestamptz
	requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false)
}
|
||||
|
||||
// requireColumn asserts that table.column exists in the public schema with
// the given information_schema data_type, character_maximum_length (checked
// only when maxLen > 0), and nullability.
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
	t.Helper()

	var row struct {
		DataType string
		MaxLen   sql.NullInt64
		Nullable string
	}

	err := tx.QueryRowContext(context.Background(), `
		SELECT
			data_type,
			character_maximum_length,
			is_nullable
		FROM information_schema.columns
		WHERE table_schema = 'public'
		  AND table_name = $1
		  AND column_name = $2
	`, table, column).Scan(&row.DataType, &row.MaxLen, &row.Nullable)
	require.NoError(t, err, "query information_schema.columns for %s.%s", table, column)
	require.Equal(t, dataType, row.DataType, "data_type mismatch for %s.%s", table, column)

	if maxLen > 0 {
		require.True(t, row.MaxLen.Valid, "expected maxLen for %s.%s", table, column)
		require.Equal(t, int64(maxLen), row.MaxLen.Int64, "maxLen mismatch for %s.%s", table, column)
	}

	// information_schema reports nullability as the strings "YES"/"NO".
	if nullable {
		require.Equal(t, "YES", row.Nullable, "nullable mismatch for %s.%s", table, column)
	} else {
		require.Equal(t, "NO", row.Nullable, "nullable mismatch for %s.%s", table, column)
	}
}
|
||||
89
backend/internal/repository/openai_oauth_service.go
Normal file
89
backend/internal/repository/openai_oauth_service.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
// NewOpenAIOAuthClient creates a new OpenAI OAuth client backed by the
// production token endpoint (openai.TokenURL).
func NewOpenAIOAuthClient() service.OpenAIOAuthClient {
	return &openaiOAuthService{tokenURL: openai.TokenURL}
}
||||
|
||||
// openaiOAuthService implements service.OpenAIOAuthClient against a
// configurable token endpoint (overridden in tests with an httptest server).
type openaiOAuthService struct {
	tokenURL string
}
|
||||
|
||||
// ExchangeCode exchanges an OAuth authorization code (plus its PKCE verifier)
// for a token pair, POSTing the standard authorization_code grant form to the
// configured token endpoint. An empty redirectURI falls back to
// openai.DefaultRedirectURI. proxyURL, when non-empty, routes the request
// through that proxy.
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
	client := createOpenAIReqClient(proxyURL)

	if redirectURI == "" {
		redirectURI = openai.DefaultRedirectURI
	}

	formData := url.Values{}
	formData.Set("grant_type", "authorization_code")
	formData.Set("client_id", openai.ClientID)
	formData.Set("code", code)
	formData.Set("redirect_uri", redirectURI)
	formData.Set("code_verifier", codeVerifier)

	var tokenResp openai.TokenResponse

	// SetSuccessResult decodes the JSON body into tokenResp on 2xx responses.
	resp, err := client.R().
		SetContext(ctx).
		SetFormDataFromValues(formData).
		SetSuccessResult(&tokenResp).
		Post(s.tokenURL)

	if err != nil {
		return nil, fmt.Errorf("request failed: %w", err)
	}

	if !resp.IsSuccessState() {
		return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
	}

	return &tokenResp, nil
}
|
||||
|
||||
// RefreshToken exchanges a refresh token for a fresh token pair via the
// refresh_token grant, requesting openai.RefreshScopes. proxyURL, when
// non-empty, routes the request through that proxy.
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
	client := createOpenAIReqClient(proxyURL)

	formData := url.Values{}
	formData.Set("grant_type", "refresh_token")
	formData.Set("refresh_token", refreshToken)
	formData.Set("client_id", openai.ClientID)
	formData.Set("scope", openai.RefreshScopes)

	var tokenResp openai.TokenResponse

	// SetSuccessResult decodes the JSON body into tokenResp on 2xx responses.
	resp, err := client.R().
		SetContext(ctx).
		SetFormDataFromValues(formData).
		SetSuccessResult(&tokenResp).
		Post(s.tokenURL)

	if err != nil {
		return nil, fmt.Errorf("request failed: %w", err)
	}

	if !resp.IsSuccessState() {
		return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
	}

	return &tokenResp, nil
}
|
||||
|
||||
// createOpenAIReqClient returns a shared req.Client configured with the given
// proxy (empty = direct) and a 60s request timeout.
func createOpenAIReqClient(proxyURL string) *req.Client {
	return getSharedReqClient(reqClientOptions{
		ProxyURL: proxyURL,
		Timeout:  60 * time.Second,
	})
}
|
||||
249
backend/internal/repository/openai_oauth_service_test.go
Normal file
249
backend/internal/repository/openai_oauth_service_test.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// OpenAIOAuthServiceSuite tests openaiOAuthService against a local httptest
// server standing in for the OAuth token endpoint.
type OpenAIOAuthServiceSuite struct {
	suite.Suite
	ctx      context.Context
	srv      *httptest.Server    // fake token endpoint, torn down per test
	svc      *openaiOAuthService // service under test, pointed at srv
	received chan url.Values     // buffered channel for captured form data
}
|
||||
|
||||
// SetupTest resets the context and the (buffered, size-1) capture channel
// before each test.
func (s *OpenAIOAuthServiceSuite) SetupTest() {
	s.ctx = context.Background()
	s.received = make(chan url.Values, 1)
}
|
||||
|
||||
// TearDownTest closes the fake token endpoint if the test started one.
func (s *OpenAIOAuthServiceSuite) TearDownTest() {
	if s.srv != nil {
		s.srv.Close()
		s.srv = nil
	}
}
|
||||
|
||||
// setupServer starts a local test server serving handler and points the
// service under test at it (skips the test if listeners are unavailable).
func (s *OpenAIOAuthServiceSuite) setupServer(handler http.HandlerFunc) {
	s.srv = newLocalTestServer(s.T(), handler)
	s.svc = &openaiOAuthService{tokenURL: s.srv.URL}
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_DefaultRedirectURI() {
|
||||
errCh := make(chan string, 1)
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
errCh <- "method mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
errCh <- "ParseForm failed"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("grant_type"); got != "authorization_code" {
|
||||
errCh <- "grant_type mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("client_id"); got != openai.ClientID {
|
||||
errCh <- "client_id mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("code"); got != "code" {
|
||||
errCh <- "code mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("redirect_uri"); got != openai.DefaultRedirectURI {
|
||||
errCh <- "redirect_uri mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("code_verifier"); got != "ver" {
|
||||
errCh <- "code_verifier mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`)
|
||||
}))
|
||||
|
||||
resp, err := s.svc.ExchangeCode(s.ctx, "code", "ver", "", "")
|
||||
require.NoError(s.T(), err, "ExchangeCode")
|
||||
select {
|
||||
case msg := <-errCh:
|
||||
require.Fail(s.T(), msg)
|
||||
default:
|
||||
}
|
||||
require.Equal(s.T(), "at", resp.AccessToken)
|
||||
require.Equal(s.T(), "rt", resp.RefreshToken)
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() {
|
||||
errCh := make(chan string, 1)
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
errCh <- "ParseForm failed"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("grant_type"); got != "refresh_token" {
|
||||
errCh <- "grant_type mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("refresh_token"); got != "rt" {
|
||||
errCh <- "refresh_token mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("client_id"); got != openai.ClientID {
|
||||
errCh <- "client_id mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("scope"); got != openai.RefreshScopes {
|
||||
errCh <- "scope mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"access_token":"at2","refresh_token":"rt2","token_type":"bearer","expires_in":3600}`)
|
||||
}))
|
||||
|
||||
resp, err := s.svc.RefreshToken(s.ctx, "rt", "")
|
||||
require.NoError(s.T(), err, "RefreshToken")
|
||||
select {
|
||||
case msg := <-errCh:
|
||||
require.Fail(s.T(), msg)
|
||||
default:
|
||||
}
|
||||
require.Equal(s.T(), "at2", resp.AccessToken)
|
||||
require.Equal(s.T(), "rt2", resp.RefreshToken)
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = io.WriteString(w, "bad")
|
||||
}))
|
||||
|
||||
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "status 400")
|
||||
require.ErrorContains(s.T(), err, "bad")
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestRequestError_ClosedServer() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
s.srv.Close()
|
||||
|
||||
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "request failed")
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestContextCancel() {
|
||||
started := make(chan struct{})
|
||||
block := make(chan struct{})
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
close(started)
|
||||
<-block
|
||||
}))
|
||||
|
||||
ctx, cancel := context.WithCancel(s.ctx)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := s.svc.ExchangeCode(ctx, "code", "ver", openai.DefaultRedirectURI, "")
|
||||
done <- err
|
||||
}()
|
||||
|
||||
<-started
|
||||
cancel()
|
||||
close(block)
|
||||
|
||||
err := <-done
|
||||
require.Error(s.T(), err)
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() {
|
||||
want := "http://localhost:9999/cb"
|
||||
errCh := make(chan string, 1)
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = r.ParseForm()
|
||||
if got := r.PostForm.Get("redirect_uri"); got != want {
|
||||
errCh <- "redirect_uri mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
|
||||
}))
|
||||
|
||||
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", want, "")
|
||||
require.NoError(s.T(), err, "ExchangeCode")
|
||||
select {
|
||||
case msg := <-errCh:
|
||||
require.Fail(s.T(), msg)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestTokenURL_CanBeOverriddenWithQuery() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = r.ParseForm()
|
||||
s.received <- r.PostForm
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
|
||||
}))
|
||||
s.svc.tokenURL = s.srv.URL + "?x=1"
|
||||
|
||||
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
|
||||
require.NoError(s.T(), err, "ExchangeCode")
|
||||
select {
|
||||
case <-s.received:
|
||||
default:
|
||||
require.Fail(s.T(), "expected server to receive request")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_SuccessButInvalidJSON() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = io.WriteString(w, "not-valid-json")
|
||||
}))
|
||||
|
||||
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
|
||||
require.Error(s.T(), err, "expected error for invalid JSON response")
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_NonSuccessStatus() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = io.WriteString(w, "unauthorized")
|
||||
}))
|
||||
|
||||
_, err := s.svc.RefreshToken(s.ctx, "rt", "")
|
||||
require.Error(s.T(), err, "expected error for non-2xx status")
|
||||
require.ErrorContains(s.T(), err, "status 401")
|
||||
}
|
||||
|
||||
// TestOpenAIOAuthServiceSuite wires the suite into the standard go test runner.
func TestOpenAIOAuthServiceSuite(t *testing.T) {
	suite.Run(t, new(OpenAIOAuthServiceSuite))
}
|
||||
1098
backend/internal/repository/ops_repo.go
Normal file
1098
backend/internal/repository/ops_repo.go
Normal file
File diff suppressed because it is too large
Load Diff
853
backend/internal/repository/ops_repo_alerts.go
Normal file
853
backend/internal/repository/ops_repo_alerts.go
Normal file
@@ -0,0 +1,853 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// ListAlertRules returns all alert rules ordered by id descending.
//
// Nullable text columns are COALESCEd in SQL so they scan into plain Go
// strings. last_triggered_at and the JSONB filters column are decoded into
// optional fields; a filters payload that fails to decode is silently
// dropped (best-effort), leaving rule.Filters nil.
func (r *opsRepository) ListAlertRules(ctx context.Context) ([]*service.OpsAlertRule, error) {
	if r == nil || r.db == nil {
		return nil, fmt.Errorf("nil ops repository")
	}

	q := `
	SELECT
		id,
		name,
		COALESCE(description, ''),
		enabled,
		COALESCE(severity, ''),
		metric_type,
		operator,
		threshold,
		window_minutes,
		sustained_minutes,
		cooldown_minutes,
		COALESCE(notify_email, true),
		filters,
		last_triggered_at,
		created_at,
		updated_at
	FROM ops_alert_rules
	ORDER BY id DESC`

	rows, err := r.db.QueryContext(ctx, q)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	out := []*service.OpsAlertRule{}
	for rows.Next() {
		var rule service.OpsAlertRule
		var filtersRaw []byte
		var lastTriggeredAt sql.NullTime
		// Scan order must match the SELECT column order above.
		if err := rows.Scan(
			&rule.ID,
			&rule.Name,
			&rule.Description,
			&rule.Enabled,
			&rule.Severity,
			&rule.MetricType,
			&rule.Operator,
			&rule.Threshold,
			&rule.WindowMinutes,
			&rule.SustainedMinutes,
			&rule.CooldownMinutes,
			&rule.NotifyEmail,
			&filtersRaw,
			&lastTriggeredAt,
			&rule.CreatedAt,
			&rule.UpdatedAt,
		); err != nil {
			return nil, err
		}
		if lastTriggeredAt.Valid {
			v := lastTriggeredAt.Time
			rule.LastTriggeredAt = &v
		}
		// "null" is the textual form of a JSON null value; treat it the same
		// as an absent filters payload.
		if len(filtersRaw) > 0 && string(filtersRaw) != "null" {
			var decoded map[string]any
			if err := json.Unmarshal(filtersRaw, &decoded); err == nil {
				rule.Filters = decoded
			}
		}
		out = append(out, &rule)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return out, nil
}
|
||||
|
||||
// CreateAlertRule inserts a new alert rule and returns the stored row,
// including its generated id and database-side timestamps.
//
// String inputs are trimmed before insert; Filters serialize to JSONB via
// opsNullJSONMap (nil map -> SQL NULL). The RETURNING column list mirrors
// the SELECT used by ListAlertRules so the same scan order applies.
func (r *opsRepository) CreateAlertRule(ctx context.Context, input *service.OpsAlertRule) (*service.OpsAlertRule, error) {
	if r == nil || r.db == nil {
		return nil, fmt.Errorf("nil ops repository")
	}
	if input == nil {
		return nil, fmt.Errorf("nil input")
	}

	filtersArg, err := opsNullJSONMap(input.Filters)
	if err != nil {
		return nil, err
	}

	q := `
	INSERT INTO ops_alert_rules (
		name,
		description,
		enabled,
		severity,
		metric_type,
		operator,
		threshold,
		window_minutes,
		sustained_minutes,
		cooldown_minutes,
		notify_email,
		filters,
		created_at,
		updated_at
	) VALUES (
		$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,NOW(),NOW()
	)
	RETURNING
		id,
		name,
		COALESCE(description, ''),
		enabled,
		COALESCE(severity, ''),
		metric_type,
		operator,
		threshold,
		window_minutes,
		sustained_minutes,
		cooldown_minutes,
		COALESCE(notify_email, true),
		filters,
		last_triggered_at,
		created_at,
		updated_at`

	var out service.OpsAlertRule
	var filtersRaw []byte
	var lastTriggeredAt sql.NullTime

	// Argument order matches the INSERT column list; Scan order matches RETURNING.
	if err := r.db.QueryRowContext(
		ctx,
		q,
		strings.TrimSpace(input.Name),
		strings.TrimSpace(input.Description),
		input.Enabled,
		strings.TrimSpace(input.Severity),
		strings.TrimSpace(input.MetricType),
		strings.TrimSpace(input.Operator),
		input.Threshold,
		input.WindowMinutes,
		input.SustainedMinutes,
		input.CooldownMinutes,
		input.NotifyEmail,
		filtersArg,
	).Scan(
		&out.ID,
		&out.Name,
		&out.Description,
		&out.Enabled,
		&out.Severity,
		&out.MetricType,
		&out.Operator,
		&out.Threshold,
		&out.WindowMinutes,
		&out.SustainedMinutes,
		&out.CooldownMinutes,
		&out.NotifyEmail,
		&filtersRaw,
		&lastTriggeredAt,
		&out.CreatedAt,
		&out.UpdatedAt,
	); err != nil {
		return nil, err
	}
	if lastTriggeredAt.Valid {
		v := lastTriggeredAt.Time
		out.LastTriggeredAt = &v
	}
	// Best-effort decode of the JSONB filters payload; "null" means no filters.
	if len(filtersRaw) > 0 && string(filtersRaw) != "null" {
		var decoded map[string]any
		if err := json.Unmarshal(filtersRaw, &decoded); err == nil {
			out.Filters = decoded
		}
	}

	return &out, nil
}
|
||||
|
||||
// UpdateAlertRule overwrites every mutable column of the rule identified by
// input.ID and returns the stored row.
//
// This is a full-row update, not a patch: zero values in input clear the
// corresponding columns. String inputs are trimmed; Filters serialize to
// JSONB via opsNullJSONMap (nil map -> SQL NULL). When no rule with that id
// exists, the error from Scan (sql.ErrNoRows) is returned as-is.
func (r *opsRepository) UpdateAlertRule(ctx context.Context, input *service.OpsAlertRule) (*service.OpsAlertRule, error) {
	if r == nil || r.db == nil {
		return nil, fmt.Errorf("nil ops repository")
	}
	if input == nil {
		return nil, fmt.Errorf("nil input")
	}
	if input.ID <= 0 {
		return nil, fmt.Errorf("invalid id")
	}

	filtersArg, err := opsNullJSONMap(input.Filters)
	if err != nil {
		return nil, err
	}

	q := `
	UPDATE ops_alert_rules
	SET
		name = $2,
		description = $3,
		enabled = $4,
		severity = $5,
		metric_type = $6,
		operator = $7,
		threshold = $8,
		window_minutes = $9,
		sustained_minutes = $10,
		cooldown_minutes = $11,
		notify_email = $12,
		filters = $13,
		updated_at = NOW()
	WHERE id = $1
	RETURNING
		id,
		name,
		COALESCE(description, ''),
		enabled,
		COALESCE(severity, ''),
		metric_type,
		operator,
		threshold,
		window_minutes,
		sustained_minutes,
		cooldown_minutes,
		COALESCE(notify_email, true),
		filters,
		last_triggered_at,
		created_at,
		updated_at`

	var out service.OpsAlertRule
	var filtersRaw []byte
	var lastTriggeredAt sql.NullTime

	// Argument order matches the $n placeholders; Scan order matches RETURNING.
	if err := r.db.QueryRowContext(
		ctx,
		q,
		input.ID,
		strings.TrimSpace(input.Name),
		strings.TrimSpace(input.Description),
		input.Enabled,
		strings.TrimSpace(input.Severity),
		strings.TrimSpace(input.MetricType),
		strings.TrimSpace(input.Operator),
		input.Threshold,
		input.WindowMinutes,
		input.SustainedMinutes,
		input.CooldownMinutes,
		input.NotifyEmail,
		filtersArg,
	).Scan(
		&out.ID,
		&out.Name,
		&out.Description,
		&out.Enabled,
		&out.Severity,
		&out.MetricType,
		&out.Operator,
		&out.Threshold,
		&out.WindowMinutes,
		&out.SustainedMinutes,
		&out.CooldownMinutes,
		&out.NotifyEmail,
		&filtersRaw,
		&lastTriggeredAt,
		&out.CreatedAt,
		&out.UpdatedAt,
	); err != nil {
		return nil, err
	}

	if lastTriggeredAt.Valid {
		v := lastTriggeredAt.Time
		out.LastTriggeredAt = &v
	}
	// Best-effort decode of the JSONB filters payload; "null" means no filters.
	if len(filtersRaw) > 0 && string(filtersRaw) != "null" {
		var decoded map[string]any
		if err := json.Unmarshal(filtersRaw, &decoded); err == nil {
			out.Filters = decoded
		}
	}

	return &out, nil
}
|
||||
|
||||
func (r *opsRepository) DeleteAlertRule(ctx context.Context, id int64) error {
|
||||
if r == nil || r.db == nil {
|
||||
return fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if id <= 0 {
|
||||
return fmt.Errorf("invalid id")
|
||||
}
|
||||
|
||||
res, err := r.db.ExecContext(ctx, "DELETE FROM ops_alert_rules WHERE id = $1", id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) ListAlertEvents(ctx context.Context, filter *service.OpsAlertEventFilter) ([]*service.OpsAlertEvent, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if filter == nil {
|
||||
filter = &service.OpsAlertEventFilter{}
|
||||
}
|
||||
|
||||
limit := filter.Limit
|
||||
if limit <= 0 {
|
||||
limit = 100
|
||||
}
|
||||
if limit > 500 {
|
||||
limit = 500
|
||||
}
|
||||
|
||||
where, args := buildOpsAlertEventsWhere(filter)
|
||||
args = append(args, limit)
|
||||
limitArg := "$" + itoa(len(args))
|
||||
|
||||
q := `
|
||||
SELECT
|
||||
id,
|
||||
COALESCE(rule_id, 0),
|
||||
COALESCE(severity, ''),
|
||||
COALESCE(status, ''),
|
||||
COALESCE(title, ''),
|
||||
COALESCE(description, ''),
|
||||
metric_value,
|
||||
threshold_value,
|
||||
dimensions,
|
||||
fired_at,
|
||||
resolved_at,
|
||||
email_sent,
|
||||
created_at
|
||||
FROM ops_alert_events
|
||||
` + where + `
|
||||
ORDER BY fired_at DESC, id DESC
|
||||
LIMIT ` + limitArg
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, q, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
out := []*service.OpsAlertEvent{}
|
||||
for rows.Next() {
|
||||
var ev service.OpsAlertEvent
|
||||
var metricValue sql.NullFloat64
|
||||
var thresholdValue sql.NullFloat64
|
||||
var dimensionsRaw []byte
|
||||
var resolvedAt sql.NullTime
|
||||
if err := rows.Scan(
|
||||
&ev.ID,
|
||||
&ev.RuleID,
|
||||
&ev.Severity,
|
||||
&ev.Status,
|
||||
&ev.Title,
|
||||
&ev.Description,
|
||||
&metricValue,
|
||||
&thresholdValue,
|
||||
&dimensionsRaw,
|
||||
&ev.FiredAt,
|
||||
&resolvedAt,
|
||||
&ev.EmailSent,
|
||||
&ev.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if metricValue.Valid {
|
||||
v := metricValue.Float64
|
||||
ev.MetricValue = &v
|
||||
}
|
||||
if thresholdValue.Valid {
|
||||
v := thresholdValue.Float64
|
||||
ev.ThresholdValue = &v
|
||||
}
|
||||
if resolvedAt.Valid {
|
||||
v := resolvedAt.Time
|
||||
ev.ResolvedAt = &v
|
||||
}
|
||||
if len(dimensionsRaw) > 0 && string(dimensionsRaw) != "null" {
|
||||
var decoded map[string]any
|
||||
if err := json.Unmarshal(dimensionsRaw, &decoded); err == nil {
|
||||
ev.Dimensions = decoded
|
||||
}
|
||||
}
|
||||
out = append(out, &ev)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) GetAlertEventByID(ctx context.Context, eventID int64) (*service.OpsAlertEvent, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if eventID <= 0 {
|
||||
return nil, fmt.Errorf("invalid event id")
|
||||
}
|
||||
|
||||
q := `
|
||||
SELECT
|
||||
id,
|
||||
COALESCE(rule_id, 0),
|
||||
COALESCE(severity, ''),
|
||||
COALESCE(status, ''),
|
||||
COALESCE(title, ''),
|
||||
COALESCE(description, ''),
|
||||
metric_value,
|
||||
threshold_value,
|
||||
dimensions,
|
||||
fired_at,
|
||||
resolved_at,
|
||||
email_sent,
|
||||
created_at
|
||||
FROM ops_alert_events
|
||||
WHERE id = $1`
|
||||
|
||||
row := r.db.QueryRowContext(ctx, q, eventID)
|
||||
ev, err := scanOpsAlertEvent(row)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return ev, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*service.OpsAlertEvent, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if ruleID <= 0 {
|
||||
return nil, fmt.Errorf("invalid rule id")
|
||||
}
|
||||
|
||||
q := `
|
||||
SELECT
|
||||
id,
|
||||
COALESCE(rule_id, 0),
|
||||
COALESCE(severity, ''),
|
||||
COALESCE(status, ''),
|
||||
COALESCE(title, ''),
|
||||
COALESCE(description, ''),
|
||||
metric_value,
|
||||
threshold_value,
|
||||
dimensions,
|
||||
fired_at,
|
||||
resolved_at,
|
||||
email_sent,
|
||||
created_at
|
||||
FROM ops_alert_events
|
||||
WHERE rule_id = $1 AND status = $2
|
||||
ORDER BY fired_at DESC
|
||||
LIMIT 1`
|
||||
|
||||
row := r.db.QueryRowContext(ctx, q, ruleID, service.OpsAlertStatusFiring)
|
||||
ev, err := scanOpsAlertEvent(row)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return ev, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) GetLatestAlertEvent(ctx context.Context, ruleID int64) (*service.OpsAlertEvent, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if ruleID <= 0 {
|
||||
return nil, fmt.Errorf("invalid rule id")
|
||||
}
|
||||
|
||||
q := `
|
||||
SELECT
|
||||
id,
|
||||
COALESCE(rule_id, 0),
|
||||
COALESCE(severity, ''),
|
||||
COALESCE(status, ''),
|
||||
COALESCE(title, ''),
|
||||
COALESCE(description, ''),
|
||||
metric_value,
|
||||
threshold_value,
|
||||
dimensions,
|
||||
fired_at,
|
||||
resolved_at,
|
||||
email_sent,
|
||||
created_at
|
||||
FROM ops_alert_events
|
||||
WHERE rule_id = $1
|
||||
ORDER BY fired_at DESC
|
||||
LIMIT 1`
|
||||
|
||||
row := r.db.QueryRowContext(ctx, q, ruleID)
|
||||
ev, err := scanOpsAlertEvent(row)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return ev, nil
|
||||
}
|
||||
|
||||
// CreateAlertEvent inserts a new alert event and returns the stored row as
// decoded by scanOpsAlertEvent.
//
// Nullable columns are bound through the opsNull* helpers and Dimensions
// serialize to JSONB via opsNullJSONMap (nil map -> SQL NULL).
// NOTE(review): rule_id is bound as opsNullInt64(&event.RuleID) — that
// pointer is always non-nil here, so a zero RuleID presumably inserts 0
// rather than NULL; confirm against the opsNullInt64 implementation and
// any rule_id foreign-key constraint.
func (r *opsRepository) CreateAlertEvent(ctx context.Context, event *service.OpsAlertEvent) (*service.OpsAlertEvent, error) {
	if r == nil || r.db == nil {
		return nil, fmt.Errorf("nil ops repository")
	}
	if event == nil {
		return nil, fmt.Errorf("nil event")
	}

	dimensionsArg, err := opsNullJSONMap(event.Dimensions)
	if err != nil {
		return nil, err
	}

	q := `
	INSERT INTO ops_alert_events (
		rule_id,
		severity,
		status,
		title,
		description,
		metric_value,
		threshold_value,
		dimensions,
		fired_at,
		resolved_at,
		email_sent,
		created_at
	) VALUES (
		$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,NOW()
	)
	RETURNING
		id,
		COALESCE(rule_id, 0),
		COALESCE(severity, ''),
		COALESCE(status, ''),
		COALESCE(title, ''),
		COALESCE(description, ''),
		metric_value,
		threshold_value,
		dimensions,
		fired_at,
		resolved_at,
		email_sent,
		created_at`

	// Argument order matches the INSERT column list above; the RETURNING
	// column list matches the scan order in scanOpsAlertEvent.
	row := r.db.QueryRowContext(
		ctx,
		q,
		opsNullInt64(&event.RuleID),
		opsNullString(event.Severity),
		opsNullString(event.Status),
		opsNullString(event.Title),
		opsNullString(event.Description),
		opsNullFloat64(event.MetricValue),
		opsNullFloat64(event.ThresholdValue),
		dimensionsArg,
		event.FiredAt,
		opsNullTime(event.ResolvedAt),
		event.EmailSent,
	)
	return scanOpsAlertEvent(row)
}
|
||||
|
||||
func (r *opsRepository) UpdateAlertEventStatus(ctx context.Context, eventID int64, status string, resolvedAt *time.Time) error {
|
||||
if r == nil || r.db == nil {
|
||||
return fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if eventID <= 0 {
|
||||
return fmt.Errorf("invalid event id")
|
||||
}
|
||||
if strings.TrimSpace(status) == "" {
|
||||
return fmt.Errorf("invalid status")
|
||||
}
|
||||
|
||||
q := `
|
||||
UPDATE ops_alert_events
|
||||
SET status = $2,
|
||||
resolved_at = $3
|
||||
WHERE id = $1`
|
||||
|
||||
_, err := r.db.ExecContext(ctx, q, eventID, strings.TrimSpace(status), opsNullTime(resolvedAt))
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *opsRepository) UpdateAlertEventEmailSent(ctx context.Context, eventID int64, emailSent bool) error {
|
||||
if r == nil || r.db == nil {
|
||||
return fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if eventID <= 0 {
|
||||
return fmt.Errorf("invalid event id")
|
||||
}
|
||||
|
||||
_, err := r.db.ExecContext(ctx, "UPDATE ops_alert_events SET email_sent = $2 WHERE id = $1", eventID, emailSent)
|
||||
return err
|
||||
}
|
||||
|
||||
// opsAlertEventRow abstracts over *sql.Row and *sql.Rows so that
// scanOpsAlertEvent can decode an alert event from either a single-row
// query or a row iterator.
type opsAlertEventRow interface {
	Scan(dest ...any) error
}
|
||||
|
||||
// CreateAlertSilence inserts a silence window for an alert rule and returns
// the stored row.
//
// A positive RuleID, a non-empty (trimmed) Platform and a non-zero Until
// are required. GroupID, Region and CreatedBy are optional scope/audit
// fields stored as SQL NULL when absent; on read, Region is trimmed and an
// empty string collapses to nil.
func (r *opsRepository) CreateAlertSilence(ctx context.Context, input *service.OpsAlertSilence) (*service.OpsAlertSilence, error) {
	if r == nil || r.db == nil {
		return nil, fmt.Errorf("nil ops repository")
	}
	if input == nil {
		return nil, fmt.Errorf("nil input")
	}
	if input.RuleID <= 0 {
		return nil, fmt.Errorf("invalid rule_id")
	}
	platform := strings.TrimSpace(input.Platform)
	if platform == "" {
		return nil, fmt.Errorf("invalid platform")
	}
	if input.Until.IsZero() {
		return nil, fmt.Errorf("invalid until")
	}

	q := `
	INSERT INTO ops_alert_silences (
		rule_id,
		platform,
		group_id,
		region,
		until,
		reason,
		created_by,
		created_at
	) VALUES (
		$1,$2,$3,$4,$5,$6,$7,NOW()
	)
	RETURNING id, rule_id, platform, group_id, region, until, COALESCE(reason,''), created_by, created_at`

	// Argument order matches the INSERT column list above.
	row := r.db.QueryRowContext(
		ctx,
		q,
		input.RuleID,
		platform,
		opsNullInt64(input.GroupID),
		opsNullString(input.Region),
		input.Until,
		opsNullString(input.Reason),
		opsNullInt64(input.CreatedBy),
	)

	var out service.OpsAlertSilence
	var groupID sql.NullInt64
	var region sql.NullString
	var createdBy sql.NullInt64
	// Scan order matches the RETURNING column list.
	if err := row.Scan(
		&out.ID,
		&out.RuleID,
		&out.Platform,
		&groupID,
		&region,
		&out.Until,
		&out.Reason,
		&createdBy,
		&out.CreatedAt,
	); err != nil {
		return nil, err
	}
	// Promote nullable columns into optional pointer fields.
	if groupID.Valid {
		v := groupID.Int64
		out.GroupID = &v
	}
	if region.Valid {
		v := strings.TrimSpace(region.String)
		if v != "" {
			out.Region = &v
		}
	}
	if createdBy.Valid {
		v := createdBy.Int64
		out.CreatedBy = &v
	}
	return &out, nil
}
|
||||
|
||||
func (r *opsRepository) IsAlertSilenced(ctx context.Context, ruleID int64, platform string, groupID *int64, region *string, now time.Time) (bool, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return false, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if ruleID <= 0 {
|
||||
return false, fmt.Errorf("invalid rule id")
|
||||
}
|
||||
platform = strings.TrimSpace(platform)
|
||||
if platform == "" {
|
||||
return false, nil
|
||||
}
|
||||
if now.IsZero() {
|
||||
now = time.Now().UTC()
|
||||
}
|
||||
|
||||
q := `
|
||||
SELECT 1
|
||||
FROM ops_alert_silences
|
||||
WHERE rule_id = $1
|
||||
AND platform = $2
|
||||
AND (group_id IS NOT DISTINCT FROM $3)
|
||||
AND (region IS NOT DISTINCT FROM $4)
|
||||
AND until > $5
|
||||
LIMIT 1`
|
||||
|
||||
var dummy int
|
||||
err := r.db.QueryRowContext(ctx, q, ruleID, platform, opsNullInt64(groupID), opsNullString(region), now).Scan(&dummy)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// scanOpsAlertEvent decodes one alert-event row (from *sql.Row or *sql.Rows)
// into a service.OpsAlertEvent. The destination order must match the shared
// 13-column SELECT/RETURNING list used by the alert-event queries. Scan
// errors (including sql.ErrNoRows) are returned to the caller unchanged.
func scanOpsAlertEvent(row opsAlertEventRow) (*service.OpsAlertEvent, error) {
	var ev service.OpsAlertEvent
	var metricValue sql.NullFloat64
	var thresholdValue sql.NullFloat64
	var dimensionsRaw []byte
	var resolvedAt sql.NullTime

	if err := row.Scan(
		&ev.ID,
		&ev.RuleID,
		&ev.Severity,
		&ev.Status,
		&ev.Title,
		&ev.Description,
		&metricValue,
		&thresholdValue,
		&dimensionsRaw,
		&ev.FiredAt,
		&resolvedAt,
		&ev.EmailSent,
		&ev.CreatedAt,
	); err != nil {
		return nil, err
	}
	// Promote nullable columns into optional pointer fields.
	if metricValue.Valid {
		v := metricValue.Float64
		ev.MetricValue = &v
	}
	if thresholdValue.Valid {
		v := thresholdValue.Float64
		ev.ThresholdValue = &v
	}
	if resolvedAt.Valid {
		v := resolvedAt.Time
		ev.ResolvedAt = &v
	}
	// Best-effort decode of the JSONB dimensions payload: "null" (a JSON
	// null) is treated like an absent value, and decode errors leave
	// ev.Dimensions nil rather than failing the row.
	if len(dimensionsRaw) > 0 && string(dimensionsRaw) != "null" {
		var decoded map[string]any
		if err := json.Unmarshal(dimensionsRaw, &decoded); err == nil {
			ev.Dimensions = decoded
		}
	}
	return &ev, nil
}
|
||||
|
||||
// buildOpsAlertEventsWhere assembles the WHERE clause and positional
// arguments for the alert-event listing query.
//
// Placeholders are numbered $1..$n in the order arguments are appended
// (itoa(len(args)) immediately after each append); callers may append
// further arguments (e.g. LIMIT) continuing the same numbering. The clause
// always begins with "WHERE 1=1" so it can be concatenated unconditionally.
func buildOpsAlertEventsWhere(filter *service.OpsAlertEventFilter) (string, []any) {
	clauses := []string{"1=1"}
	args := []any{}

	if filter == nil {
		return "WHERE " + strings.Join(clauses, " AND "), args
	}

	if status := strings.TrimSpace(filter.Status); status != "" {
		args = append(args, status)
		clauses = append(clauses, "status = $"+itoa(len(args)))
	}
	if severity := strings.TrimSpace(filter.Severity); severity != "" {
		args = append(args, severity)
		clauses = append(clauses, "severity = $"+itoa(len(args)))
	}
	if filter.EmailSent != nil {
		args = append(args, *filter.EmailSent)
		clauses = append(clauses, "email_sent = $"+itoa(len(args)))
	}
	// Time range is half-open: [StartTime, EndTime).
	if filter.StartTime != nil && !filter.StartTime.IsZero() {
		args = append(args, *filter.StartTime)
		clauses = append(clauses, "fired_at >= $"+itoa(len(args)))
	}
	if filter.EndTime != nil && !filter.EndTime.IsZero() {
		args = append(args, *filter.EndTime)
		clauses = append(clauses, "fired_at < $"+itoa(len(args)))
	}

	// Cursor pagination (descending by fired_at, then id)
	if filter.BeforeFiredAt != nil && !filter.BeforeFiredAt.IsZero() && filter.BeforeID != nil && *filter.BeforeID > 0 {
		args = append(args, *filter.BeforeFiredAt)
		tsArg := "$" + itoa(len(args))
		args = append(args, *filter.BeforeID)
		idArg := "$" + itoa(len(args))
		clauses = append(clauses, fmt.Sprintf("(fired_at < %s OR (fired_at = %s AND id < %s))", tsArg, tsArg, idArg))
	}
	// Dimensions are stored in JSONB. We filter best-effort without requiring GIN indexes.
	if platform := strings.TrimSpace(filter.Platform); platform != "" {
		args = append(args, platform)
		clauses = append(clauses, "(dimensions->>'platform') = $"+itoa(len(args)))
	}
	if filter.GroupID != nil && *filter.GroupID > 0 {
		// ->> yields text, so the id is compared as its decimal string form.
		args = append(args, fmt.Sprintf("%d", *filter.GroupID))
		clauses = append(clauses, "(dimensions->>'group_id') = $"+itoa(len(args)))
	}

	return "WHERE " + strings.Join(clauses, " AND "), args
}
|
||||
|
||||
// opsNullJSONMap converts a filters/dimensions map into a driver argument:
// SQL NULL (an invalid NullString) for a nil map, otherwise the map's JSON
// encoding as a valid NullString.
func opsNullJSONMap(v map[string]any) (any, error) {
	if v == nil {
		return sql.NullString{}, nil
	}
	encoded, err := json.Marshal(v)
	if err != nil {
		return nil, err
	}
	if len(encoded) == 0 {
		// Defensive: an empty encoding is mapped to NULL as well.
		return sql.NullString{}, nil
	}
	return sql.NullString{Valid: true, String: string(encoded)}, nil
}
|
||||
1015
backend/internal/repository/ops_repo_dashboard.go
Normal file
1015
backend/internal/repository/ops_repo_dashboard.go
Normal file
File diff suppressed because it is too large
Load Diff
79
backend/internal/repository/ops_repo_histograms.go
Normal file
79
backend/internal/repository/ops_repo_histograms.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func (r *opsRepository) GetLatencyHistogram(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsLatencyHistogramResponse, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if filter == nil {
|
||||
return nil, fmt.Errorf("nil filter")
|
||||
}
|
||||
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
|
||||
return nil, fmt.Errorf("start_time/end_time required")
|
||||
}
|
||||
|
||||
start := filter.StartTime.UTC()
|
||||
end := filter.EndTime.UTC()
|
||||
|
||||
join, where, args, _ := buildUsageWhere(filter, start, end, 1)
|
||||
rangeExpr := latencyHistogramRangeCaseExpr("ul.duration_ms")
|
||||
orderExpr := latencyHistogramRangeOrderCaseExpr("ul.duration_ms")
|
||||
|
||||
q := `
|
||||
SELECT
|
||||
` + rangeExpr + ` AS range,
|
||||
COALESCE(COUNT(*), 0) AS count,
|
||||
` + orderExpr + ` AS ord
|
||||
FROM usage_logs ul
|
||||
` + join + `
|
||||
` + where + `
|
||||
AND ul.duration_ms IS NOT NULL
|
||||
GROUP BY 1, 3
|
||||
ORDER BY 3 ASC`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, q, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
counts := make(map[string]int64, len(latencyHistogramOrderedRanges))
|
||||
var total int64
|
||||
for rows.Next() {
|
||||
var label string
|
||||
var count int64
|
||||
var _ord int
|
||||
if err := rows.Scan(&label, &count, &_ord); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts[label] = count
|
||||
total += count
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
buckets := make([]*service.OpsLatencyHistogramBucket, 0, len(latencyHistogramOrderedRanges))
|
||||
for _, label := range latencyHistogramOrderedRanges {
|
||||
buckets = append(buckets, &service.OpsLatencyHistogramBucket{
|
||||
Range: label,
|
||||
Count: counts[label],
|
||||
})
|
||||
}
|
||||
|
||||
return &service.OpsLatencyHistogramResponse{
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
Platform: strings.TrimSpace(filter.Platform),
|
||||
GroupID: filter.GroupID,
|
||||
TotalRequests: total,
|
||||
Buckets: buckets,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// latencyHistogramBucket describes one histogram bin. Requests with a duration
// strictly below upperMs fall into the bin (unless claimed by an earlier,
// lower bin); upperMs == 0 marks the open-ended catch-all bin.
type latencyHistogramBucket struct {
	upperMs int    // exclusive upper bound in milliseconds; 0 = unbounded default bin
	label   string // range label surfaced to the dashboard (e.g. "100-200ms")
}
|
||||
|
||||
// latencyHistogramBuckets defines the dashboard latency bins in ascending
// order of upper bound. The final entry (upperMs == 0) is the catch-all and
// must stay last: the SQL CASE builders use the last element as the ELSE arm.
var latencyHistogramBuckets = []latencyHistogramBucket{
	{upperMs: 100, label: "0-100ms"},
	{upperMs: 200, label: "100-200ms"},
	{upperMs: 500, label: "200-500ms"},
	{upperMs: 1000, label: "500-1000ms"},
	{upperMs: 2000, label: "1000-2000ms"},
	{upperMs: 0, label: "2000ms+"}, // default bucket
}
|
||||
|
||||
var latencyHistogramOrderedRanges = func() []string {
|
||||
out := make([]string, 0, len(latencyHistogramBuckets))
|
||||
for _, b := range latencyHistogramBuckets {
|
||||
out = append(out, b.label)
|
||||
}
|
||||
return out
|
||||
}()
|
||||
|
||||
func latencyHistogramRangeCaseExpr(column string) string {
|
||||
var sb strings.Builder
|
||||
_, _ = sb.WriteString("CASE\n")
|
||||
|
||||
for _, b := range latencyHistogramBuckets {
|
||||
if b.upperMs <= 0 {
|
||||
continue
|
||||
}
|
||||
_, _ = sb.WriteString(fmt.Sprintf("\tWHEN %s < %d THEN '%s'\n", column, b.upperMs, b.label))
|
||||
}
|
||||
|
||||
// Default bucket.
|
||||
last := latencyHistogramBuckets[len(latencyHistogramBuckets)-1]
|
||||
_, _ = sb.WriteString(fmt.Sprintf("\tELSE '%s'\n", last.label))
|
||||
_, _ = sb.WriteString("END")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func latencyHistogramRangeOrderCaseExpr(column string) string {
|
||||
var sb strings.Builder
|
||||
_, _ = sb.WriteString("CASE\n")
|
||||
|
||||
order := 1
|
||||
for _, b := range latencyHistogramBuckets {
|
||||
if b.upperMs <= 0 {
|
||||
continue
|
||||
}
|
||||
_, _ = sb.WriteString(fmt.Sprintf("\tWHEN %s < %d THEN %d\n", column, b.upperMs, order))
|
||||
order++
|
||||
}
|
||||
|
||||
_, _ = sb.WriteString(fmt.Sprintf("\tELSE %d\n", order))
|
||||
_, _ = sb.WriteString("END")
|
||||
return sb.String()
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLatencyHistogramBuckets_AreConsistent(t *testing.T) {
|
||||
require.Equal(t, len(latencyHistogramBuckets), len(latencyHistogramOrderedRanges))
|
||||
for i, b := range latencyHistogramBuckets {
|
||||
require.Equal(t, b.label, latencyHistogramOrderedRanges[i])
|
||||
}
|
||||
}
|
||||
422
backend/internal/repository/ops_repo_metrics.go
Normal file
422
backend/internal/repository/ops_repo_metrics.go
Normal file
@@ -0,0 +1,422 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// InsertSystemMetrics appends one row of aggregated traffic/system metrics to
// ops_system_metrics. Optional gauge fields (pointer fields on input) are
// written as SQL NULL when absent via the opsNull* adapters. WindowMinutes
// defaults to 1 and CreatedAt defaults to time.Now().UTC() when unset.
//
// NOTE: the positional $1..$39 placeholder list is order-sensitive; the
// argument list below must stay in lockstep with the column list, group by
// group.
func (r *opsRepository) InsertSystemMetrics(ctx context.Context, input *service.OpsInsertSystemMetricsInput) error {
	if r == nil || r.db == nil {
		return fmt.Errorf("nil ops repository")
	}
	if input == nil {
		return fmt.Errorf("nil input")
	}

	// Normalize the aggregation window and row timestamp.
	window := input.WindowMinutes
	if window <= 0 {
		window = 1
	}
	createdAt := input.CreatedAt
	if createdAt.IsZero() {
		createdAt = time.Now().UTC()
	}

	q := `
	INSERT INTO ops_system_metrics (
		created_at,
		window_minutes,
		platform,
		group_id,

		success_count,
		error_count_total,
		business_limited_count,
		error_count_sla,

		upstream_error_count_excl_429_529,
		upstream_429_count,
		upstream_529_count,

		token_consumed,
		qps,
		tps,

		duration_p50_ms,
		duration_p90_ms,
		duration_p95_ms,
		duration_p99_ms,
		duration_avg_ms,
		duration_max_ms,

		ttft_p50_ms,
		ttft_p90_ms,
		ttft_p95_ms,
		ttft_p99_ms,
		ttft_avg_ms,
		ttft_max_ms,

		cpu_usage_percent,
		memory_used_mb,
		memory_total_mb,
		memory_usage_percent,

		db_ok,
		redis_ok,

		redis_conn_total,
		redis_conn_idle,

		db_conn_active,
		db_conn_idle,
		db_conn_waiting,

		goroutine_count,
		concurrency_queue_depth
	) VALUES (
		$1,$2,$3,$4,
		$5,$6,$7,$8,
		$9,$10,$11,
		$12,$13,$14,
		$15,$16,$17,$18,$19,$20,
		$21,$22,$23,$24,$25,$26,
		$27,$28,$29,$30,
		$31,$32,
		$33,$34,
		$35,$36,$37,
		$38,$39
	)`

	// Arguments mirror the column list above; blank lines mark the same
	// groups as the SQL for easier auditing.
	_, err := r.db.ExecContext(
		ctx,
		q,
		createdAt,
		window,
		opsNullString(input.Platform),
		opsNullInt64(input.GroupID),

		input.SuccessCount,
		input.ErrorCountTotal,
		input.BusinessLimitedCount,
		input.ErrorCountSLA,

		input.UpstreamErrorCountExcl429529,
		input.Upstream429Count,
		input.Upstream529Count,

		input.TokenConsumed,
		opsNullFloat64(input.QPS),
		opsNullFloat64(input.TPS),

		opsNullInt(input.DurationP50Ms),
		opsNullInt(input.DurationP90Ms),
		opsNullInt(input.DurationP95Ms),
		opsNullInt(input.DurationP99Ms),
		opsNullFloat64(input.DurationAvgMs),
		opsNullInt(input.DurationMaxMs),

		opsNullInt(input.TTFTP50Ms),
		opsNullInt(input.TTFTP90Ms),
		opsNullInt(input.TTFTP95Ms),
		opsNullInt(input.TTFTP99Ms),
		opsNullFloat64(input.TTFTAvgMs),
		opsNullInt(input.TTFTMaxMs),

		opsNullFloat64(input.CPUUsagePercent),
		opsNullInt(input.MemoryUsedMB),
		opsNullInt(input.MemoryTotalMB),
		opsNullFloat64(input.MemoryUsagePercent),

		opsNullBool(input.DBOK),
		opsNullBool(input.RedisOK),

		opsNullInt(input.RedisConnTotal),
		opsNullInt(input.RedisConnIdle),

		opsNullInt(input.DBConnActive),
		opsNullInt(input.DBConnIdle),
		opsNullInt(input.DBConnWaiting),

		opsNullInt(input.GoroutineCount),
		opsNullInt(input.ConcurrencyQueueDepth),
	)
	return err
}
|
||||
|
||||
func (r *opsRepository) GetLatestSystemMetrics(ctx context.Context, windowMinutes int) (*service.OpsSystemMetricsSnapshot, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if windowMinutes <= 0 {
|
||||
windowMinutes = 1
|
||||
}
|
||||
|
||||
q := `
|
||||
SELECT
|
||||
id,
|
||||
created_at,
|
||||
window_minutes,
|
||||
|
||||
cpu_usage_percent,
|
||||
memory_used_mb,
|
||||
memory_total_mb,
|
||||
memory_usage_percent,
|
||||
|
||||
db_ok,
|
||||
redis_ok,
|
||||
|
||||
redis_conn_total,
|
||||
redis_conn_idle,
|
||||
|
||||
db_conn_active,
|
||||
db_conn_idle,
|
||||
db_conn_waiting,
|
||||
|
||||
goroutine_count,
|
||||
concurrency_queue_depth
|
||||
FROM ops_system_metrics
|
||||
WHERE window_minutes = $1
|
||||
AND platform IS NULL
|
||||
AND group_id IS NULL
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 1`
|
||||
|
||||
var out service.OpsSystemMetricsSnapshot
|
||||
var cpu sql.NullFloat64
|
||||
var memUsed sql.NullInt64
|
||||
var memTotal sql.NullInt64
|
||||
var memPct sql.NullFloat64
|
||||
var dbOK sql.NullBool
|
||||
var redisOK sql.NullBool
|
||||
var redisTotal sql.NullInt64
|
||||
var redisIdle sql.NullInt64
|
||||
var dbActive sql.NullInt64
|
||||
var dbIdle sql.NullInt64
|
||||
var dbWaiting sql.NullInt64
|
||||
var goroutines sql.NullInt64
|
||||
var queueDepth sql.NullInt64
|
||||
|
||||
if err := r.db.QueryRowContext(ctx, q, windowMinutes).Scan(
|
||||
&out.ID,
|
||||
&out.CreatedAt,
|
||||
&out.WindowMinutes,
|
||||
&cpu,
|
||||
&memUsed,
|
||||
&memTotal,
|
||||
&memPct,
|
||||
&dbOK,
|
||||
&redisOK,
|
||||
&redisTotal,
|
||||
&redisIdle,
|
||||
&dbActive,
|
||||
&dbIdle,
|
||||
&dbWaiting,
|
||||
&goroutines,
|
||||
&queueDepth,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cpu.Valid {
|
||||
v := cpu.Float64
|
||||
out.CPUUsagePercent = &v
|
||||
}
|
||||
if memUsed.Valid {
|
||||
v := memUsed.Int64
|
||||
out.MemoryUsedMB = &v
|
||||
}
|
||||
if memTotal.Valid {
|
||||
v := memTotal.Int64
|
||||
out.MemoryTotalMB = &v
|
||||
}
|
||||
if memPct.Valid {
|
||||
v := memPct.Float64
|
||||
out.MemoryUsagePercent = &v
|
||||
}
|
||||
if dbOK.Valid {
|
||||
v := dbOK.Bool
|
||||
out.DBOK = &v
|
||||
}
|
||||
if redisOK.Valid {
|
||||
v := redisOK.Bool
|
||||
out.RedisOK = &v
|
||||
}
|
||||
if redisTotal.Valid {
|
||||
v := int(redisTotal.Int64)
|
||||
out.RedisConnTotal = &v
|
||||
}
|
||||
if redisIdle.Valid {
|
||||
v := int(redisIdle.Int64)
|
||||
out.RedisConnIdle = &v
|
||||
}
|
||||
if dbActive.Valid {
|
||||
v := int(dbActive.Int64)
|
||||
out.DBConnActive = &v
|
||||
}
|
||||
if dbIdle.Valid {
|
||||
v := int(dbIdle.Int64)
|
||||
out.DBConnIdle = &v
|
||||
}
|
||||
if dbWaiting.Valid {
|
||||
v := int(dbWaiting.Int64)
|
||||
out.DBConnWaiting = &v
|
||||
}
|
||||
if goroutines.Valid {
|
||||
v := int(goroutines.Int64)
|
||||
out.GoroutineCount = &v
|
||||
}
|
||||
if queueDepth.Valid {
|
||||
v := int(queueDepth.Int64)
|
||||
out.ConcurrencyQueueDepth = &v
|
||||
}
|
||||
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) UpsertJobHeartbeat(ctx context.Context, input *service.OpsUpsertJobHeartbeatInput) error {
|
||||
if r == nil || r.db == nil {
|
||||
return fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if input == nil {
|
||||
return fmt.Errorf("nil input")
|
||||
}
|
||||
if input.JobName == "" {
|
||||
return fmt.Errorf("job_name required")
|
||||
}
|
||||
|
||||
q := `
|
||||
INSERT INTO ops_job_heartbeats (
|
||||
job_name,
|
||||
last_run_at,
|
||||
last_success_at,
|
||||
last_error_at,
|
||||
last_error,
|
||||
last_duration_ms,
|
||||
updated_at
|
||||
) VALUES (
|
||||
$1,$2,$3,$4,$5,$6,NOW()
|
||||
)
|
||||
ON CONFLICT (job_name) DO UPDATE SET
|
||||
last_run_at = COALESCE(EXCLUDED.last_run_at, ops_job_heartbeats.last_run_at),
|
||||
last_success_at = COALESCE(EXCLUDED.last_success_at, ops_job_heartbeats.last_success_at),
|
||||
last_error_at = CASE
|
||||
WHEN EXCLUDED.last_success_at IS NOT NULL THEN NULL
|
||||
ELSE COALESCE(EXCLUDED.last_error_at, ops_job_heartbeats.last_error_at)
|
||||
END,
|
||||
last_error = CASE
|
||||
WHEN EXCLUDED.last_success_at IS NOT NULL THEN NULL
|
||||
ELSE COALESCE(EXCLUDED.last_error, ops_job_heartbeats.last_error)
|
||||
END,
|
||||
last_duration_ms = COALESCE(EXCLUDED.last_duration_ms, ops_job_heartbeats.last_duration_ms),
|
||||
updated_at = NOW()`
|
||||
|
||||
_, err := r.db.ExecContext(
|
||||
ctx,
|
||||
q,
|
||||
input.JobName,
|
||||
opsNullTime(input.LastRunAt),
|
||||
opsNullTime(input.LastSuccessAt),
|
||||
opsNullTime(input.LastErrorAt),
|
||||
opsNullString(input.LastError),
|
||||
opsNullInt(input.LastDurationMs),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *opsRepository) ListJobHeartbeats(ctx context.Context) ([]*service.OpsJobHeartbeat, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
|
||||
q := `
|
||||
SELECT
|
||||
job_name,
|
||||
last_run_at,
|
||||
last_success_at,
|
||||
last_error_at,
|
||||
last_error,
|
||||
last_duration_ms,
|
||||
updated_at
|
||||
FROM ops_job_heartbeats
|
||||
ORDER BY job_name ASC`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
out := make([]*service.OpsJobHeartbeat, 0, 8)
|
||||
for rows.Next() {
|
||||
var item service.OpsJobHeartbeat
|
||||
var lastRun sql.NullTime
|
||||
var lastSuccess sql.NullTime
|
||||
var lastErrorAt sql.NullTime
|
||||
var lastError sql.NullString
|
||||
var lastDuration sql.NullInt64
|
||||
|
||||
if err := rows.Scan(
|
||||
&item.JobName,
|
||||
&lastRun,
|
||||
&lastSuccess,
|
||||
&lastErrorAt,
|
||||
&lastError,
|
||||
&lastDuration,
|
||||
&item.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if lastRun.Valid {
|
||||
v := lastRun.Time
|
||||
item.LastRunAt = &v
|
||||
}
|
||||
if lastSuccess.Valid {
|
||||
v := lastSuccess.Time
|
||||
item.LastSuccessAt = &v
|
||||
}
|
||||
if lastErrorAt.Valid {
|
||||
v := lastErrorAt.Time
|
||||
item.LastErrorAt = &v
|
||||
}
|
||||
if lastError.Valid {
|
||||
v := lastError.String
|
||||
item.LastError = &v
|
||||
}
|
||||
if lastDuration.Valid {
|
||||
v := lastDuration.Int64
|
||||
item.LastDurationMs = &v
|
||||
}
|
||||
|
||||
out = append(out, &item)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func opsNullBool(v *bool) any {
|
||||
if v == nil {
|
||||
return sql.NullBool{}
|
||||
}
|
||||
return sql.NullBool{Bool: *v, Valid: true}
|
||||
}
|
||||
|
||||
func opsNullFloat64(v *float64) any {
|
||||
if v == nil {
|
||||
return sql.NullFloat64{}
|
||||
}
|
||||
return sql.NullFloat64{Float64: *v, Valid: true}
|
||||
}
|
||||
|
||||
func opsNullTime(v *time.Time) any {
|
||||
if v == nil || v.IsZero() {
|
||||
return sql.NullTime{}
|
||||
}
|
||||
return sql.NullTime{Time: *v, Valid: true}
|
||||
}
|
||||
363
backend/internal/repository/ops_repo_preagg.go
Normal file
363
backend/internal/repository/ops_repo_preagg.go
Normal file
@@ -0,0 +1,363 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// UpsertHourlyMetrics recomputes ops_metrics_hourly rows for all hour buckets
// overlapping [startTime, endTime). Successful requests come from usage_logs,
// failures from ops_error_logs, each aggregated at three granularities
// (overall / per-platform / per-platform-and-group) via GROUPING SETS, then
// FULL OUTER JOINed and upserted. A zero or inverted time range is treated as
// a no-op, not an error, so callers can pass empty catch-up windows freely.
func (r *opsRepository) UpsertHourlyMetrics(ctx context.Context, startTime, endTime time.Time) error {
	if r == nil || r.db == nil {
		return fmt.Errorf("nil ops repository")
	}
	if startTime.IsZero() || endTime.IsZero() || !endTime.After(startTime) {
		return nil
	}

	// Buckets are computed in UTC so re-runs are stable across server TZs.
	start := startTime.UTC()
	end := endTime.UTC()

	// NOTE:
	// - We aggregate usage_logs + ops_error_logs into ops_metrics_hourly.
	// - We emit three dimension granularities via GROUPING SETS:
	//   1) overall:  (bucket_start)
	//   2) platform: (bucket_start, platform)
	//   3) group:    (bucket_start, platform, group_id)
	//
	// IMPORTANT: Postgres UNIQUE treats NULLs as distinct, so the table uses a COALESCE-based
	// unique index; our ON CONFLICT target must match that expression set.
	q := `
	WITH usage_base AS (
		SELECT
			date_trunc('hour', ul.created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start,
			g.platform AS platform,
			ul.group_id AS group_id,
			ul.duration_ms AS duration_ms,
			ul.first_token_ms AS first_token_ms,
			(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens) AS tokens
		FROM usage_logs ul
		JOIN groups g ON g.id = ul.group_id
		WHERE ul.created_at >= $1 AND ul.created_at < $2
	),
	usage_agg AS (
		SELECT
			bucket_start,
			CASE WHEN GROUPING(platform) = 1 THEN NULL ELSE platform END AS platform,
			CASE WHEN GROUPING(group_id) = 1 THEN NULL ELSE group_id END AS group_id,
			COUNT(*) AS success_count,
			COALESCE(SUM(tokens), 0) AS token_consumed,

			percentile_cont(0.50) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p50_ms,
			percentile_cont(0.90) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p90_ms,
			percentile_cont(0.95) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p95_ms,
			percentile_cont(0.99) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p99_ms,
			AVG(duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_avg_ms,
			MAX(duration_ms) AS duration_max_ms,

			percentile_cont(0.50) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p50_ms,
			percentile_cont(0.90) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p90_ms,
			percentile_cont(0.95) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p95_ms,
			percentile_cont(0.99) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p99_ms,
			AVG(first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_avg_ms,
			MAX(first_token_ms) AS ttft_max_ms
		FROM usage_base
		GROUP BY GROUPING SETS (
			(bucket_start),
			(bucket_start, platform),
			(bucket_start, platform, group_id)
		)
	),
	error_base AS (
		SELECT
			date_trunc('hour', created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start,
			-- platform is NULL for some early-phase errors (e.g. before routing); map to a sentinel
			-- value so platform-level GROUPING SETS don't collide with the overall (platform=NULL) row.
			COALESCE(platform, 'unknown') AS platform,
			group_id AS group_id,
			is_business_limited AS is_business_limited,
			error_owner AS error_owner,
			status_code AS client_status_code,
			COALESCE(upstream_status_code, status_code, 0) AS effective_status_code
		FROM ops_error_logs
		-- Exclude count_tokens requests from error metrics as they are informational probes
		WHERE created_at >= $1 AND created_at < $2
		  AND is_count_tokens = FALSE
	),
	error_agg AS (
		SELECT
			bucket_start,
			CASE WHEN GROUPING(platform) = 1 THEN NULL ELSE platform END AS platform,
			CASE WHEN GROUPING(group_id) = 1 THEN NULL ELSE group_id END AS group_id,
			COUNT(*) FILTER (WHERE COALESCE(client_status_code, 0) >= 400) AS error_count_total,
			COUNT(*) FILTER (WHERE COALESCE(client_status_code, 0) >= 400 AND is_business_limited) AS business_limited_count,
			COUNT(*) FILTER (WHERE COALESCE(client_status_code, 0) >= 400 AND NOT is_business_limited) AS error_count_sla,
			COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(effective_status_code, 0) NOT IN (429, 529)) AS upstream_error_count_excl_429_529,
			COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(effective_status_code, 0) = 429) AS upstream_429_count,
			COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(effective_status_code, 0) = 529) AS upstream_529_count
		FROM error_base
		GROUP BY GROUPING SETS (
			(bucket_start),
			(bucket_start, platform),
			(bucket_start, platform, group_id)
		)
		HAVING GROUPING(group_id) = 1 OR group_id IS NOT NULL
	),
	combined AS (
		SELECT
			COALESCE(u.bucket_start, e.bucket_start) AS bucket_start,
			COALESCE(u.platform, e.platform) AS platform,
			COALESCE(u.group_id, e.group_id) AS group_id,

			COALESCE(u.success_count, 0) AS success_count,
			COALESCE(e.error_count_total, 0) AS error_count_total,
			COALESCE(e.business_limited_count, 0) AS business_limited_count,
			COALESCE(e.error_count_sla, 0) AS error_count_sla,
			COALESCE(e.upstream_error_count_excl_429_529, 0) AS upstream_error_count_excl_429_529,
			COALESCE(e.upstream_429_count, 0) AS upstream_429_count,
			COALESCE(e.upstream_529_count, 0) AS upstream_529_count,

			COALESCE(u.token_consumed, 0) AS token_consumed,

			u.duration_p50_ms,
			u.duration_p90_ms,
			u.duration_p95_ms,
			u.duration_p99_ms,
			u.duration_avg_ms,
			u.duration_max_ms,

			u.ttft_p50_ms,
			u.ttft_p90_ms,
			u.ttft_p95_ms,
			u.ttft_p99_ms,
			u.ttft_avg_ms,
			u.ttft_max_ms
		FROM usage_agg u
		FULL OUTER JOIN error_agg e
			ON u.bucket_start = e.bucket_start
			AND COALESCE(u.platform, '') = COALESCE(e.platform, '')
			AND COALESCE(u.group_id, 0) = COALESCE(e.group_id, 0)
	)
	INSERT INTO ops_metrics_hourly (
		bucket_start,
		platform,
		group_id,
		success_count,
		error_count_total,
		business_limited_count,
		error_count_sla,
		upstream_error_count_excl_429_529,
		upstream_429_count,
		upstream_529_count,
		token_consumed,
		duration_p50_ms,
		duration_p90_ms,
		duration_p95_ms,
		duration_p99_ms,
		duration_avg_ms,
		duration_max_ms,
		ttft_p50_ms,
		ttft_p90_ms,
		ttft_p95_ms,
		ttft_p99_ms,
		ttft_avg_ms,
		ttft_max_ms,
		computed_at
	)
	SELECT
		bucket_start,
		NULLIF(platform, '') AS platform,
		group_id,
		success_count,
		error_count_total,
		business_limited_count,
		error_count_sla,
		upstream_error_count_excl_429_529,
		upstream_429_count,
		upstream_529_count,
		token_consumed,
		duration_p50_ms::int,
		duration_p90_ms::int,
		duration_p95_ms::int,
		duration_p99_ms::int,
		duration_avg_ms,
		duration_max_ms::int,
		ttft_p50_ms::int,
		ttft_p90_ms::int,
		ttft_p95_ms::int,
		ttft_p99_ms::int,
		ttft_avg_ms,
		ttft_max_ms::int,
		NOW()
	FROM combined
	WHERE bucket_start IS NOT NULL
	  AND (platform IS NULL OR platform <> '')
	ON CONFLICT (bucket_start, COALESCE(platform, ''), COALESCE(group_id, 0)) DO UPDATE SET
		success_count = EXCLUDED.success_count,
		error_count_total = EXCLUDED.error_count_total,
		business_limited_count = EXCLUDED.business_limited_count,
		error_count_sla = EXCLUDED.error_count_sla,
		upstream_error_count_excl_429_529 = EXCLUDED.upstream_error_count_excl_429_529,
		upstream_429_count = EXCLUDED.upstream_429_count,
		upstream_529_count = EXCLUDED.upstream_529_count,
		token_consumed = EXCLUDED.token_consumed,

		duration_p50_ms = EXCLUDED.duration_p50_ms,
		duration_p90_ms = EXCLUDED.duration_p90_ms,
		duration_p95_ms = EXCLUDED.duration_p95_ms,
		duration_p99_ms = EXCLUDED.duration_p99_ms,
		duration_avg_ms = EXCLUDED.duration_avg_ms,
		duration_max_ms = EXCLUDED.duration_max_ms,

		ttft_p50_ms = EXCLUDED.ttft_p50_ms,
		ttft_p90_ms = EXCLUDED.ttft_p90_ms,
		ttft_p95_ms = EXCLUDED.ttft_p95_ms,
		ttft_p99_ms = EXCLUDED.ttft_p99_ms,
		ttft_avg_ms = EXCLUDED.ttft_avg_ms,
		ttft_max_ms = EXCLUDED.ttft_max_ms,

		computed_at = NOW()
	`

	_, err := r.db.ExecContext(ctx, q, start, end)
	return err
}
|
||||
|
||||
// UpsertDailyMetrics rolls ops_metrics_hourly rows in [startTime, endTime)
// up into ops_metrics_daily, preserving the (platform, group_id) dimension
// split. Counts and token totals are exact sums; percentiles cannot be summed,
// so p50/p90 use a success-count-weighted average and p95/p99 take the hourly
// maximum (a conservative tail estimate) — see the inline SQL comment.
// A zero or inverted time range is treated as a no-op, not an error.
func (r *opsRepository) UpsertDailyMetrics(ctx context.Context, startTime, endTime time.Time) error {
	if r == nil || r.db == nil {
		return fmt.Errorf("nil ops repository")
	}
	if startTime.IsZero() || endTime.IsZero() || !endTime.After(startTime) {
		return nil
	}

	// Buckets are computed in UTC so re-runs are stable across server TZs.
	start := startTime.UTC()
	end := endTime.UTC()

	q := `
	INSERT INTO ops_metrics_daily (
		bucket_date,
		platform,
		group_id,
		success_count,
		error_count_total,
		business_limited_count,
		error_count_sla,
		upstream_error_count_excl_429_529,
		upstream_429_count,
		upstream_529_count,
		token_consumed,
		duration_p50_ms,
		duration_p90_ms,
		duration_p95_ms,
		duration_p99_ms,
		duration_avg_ms,
		duration_max_ms,
		ttft_p50_ms,
		ttft_p90_ms,
		ttft_p95_ms,
		ttft_p99_ms,
		ttft_avg_ms,
		ttft_max_ms,
		computed_at
	)
	SELECT
		(bucket_start AT TIME ZONE 'UTC')::date AS bucket_date,
		platform,
		group_id,

		COALESCE(SUM(success_count), 0) AS success_count,
		COALESCE(SUM(error_count_total), 0) AS error_count_total,
		COALESCE(SUM(business_limited_count), 0) AS business_limited_count,
		COALESCE(SUM(error_count_sla), 0) AS error_count_sla,
		COALESCE(SUM(upstream_error_count_excl_429_529), 0) AS upstream_error_count_excl_429_529,
		COALESCE(SUM(upstream_429_count), 0) AS upstream_429_count,
		COALESCE(SUM(upstream_529_count), 0) AS upstream_529_count,
		COALESCE(SUM(token_consumed), 0) AS token_consumed,

		-- Approximation: weighted average for p50/p90, max for p95/p99 (conservative tail).
		ROUND(SUM(duration_p50_ms::double precision * success_count) FILTER (WHERE duration_p50_ms IS NOT NULL)
			/ NULLIF(SUM(success_count) FILTER (WHERE duration_p50_ms IS NOT NULL), 0))::int AS duration_p50_ms,
		ROUND(SUM(duration_p90_ms::double precision * success_count) FILTER (WHERE duration_p90_ms IS NOT NULL)
			/ NULLIF(SUM(success_count) FILTER (WHERE duration_p90_ms IS NOT NULL), 0))::int AS duration_p90_ms,
		MAX(duration_p95_ms) AS duration_p95_ms,
		MAX(duration_p99_ms) AS duration_p99_ms,
		SUM(duration_avg_ms * success_count) FILTER (WHERE duration_avg_ms IS NOT NULL)
			/ NULLIF(SUM(success_count) FILTER (WHERE duration_avg_ms IS NOT NULL), 0) AS duration_avg_ms,
		MAX(duration_max_ms) AS duration_max_ms,

		ROUND(SUM(ttft_p50_ms::double precision * success_count) FILTER (WHERE ttft_p50_ms IS NOT NULL)
			/ NULLIF(SUM(success_count) FILTER (WHERE ttft_p50_ms IS NOT NULL), 0))::int AS ttft_p50_ms,
		ROUND(SUM(ttft_p90_ms::double precision * success_count) FILTER (WHERE ttft_p90_ms IS NOT NULL)
			/ NULLIF(SUM(success_count) FILTER (WHERE ttft_p90_ms IS NOT NULL), 0))::int AS ttft_p90_ms,
		MAX(ttft_p95_ms) AS ttft_p95_ms,
		MAX(ttft_p99_ms) AS ttft_p99_ms,
		SUM(ttft_avg_ms * success_count) FILTER (WHERE ttft_avg_ms IS NOT NULL)
			/ NULLIF(SUM(success_count) FILTER (WHERE ttft_avg_ms IS NOT NULL), 0) AS ttft_avg_ms,
		MAX(ttft_max_ms) AS ttft_max_ms,

		NOW()
	FROM ops_metrics_hourly
	WHERE bucket_start >= $1 AND bucket_start < $2
	GROUP BY 1, 2, 3
	ON CONFLICT (bucket_date, COALESCE(platform, ''), COALESCE(group_id, 0)) DO UPDATE SET
		success_count = EXCLUDED.success_count,
		error_count_total = EXCLUDED.error_count_total,
		business_limited_count = EXCLUDED.business_limited_count,
		error_count_sla = EXCLUDED.error_count_sla,
		upstream_error_count_excl_429_529 = EXCLUDED.upstream_error_count_excl_429_529,
		upstream_429_count = EXCLUDED.upstream_429_count,
		upstream_529_count = EXCLUDED.upstream_529_count,
		token_consumed = EXCLUDED.token_consumed,

		duration_p50_ms = EXCLUDED.duration_p50_ms,
		duration_p90_ms = EXCLUDED.duration_p90_ms,
		duration_p95_ms = EXCLUDED.duration_p95_ms,
		duration_p99_ms = EXCLUDED.duration_p99_ms,
		duration_avg_ms = EXCLUDED.duration_avg_ms,
		duration_max_ms = EXCLUDED.duration_max_ms,

		ttft_p50_ms = EXCLUDED.ttft_p50_ms,
		ttft_p90_ms = EXCLUDED.ttft_p90_ms,
		ttft_p95_ms = EXCLUDED.ttft_p95_ms,
		ttft_p99_ms = EXCLUDED.ttft_p99_ms,
		ttft_avg_ms = EXCLUDED.ttft_avg_ms,
		ttft_max_ms = EXCLUDED.ttft_max_ms,

		computed_at = NOW()
	`

	_, err := r.db.ExecContext(ctx, q, start, end)
	return err
}
|
||||
|
||||
func (r *opsRepository) GetLatestHourlyBucketStart(ctx context.Context) (time.Time, bool, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return time.Time{}, false, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
|
||||
var value sql.NullTime
|
||||
if err := r.db.QueryRowContext(ctx, `SELECT MAX(bucket_start) FROM ops_metrics_hourly`).Scan(&value); err != nil {
|
||||
return time.Time{}, false, err
|
||||
}
|
||||
if !value.Valid {
|
||||
return time.Time{}, false, nil
|
||||
}
|
||||
return value.Time.UTC(), true, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) GetLatestDailyBucketDate(ctx context.Context) (time.Time, bool, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return time.Time{}, false, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
|
||||
var value sql.NullTime
|
||||
if err := r.db.QueryRowContext(ctx, `SELECT MAX(bucket_date) FROM ops_metrics_daily`).Scan(&value); err != nil {
|
||||
return time.Time{}, false, err
|
||||
}
|
||||
if !value.Valid {
|
||||
return time.Time{}, false, nil
|
||||
}
|
||||
t := value.Time.UTC()
|
||||
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC), true, nil
|
||||
}
|
||||
129
backend/internal/repository/ops_repo_realtime_traffic.go
Normal file
129
backend/internal/repository/ops_repo_realtime_traffic.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// GetRealtimeTrafficSummary aggregates per-minute traffic over the filter's
// [StartTime, EndTime) window and returns request/error/token totals plus
// current, peak and average QPS/TPS. The window is capped at one hour so the
// per-minute aggregation stays cheap.
func (r *opsRepository) GetRealtimeTrafficSummary(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsRealtimeTrafficSummary, error) {
	if r == nil || r.db == nil {
		return nil, fmt.Errorf("nil ops repository")
	}
	if filter == nil {
		return nil, fmt.Errorf("nil filter")
	}
	if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
		return nil, fmt.Errorf("start_time/end_time required")
	}

	// All bucketing below is done in UTC.
	start := filter.StartTime.UTC()
	end := filter.EndTime.UTC()
	if start.After(end) {
		return nil, fmt.Errorf("start_time must be <= end_time")
	}

	window := end.Sub(start)
	if window <= 0 {
		return nil, fmt.Errorf("invalid time window")
	}
	if window > time.Hour {
		return nil, fmt.Errorf("window too large")
	}

	// Build both WHERE clauses with one continuous placeholder sequence:
	// usage placeholders start at $1 and error placeholders continue at
	// `next`, so the args slices must be concatenated in the same order
	// (usage first, then error) further below.
	usageJoin, usageWhere, usageArgs, next := buildUsageWhere(filter, start, end, 1)
	errorWhere, errorArgs, _ := buildErrorWhere(filter, start, end, next)

	// Per-minute success/token buckets (usage_logs) are FULL OUTER JOINed
	// with per-minute error buckets (ops_error_logs, status >= 400) so a
	// minute that saw only errors still contributes to request totals and
	// to the per-minute peak.
	q := `
		WITH usage_buckets AS (
			SELECT
				date_trunc('minute', ul.created_at) AS bucket,
				COALESCE(COUNT(*), 0) AS success_count,
				COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_sum
			FROM usage_logs ul
			` + usageJoin + `
			` + usageWhere + `
			GROUP BY 1
		),
		error_buckets AS (
			SELECT
				date_trunc('minute', created_at) AS bucket,
				COALESCE(COUNT(*), 0) AS error_count
			FROM ops_error_logs
			` + errorWhere + `
			AND COALESCE(status_code, 0) >= 400
			GROUP BY 1
		),
		combined AS (
			SELECT
				COALESCE(u.bucket, e.bucket) AS bucket,
				COALESCE(u.success_count, 0) AS success_count,
				COALESCE(u.token_sum, 0) AS token_sum,
				COALESCE(e.error_count, 0) AS error_count,
				COALESCE(u.success_count, 0) + COALESCE(e.error_count, 0) AS request_total
			FROM usage_buckets u
			FULL OUTER JOIN error_buckets e ON u.bucket = e.bucket
		)
		SELECT
			COALESCE(SUM(success_count), 0) AS success_total,
			COALESCE(SUM(error_count), 0) AS error_total,
			COALESCE(SUM(token_sum), 0) AS token_total,
			COALESCE(MAX(request_total), 0) AS peak_requests_per_min,
			COALESCE(MAX(token_sum), 0) AS peak_tokens_per_min
		FROM combined`

	// Placeholder order in q: usage args first, then error args.
	args := append(usageArgs, errorArgs...)
	var successCount int64
	var errorTotal int64
	var tokenConsumed int64
	var peakRequestsPerMin int64
	var peakTokensPerMin int64
	if err := r.db.QueryRowContext(ctx, q, args...).Scan(
		&successCount,
		&errorTotal,
		&tokenConsumed,
		&peakRequestsPerMin,
		&peakTokensPerMin,
	); err != nil {
		return nil, err
	}

	// window > 0 was verified above; this is purely a defensive guard
	// against division by zero.
	windowSeconds := window.Seconds()
	if windowSeconds <= 0 {
		windowSeconds = 1
	}

	// Averages are over the full window; "requests" = successes + errors.
	requestCountTotal := successCount + errorTotal
	qpsAvg := roundTo1DP(float64(requestCountTotal) / windowSeconds)
	tpsAvg := roundTo1DP(float64(tokenConsumed) / windowSeconds)

	// Keep "current" consistent with the dashboard overview semantics: last 1 minute.
	// This remains "within the selected window" since end=start+window.
	qpsCurrent, tpsCurrent, err := r.queryCurrentRates(ctx, filter, end)
	if err != nil {
		return nil, err
	}

	// Peaks come from the busiest single minute, converted to per-second.
	qpsPeak := roundTo1DP(float64(peakRequestsPerMin) / 60.0)
	tpsPeak := roundTo1DP(float64(peakTokensPerMin) / 60.0)

	return &service.OpsRealtimeTrafficSummary{
		StartTime: start,
		EndTime:   end,
		Platform:  strings.TrimSpace(filter.Platform),
		GroupID:   filter.GroupID,
		QPS: service.OpsRateSummary{
			Current: qpsCurrent,
			Peak:    qpsPeak,
			Avg:     qpsAvg,
		},
		TPS: service.OpsRateSummary{
			Current: tpsCurrent,
			Peak:    tpsPeak,
			Avg:     tpsAvg,
		},
	}, nil
}
|
||||
286
backend/internal/repository/ops_repo_request_details.go
Normal file
286
backend/internal/repository/ops_repo_request_details.go
Normal file
@@ -0,0 +1,286 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// ListRequestDetails returns a paginated, merged view of successful requests
// (usage_logs) and failed requests (ops_error_logs with status >= 400) inside
// the filter's time window, newest first by default. The second return value
// is the total row count across all pages.
//
// NOTE(review): filter.Normalize() is invoked before the `filter != nil`
// checks below, so a nil filter panics here before those guards ever run.
// Presumably callers always pass a non-nil filter — confirm, and either add
// an early nil guard or drop the redundant inner checks.
func (r *opsRepository) ListRequestDetails(ctx context.Context, filter *service.OpsRequestDetailFilter) ([]*service.OpsRequestDetail, int64, error) {
	if r == nil || r.db == nil {
		return nil, 0, fmt.Errorf("nil ops repository")
	}

	page, pageSize, startTime, endTime := filter.Normalize()
	offset := (page - 1) * pageSize

	conditions := make([]string, 0, 16)
	args := make([]any, 0, 24)

	// Placeholders $1/$2 reserved for time window inside the CTE.
	args = append(args, startTime.UTC(), endTime.UTC())

	// addCondition derives each placeholder number from the current length
	// of args, so conditions and their values must be appended in lockstep.
	addCondition := func(condition string, values ...any) {
		conditions = append(conditions, condition)
		args = append(args, values...)
	}

	if filter != nil {
		// kind narrows to success-only or error-only rows; ""/"all" keeps both.
		if kind := strings.TrimSpace(strings.ToLower(filter.Kind)); kind != "" && kind != "all" {
			if kind != string(service.OpsRequestKindSuccess) && kind != string(service.OpsRequestKindError) {
				return nil, 0, fmt.Errorf("invalid kind")
			}
			addCondition(fmt.Sprintf("kind = $%d", len(args)+1), kind)
		}

		if platform := strings.TrimSpace(strings.ToLower(filter.Platform)); platform != "" {
			addCondition(fmt.Sprintf("platform = $%d", len(args)+1), platform)
		}
		if filter.GroupID != nil && *filter.GroupID > 0 {
			addCondition(fmt.Sprintf("group_id = $%d", len(args)+1), *filter.GroupID)
		}

		if filter.UserID != nil && *filter.UserID > 0 {
			addCondition(fmt.Sprintf("user_id = $%d", len(args)+1), *filter.UserID)
		}
		if filter.APIKeyID != nil && *filter.APIKeyID > 0 {
			addCondition(fmt.Sprintf("api_key_id = $%d", len(args)+1), *filter.APIKeyID)
		}
		if filter.AccountID != nil && *filter.AccountID > 0 {
			addCondition(fmt.Sprintf("account_id = $%d", len(args)+1), *filter.AccountID)
		}

		if model := strings.TrimSpace(filter.Model); model != "" {
			addCondition(fmt.Sprintf("model = $%d", len(args)+1), model)
		}
		if requestID := strings.TrimSpace(filter.RequestID); requestID != "" {
			addCondition(fmt.Sprintf("request_id = $%d", len(args)+1), requestID)
		}
		// Free-text search: case-insensitive substring match across request
		// id, model name, and error message.
		if q := strings.TrimSpace(filter.Query); q != "" {
			like := "%" + strings.ToLower(q) + "%"
			startIdx := len(args) + 1
			addCondition(
				fmt.Sprintf("(LOWER(COALESCE(request_id,'')) LIKE $%d OR LOWER(COALESCE(model,'')) LIKE $%d OR LOWER(COALESCE(message,'')) LIKE $%d)",
					startIdx, startIdx+1, startIdx+2,
				),
				like, like, like,
			)
		}

		if filter.MinDurationMs != nil {
			addCondition(fmt.Sprintf("duration_ms >= $%d", len(args)+1), *filter.MinDurationMs)
		}
		if filter.MaxDurationMs != nil {
			addCondition(fmt.Sprintf("duration_ms <= $%d", len(args)+1), *filter.MaxDurationMs)
		}
	}

	where := ""
	if len(conditions) > 0 {
		where = "WHERE " + strings.Join(conditions, " AND ")
	}

	// combined normalizes both log tables into a single row shape; columns
	// that only exist on the error side (status_code, error_id, phase,
	// severity, message) are NULL for success rows.
	cte := `
	WITH combined AS (
		SELECT
			'success'::TEXT AS kind,
			ul.created_at AS created_at,
			ul.request_id AS request_id,
			COALESCE(NULLIF(g.platform, ''), NULLIF(a.platform, ''), '') AS platform,
			ul.model AS model,
			ul.duration_ms AS duration_ms,
			NULL::INT AS status_code,
			NULL::BIGINT AS error_id,
			NULL::TEXT AS phase,
			NULL::TEXT AS severity,
			NULL::TEXT AS message,
			ul.user_id AS user_id,
			ul.api_key_id AS api_key_id,
			ul.account_id AS account_id,
			ul.group_id AS group_id,
			ul.stream AS stream
		FROM usage_logs ul
		LEFT JOIN groups g ON g.id = ul.group_id
		LEFT JOIN accounts a ON a.id = ul.account_id
		WHERE ul.created_at >= $1 AND ul.created_at < $2

		UNION ALL

		SELECT
			'error'::TEXT AS kind,
			o.created_at AS created_at,
			COALESCE(NULLIF(o.request_id,''), NULLIF(o.client_request_id,''), '') AS request_id,
			COALESCE(NULLIF(o.platform, ''), NULLIF(g.platform, ''), NULLIF(a.platform, ''), '') AS platform,
			o.model AS model,
			o.duration_ms AS duration_ms,
			o.status_code AS status_code,
			o.id AS error_id,
			o.error_phase AS phase,
			o.severity AS severity,
			o.error_message AS message,
			o.user_id AS user_id,
			o.api_key_id AS api_key_id,
			o.account_id AS account_id,
			o.group_id AS group_id,
			o.stream AS stream
		FROM ops_error_logs o
		LEFT JOIN groups g ON g.id = o.group_id
		LEFT JOIN accounts a ON a.id = o.account_id
		WHERE o.created_at >= $1 AND o.created_at < $2
		AND COALESCE(o.status_code, 0) >= 400
	)
	`

	// Total count across all pages, using the same CTE and filters.
	countQuery := fmt.Sprintf(`%s SELECT COUNT(1) FROM combined %s`, cte, where)
	var total int64
	if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
		if err == sql.ErrNoRows {
			total = 0
		} else {
			return nil, 0, err
		}
	}

	// Only two sort orders are accepted; anything else is rejected rather
	// than interpolated into the SQL.
	sort := "ORDER BY created_at DESC"
	if filter != nil {
		switch strings.TrimSpace(strings.ToLower(filter.Sort)) {
		case "", "created_at_desc":
			// default
		case "duration_desc":
			sort = "ORDER BY duration_ms DESC NULLS LAST, created_at DESC"
		default:
			return nil, 0, fmt.Errorf("invalid sort")
		}
	}

	// LIMIT/OFFSET placeholders are numbered after every filter placeholder.
	listQuery := fmt.Sprintf(`
	%s
	SELECT
		kind,
		created_at,
		request_id,
		platform,
		model,
		duration_ms,
		status_code,
		error_id,
		phase,
		severity,
		message,
		user_id,
		api_key_id,
		account_id,
		group_id,
		stream
	FROM combined
	%s
	%s
	LIMIT $%d OFFSET $%d
	`, cte, where, sort, len(args)+1, len(args)+2)

	// Copy args before appending so the count query's slice is not aliased.
	listArgs := append(append([]any{}, args...), pageSize, offset)
	rows, err := r.db.QueryContext(ctx, listQuery, listArgs...)
	if err != nil {
		return nil, 0, err
	}
	defer func() { _ = rows.Close() }()

	// Helpers mapping nullable DB integers onto optional struct fields.
	toIntPtr := func(v sql.NullInt64) *int {
		if !v.Valid {
			return nil
		}
		i := int(v.Int64)
		return &i
	}
	toInt64Ptr := func(v sql.NullInt64) *int64 {
		if !v.Valid {
			return nil
		}
		i := v.Int64
		return &i
	}

	out := make([]*service.OpsRequestDetail, 0, pageSize)
	for rows.Next() {
		var (
			kind      string
			createdAt time.Time
			requestID sql.NullString
			platform  sql.NullString
			model     sql.NullString

			durationMs sql.NullInt64
			statusCode sql.NullInt64
			errorID    sql.NullInt64

			phase    sql.NullString
			severity sql.NullString
			message  sql.NullString

			userID    sql.NullInt64
			apiKeyID  sql.NullInt64
			accountID sql.NullInt64
			groupID   sql.NullInt64

			stream bool
		)

		if err := rows.Scan(
			&kind,
			&createdAt,
			&requestID,
			&platform,
			&model,
			&durationMs,
			&statusCode,
			&errorID,
			&phase,
			&severity,
			&message,
			&userID,
			&apiKeyID,
			&accountID,
			&groupID,
			&stream,
		); err != nil {
			return nil, 0, err
		}

		item := &service.OpsRequestDetail{
			Kind:      service.OpsRequestKind(kind),
			CreatedAt: createdAt,
			RequestID: strings.TrimSpace(requestID.String),
			Platform:  strings.TrimSpace(platform.String),
			Model:     strings.TrimSpace(model.String),

			DurationMs: toIntPtr(durationMs),
			StatusCode: toIntPtr(statusCode),
			ErrorID:    toInt64Ptr(errorID),
			Phase:      phase.String,
			Severity:   severity.String,
			Message:    message.String,

			UserID:    toInt64Ptr(userID),
			APIKeyID:  toInt64Ptr(apiKeyID),
			AccountID: toInt64Ptr(accountID),
			GroupID:   toInt64Ptr(groupID),

			Stream: stream,
		}

		// Rows whose platform could not be resolved via group/account joins
		// are surfaced as "unknown" rather than an empty string.
		if item.Platform == "" {
			item.Platform = "unknown"
		}

		out = append(out, item)
	}
	if err := rows.Err(); err != nil {
		return nil, 0, err
	}

	return out, total, nil
}
|
||||
573
backend/internal/repository/ops_repo_trends.go
Normal file
573
backend/internal/repository/ops_repo_trends.go
Normal file
@@ -0,0 +1,573 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// GetThroughputTrend returns request/token throughput bucketed over the
// filter window (1m, 5m or 1h buckets), with missing buckets zero-filled so
// charts render continuously, plus optional drilldown breakdowns: totals by
// platform when no platform/group is selected, or the top groups when a
// platform is selected without a group.
func (r *opsRepository) GetThroughputTrend(ctx context.Context, filter *service.OpsDashboardFilter, bucketSeconds int) (*service.OpsThroughputTrendResponse, error) {
	if r == nil || r.db == nil {
		return nil, fmt.Errorf("nil ops repository")
	}
	if filter == nil {
		return nil, fmt.Errorf("nil filter")
	}
	if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
		return nil, fmt.Errorf("start_time/end_time required")
	}

	if bucketSeconds <= 0 {
		bucketSeconds = 60
	}
	if bucketSeconds != 60 && bucketSeconds != 300 && bucketSeconds != 3600 {
		// Keep a small, predictable set of supported buckets for now.
		bucketSeconds = 60
	}

	start := filter.StartTime.UTC()
	end := filter.EndTime.UTC()

	// Shared placeholder sequence: usage placeholders start at $1, error
	// placeholders continue at `next`; args must be appended usage-first.
	usageJoin, usageWhere, usageArgs, next := buildUsageWhere(filter, start, end, 1)
	errorWhere, errorArgs, _ := buildErrorWhere(filter, start, end, next)

	usageBucketExpr := opsBucketExprForUsage(bucketSeconds)
	errorBucketExpr := opsBucketExprForError(bucketSeconds)

	// Success buckets (usage_logs) and error buckets (ops_error_logs,
	// status >= 400) are FULL OUTER JOINed so error-only buckets still
	// contribute to the request count.
	q := `
		WITH usage_buckets AS (
			SELECT ` + usageBucketExpr + ` AS bucket,
				COUNT(*) AS success_count,
				COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_consumed
			FROM usage_logs ul
			` + usageJoin + `
			` + usageWhere + `
			GROUP BY 1
		),
		error_buckets AS (
			SELECT ` + errorBucketExpr + ` AS bucket,
				COUNT(*) AS error_count
			FROM ops_error_logs
			` + errorWhere + `
			AND COALESCE(status_code, 0) >= 400
			GROUP BY 1
		),
		combined AS (
			SELECT COALESCE(u.bucket, e.bucket) AS bucket,
				COALESCE(u.success_count, 0) AS success_count,
				COALESCE(e.error_count, 0) AS error_count,
				COALESCE(u.token_consumed, 0) AS token_consumed
			FROM usage_buckets u
			FULL OUTER JOIN error_buckets e ON u.bucket = e.bucket
		)
		SELECT
			bucket,
			(success_count + error_count) AS request_count,
			token_consumed
		FROM combined
		ORDER BY bucket ASC`

	args := append(usageArgs, errorArgs...)

	rows, err := r.db.QueryContext(ctx, q, args...)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	points := make([]*service.OpsThroughputTrendPoint, 0, 256)
	for rows.Next() {
		var bucket time.Time
		var requests int64
		var tokens sql.NullInt64
		if err := rows.Scan(&bucket, &requests, &tokens); err != nil {
			return nil, err
		}
		tokenConsumed := int64(0)
		if tokens.Valid {
			tokenConsumed = tokens.Int64
		}

		// Per-bucket rates; defensive guard against a zero denominator.
		denom := float64(bucketSeconds)
		if denom <= 0 {
			denom = 60
		}
		qps := roundTo1DP(float64(requests) / denom)
		tps := roundTo1DP(float64(tokenConsumed) / denom)

		points = append(points, &service.OpsThroughputTrendPoint{
			BucketStart:   bucket.UTC(),
			RequestCount:  requests,
			TokenConsumed: tokenConsumed,
			QPS:           qps,
			TPS:           tps,
		})
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	// Fill missing buckets with zeros so charts render continuous timelines.
	points = fillOpsThroughputBuckets(start, end, bucketSeconds, points)

	var byPlatform []*service.OpsThroughputPlatformBreakdownItem
	var topGroups []*service.OpsThroughputGroupBreakdownItem

	// filter was nil-checked above, so these `filter != nil` guards are
	// redundant but harmless.
	platform := ""
	if filter != nil {
		platform = strings.TrimSpace(strings.ToLower(filter.Platform))
	}
	groupID := (*int64)(nil)
	if filter != nil {
		groupID = filter.GroupID
	}

	// Drilldown helpers:
	// - No platform/group: totals by platform
	// - Platform selected but no group: top groups in that platform
	if platform == "" && (groupID == nil || *groupID <= 0) {
		items, err := r.getThroughputBreakdownByPlatform(ctx, start, end)
		if err != nil {
			return nil, err
		}
		byPlatform = items
	} else if platform != "" && (groupID == nil || *groupID <= 0) {
		items, err := r.getThroughputTopGroupsByPlatform(ctx, start, end, platform, 10)
		if err != nil {
			return nil, err
		}
		topGroups = items
	}

	return &service.OpsThroughputTrendResponse{
		Bucket: opsBucketLabel(bucketSeconds),
		Points: points,

		ByPlatform: byPlatform,
		TopGroups:  topGroups,
	}, nil
}
|
||||
|
||||
// getThroughputBreakdownByPlatform totals requests and tokens per platform
// over [start, end). Success totals come from usage_logs (platform resolved
// via the group, falling back to the account); error totals come from
// ops_error_logs rows with status >= 400. Rows with no resolvable platform
// are dropped.
//
// NOTE(review): only the time window is applied here — the caller's other
// dashboard filters are not. Also, the error side excludes count_tokens
// rows while the usage side has no such exclusion; confirm both asymmetries
// are intended.
func (r *opsRepository) getThroughputBreakdownByPlatform(ctx context.Context, start, end time.Time) ([]*service.OpsThroughputPlatformBreakdownItem, error) {
	q := `
		WITH usage_totals AS (
			SELECT COALESCE(NULLIF(g.platform,''), a.platform) AS platform,
				COUNT(*) AS success_count,
				COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_consumed
			FROM usage_logs ul
			LEFT JOIN groups g ON g.id = ul.group_id
			LEFT JOIN accounts a ON a.id = ul.account_id
			WHERE ul.created_at >= $1 AND ul.created_at < $2
			GROUP BY 1
		),
		error_totals AS (
			SELECT platform,
				COUNT(*) AS error_count
			FROM ops_error_logs
			WHERE created_at >= $1 AND created_at < $2
			AND COALESCE(status_code, 0) >= 400
			AND is_count_tokens = FALSE -- 排除 count_tokens 请求的错误
			GROUP BY 1
		),
		combined AS (
			SELECT COALESCE(u.platform, e.platform) AS platform,
				COALESCE(u.success_count, 0) AS success_count,
				COALESCE(e.error_count, 0) AS error_count,
				COALESCE(u.token_consumed, 0) AS token_consumed
			FROM usage_totals u
			FULL OUTER JOIN error_totals e ON u.platform = e.platform
		)
		SELECT platform, (success_count + error_count) AS request_count, token_consumed
		FROM combined
		WHERE platform IS NOT NULL AND platform <> ''
		ORDER BY request_count DESC`

	rows, err := r.db.QueryContext(ctx, q, start, end)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	items := make([]*service.OpsThroughputPlatformBreakdownItem, 0, 8)
	for rows.Next() {
		var platform string
		var requests int64
		var tokens sql.NullInt64
		if err := rows.Scan(&platform, &requests, &tokens); err != nil {
			return nil, err
		}
		// SUM over zero rows can be NULL despite the COALESCE in the CTE.
		tokenConsumed := int64(0)
		if tokens.Valid {
			tokenConsumed = tokens.Int64
		}
		items = append(items, &service.OpsThroughputPlatformBreakdownItem{
			Platform:      platform,
			RequestCount:  requests,
			TokenConsumed: tokenConsumed,
		})
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}
|
||||
|
||||
// getThroughputTopGroupsByPlatform totals requests and tokens per group for
// one platform over [start, end) and returns the busiest groups, ordered by
// request count. limit is clamped to (0, 100]; out-of-range values fall back
// to 10. An empty platform yields (nil, nil).
//
// NOTE(review): like the platform breakdown, the error side excludes
// count_tokens rows while the usage side does not — confirm intended.
func (r *opsRepository) getThroughputTopGroupsByPlatform(ctx context.Context, start, end time.Time, platform string, limit int) ([]*service.OpsThroughputGroupBreakdownItem, error) {
	if strings.TrimSpace(platform) == "" {
		return nil, nil
	}
	if limit <= 0 || limit > 100 {
		limit = 10
	}

	q := `
		WITH usage_totals AS (
			SELECT ul.group_id AS group_id,
				g.name AS group_name,
				COUNT(*) AS success_count,
				COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_consumed
			FROM usage_logs ul
			JOIN groups g ON g.id = ul.group_id
			WHERE ul.created_at >= $1 AND ul.created_at < $2
			AND g.platform = $3
			GROUP BY 1, 2
		),
		error_totals AS (
			SELECT group_id,
				COUNT(*) AS error_count
			FROM ops_error_logs
			WHERE created_at >= $1 AND created_at < $2
			AND platform = $3
			AND group_id IS NOT NULL
			AND COALESCE(status_code, 0) >= 400
			AND is_count_tokens = FALSE -- 排除 count_tokens 请求的错误
			GROUP BY 1
		),
		combined AS (
			SELECT COALESCE(u.group_id, e.group_id) AS group_id,
				COALESCE(u.group_name, g2.name, '') AS group_name,
				COALESCE(u.success_count, 0) AS success_count,
				COALESCE(e.error_count, 0) AS error_count,
				COALESCE(u.token_consumed, 0) AS token_consumed
			FROM usage_totals u
			FULL OUTER JOIN error_totals e ON u.group_id = e.group_id
			LEFT JOIN groups g2 ON g2.id = COALESCE(u.group_id, e.group_id)
		)
		SELECT group_id, group_name, (success_count + error_count) AS request_count, token_consumed
		FROM combined
		WHERE group_id IS NOT NULL
		ORDER BY request_count DESC
		LIMIT $4`

	rows, err := r.db.QueryContext(ctx, q, start, end, platform, limit)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	items := make([]*service.OpsThroughputGroupBreakdownItem, 0, limit)
	for rows.Next() {
		var groupID int64
		var groupName sql.NullString
		var requests int64
		var tokens sql.NullInt64
		if err := rows.Scan(&groupID, &groupName, &requests, &tokens); err != nil {
			return nil, err
		}
		tokenConsumed := int64(0)
		if tokens.Valid {
			tokenConsumed = tokens.Int64
		}
		// Groups that only appear via error logs may have no resolvable name.
		name := ""
		if groupName.Valid {
			name = groupName.String
		}
		items = append(items, &service.OpsThroughputGroupBreakdownItem{
			GroupID:       groupID,
			GroupName:     name,
			RequestCount:  requests,
			TokenConsumed: tokenConsumed,
		})
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}
|
||||
|
||||
// opsBucketExprForUsage returns the SQL expression that maps
// usage_logs.created_at (aliased ul) onto the start of its time bucket.
// Supported widths: 3600s (hour) and 300s (5 minutes); any other value
// falls back to minute-level truncation.
func opsBucketExprForUsage(bucketSeconds int) string {
	if bucketSeconds == 3600 {
		return "date_trunc('hour', ul.created_at)"
	}
	if bucketSeconds == 300 {
		// 5-minute buckets in UTC via epoch flooring (date_trunc has no
		// five-minute unit).
		return "to_timestamp(floor(extract(epoch from ul.created_at) / 300) * 300)"
	}
	return "date_trunc('minute', ul.created_at)"
}
|
||||
|
||||
// opsBucketExprForError returns the SQL expression that maps
// ops_error_logs.created_at onto the start of its time bucket.
// Supported widths: 3600s (hour) and 300s (5 minutes); any other value
// falls back to minute-level truncation.
func opsBucketExprForError(bucketSeconds int) string {
	if bucketSeconds == 3600 {
		return "date_trunc('hour', created_at)"
	}
	if bucketSeconds == 300 {
		// Epoch flooring, since date_trunc has no five-minute unit.
		return "to_timestamp(floor(extract(epoch from created_at) / 300) * 300)"
	}
	return "date_trunc('minute', created_at)"
}
|
||||
|
||||
// opsBucketLabel renders a bucket width in seconds as a short label, e.g.
// 60 -> "1m", 300 -> "5m", 3600 -> "1h", 7200 -> "2h". Non-positive widths
// default to "1m"; whole-hour multiples use "h", everything else "m", with
// sub-minute widths clamped up to "1m".
func opsBucketLabel(bucketSeconds int) string {
	switch {
	case bucketSeconds <= 0:
		return "1m"
	case bucketSeconds%3600 == 0:
		hours := bucketSeconds / 3600
		if hours < 1 {
			hours = 1
		}
		return fmt.Sprintf("%dh", hours)
	default:
		minutes := bucketSeconds / 60
		if minutes < 1 {
			minutes = 1
		}
		return fmt.Sprintf("%dm", minutes)
	}
}
|
||||
|
||||
func opsFloorToBucketStart(t time.Time, bucketSeconds int) time.Time {
|
||||
t = t.UTC()
|
||||
if bucketSeconds <= 0 {
|
||||
bucketSeconds = 60
|
||||
}
|
||||
secs := t.Unix()
|
||||
floored := secs - (secs % int64(bucketSeconds))
|
||||
return time.Unix(floored, 0).UTC()
|
||||
}
|
||||
|
||||
func fillOpsThroughputBuckets(start, end time.Time, bucketSeconds int, points []*service.OpsThroughputTrendPoint) []*service.OpsThroughputTrendPoint {
|
||||
if bucketSeconds <= 0 {
|
||||
bucketSeconds = 60
|
||||
}
|
||||
if !start.Before(end) {
|
||||
return points
|
||||
}
|
||||
|
||||
endMinus := end.Add(-time.Nanosecond)
|
||||
if endMinus.Before(start) {
|
||||
return points
|
||||
}
|
||||
|
||||
first := opsFloorToBucketStart(start, bucketSeconds)
|
||||
last := opsFloorToBucketStart(endMinus, bucketSeconds)
|
||||
step := time.Duration(bucketSeconds) * time.Second
|
||||
|
||||
existing := make(map[int64]*service.OpsThroughputTrendPoint, len(points))
|
||||
for _, p := range points {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
existing[p.BucketStart.UTC().Unix()] = p
|
||||
}
|
||||
|
||||
out := make([]*service.OpsThroughputTrendPoint, 0, int(last.Sub(first)/step)+1)
|
||||
for cursor := first; !cursor.After(last); cursor = cursor.Add(step) {
|
||||
if p, ok := existing[cursor.Unix()]; ok && p != nil {
|
||||
out = append(out, p)
|
||||
continue
|
||||
}
|
||||
out = append(out, &service.OpsThroughputTrendPoint{
|
||||
BucketStart: cursor,
|
||||
RequestCount: 0,
|
||||
TokenConsumed: 0,
|
||||
QPS: 0,
|
||||
TPS: 0,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// GetErrorTrend returns error counts bucketed over the filter window (1m, 5m
// or 1h buckets), split into total / business-limited / SLA-relevant errors
// plus provider-owned upstream errors with 429 and 529 broken out. Missing
// buckets are zero-filled for continuous charts.
func (r *opsRepository) GetErrorTrend(ctx context.Context, filter *service.OpsDashboardFilter, bucketSeconds int) (*service.OpsErrorTrendResponse, error) {
	if r == nil || r.db == nil {
		return nil, fmt.Errorf("nil ops repository")
	}
	if filter == nil {
		return nil, fmt.Errorf("nil filter")
	}
	if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
		return nil, fmt.Errorf("start_time/end_time required")
	}

	// Only 1m / 5m / 1h buckets are supported; anything else falls back to 1m.
	if bucketSeconds <= 0 {
		bucketSeconds = 60
	}
	if bucketSeconds != 60 && bucketSeconds != 300 && bucketSeconds != 3600 {
		bucketSeconds = 60
	}

	start := filter.StartTime.UTC()
	end := filter.EndTime.UTC()
	where, args, _ := buildErrorWhere(filter, start, end, 1)
	bucketExpr := opsBucketExprForError(bucketSeconds)

	// FILTER clauses partition each bucket's rows:
	//   error_total:      all rows with status >= 400
	//   business_limited: the subset flagged is_business_limited
	//   error_sla:        the remainder (not business-limited)
	//   upstream_*:       provider-owned, non-business-limited errors, with
	//                     429/529 split out; the upstream status code falls
	//                     back to the local status code when absent.
	q := `
	SELECT
		` + bucketExpr + ` AS bucket,
		COUNT(*) FILTER (WHERE COALESCE(status_code, 0) >= 400) AS error_total,
		COUNT(*) FILTER (WHERE COALESCE(status_code, 0) >= 400 AND is_business_limited) AS business_limited,
		COUNT(*) FILTER (WHERE COALESCE(status_code, 0) >= 400 AND NOT is_business_limited) AS error_sla,
		COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)) AS upstream_excl,
		COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(upstream_status_code, status_code, 0) = 429) AS upstream_429,
		COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(upstream_status_code, status_code, 0) = 529) AS upstream_529
	FROM ops_error_logs
	` + where + `
	GROUP BY 1
	ORDER BY 1 ASC`

	rows, err := r.db.QueryContext(ctx, q, args...)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	points := make([]*service.OpsErrorTrendPoint, 0, 256)
	for rows.Next() {
		var bucket time.Time
		var total, businessLimited, sla, upstreamExcl, upstream429, upstream529 int64
		if err := rows.Scan(&bucket, &total, &businessLimited, &sla, &upstreamExcl, &upstream429, &upstream529); err != nil {
			return nil, err
		}
		points = append(points, &service.OpsErrorTrendPoint{
			BucketStart: bucket.UTC(),

			ErrorCountTotal:      total,
			BusinessLimitedCount: businessLimited,
			ErrorCountSLA:        sla,

			UpstreamErrorCountExcl429529: upstreamExcl,
			Upstream429Count:             upstream429,
			Upstream529Count:             upstream529,
		})
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	// Zero-fill buckets with no errors so the chart timeline is continuous.
	points = fillOpsErrorTrendBuckets(start, end, bucketSeconds, points)

	return &service.OpsErrorTrendResponse{
		Bucket: opsBucketLabel(bucketSeconds),
		Points: points,
	}, nil
}
|
||||
|
||||
func fillOpsErrorTrendBuckets(start, end time.Time, bucketSeconds int, points []*service.OpsErrorTrendPoint) []*service.OpsErrorTrendPoint {
|
||||
if bucketSeconds <= 0 {
|
||||
bucketSeconds = 60
|
||||
}
|
||||
if !start.Before(end) {
|
||||
return points
|
||||
}
|
||||
|
||||
endMinus := end.Add(-time.Nanosecond)
|
||||
if endMinus.Before(start) {
|
||||
return points
|
||||
}
|
||||
|
||||
first := opsFloorToBucketStart(start, bucketSeconds)
|
||||
last := opsFloorToBucketStart(endMinus, bucketSeconds)
|
||||
step := time.Duration(bucketSeconds) * time.Second
|
||||
|
||||
existing := make(map[int64]*service.OpsErrorTrendPoint, len(points))
|
||||
for _, p := range points {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
existing[p.BucketStart.UTC().Unix()] = p
|
||||
}
|
||||
|
||||
out := make([]*service.OpsErrorTrendPoint, 0, int(last.Sub(first)/step)+1)
|
||||
for cursor := first; !cursor.After(last); cursor = cursor.Add(step) {
|
||||
if p, ok := existing[cursor.Unix()]; ok && p != nil {
|
||||
out = append(out, p)
|
||||
continue
|
||||
}
|
||||
out = append(out, &service.OpsErrorTrendPoint{
|
||||
BucketStart: cursor,
|
||||
|
||||
ErrorCountTotal: 0,
|
||||
BusinessLimitedCount: 0,
|
||||
ErrorCountSLA: 0,
|
||||
|
||||
UpstreamErrorCountExcl429529: 0,
|
||||
Upstream429Count: 0,
|
||||
Upstream529Count: 0,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// GetErrorDistribution returns error counts grouped by status code within
// the filter window, preferring the upstream status code over the local one
// when present, limited to the 20 most frequent codes. Each item is split
// into SLA-relevant vs business-limited counts. Note that Total sums only
// the returned (top-20) rows, not all errors in the window.
func (r *opsRepository) GetErrorDistribution(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsErrorDistributionResponse, error) {
	if r == nil || r.db == nil {
		return nil, fmt.Errorf("nil ops repository")
	}
	if filter == nil {
		return nil, fmt.Errorf("nil filter")
	}
	if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
		return nil, fmt.Errorf("start_time/end_time required")
	}

	start := filter.StartTime.UTC()
	end := filter.EndTime.UTC()
	where, args, _ := buildErrorWhere(filter, start, end, 1)

	q := `
	SELECT
		COALESCE(upstream_status_code, status_code, 0) AS status_code,
		COUNT(*) AS total,
		COUNT(*) FILTER (WHERE NOT is_business_limited) AS sla,
		COUNT(*) FILTER (WHERE is_business_limited) AS business_limited
	FROM ops_error_logs
	` + where + `
	AND COALESCE(status_code, 0) >= 400
	GROUP BY 1
	ORDER BY total DESC
	LIMIT 20`

	rows, err := r.db.QueryContext(ctx, q, args...)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	items := make([]*service.OpsErrorDistributionItem, 0, 16)
	var total int64
	for rows.Next() {
		var statusCode int
		var cntTotal, cntSLA, cntBiz int64
		if err := rows.Scan(&statusCode, &cntTotal, &cntSLA, &cntBiz); err != nil {
			return nil, err
		}
		total += cntTotal
		items = append(items, &service.OpsErrorDistributionItem{
			StatusCode:      statusCode,
			Total:           cntTotal,
			SLA:             cntSLA,
			BusinessLimited: cntBiz,
		})
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	return &service.OpsErrorDistributionResponse{
		Total: total,
		Items: items,
	}, nil
}
|
||||
50
backend/internal/repository/ops_repo_window_stats.go
Normal file
50
backend/internal/repository/ops_repo_window_stats.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// GetWindowStats returns aggregate success/error/token counts for the
// requested time window. Windows larger than 24h are rejected to prevent
// accidental heavy queries against the usage/error tables.
func (r *opsRepository) GetWindowStats(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsWindowStats, error) {
	if r == nil || r.db == nil {
		return nil, fmt.Errorf("nil ops repository")
	}
	if filter == nil {
		return nil, fmt.Errorf("nil filter")
	}
	if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
		return nil, fmt.Errorf("start_time/end_time required")
	}

	// Normalize to UTC and validate the window ordering/size.
	start := filter.StartTime.UTC()
	end := filter.EndTime.UTC()
	if start.After(end) {
		return nil, fmt.Errorf("start_time must be <= end_time")
	}
	// Bound excessively large windows to prevent accidental heavy queries.
	if end.Sub(start) > 24*time.Hour {
		return nil, fmt.Errorf("window too large")
	}

	successCount, tokenConsumed, err := r.queryUsageCounts(ctx, filter, start, end)
	if err != nil {
		return nil, err
	}

	// Only the total error count is needed here; the other error breakdowns
	// returned by queryErrorCounts are intentionally discarded.
	errorTotal, _, _, _, _, _, err := r.queryErrorCounts(ctx, filter, start, end)
	if err != nil {
		return nil, err
	}

	return &service.OpsWindowStats{
		StartTime: start,
		EndTime:   end,

		SuccessCount:    successCount,
		ErrorCountTotal: errorTotal,
		TokenConsumed:   tokenConsumed,
	}, nil
}
|
||||
16
backend/internal/repository/pagination.go
Normal file
16
backend/internal/repository/pagination.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package repository
|
||||
|
||||
import "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
|
||||
func paginationResultFromTotal(total int64, params pagination.PaginationParams) *pagination.PaginationResult {
|
||||
pages := int(total) / params.Limit()
|
||||
if int(total)%params.Limit() > 0 {
|
||||
pages++
|
||||
}
|
||||
return &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
Pages: pages,
|
||||
}
|
||||
}
|
||||
81
backend/internal/repository/pricing_service.go
Normal file
81
backend/internal/repository/pricing_service.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// pricingRemoteClient fetches pricing documents and their checksums over HTTP.
type pricingRemoteClient struct {
	httpClient *http.Client
}

// NewPricingRemoteClient creates a remote client for pricing data.
// An empty proxyURL means a direct connection; http/https/socks5/socks5h
// proxy schemes are supported by the shared HTTP client factory.
func NewPricingRemoteClient(proxyURL string) service.PricingRemoteClient {
	sharedClient, err := httpclient.GetClient(httpclient.Options{
		Timeout:  30 * time.Second,
		ProxyURL: proxyURL,
	})
	if err != nil {
		// Fall back to a plain client (no proxy) rather than failing
		// construction; the error is intentionally swallowed here.
		sharedClient = &http.Client{Timeout: 30 * time.Second}
	}
	return &pricingRemoteClient{
		httpClient: sharedClient,
	}
}
|
||||
|
||||
func (c *pricingRemoteClient) FetchPricingJSON(ctx context.Context, url string) ([]byte, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
func (c *pricingRemoteClient) FetchHashText(ctx context.Context, url string) (string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 哈希文件格式:hash filename 或者纯 hash
|
||||
hash := strings.TrimSpace(string(body))
|
||||
parts := strings.Fields(hash)
|
||||
if len(parts) > 0 {
|
||||
return parts[0], nil
|
||||
}
|
||||
return hash, nil
|
||||
}
|
||||
145
backend/internal/repository/pricing_service_test.go
Normal file
145
backend/internal/repository/pricing_service_test.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// PricingServiceSuite exercises pricingRemoteClient against a local
// httptest server.
type PricingServiceSuite struct {
	suite.Suite
	ctx    context.Context
	srv    *httptest.Server // per-test server; closed in TearDownTest
	client *pricingRemoteClient
}
|
||||
|
||||
// SetupTest builds a direct (no-proxy) client before each test.
func (s *PricingServiceSuite) SetupTest() {
	s.ctx = context.Background()
	// Downcast to the concrete type so tests can call unexported behavior.
	client, ok := NewPricingRemoteClient("").(*pricingRemoteClient)
	require.True(s.T(), ok, "type assertion failed")
	s.client = client
}
|
||||
|
||||
// TearDownTest closes the per-test server, if one was started.
func (s *PricingServiceSuite) TearDownTest() {
	if s.srv != nil {
		s.srv.Close()
		s.srv = nil
	}
}
|
||||
|
||||
// setupServer starts a local test server backed by handler; it is closed in
// TearDownTest.
func (s *PricingServiceSuite) setupServer(handler http.HandlerFunc) {
	s.srv = newLocalTestServer(s.T(), handler)
}
|
||||
|
||||
func (s *PricingServiceSuite) TestFetchPricingJSON_Success() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/ok" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`{"ok":true}`))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
|
||||
body, err := s.client.FetchPricingJSON(s.ctx, s.srv.URL+"/ok")
|
||||
require.NoError(s.T(), err, "FetchPricingJSON")
|
||||
require.Equal(s.T(), `{"ok":true}`, string(body), "body mismatch")
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) TestFetchPricingJSON_NonOKStatus() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
|
||||
_, err := s.client.FetchPricingJSON(s.ctx, s.srv.URL+"/err")
|
||||
require.Error(s.T(), err, "expected error for non-200 status")
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) TestFetchHashText_ParsesFields() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/hashfile":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("abc123 model_prices.json\n"))
|
||||
case "/hashonly":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("def456\n"))
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
|
||||
hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/hashfile")
|
||||
require.NoError(s.T(), err, "FetchHashText")
|
||||
require.Equal(s.T(), "abc123", hash, "hash mismatch")
|
||||
|
||||
hash2, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/hashonly")
|
||||
require.NoError(s.T(), err, "FetchHashText")
|
||||
require.Equal(s.T(), "def456", hash2, "hash mismatch")
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) TestFetchHashText_NonOKStatus() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
|
||||
_, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/nope")
|
||||
require.Error(s.T(), err, "expected error for non-200 status")
|
||||
}
|
||||
|
||||
// TestFetchPricingJSON_InvalidURL verifies request construction fails fast on
// a malformed URL (no network involved).
func (s *PricingServiceSuite) TestFetchPricingJSON_InvalidURL() {
	_, err := s.client.FetchPricingJSON(s.ctx, "://invalid-url")
	require.Error(s.T(), err, "expected error for invalid URL")
}
|
||||
|
||||
func (s *PricingServiceSuite) TestFetchHashText_EmptyBody() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
// empty body
|
||||
}))
|
||||
|
||||
hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/empty")
|
||||
require.NoError(s.T(), err, "FetchHashText empty body should not error")
|
||||
require.Equal(s.T(), "", hash, "expected empty hash")
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) TestFetchHashText_WhitespaceOnly() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(" \n"))
|
||||
}))
|
||||
|
||||
hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/ws")
|
||||
require.NoError(s.T(), err, "FetchHashText whitespace body should not error")
|
||||
require.Equal(s.T(), "", hash, "expected empty hash after trimming")
|
||||
}
|
||||
|
||||
// TestFetchPricingJSON_ContextCancel verifies the fetch aborts when the
// caller cancels the context mid-request. The handler blocks until the
// client disconnects, so cancellation is the only way the call can return.
func (s *PricingServiceSuite) TestFetchPricingJSON_ContextCancel() {
	started := make(chan struct{})
	s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		close(started)
		// Block until the request context is torn down by the client cancel.
		<-r.Context().Done()
	}))

	ctx, cancel := context.WithCancel(s.ctx)

	done := make(chan error, 1)
	go func() {
		_, err := s.client.FetchPricingJSON(ctx, s.srv.URL+"/block")
		done <- err
	}()

	// Cancel only after the handler has received the request, so we exercise
	// cancellation of an in-flight call rather than a pre-flight one.
	<-started
	cancel()

	err := <-done
	require.Error(s.T(), err)
}
|
||||
|
||||
// TestPricingServiceSuite is the testify entry point for PricingServiceSuite.
func TestPricingServiceSuite(t *testing.T) {
	suite.Run(t, new(PricingServiceSuite))
}
|
||||
273
backend/internal/repository/promo_code_repo.go
Normal file
273
backend/internal/repository/promo_code_repo.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// promoCodeRepository implements service.PromoCodeRepository on top of ent.
type promoCodeRepository struct {
	client *dbent.Client
}

// NewPromoCodeRepository wires an ent-backed promo-code repository.
func NewPromoCodeRepository(client *dbent.Client) service.PromoCodeRepository {
	return &promoCodeRepository{client: client}
}
|
||||
|
||||
func (r *promoCodeRepository) Create(ctx context.Context, code *service.PromoCode) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
builder := client.PromoCode.Create().
|
||||
SetCode(code.Code).
|
||||
SetBonusAmount(code.BonusAmount).
|
||||
SetMaxUses(code.MaxUses).
|
||||
SetUsedCount(code.UsedCount).
|
||||
SetStatus(code.Status).
|
||||
SetNotes(code.Notes)
|
||||
|
||||
if code.ExpiresAt != nil {
|
||||
builder.SetExpiresAt(*code.ExpiresAt)
|
||||
}
|
||||
|
||||
created, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
code.ID = created.ID
|
||||
code.CreatedAt = created.CreatedAt
|
||||
code.UpdatedAt = created.UpdatedAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) GetByID(ctx context.Context, id int64) (*service.PromoCode, error) {
|
||||
m, err := r.client.PromoCode.Query().
|
||||
Where(promocode.IDEQ(id)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrPromoCodeNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return promoCodeEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) GetByCode(ctx context.Context, code string) (*service.PromoCode, error) {
|
||||
m, err := r.client.PromoCode.Query().
|
||||
Where(promocode.CodeEqualFold(code)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrPromoCodeNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return promoCodeEntityToService(m), nil
|
||||
}
|
||||
|
||||
// GetByCodeForUpdate loads a promo code by case-insensitive code while taking
// a row-level lock (SELECT ... FOR UPDATE). Intended to be called inside a
// transaction carried in ctx; outside one the lock has no lasting effect.
func (r *promoCodeRepository) GetByCodeForUpdate(ctx context.Context, code string) (*service.PromoCode, error) {
	client := clientFromContext(ctx, r.client)
	m, err := client.PromoCode.Query().
		Where(promocode.CodeEqualFold(code)).
		ForUpdate().
		Only(ctx)
	if err != nil {
		if dbent.IsNotFound(err) {
			return nil, service.ErrPromoCodeNotFound
		}
		return nil, err
	}
	return promoCodeEntityToService(m), nil
}
|
||||
|
||||
func (r *promoCodeRepository) Update(ctx context.Context, code *service.PromoCode) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
builder := client.PromoCode.UpdateOneID(code.ID).
|
||||
SetCode(code.Code).
|
||||
SetBonusAmount(code.BonusAmount).
|
||||
SetMaxUses(code.MaxUses).
|
||||
SetUsedCount(code.UsedCount).
|
||||
SetStatus(code.Status).
|
||||
SetNotes(code.Notes)
|
||||
|
||||
if code.ExpiresAt != nil {
|
||||
builder.SetExpiresAt(*code.ExpiresAt)
|
||||
} else {
|
||||
builder.ClearExpiresAt()
|
||||
}
|
||||
|
||||
updated, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return service.ErrPromoCodeNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
code.UpdatedAt = updated.UpdatedAt
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a promo code by ID. Deleting a nonexistent ID is a no-op
// (no ErrPromoCodeNotFound is returned). Honors a transaction carried in ctx.
func (r *promoCodeRepository) Delete(ctx context.Context, id int64) error {
	client := clientFromContext(ctx, r.client)
	_, err := client.PromoCode.Delete().Where(promocode.IDEQ(id)).Exec(ctx)
	return err
}
|
||||
|
||||
// List returns promo codes paginated, with no status/search filtering.
func (r *promoCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.PromoCode, *pagination.PaginationResult, error) {
	return r.ListWithFilters(ctx, params, "", "")
}
|
||||
|
||||
func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, search string) ([]service.PromoCode, *pagination.PaginationResult, error) {
|
||||
q := r.client.PromoCode.Query()
|
||||
|
||||
if status != "" {
|
||||
q = q.Where(promocode.StatusEQ(status))
|
||||
}
|
||||
if search != "" {
|
||||
q = q.Where(promocode.CodeContainsFold(search))
|
||||
}
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
codes, err := q.
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(promocode.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outCodes := promoCodeEntitiesToService(codes)
|
||||
|
||||
return outCodes, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) CreateUsage(ctx context.Context, usage *service.PromoCodeUsage) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
created, err := client.PromoCodeUsage.Create().
|
||||
SetPromoCodeID(usage.PromoCodeID).
|
||||
SetUserID(usage.UserID).
|
||||
SetBonusAmount(usage.BonusAmount).
|
||||
SetUsedAt(usage.UsedAt).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
usage.ID = created.ID
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUsageByPromoCodeAndUser returns the usage record for the
// (promoCodeID, userID) pair, or (nil, nil) when the user has not redeemed
// this code — callers rely on the nil,nil contract to mean "not yet used".
func (r *promoCodeRepository) GetUsageByPromoCodeAndUser(ctx context.Context, promoCodeID, userID int64) (*service.PromoCodeUsage, error) {
	m, err := r.client.PromoCodeUsage.Query().
		Where(
			promocodeusage.PromoCodeIDEQ(promoCodeID),
			promocodeusage.UserIDEQ(userID),
		).
		Only(ctx)
	if err != nil {
		if dbent.IsNotFound(err) {
			return nil, nil
		}
		return nil, err
	}
	return promoCodeUsageEntityToService(m), nil
}
|
||||
|
||||
func (r *promoCodeRepository) ListUsagesByPromoCode(ctx context.Context, promoCodeID int64, params pagination.PaginationParams) ([]service.PromoCodeUsage, *pagination.PaginationResult, error) {
|
||||
q := r.client.PromoCodeUsage.Query().
|
||||
Where(promocodeusage.PromoCodeIDEQ(promoCodeID))
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
usages, err := q.
|
||||
WithUser().
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(promocodeusage.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outUsages := promoCodeUsageEntitiesToService(usages)
|
||||
|
||||
return outUsages, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
// IncrementUsedCount bumps used_count by one in a single UPDATE (no
// read-modify-write race). Honors a transaction carried in ctx.
func (r *promoCodeRepository) IncrementUsedCount(ctx context.Context, id int64) error {
	client := clientFromContext(ctx, r.client)
	_, err := client.PromoCode.UpdateOneID(id).
		AddUsedCount(1).
		Save(ctx)
	return err
}
|
||||
|
||||
// Entity to Service conversions
|
||||
|
||||
// promoCodeEntityToService converts an ent PromoCode row to the service
// model; a nil input yields nil.
func promoCodeEntityToService(m *dbent.PromoCode) *service.PromoCode {
	if m == nil {
		return nil
	}
	return &service.PromoCode{
		ID:          m.ID,
		Code:        m.Code,
		BonusAmount: m.BonusAmount,
		MaxUses:     m.MaxUses,
		UsedCount:   m.UsedCount,
		Status:      m.Status,
		ExpiresAt:   m.ExpiresAt,
		// Notes is nullable in the schema; flatten to "" for the service model.
		Notes:     derefString(m.Notes),
		CreatedAt: m.CreatedAt,
		UpdatedAt: m.UpdatedAt,
	}
}
|
||||
|
||||
func promoCodeEntitiesToService(models []*dbent.PromoCode) []service.PromoCode {
|
||||
out := make([]service.PromoCode, 0, len(models))
|
||||
for i := range models {
|
||||
if s := promoCodeEntityToService(models[i]); s != nil {
|
||||
out = append(out, *s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// promoCodeUsageEntityToService converts an ent PromoCodeUsage row to the
// service model, carrying the eager-loaded User edge when present; a nil
// input yields nil.
func promoCodeUsageEntityToService(m *dbent.PromoCodeUsage) *service.PromoCodeUsage {
	if m == nil {
		return nil
	}
	out := &service.PromoCodeUsage{
		ID:          m.ID,
		PromoCodeID: m.PromoCodeID,
		UserID:      m.UserID,
		BonusAmount: m.BonusAmount,
		UsedAt:      m.UsedAt,
	}
	// The User edge is only populated when the query used WithUser().
	if m.Edges.User != nil {
		out.User = userEntityToService(m.Edges.User)
	}
	return out
}
|
||||
|
||||
func promoCodeUsageEntitiesToService(models []*dbent.PromoCodeUsage) []service.PromoCodeUsage {
|
||||
out := make([]service.PromoCodeUsage, 0, len(models))
|
||||
for i := range models {
|
||||
if s := promoCodeUsageEntityToService(models[i]); s != nil {
|
||||
out = append(out, *s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
74
backend/internal/repository/proxy_latency_cache.go
Normal file
74
backend/internal/repository/proxy_latency_cache.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// proxyLatencyKeyPrefix namespaces latency entries in Redis.
const proxyLatencyKeyPrefix = "proxy:latency:"

// proxyLatencyKey builds the Redis key for one proxy's latency record.
func proxyLatencyKey(proxyID int64) string {
	return proxyLatencyKeyPrefix + fmt.Sprint(proxyID)
}
|
||||
|
||||
// proxyLatencyCache stores proxy latency probe results in Redis.
type proxyLatencyCache struct {
	rdb *redis.Client
}

// NewProxyLatencyCache wires a Redis-backed latency cache.
func NewProxyLatencyCache(rdb *redis.Client) service.ProxyLatencyCache {
	return &proxyLatencyCache{rdb: rdb}
}
|
||||
|
||||
func (c *proxyLatencyCache) GetProxyLatencies(ctx context.Context, proxyIDs []int64) (map[int64]*service.ProxyLatencyInfo, error) {
|
||||
results := make(map[int64]*service.ProxyLatencyInfo)
|
||||
if len(proxyIDs) == 0 {
|
||||
return results, nil
|
||||
}
|
||||
|
||||
keys := make([]string, 0, len(proxyIDs))
|
||||
for _, id := range proxyIDs {
|
||||
keys = append(keys, proxyLatencyKey(id))
|
||||
}
|
||||
|
||||
values, err := c.rdb.MGet(ctx, keys...).Result()
|
||||
if err != nil {
|
||||
return results, err
|
||||
}
|
||||
|
||||
for i, raw := range values {
|
||||
if raw == nil {
|
||||
continue
|
||||
}
|
||||
var payload []byte
|
||||
switch v := raw.(type) {
|
||||
case string:
|
||||
payload = []byte(v)
|
||||
case []byte:
|
||||
payload = v
|
||||
default:
|
||||
continue
|
||||
}
|
||||
var info service.ProxyLatencyInfo
|
||||
if err := json.Unmarshal(payload, &info); err != nil {
|
||||
continue
|
||||
}
|
||||
results[proxyIDs[i]] = &info
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// SetProxyLatency stores the latency info for one proxy as JSON. A nil info
// is silently ignored. The key is written with no TTL (expiration 0), so
// entries persist until overwritten or deleted externally.
func (c *proxyLatencyCache) SetProxyLatency(ctx context.Context, proxyID int64, info *service.ProxyLatencyInfo) error {
	if info == nil {
		return nil
	}
	payload, err := json.Marshal(info)
	if err != nil {
		return err
	}
	return c.rdb.Set(ctx, proxyLatencyKey(proxyID), payload, 0).Err()
}
|
||||
118
backend/internal/repository/proxy_probe_service.go
Normal file
118
backend/internal/repository/proxy_probe_service.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// NewProxyExitInfoProber builds a prober that checks a proxy's exit IP and
// geolocation via an external IP-info service, applying the security config's
// TLS and URL-allowlist settings. A nil cfg falls back to safe defaults
// (no insecure TLS, no private hosts, resolved-IP validation on).
func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
	insecure := false
	allowPrivate := false
	validateResolvedIP := true
	if cfg != nil {
		insecure = cfg.Security.ProxyProbe.InsecureSkipVerify
		allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
		validateResolvedIP = cfg.Security.URLAllowlist.Enabled
	}
	// Insecure TLS is rejected downstream per the warning text; surface it
	// loudly at construction time so misconfiguration is discoverable.
	if insecure {
		log.Printf("[ProxyProbe] Warning: insecure_skip_verify is not allowed and will cause probe failure.")
	}
	return &proxyProbeService{
		ipInfoURL:          defaultIPInfoURL,
		insecureSkipVerify: insecure,
		allowPrivateHosts:  allowPrivate,
		validateResolvedIP: validateResolvedIP,
	}
}
||||
|
||||
const (
	// defaultIPInfoURL is the exit-IP geolocation endpoint used by probes.
	// NOTE(review): plain HTTP to a third-party service; responses are
	// unauthenticated and best-effort.
	defaultIPInfoURL = "http://ip-api.com/json/?lang=zh-CN"
	// defaultProxyProbeTimeout bounds a single probe round-trip.
	defaultProxyProbeTimeout = 30 * time.Second
)
||||
|
||||
// proxyProbeService probes a proxy by issuing a request through it to an
// IP-info endpoint and reporting the observed exit IP, location and latency.
type proxyProbeService struct {
	ipInfoURL          string // IP-info endpoint; overridable in tests
	insecureSkipVerify bool
	allowPrivateHosts  bool
	validateResolvedIP bool
}
||||
|
||||
// ProbeProxy issues a request to the IP-info endpoint through proxyURL and
// returns the exit info plus the observed round-trip latency in
// milliseconds. Latency is 0 when the client or request could not be built,
// and is the measured value even on post-connection failures (bad status,
// unreadable/invalid body).
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
	// Build a proxied client with strict proxy enforcement and the
	// configured allowlist/TLS settings.
	client, err := httpclient.GetClient(httpclient.Options{
		ProxyURL:           proxyURL,
		Timeout:            defaultProxyProbeTimeout,
		InsecureSkipVerify: s.insecureSkipVerify,
		ProxyStrict:        true,
		ValidateResolvedIP: s.validateResolvedIP,
		AllowPrivateHosts:  s.allowPrivateHosts,
	})
	if err != nil {
		return nil, 0, fmt.Errorf("failed to create proxy client: %w", err)
	}

	// Latency is measured from just before request construction to response
	// headers received.
	startTime := time.Now()
	req, err := http.NewRequestWithContext(ctx, "GET", s.ipInfoURL, nil)
	if err != nil {
		return nil, 0, fmt.Errorf("failed to create request: %w", err)
	}

	resp, err := client.Do(req)
	if err != nil {
		return nil, 0, fmt.Errorf("proxy connection failed: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()

	latencyMs := time.Since(startTime).Milliseconds()

	if resp.StatusCode != http.StatusOK {
		return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode)
	}

	// Subset of the ip-api.com JSON response we care about.
	var ipInfo struct {
		Status      string `json:"status"`
		Message     string `json:"message"`
		Query       string `json:"query"`
		City        string `json:"city"`
		Region      string `json:"region"`
		RegionName  string `json:"regionName"`
		Country     string `json:"country"`
		CountryCode string `json:"countryCode"`
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, latencyMs, fmt.Errorf("failed to read response: %w", err)
	}

	if err := json.Unmarshal(body, &ipInfo); err != nil {
		return nil, latencyMs, fmt.Errorf("failed to parse response: %w", err)
	}
	// ip-api signals application-level failure via status != "success".
	if strings.ToLower(ipInfo.Status) != "success" {
		if ipInfo.Message == "" {
			ipInfo.Message = "ip-api request failed"
		}
		return nil, latencyMs, fmt.Errorf("ip-api request failed: %s", ipInfo.Message)
	}

	// Prefer the human-readable region name; fall back to the region code.
	region := ipInfo.RegionName
	if region == "" {
		region = ipInfo.Region
	}
	return &service.ProxyExitInfo{
		IP:          ipInfo.Query,
		City:        ipInfo.City,
		Region:      region,
		Country:     ipInfo.Country,
		CountryCode: ipInfo.CountryCode,
	}, latencyMs, nil
}
|
||||
119
backend/internal/repository/proxy_probe_service_test.go
Normal file
119
backend/internal/repository/proxy_probe_service_test.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// ProxyProbeServiceSuite tests proxyProbeService against a local HTTP server
// standing in for the proxy.
type ProxyProbeServiceSuite struct {
	suite.Suite
	ctx      context.Context
	proxySrv *httptest.Server // fake proxy; closed in TearDownTest
	prober   *proxyProbeService
}
||||
|
||||
// SetupTest builds a prober pointed at a fake ip-api host; allowPrivateHosts
// lets it reach the loopback test server acting as the proxy.
func (s *ProxyProbeServiceSuite) SetupTest() {
	s.ctx = context.Background()
	s.prober = &proxyProbeService{
		ipInfoURL:         "http://ip-api.test/json/?lang=zh-CN",
		allowPrivateHosts: true,
	}
}
||||
|
||||
// TearDownTest closes the fake proxy server, if one was started.
func (s *ProxyProbeServiceSuite) TearDownTest() {
	if s.proxySrv != nil {
		s.proxySrv.Close()
		s.proxySrv = nil
	}
}
||||
|
||||
// setupProxyServer starts a local server standing in for the proxy; it is
// closed in TearDownTest.
func (s *ProxyProbeServiceSuite) setupProxyServer(handler http.HandlerFunc) {
	s.proxySrv = newLocalTestServer(s.T(), handler)
}
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidProxyURL() {
|
||||
_, _, err := s.prober.ProbeProxy(s.ctx, "://bad")
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "failed to create proxy client")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_UnsupportedProxyScheme() {
|
||||
_, _, err := s.prober.ProbeProxy(s.ctx, "ftp://127.0.0.1:1")
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "failed to create proxy client")
|
||||
}
|
||||
|
||||
// TestProbeProxy_Success drives a full probe through the fake proxy and
// checks the parsed exit info plus that the request actually traversed the
// proxy (the proxy sees the absolute ip-api.test URI).
func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
	seen := make(chan string, 1) // buffered so the handler never blocks
	s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		seen <- r.RequestURI
		w.Header().Set("Content-Type", "application/json")
		_, _ = io.WriteString(w, `{"status":"success","query":"1.2.3.4","city":"c","regionName":"r","country":"cc","countryCode":"CC"}`)
	}))

	info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
	require.NoError(s.T(), err, "ProbeProxy")
	require.GreaterOrEqual(s.T(), latencyMs, int64(0), "unexpected latency")
	require.Equal(s.T(), "1.2.3.4", info.IP)
	require.Equal(s.T(), "c", info.City)
	require.Equal(s.T(), "r", info.Region)
	require.Equal(s.T(), "cc", info.Country)
	require.Equal(s.T(), "CC", info.CountryCode)

	// Verify proxy received the request; the channel is already populated
	// because ProbeProxy has returned, so the non-blocking receive is safe.
	select {
	case uri := <-seen:
		require.Contains(s.T(), uri, "ip-api.test", "expected request to go through proxy")
	default:
		require.Fail(s.T(), "expected proxy to receive request")
	}
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_NonOKStatus() {
|
||||
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
}))
|
||||
|
||||
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "status: 503")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidJSON() {
|
||||
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, "not-json")
|
||||
}))
|
||||
|
||||
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "failed to parse response")
|
||||
}
|
||||
|
||||
// TestProbeProxy_InvalidIPInfoURL verifies a malformed probe target URL fails
// during request construction even though the proxy itself is reachable.
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidIPInfoURL() {
	s.prober.ipInfoURL = "://invalid-url"
	s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
	require.Error(s.T(), err, "expected error for invalid ipInfoURL")
}
|
||||
|
||||
// TestProbeProxy_ProxyServerClosed verifies an unreachable proxy surfaces a
// connection error. The server is closed immediately after creation so its
// address is valid but nothing is listening.
func (s *ProxyProbeServiceSuite) TestProbeProxy_ProxyServerClosed() {
	s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
	s.proxySrv.Close()

	_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
	require.Error(s.T(), err, "expected error when proxy server is closed")
}
|
||||
|
||||
// TestProxyProbeServiceSuite is the testify entry point for ProxyProbeServiceSuite.
func TestProxyProbeServiceSuite(t *testing.T) {
	suite.Run(t, new(ProxyProbeServiceSuite))
}
|
||||
359
backend/internal/repository/proxy_repo.go
Normal file
359
backend/internal/repository/proxy_repo.go
Normal file
@@ -0,0 +1,359 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/proxy"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
// sqlQuerier abstracts the raw-SQL query capability (satisfied by *sql.DB and
// *sql.Tx), allowing tests to inject a transaction.
type sqlQuerier interface {
	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}

// proxyRepository implements service.ProxyRepository on top of ent, plus a
// raw SQL querier for queries ent cannot express.
type proxyRepository struct {
	client *dbent.Client
	sql    sqlQuerier
}

// NewProxyRepository wires the production repository with a concrete *sql.DB.
func NewProxyRepository(client *dbent.Client, sqlDB *sql.DB) service.ProxyRepository {
	return newProxyRepositoryWithSQL(client, sqlDB)
}

// newProxyRepositoryWithSQL is the injectable constructor used by tests.
func newProxyRepositoryWithSQL(client *dbent.Client, sqlq sqlQuerier) *proxyRepository {
	return &proxyRepository{client: client, sql: sqlq}
}
|
||||
|
||||
func (r *proxyRepository) Create(ctx context.Context, proxyIn *service.Proxy) error {
|
||||
builder := r.client.Proxy.Create().
|
||||
SetName(proxyIn.Name).
|
||||
SetProtocol(proxyIn.Protocol).
|
||||
SetHost(proxyIn.Host).
|
||||
SetPort(proxyIn.Port).
|
||||
SetStatus(proxyIn.Status)
|
||||
if proxyIn.Username != "" {
|
||||
builder.SetUsername(proxyIn.Username)
|
||||
}
|
||||
if proxyIn.Password != "" {
|
||||
builder.SetPassword(proxyIn.Password)
|
||||
}
|
||||
|
||||
created, err := builder.Save(ctx)
|
||||
if err == nil {
|
||||
applyProxyEntityToService(proxyIn, created)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *proxyRepository) GetByID(ctx context.Context, id int64) (*service.Proxy, error) {
|
||||
m, err := r.client.Proxy.Get(ctx, id)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrProxyNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return proxyEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *proxyRepository) Update(ctx context.Context, proxyIn *service.Proxy) error {
|
||||
builder := r.client.Proxy.UpdateOneID(proxyIn.ID).
|
||||
SetName(proxyIn.Name).
|
||||
SetProtocol(proxyIn.Protocol).
|
||||
SetHost(proxyIn.Host).
|
||||
SetPort(proxyIn.Port).
|
||||
SetStatus(proxyIn.Status)
|
||||
if proxyIn.Username != "" {
|
||||
builder.SetUsername(proxyIn.Username)
|
||||
} else {
|
||||
builder.ClearUsername()
|
||||
}
|
||||
if proxyIn.Password != "" {
|
||||
builder.SetPassword(proxyIn.Password)
|
||||
} else {
|
||||
builder.ClearPassword()
|
||||
}
|
||||
|
||||
updated, err := builder.Save(ctx)
|
||||
if err == nil {
|
||||
applyProxyEntityToService(proxyIn, updated)
|
||||
return nil
|
||||
}
|
||||
if dbent.IsNotFound(err) {
|
||||
return service.ErrProxyNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *proxyRepository) Delete(ctx context.Context, id int64) error {
|
||||
_, err := r.client.Proxy.Delete().Where(proxy.IDEQ(id)).Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// List returns a page of proxies with no filtering applied; it delegates to
// ListWithFilters with empty filter arguments.
func (r *proxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Proxy, *pagination.PaginationResult, error) {
	return r.ListWithFilters(ctx, params, "", "", "")
}
|
||||
|
||||
// ListWithFilters lists proxies with optional filtering by protocol, status, and search query
|
||||
func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]service.Proxy, *pagination.PaginationResult, error) {
|
||||
q := r.client.Proxy.Query()
|
||||
if protocol != "" {
|
||||
q = q.Where(proxy.ProtocolEQ(protocol))
|
||||
}
|
||||
if status != "" {
|
||||
q = q.Where(proxy.StatusEQ(status))
|
||||
}
|
||||
if search != "" {
|
||||
q = q.Where(proxy.NameContainsFold(search))
|
||||
}
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
proxies, err := q.
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(proxy.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outProxies := make([]service.Proxy, 0, len(proxies))
|
||||
for i := range proxies {
|
||||
outProxies = append(outProxies, *proxyEntityToService(proxies[i]))
|
||||
}
|
||||
|
||||
return outProxies, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
// ListWithFiltersAndAccountCount lists proxies with filters and includes account count per proxy.
// Filtering and pagination behave exactly like ListWithFilters; each returned
// element additionally carries the number of non-deleted accounts using it.
func (r *proxyRepository) ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]service.ProxyWithAccountCount, *pagination.PaginationResult, error) {
	q := r.client.Proxy.Query()
	// Optional filters, applied only when non-empty.
	if protocol != "" {
		q = q.Where(proxy.ProtocolEQ(protocol))
	}
	if status != "" {
		q = q.Where(proxy.StatusEQ(status))
	}
	if search != "" {
		q = q.Where(proxy.NameContainsFold(search))
	}

	// Count before pagination so Total reflects every match.
	total, err := q.Count(ctx)
	if err != nil {
		return nil, nil, err
	}

	proxies, err := q.
		Offset(params.Offset()).
		Limit(params.Limit()).
		Order(dbent.Desc(proxy.FieldID)).
		All(ctx)
	if err != nil {
		return nil, nil, err
	}

	// Get account counts
	// NOTE(review): counts are fetched for ALL proxies, not only the current
	// page — fine while the proxies table stays small; confirm if it grows.
	counts, err := r.GetAccountCountsForProxies(ctx)
	if err != nil {
		return nil, nil, err
	}

	// Build result with account counts
	result := make([]service.ProxyWithAccountCount, 0, len(proxies))
	for i := range proxies {
		proxyOut := proxyEntityToService(proxies[i])
		if proxyOut == nil {
			continue
		}
		result = append(result, service.ProxyWithAccountCount{
			Proxy:        *proxyOut,
			AccountCount: counts[proxyOut.ID], // missing key yields zero
		})
	}

	return result, paginationResultFromTotal(int64(total), params), nil
}
|
||||
|
||||
func (r *proxyRepository) ListActive(ctx context.Context) ([]service.Proxy, error) {
|
||||
proxies, err := r.client.Proxy.Query().
|
||||
Where(proxy.StatusEQ(service.StatusActive)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
outProxies := make([]service.Proxy, 0, len(proxies))
|
||||
for i := range proxies {
|
||||
outProxies = append(outProxies, *proxyEntityToService(proxies[i]))
|
||||
}
|
||||
return outProxies, nil
|
||||
}
|
||||
|
||||
// ExistsByHostPortAuth checks if a proxy with the same host, port, username, and password exists
|
||||
func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
|
||||
q := r.client.Proxy.Query().
|
||||
Where(proxy.HostEQ(host), proxy.PortEQ(port))
|
||||
|
||||
if username == "" {
|
||||
q = q.Where(proxy.Or(proxy.UsernameIsNil(), proxy.UsernameEQ("")))
|
||||
} else {
|
||||
q = q.Where(proxy.UsernameEQ(username))
|
||||
}
|
||||
if password == "" {
|
||||
q = q.Where(proxy.Or(proxy.PasswordIsNil(), proxy.PasswordEQ("")))
|
||||
} else {
|
||||
q = q.Where(proxy.PasswordEQ(password))
|
||||
}
|
||||
|
||||
count, err := q.Count(ctx)
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// CountAccountsByProxyID returns the number of accounts using a specific proxy.
// Soft-deleted accounts (deleted_at set) are excluded from the count.
func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
	var count int64
	// Raw SQL over the accounts table via the shared scanSingleRow helper.
	if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM accounts WHERE proxy_id = $1 AND deleted_at IS NULL", []any{proxyID}, &count); err != nil {
		return 0, err
	}
	return count, nil
}
|
||||
|
||||
// ListAccountSummariesByProxyID returns lightweight summaries (id, name,
// platform, type, notes) of every non-deleted account bound to the proxy,
// newest (highest ID) first. Returns an empty slice when no accounts match.
func (r *proxyRepository) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]service.ProxyAccountSummary, error) {
	rows, err := r.sql.QueryContext(ctx, `
		SELECT id, name, platform, type, notes
		FROM accounts
		WHERE proxy_id = $1 AND deleted_at IS NULL
		ORDER BY id DESC
	`, proxyID)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	out := make([]service.ProxyAccountSummary, 0)
	for rows.Next() {
		var (
			id       int64
			name     string
			platform string
			accType  string
			notes    sql.NullString // notes column is nullable
		)
		if err := rows.Scan(&id, &name, &platform, &accType, &notes); err != nil {
			return nil, err
		}
		// Map a NULL notes column to a nil pointer so callers can distinguish
		// "no notes" from an empty string.
		var notesPtr *string
		if notes.Valid {
			notesPtr = &notes.String
		}
		out = append(out, service.ProxyAccountSummary{
			ID:       id,
			Name:     name,
			Platform: platform,
			Type:     accType,
			Notes:    notesPtr,
		})
	}
	// Surface any error encountered during iteration.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return out, nil
}
|
||||
|
||||
// GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies.
// Only non-deleted accounts with a proxy assigned are counted; proxies with no
// accounts are simply absent from the map (lookups yield zero).
//
// Named result parameters are used deliberately so the deferred Close can
// surface an otherwise-lost close error and reset the result.
func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (counts map[int64]int64, err error) {
	rows, err := r.sql.QueryContext(ctx, "SELECT proxy_id, COUNT(*) AS count FROM accounts WHERE proxy_id IS NOT NULL AND deleted_at IS NULL GROUP BY proxy_id")
	if err != nil {
		return nil, err
	}
	defer func() {
		// Promote a Close error only when nothing else failed first.
		if closeErr := rows.Close(); closeErr != nil && err == nil {
			err = closeErr
			counts = nil
		}
	}()

	counts = make(map[int64]int64)
	for rows.Next() {
		var proxyID, count int64
		// `err =` (not `:=`) so the deferred check above sees scan failures.
		if err = rows.Scan(&proxyID, &count); err != nil {
			return nil, err
		}
		counts[proxyID] = count
	}
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return counts, nil
}
|
||||
|
||||
// ListActiveWithAccountCount returns all active proxies with account count, sorted by creation time descending
func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]service.ProxyWithAccountCount, error) {
	proxies, err := r.client.Proxy.Query().
		Where(proxy.StatusEQ(service.StatusActive)).
		Order(dbent.Desc(proxy.FieldCreatedAt)).
		All(ctx)
	if err != nil {
		return nil, err
	}

	// Get account counts
	counts, err := r.GetAccountCountsForProxies(ctx)
	if err != nil {
		return nil, err
	}

	// Build result with account counts
	result := make([]service.ProxyWithAccountCount, 0, len(proxies))
	for i := range proxies {
		proxyOut := proxyEntityToService(proxies[i])
		if proxyOut == nil {
			continue
		}
		result = append(result, service.ProxyWithAccountCount{
			Proxy:        *proxyOut,
			AccountCount: counts[proxyOut.ID], // missing key yields zero
		})
	}

	return result, nil
}
|
||||
|
||||
func proxyEntityToService(m *dbent.Proxy) *service.Proxy {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
out := &service.Proxy{
|
||||
ID: m.ID,
|
||||
Name: m.Name,
|
||||
Protocol: m.Protocol,
|
||||
Host: m.Host,
|
||||
Port: m.Port,
|
||||
Status: m.Status,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
}
|
||||
if m.Username != nil {
|
||||
out.Username = *m.Username
|
||||
}
|
||||
if m.Password != nil {
|
||||
out.Password = *m.Password
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// applyProxyEntityToService copies the DB-generated fields (ID and
// timestamps) from a saved ent row back onto the caller's service model.
// Both arguments are nil-safe.
func applyProxyEntityToService(dst *service.Proxy, src *dbent.Proxy) {
	if dst == nil || src == nil {
		return
	}
	dst.ID = src.ID
	dst.CreatedAt = src.CreatedAt
	dst.UpdatedAt = src.UpdatedAt
}
|
||||
329
backend/internal/repository/proxy_repo_integration_test.go
Normal file
329
backend/internal/repository/proxy_repo_integration_test.go
Normal file
@@ -0,0 +1,329 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// ProxyRepoSuite exercises proxyRepository against a real database using a
// per-test transaction created by testEntTx.
type ProxyRepoSuite struct {
	suite.Suite
	ctx  context.Context
	tx   *dbent.Tx
	repo *proxyRepository
}

// SetupTest opens a fresh transaction and builds the repository under test.
func (s *ProxyRepoSuite) SetupTest() {
	s.ctx = context.Background()
	tx := testEntTx(s.T())
	s.tx = tx
	s.repo = newProxyRepositoryWithSQL(tx.Client(), tx)
}

func TestProxyRepoSuite(t *testing.T) {
	suite.Run(t, new(ProxyRepoSuite))
}
|
||||
|
||||
// --- Create / GetByID / Update / Delete ---

// TestCreate verifies Create assigns an ID and the row is readable back.
func (s *ProxyRepoSuite) TestCreate() {
	proxy := &service.Proxy{
		Name:     "test-create",
		Protocol: "http",
		Host:     "127.0.0.1",
		Port:     8080,
		Status:   service.StatusActive,
	}

	err := s.repo.Create(s.ctx, proxy)
	s.Require().NoError(err, "Create")
	s.Require().NotZero(proxy.ID, "expected ID to be set")

	got, err := s.repo.GetByID(s.ctx, proxy.ID)
	s.Require().NoError(err, "GetByID")
	s.Require().Equal("test-create", got.Name)
}

// TestGetByID_NotFound expects an error for an ID that does not exist.
func (s *ProxyRepoSuite) TestGetByID_NotFound() {
	_, err := s.repo.GetByID(s.ctx, 999999)
	s.Require().Error(err, "expected error for non-existent ID")
}

// TestUpdate verifies field changes are persisted and readable back.
func (s *ProxyRepoSuite) TestUpdate() {
	proxy := &service.Proxy{
		Name:     "original",
		Protocol: "http",
		Host:     "127.0.0.1",
		Port:     8080,
		Status:   service.StatusActive,
	}
	s.Require().NoError(s.repo.Create(s.ctx, proxy))

	proxy.Name = "updated"
	err := s.repo.Update(s.ctx, proxy)
	s.Require().NoError(err, "Update")

	got, err := s.repo.GetByID(s.ctx, proxy.ID)
	s.Require().NoError(err, "GetByID after update")
	s.Require().Equal("updated", got.Name)
}

// TestDelete verifies a deleted proxy can no longer be fetched.
func (s *ProxyRepoSuite) TestDelete() {
	proxy := &service.Proxy{
		Name:     "to-delete",
		Protocol: "http",
		Host:     "127.0.0.1",
		Port:     8080,
		Status:   service.StatusActive,
	}
	s.Require().NoError(s.repo.Create(s.ctx, proxy))

	err := s.repo.Delete(s.ctx, proxy.ID)
	s.Require().NoError(err, "Delete")

	_, err = s.repo.GetByID(s.ctx, proxy.ID)
	s.Require().Error(err, "expected error after delete")
}
|
||||
|
||||
// --- List / ListWithFilters ---

// TestList checks basic pagination metadata over two rows.
func (s *ProxyRepoSuite) TestList() {
	s.mustCreateProxy(&service.Proxy{Name: "p1", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
	s.mustCreateProxy(&service.Proxy{Name: "p2", Protocol: "http", Host: "127.0.0.1", Port: 8081, Status: service.StatusActive})

	proxies, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
	s.Require().NoError(err, "List")
	s.Require().Len(proxies, 2)
	s.Require().Equal(int64(2), page.Total)
}

// TestListWithFilters_Protocol filters by protocol only.
func (s *ProxyRepoSuite) TestListWithFilters_Protocol() {
	s.mustCreateProxy(&service.Proxy{Name: "p1", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
	s.mustCreateProxy(&service.Proxy{Name: "p2", Protocol: "socks5", Host: "127.0.0.1", Port: 8081, Status: service.StatusActive})

	proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "socks5", "", "")
	s.Require().NoError(err)
	s.Require().Len(proxies, 1)
	s.Require().Equal("socks5", proxies[0].Protocol)
}

// TestListWithFilters_Status filters by status only.
func (s *ProxyRepoSuite) TestListWithFilters_Status() {
	s.mustCreateProxy(&service.Proxy{Name: "p1", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
	s.mustCreateProxy(&service.Proxy{Name: "p2", Protocol: "http", Host: "127.0.0.1", Port: 8081, Status: service.StatusDisabled})

	proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, "")
	s.Require().NoError(err)
	s.Require().Len(proxies, 1)
	s.Require().Equal(service.StatusDisabled, proxies[0].Status)
}

// TestListWithFilters_Search filters by case-insensitive name substring.
func (s *ProxyRepoSuite) TestListWithFilters_Search() {
	s.mustCreateProxy(&service.Proxy{Name: "production-proxy", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
	s.mustCreateProxy(&service.Proxy{Name: "dev-proxy", Protocol: "http", Host: "127.0.0.1", Port: 8081, Status: service.StatusActive})

	proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "prod")
	s.Require().NoError(err)
	s.Require().Len(proxies, 1)
	s.Require().Contains(proxies[0].Name, "production")
}
|
||||
|
||||
// --- ListActive ---

// TestListActive expects only active-status proxies to be returned.
func (s *ProxyRepoSuite) TestListActive() {
	s.mustCreateProxy(&service.Proxy{Name: "active1", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
	s.mustCreateProxy(&service.Proxy{Name: "inactive1", Protocol: "http", Host: "127.0.0.1", Port: 8081, Status: service.StatusDisabled})

	proxies, err := s.repo.ListActive(s.ctx)
	s.Require().NoError(err, "ListActive")
	s.Require().Len(proxies, 1)
	s.Require().Equal("active1", proxies[0].Name)
}

// --- ExistsByHostPortAuth ---

// TestExistsByHostPortAuth checks the full host/port/username/password match.
func (s *ProxyRepoSuite) TestExistsByHostPortAuth() {
	s.mustCreateProxy(&service.Proxy{
		Name:     "p1",
		Protocol: "http",
		Host:     "1.2.3.4",
		Port:     8080,
		Username: "user",
		Password: "pass",
		Status:   service.StatusActive,
	})

	exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "user", "pass")
	s.Require().NoError(err, "ExistsByHostPortAuth")
	s.Require().True(exists)

	notExists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "wrong", "creds")
	s.Require().NoError(err)
	s.Require().False(notExists)
}

// TestExistsByHostPortAuth_NoAuth covers the empty username/password branch
// (matches NULL or empty-string credential columns).
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_NoAuth() {
	s.mustCreateProxy(&service.Proxy{
		Name:     "p-noauth",
		Protocol: "http",
		Host:     "5.6.7.8",
		Port:     8081,
		Username: "",
		Password: "",
		Status:   service.StatusActive,
	})

	exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "5.6.7.8", 8081, "", "")
	s.Require().NoError(err)
	s.Require().True(exists)
}
|
||||
|
||||
// --- CountAccountsByProxyID ---

// TestCountAccountsByProxyID counts only accounts bound to the given proxy.
func (s *ProxyRepoSuite) TestCountAccountsByProxyID() {
	proxy := s.mustCreateProxy(&service.Proxy{Name: "p-count", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
	s.mustInsertAccount("a1", &proxy.ID)
	s.mustInsertAccount("a2", &proxy.ID)
	s.mustInsertAccount("a3", nil) // no proxy

	count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
	s.Require().NoError(err, "CountAccountsByProxyID")
	s.Require().Equal(int64(2), count)
}

// TestCountAccountsByProxyID_Zero expects zero for a proxy with no accounts.
func (s *ProxyRepoSuite) TestCountAccountsByProxyID_Zero() {
	proxy := s.mustCreateProxy(&service.Proxy{Name: "p-zero", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})

	count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
	s.Require().NoError(err)
	s.Require().Zero(count)
}

// --- GetAccountCountsForProxies ---

// TestGetAccountCountsForProxies checks per-proxy aggregation.
func (s *ProxyRepoSuite) TestGetAccountCountsForProxies() {
	p1 := s.mustCreateProxy(&service.Proxy{Name: "p1", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
	p2 := s.mustCreateProxy(&service.Proxy{Name: "p2", Protocol: "http", Host: "127.0.0.1", Port: 8081, Status: service.StatusActive})

	s.mustInsertAccount("a1", &p1.ID)
	s.mustInsertAccount("a2", &p1.ID)
	s.mustInsertAccount("a3", &p2.ID)

	counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
	s.Require().NoError(err, "GetAccountCountsForProxies")
	s.Require().Equal(int64(2), counts[p1.ID])
	s.Require().Equal(int64(1), counts[p2.ID])
}

// TestGetAccountCountsForProxies_Empty expects an empty map when no accounts
// reference any proxy.
func (s *ProxyRepoSuite) TestGetAccountCountsForProxies_Empty() {
	counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
	s.Require().NoError(err)
	s.Require().Empty(counts)
}
|
||||
|
||||
// --- ListActiveWithAccountCount ---

// TestListActiveWithAccountCount pins the created_at DESC ordering and the
// per-proxy account counts, excluding disabled proxies.
func (s *ProxyRepoSuite) TestListActiveWithAccountCount() {
	base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)

	p1 := s.mustCreateProxyWithTimes("p1", service.StatusActive, base.Add(-1*time.Hour))
	p2 := s.mustCreateProxyWithTimes("p2", service.StatusActive, base)
	s.mustCreateProxyWithTimes("p3-inactive", service.StatusDisabled, base.Add(1*time.Hour))

	s.mustInsertAccount("a1", &p1.ID)
	s.mustInsertAccount("a2", &p1.ID)
	s.mustInsertAccount("a3", &p2.ID)

	withCounts, err := s.repo.ListActiveWithAccountCount(s.ctx)
	s.Require().NoError(err, "ListActiveWithAccountCount")
	s.Require().Len(withCounts, 2, "expected 2 active proxies")

	// Sorted by created_at DESC, so p2 first
	s.Require().Equal(p2.ID, withCounts[0].ID)
	s.Require().Equal(int64(1), withCounts[0].AccountCount)
	s.Require().Equal(p1.ID, withCounts[1].ID)
	s.Require().Equal(int64(2), withCounts[1].AccountCount)
}

// --- Combined original test ---

// TestExistsByHostPortAuth_And_AccountCountAggregates is a broader scenario
// exercising the existence check plus all three count code paths together.
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() {
	p1 := s.mustCreateProxy(&service.Proxy{Name: "p1", Protocol: "http", Host: "1.2.3.4", Port: 8080, Username: "u", Password: "p", Status: service.StatusActive})
	p2 := s.mustCreateProxy(&service.Proxy{Name: "p2", Protocol: "http", Host: "5.6.7.8", Port: 8081, Username: "", Password: "", Status: service.StatusActive})

	exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "u", "p")
	s.Require().NoError(err, "ExistsByHostPortAuth")
	s.Require().True(exists, "expected proxy to exist")

	s.mustInsertAccount("a1", &p1.ID)
	s.mustInsertAccount("a2", &p1.ID)
	s.mustInsertAccount("a3", &p2.ID)

	count1, err := s.repo.CountAccountsByProxyID(s.ctx, p1.ID)
	s.Require().NoError(err, "CountAccountsByProxyID")
	s.Require().Equal(int64(2), count1, "expected 2 accounts for p1")

	counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
	s.Require().NoError(err, "GetAccountCountsForProxies")
	s.Require().Equal(int64(2), counts[p1.ID])
	s.Require().Equal(int64(1), counts[p2.ID])

	withCounts, err := s.repo.ListActiveWithAccountCount(s.ctx)
	s.Require().NoError(err, "ListActiveWithAccountCount")
	s.Require().Len(withCounts, 2, "expected 2 proxies")
	for _, pc := range withCounts {
		switch pc.ID {
		case p1.ID:
			s.Require().Equal(int64(2), pc.AccountCount, "p1 count mismatch")
		case p2.ID:
			s.Require().Equal(int64(1), pc.AccountCount, "p2 count mismatch")
		default:
			s.Require().Fail("unexpected proxy id", pc.ID)
		}
	}
}
|
||||
|
||||
// mustCreateProxy creates a proxy via the repository under test and fails the
// test on error.
func (s *ProxyRepoSuite) mustCreateProxy(p *service.Proxy) *service.Proxy {
	s.T().Helper()
	s.Require().NoError(s.repo.Create(s.ctx, p), "create proxy")
	return p
}

// mustCreateProxyWithTimes creates a proxy and then forces its created_at and
// updated_at to a known value so ordering assertions are deterministic.
func (s *ProxyRepoSuite) mustCreateProxyWithTimes(name, status string, createdAt time.Time) *service.Proxy {
	s.T().Helper()

	// Use the repository create for standard fields, then update timestamps via raw SQL to keep deterministic ordering.
	p := s.mustCreateProxy(&service.Proxy{
		Name:     name,
		Protocol: "http",
		Host:     "127.0.0.1",
		Port:     8080,
		Status:   status,
	})
	_, err := s.tx.ExecContext(s.ctx, "UPDATE proxies SET created_at = $1, updated_at = $1 WHERE id = $2", createdAt, p.ID)
	s.Require().NoError(err, "update proxy timestamps")
	return p
}

// mustInsertAccount inserts a minimal account row directly with raw SQL;
// proxyID may be nil to create an unassigned account (NULL proxy_id).
func (s *ProxyRepoSuite) mustInsertAccount(name string, proxyID *int64) {
	s.T().Helper()
	// A nil *int64 must become a nil `any` so the driver writes SQL NULL.
	var pid any
	if proxyID != nil {
		pid = *proxyID
	}
	_, err := s.tx.ExecContext(
		s.ctx,
		"INSERT INTO accounts (name, platform, type, proxy_id) VALUES ($1, $2, $3, $4)",
		name,
		service.PlatformAnthropic,
		service.AccountTypeOAuth,
		pid,
	)
	s.Require().NoError(err, "insert account")
}
|
||||
62
backend/internal/repository/redeem_cache.go
Normal file
62
backend/internal/repository/redeem_cache.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package repository
|
||||
|
||||
import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/service"
	"github.com/redis/go-redis/v9"
)
|
||||
|
||||
const (
	redeemRateLimitKeyPrefix = "redeem:ratelimit:"
	redeemLockKeyPrefix      = "redeem:lock:"
	redeemRateLimitDuration  = 24 * time.Hour
)

// redeemRateLimitKey generates the Redis key for redeem attempt rate limiting.
func redeemRateLimitKey(userID int64) string {
	return fmt.Sprintf(redeemRateLimitKeyPrefix+"%d", userID)
}

// redeemLockKey generates the Redis key for redeem code locking.
func redeemLockKey(code string) string {
	return redeemLockKeyPrefix + code
}
|
||||
|
||||
// redeemCache implements service.RedeemCache on top of Redis.
type redeemCache struct {
	rdb *redis.Client
}

// NewRedeemCache builds a Redis-backed redeem cache.
func NewRedeemCache(rdb *redis.Client) service.RedeemCache {
	return &redeemCache{rdb: rdb}
}
|
||||
|
||||
func (c *redeemCache) GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error) {
|
||||
key := redeemRateLimitKey(userID)
|
||||
count, err := c.rdb.Get(ctx, key).Int()
|
||||
if err == redis.Nil {
|
||||
return 0, nil
|
||||
}
|
||||
return count, err
|
||||
}
|
||||
|
||||
// IncrementRedeemAttemptCount bumps the user's attempt counter and refreshes
// its expiry, pipelining both commands in one round trip.
//
// NOTE(review): Expire is called on every increment, so the 24h window slides
// forward with each attempt rather than being fixed from the first attempt —
// confirm this is the intended rate-limit semantics.
func (c *redeemCache) IncrementRedeemAttemptCount(ctx context.Context, userID int64) error {
	key := redeemRateLimitKey(userID)
	pipe := c.rdb.Pipeline()
	pipe.Incr(ctx, key)
	pipe.Expire(ctx, key, redeemRateLimitDuration)
	_, err := pipe.Exec(ctx)
	return err
}
|
||||
|
||||
// AcquireRedeemLock attempts to take an exclusive lock on a redeem code via
// SET NX with the given TTL. It returns true when the lock was acquired and
// false when another holder already owns it.
func (c *redeemCache) AcquireRedeemLock(ctx context.Context, code string, ttl time.Duration) (bool, error) {
	key := redeemLockKey(code)
	return c.rdb.SetNX(ctx, key, 1, ttl).Result()
}

// ReleaseRedeemLock deletes the lock key for a code. Releasing a lock that is
// not held is a no-op (DEL on a missing key does not error).
func (c *redeemCache) ReleaseRedeemLock(ctx context.Context, code string) error {
	key := redeemLockKey(code)
	return c.rdb.Del(ctx, key).Err()
}
|
||||
103
backend/internal/repository/redeem_cache_integration_test.go
Normal file
103
backend/internal/repository/redeem_cache_integration_test.go
Normal file
@@ -0,0 +1,103 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// RedeemCacheSuite exercises redeemCache against the Redis instance provided
// by IntegrationRedisSuite.
type RedeemCacheSuite struct {
	IntegrationRedisSuite
	cache *redeemCache
}

// SetupTest builds a fresh cache on top of the suite's Redis client.
func (s *RedeemCacheSuite) SetupTest() {
	s.IntegrationRedisSuite.SetupTest()
	s.cache = NewRedeemCache(s.rdb).(*redeemCache)
}
|
||||
|
||||
// TestGetRedeemAttemptCount_Missing expects zero (and no error) for a user
// with no rate-limit key.
func (s *RedeemCacheSuite) TestGetRedeemAttemptCount_Missing() {
	missingUserID := int64(99999)
	count, err := s.cache.GetRedeemAttemptCount(s.ctx, missingUserID)
	require.NoError(s.T(), err, "expected nil error for missing rate-limit key")
	require.Equal(s.T(), 0, count, "expected zero count for missing key")
}

// TestIncrementAndGetRedeemAttemptCount verifies increment, readback, and
// that the key carries the expected TTL.
func (s *RedeemCacheSuite) TestIncrementAndGetRedeemAttemptCount() {
	userID := int64(1)
	key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)

	require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID), "IncrementRedeemAttemptCount")
	count, err := s.cache.GetRedeemAttemptCount(s.ctx, userID)
	require.NoError(s.T(), err, "GetRedeemAttemptCount")
	require.Equal(s.T(), 1, count, "count mismatch")

	ttl, err := s.rdb.TTL(s.ctx, key).Result()
	require.NoError(s.T(), err, "TTL")
	s.AssertTTLWithin(ttl, 1*time.Second, redeemRateLimitDuration)
}

// TestMultipleIncrements checks the counter accumulates across calls.
func (s *RedeemCacheSuite) TestMultipleIncrements() {
	userID := int64(2)

	require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID))
	require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID))
	require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID))

	count, err := s.cache.GetRedeemAttemptCount(s.ctx, userID)
	require.NoError(s.T(), err)
	require.Equal(s.T(), 3, count, "count after 3 increments")
}
|
||||
|
||||
// TestAcquireAndReleaseRedeemLock covers the acquire/conflict/release/
// reacquire lifecycle of the redeem lock.
func (s *RedeemCacheSuite) TestAcquireAndReleaseRedeemLock() {
	ok, err := s.cache.AcquireRedeemLock(s.ctx, "CODE", 10*time.Second)
	require.NoError(s.T(), err, "AcquireRedeemLock")
	require.True(s.T(), ok)

	// Second acquire should fail
	ok, err = s.cache.AcquireRedeemLock(s.ctx, "CODE", 10*time.Second)
	require.NoError(s.T(), err, "AcquireRedeemLock 2")
	require.False(s.T(), ok, "expected lock to be held")

	// Release
	require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "CODE"), "ReleaseRedeemLock")

	// Now acquire should succeed
	ok, err = s.cache.AcquireRedeemLock(s.ctx, "CODE", 10*time.Second)
	require.NoError(s.T(), err, "AcquireRedeemLock after release")
	require.True(s.T(), ok)
}

// TestAcquireRedeemLock_TTL verifies the lock key expires with the given TTL.
func (s *RedeemCacheSuite) TestAcquireRedeemLock_TTL() {
	lockKey := redeemLockKeyPrefix + "CODE2"
	lockTTL := 15 * time.Second

	ok, err := s.cache.AcquireRedeemLock(s.ctx, "CODE2", lockTTL)
	require.NoError(s.T(), err, "AcquireRedeemLock CODE2")
	require.True(s.T(), ok)

	ttl, err := s.rdb.TTL(s.ctx, lockKey).Result()
	require.NoError(s.T(), err, "TTL lock key")
	s.AssertTTLWithin(ttl, 1*time.Second, lockTTL)
}

// TestReleaseRedeemLock_Idempotent ensures releasing a missing or
// already-released lock does not error.
func (s *RedeemCacheSuite) TestReleaseRedeemLock_Idempotent() {
	// Release a lock that doesn't exist should not error
	require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "NONEXISTENT"))

	// Acquire, release, release again
	ok, err := s.cache.AcquireRedeemLock(s.ctx, "IDEMPOTENT", 10*time.Second)
	require.NoError(s.T(), err)
	require.True(s.T(), ok)
	require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "IDEMPOTENT"))
	require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "IDEMPOTENT"), "second release should be idempotent")
}

func TestRedeemCacheSuite(t *testing.T) {
	suite.Run(t, new(RedeemCacheSuite))
}
|
||||
77
backend/internal/repository/redeem_cache_test.go
Normal file
77
backend/internal/repository/redeem_cache_test.go
Normal file
@@ -0,0 +1,77 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestRedeemRateLimitKey covers normal, zero, negative, and max int64 user IDs.
func TestRedeemRateLimitKey(t *testing.T) {
	tests := []struct {
		name     string
		userID   int64
		expected string
	}{
		{
			name:     "normal_user_id",
			userID:   123,
			expected: "redeem:ratelimit:123",
		},
		{
			name:     "zero_user_id",
			userID:   0,
			expected: "redeem:ratelimit:0",
		},
		{
			name:     "negative_user_id",
			userID:   -1,
			expected: "redeem:ratelimit:-1",
		},
		{
			name:     "max_int64",
			userID:   math.MaxInt64,
			expected: "redeem:ratelimit:9223372036854775807",
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			got := redeemRateLimitKey(tc.userID)
			require.Equal(t, tc.expected, got)
		})
	}
}

// TestRedeemLockKey covers normal, empty, and special-character codes.
func TestRedeemLockKey(t *testing.T) {
	tests := []struct {
		name     string
		code     string
		expected string
	}{
		{
			name:     "normal_code",
			code:     "ABC123",
			expected: "redeem:lock:ABC123",
		},
		{
			name:     "empty_code",
			code:     "",
			expected: "redeem:lock:",
		},
		{
			name:     "code_with_special_chars",
			code:     "CODE-2024:test",
			expected: "redeem:lock:CODE-2024:test",
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			got := redeemLockKey(tc.code)
			require.Equal(t, tc.expected, got)
		})
	}
}
|
||||
239
backend/internal/repository/redeem_code_repo.go
Normal file
239
backend/internal/repository/redeem_code_repo.go
Normal file
@@ -0,0 +1,239 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// redeemCodeRepository implements service.RedeemCodeRepository with ent.
type redeemCodeRepository struct {
	client *dbent.Client
}

// NewRedeemCodeRepository builds an ent-backed redeem code repository.
func NewRedeemCodeRepository(client *dbent.Client) service.RedeemCodeRepository {
	return &redeemCodeRepository{client: client}
}
|
||||
|
||||
func (r *redeemCodeRepository) Create(ctx context.Context, code *service.RedeemCode) error {
|
||||
created, err := r.client.RedeemCode.Create().
|
||||
SetCode(code.Code).
|
||||
SetType(code.Type).
|
||||
SetValue(code.Value).
|
||||
SetStatus(code.Status).
|
||||
SetNotes(code.Notes).
|
||||
SetValidityDays(code.ValidityDays).
|
||||
SetNillableUsedBy(code.UsedBy).
|
||||
SetNillableUsedAt(code.UsedAt).
|
||||
SetNillableGroupID(code.GroupID).
|
||||
Save(ctx)
|
||||
if err == nil {
|
||||
code.ID = created.ID
|
||||
code.CreatedAt = created.CreatedAt
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *redeemCodeRepository) CreateBatch(ctx context.Context, codes []service.RedeemCode) error {
|
||||
if len(codes) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
builders := make([]*dbent.RedeemCodeCreate, 0, len(codes))
|
||||
for i := range codes {
|
||||
c := &codes[i]
|
||||
b := r.client.RedeemCode.Create().
|
||||
SetCode(c.Code).
|
||||
SetType(c.Type).
|
||||
SetValue(c.Value).
|
||||
SetStatus(c.Status).
|
||||
SetNotes(c.Notes).
|
||||
SetValidityDays(c.ValidityDays).
|
||||
SetNillableUsedBy(c.UsedBy).
|
||||
SetNillableUsedAt(c.UsedAt).
|
||||
SetNillableGroupID(c.GroupID)
|
||||
builders = append(builders, b)
|
||||
}
|
||||
|
||||
return r.client.RedeemCode.CreateBulk(builders...).Exec(ctx)
|
||||
}
|
||||
|
||||
func (r *redeemCodeRepository) GetByID(ctx context.Context, id int64) (*service.RedeemCode, error) {
|
||||
m, err := r.client.RedeemCode.Query().
|
||||
Where(redeemcode.IDEQ(id)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrRedeemCodeNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return redeemCodeEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *redeemCodeRepository) GetByCode(ctx context.Context, code string) (*service.RedeemCode, error) {
|
||||
m, err := r.client.RedeemCode.Query().
|
||||
Where(redeemcode.CodeEQ(code)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrRedeemCodeNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return redeemCodeEntityToService(m), nil
|
||||
}
|
||||
|
||||
// Delete removes the redeem code with the given ID.
//
// NOTE(review): the affected-row count is discarded, so deleting a
// non-existent ID succeeds silently (idempotent delete), unlike GetByID and
// Update which return ErrRedeemCodeNotFound — confirm this asymmetry is
// intended.
func (r *redeemCodeRepository) Delete(ctx context.Context, id int64) error {
	_, err := r.client.RedeemCode.Delete().Where(redeemcode.IDEQ(id)).Exec(ctx)
	return err
}

// List returns one page of redeem codes with no filtering; it is a
// convenience wrapper over ListWithFilters with empty filter values.
func (r *redeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.RedeemCode, *pagination.PaginationResult, error) {
	return r.ListWithFilters(ctx, params, "", "", "")
}
|
||||
|
||||
func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
|
||||
q := r.client.RedeemCode.Query()
|
||||
|
||||
if codeType != "" {
|
||||
q = q.Where(redeemcode.TypeEQ(codeType))
|
||||
}
|
||||
if status != "" {
|
||||
q = q.Where(redeemcode.StatusEQ(status))
|
||||
}
|
||||
if search != "" {
|
||||
q = q.Where(redeemcode.CodeContainsFold(search))
|
||||
}
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
codes, err := q.
|
||||
WithUser().
|
||||
WithGroup().
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(redeemcode.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outCodes := redeemCodeEntitiesToService(codes)
|
||||
|
||||
return outCodes, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
// Update overwrites every mutable column of the redeem code identified by
// code.ID with the values on the given struct, returning
// service.ErrRedeemCodeNotFound when the row does not exist.
//
// Nil pointer fields (UsedBy, UsedAt, GroupID) mean "clear the column", not
// "leave unchanged" — callers must pass a fully populated struct
// (read-modify-write) or they will wipe those columns.
func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemCode) error {
	up := r.client.RedeemCode.UpdateOneID(code.ID).
		SetCode(code.Code).
		SetType(code.Type).
		SetValue(code.Value).
		SetStatus(code.Status).
		SetNotes(code.Notes).
		SetValidityDays(code.ValidityDays)

	// Pointer fields: set when present, clear when nil.
	if code.UsedBy != nil {
		up.SetUsedBy(*code.UsedBy)
	} else {
		up.ClearUsedBy()
	}
	if code.UsedAt != nil {
		up.SetUsedAt(*code.UsedAt)
	} else {
		up.ClearUsedAt()
	}
	if code.GroupID != nil {
		up.SetGroupID(*code.GroupID)
	} else {
		up.ClearGroupID()
	}

	updated, err := up.Save(ctx)
	if err != nil {
		if dbent.IsNotFound(err) {
			return service.ErrRedeemCodeNotFound
		}
		return err
	}
	// Keep the caller's struct in sync with the DB-managed creation time.
	code.CreatedAt = updated.CreatedAt
	return nil
}
|
||||
|
||||
func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error {
|
||||
now := time.Now()
|
||||
client := clientFromContext(ctx, r.client)
|
||||
affected, err := client.RedeemCode.Update().
|
||||
Where(redeemcode.IDEQ(id), redeemcode.StatusEQ(service.StatusUnused)).
|
||||
SetStatus(service.StatusUsed).
|
||||
SetUsedBy(userID).
|
||||
SetUsedAt(now).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
return service.ErrRedeemCodeUsed
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *redeemCodeRepository) ListByUser(ctx context.Context, userID int64, limit int) ([]service.RedeemCode, error) {
|
||||
if limit <= 0 {
|
||||
limit = 10
|
||||
}
|
||||
|
||||
codes, err := r.client.RedeemCode.Query().
|
||||
Where(redeemcode.UsedByEQ(userID)).
|
||||
WithGroup().
|
||||
Order(dbent.Desc(redeemcode.FieldUsedAt)).
|
||||
Limit(limit).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return redeemCodeEntitiesToService(codes), nil
|
||||
}
|
||||
|
||||
func redeemCodeEntityToService(m *dbent.RedeemCode) *service.RedeemCode {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
out := &service.RedeemCode{
|
||||
ID: m.ID,
|
||||
Code: m.Code,
|
||||
Type: m.Type,
|
||||
Value: m.Value,
|
||||
Status: m.Status,
|
||||
UsedBy: m.UsedBy,
|
||||
UsedAt: m.UsedAt,
|
||||
Notes: derefString(m.Notes),
|
||||
CreatedAt: m.CreatedAt,
|
||||
GroupID: m.GroupID,
|
||||
ValidityDays: m.ValidityDays,
|
||||
}
|
||||
if m.Edges.User != nil {
|
||||
out.User = userEntityToService(m.Edges.User)
|
||||
}
|
||||
if m.Edges.Group != nil {
|
||||
out.Group = groupEntityToService(m.Edges.Group)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func redeemCodeEntitiesToService(models []*dbent.RedeemCode) []service.RedeemCode {
|
||||
out := make([]service.RedeemCode, 0, len(models))
|
||||
for i := range models {
|
||||
if s := redeemCodeEntityToService(models[i]); s != nil {
|
||||
out = append(out, *s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
390
backend/internal/repository/redeem_code_repo_integration_test.go
Normal file
390
backend/internal/repository/redeem_code_repo_integration_test.go
Normal file
@@ -0,0 +1,390 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// RedeemCodeRepoSuite exercises redeemCodeRepository against a real database
// through a per-test transaction, keeping tests isolated from each other.
type RedeemCodeRepoSuite struct {
	suite.Suite
	ctx    context.Context
	client *dbent.Client // transactional ent client scoped to the current test
	repo   *redeemCodeRepository
}

// SetupTest opens a fresh test transaction and builds the repository under
// test on its client. Cleanup is presumably handled by testEntTx via
// t.Cleanup — confirm in that helper.
func (s *RedeemCodeRepoSuite) SetupTest() {
	s.ctx = context.Background()
	tx := testEntTx(s.T())
	s.client = tx.Client()
	s.repo = NewRedeemCodeRepository(s.client).(*redeemCodeRepository)
}

// TestRedeemCodeRepoSuite is the `go test` entry point for the suite.
func TestRedeemCodeRepoSuite(t *testing.T) {
	suite.Run(t, new(RedeemCodeRepoSuite))
}

// createUser seeds a user row with a fixed password hash; fails the test on error.
func (s *RedeemCodeRepoSuite) createUser(email string) *dbent.User {
	u, err := s.client.User.Create().
		SetEmail(email).
		SetPasswordHash("test-password-hash").
		Save(s.ctx)
	s.Require().NoError(err, "create user")
	return u
}

// createGroup seeds a group row; fails the test on error.
func (s *RedeemCodeRepoSuite) createGroup(name string) *dbent.Group {
	g, err := s.client.Group.Create().
		SetName(name).
		Save(s.ctx)
	s.Require().NoError(err, "create group")
	return g
}
|
||||
|
||||
// --- Create / CreateBatch / GetByID / GetByCode ---

// TestCreate verifies Create assigns an ID and the row reads back by ID.
func (s *RedeemCodeRepoSuite) TestCreate() {
	code := &service.RedeemCode{
		Code:   "TEST-CREATE",
		Type:   service.RedeemTypeBalance,
		Value:  100,
		Status: service.StatusUnused,
	}

	err := s.repo.Create(s.ctx, code)
	s.Require().NoError(err, "Create")
	s.Require().NotZero(code.ID, "expected ID to be set")

	got, err := s.repo.GetByID(s.ctx, code.ID)
	s.Require().NoError(err, "GetByID")
	s.Require().Equal("TEST-CREATE", got.Code)
}

// TestCreateBatch verifies bulk insert persists every code with its value.
func (s *RedeemCodeRepoSuite) TestCreateBatch() {
	codes := []service.RedeemCode{
		{Code: "BATCH-1", Type: service.RedeemTypeBalance, Value: 10, Status: service.StatusUnused},
		{Code: "BATCH-2", Type: service.RedeemTypeBalance, Value: 20, Status: service.StatusUnused},
	}

	err := s.repo.CreateBatch(s.ctx, codes)
	s.Require().NoError(err, "CreateBatch")

	got1, err := s.repo.GetByCode(s.ctx, "BATCH-1")
	s.Require().NoError(err)
	s.Require().Equal(float64(10), got1.Value)

	got2, err := s.repo.GetByCode(s.ctx, "BATCH-2")
	s.Require().NoError(err)
	s.Require().Equal(float64(20), got2.Value)
}

// TestGetByID_NotFound verifies the sentinel error for a missing ID.
func (s *RedeemCodeRepoSuite) TestGetByID_NotFound() {
	_, err := s.repo.GetByID(s.ctx, 999999)
	s.Require().Error(err, "expected error for non-existent ID")
	s.Require().ErrorIs(err, service.ErrRedeemCodeNotFound)
}

// TestGetByCode verifies lookup by the unique code string.
func (s *RedeemCodeRepoSuite) TestGetByCode() {
	_, err := s.client.RedeemCode.Create().
		SetCode("GET-BY-CODE").
		SetType(service.RedeemTypeBalance).
		SetStatus(service.StatusUnused).
		SetValue(0).
		SetNotes("").
		SetValidityDays(30).
		Save(s.ctx)
	s.Require().NoError(err, "seed redeem code")

	got, err := s.repo.GetByCode(s.ctx, "GET-BY-CODE")
	s.Require().NoError(err, "GetByCode")
	s.Require().Equal("GET-BY-CODE", got.Code)
}

// TestGetByCode_NotFound verifies the sentinel error for a missing code.
func (s *RedeemCodeRepoSuite) TestGetByCode_NotFound() {
	_, err := s.repo.GetByCode(s.ctx, "NON-EXISTENT")
	s.Require().Error(err, "expected error for non-existent code")
	s.Require().ErrorIs(err, service.ErrRedeemCodeNotFound)
}

// --- Delete ---

// TestDelete verifies a deleted code is no longer retrievable by ID.
func (s *RedeemCodeRepoSuite) TestDelete() {
	created, err := s.client.RedeemCode.Create().
		SetCode("TO-DELETE").
		SetType(service.RedeemTypeBalance).
		SetStatus(service.StatusUnused).
		SetValue(0).
		SetNotes("").
		SetValidityDays(30).
		Save(s.ctx)
	s.Require().NoError(err)

	err = s.repo.Delete(s.ctx, created.ID)
	s.Require().NoError(err, "Delete")

	_, err = s.repo.GetByID(s.ctx, created.ID)
	s.Require().Error(err, "expected error after delete")
	s.Require().ErrorIs(err, service.ErrRedeemCodeNotFound)
}
|
||||
|
||||
// --- List / ListWithFilters ---

// TestList verifies the unfiltered page returns all rows with a correct total.
func (s *RedeemCodeRepoSuite) TestList() {
	s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "LIST-1", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))
	s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "LIST-2", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))

	codes, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
	s.Require().NoError(err, "List")
	s.Require().Len(codes, 2)
	s.Require().Equal(int64(2), page.Total)
}

// TestListWithFilters_Type verifies filtering by code type.
func (s *RedeemCodeRepoSuite) TestListWithFilters_Type() {
	s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "TYPE-BAL", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))
	s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "TYPE-SUB", Type: service.RedeemTypeSubscription, Value: 0, Status: service.StatusUnused}))

	codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.RedeemTypeSubscription, "", "")
	s.Require().NoError(err)
	s.Require().Len(codes, 1)
	s.Require().Equal(service.RedeemTypeSubscription, codes[0].Type)
}

// TestListWithFilters_Status verifies filtering by status.
func (s *RedeemCodeRepoSuite) TestListWithFilters_Status() {
	s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "STAT-UNUSED", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))
	s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "STAT-USED", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUsed}))

	codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusUsed, "")
	s.Require().NoError(err)
	s.Require().Len(codes, 1)
	s.Require().Equal(service.StatusUsed, codes[0].Status)
}

// TestListWithFilters_Search verifies case-insensitive substring search on the code.
func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() {
	s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "ALPHA-CODE", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))
	s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "BETA-CODE", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))

	codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alpha")
	s.Require().NoError(err)
	s.Require().Len(codes, 1)
	s.Require().Contains(codes[0].Code, "ALPHA")
}

// TestListWithFilters_GroupPreload verifies the Group edge is loaded on list results.
func (s *RedeemCodeRepoSuite) TestListWithFilters_GroupPreload() {
	group := s.createGroup(uniqueTestValue(s.T(), "g-preload"))
	_, err := s.client.RedeemCode.Create().
		SetCode("WITH-GROUP").
		SetType(service.RedeemTypeSubscription).
		SetStatus(service.StatusUnused).
		SetValue(0).
		SetNotes("").
		SetValidityDays(30).
		SetGroupID(group.ID).
		Save(s.ctx)
	s.Require().NoError(err)

	codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "")
	s.Require().NoError(err)
	s.Require().Len(codes, 1)
	s.Require().NotNil(codes[0].Group, "expected Group preload")
	s.Require().Equal(group.ID, codes[0].Group.ID)
}
|
||||
|
||||
// --- Update ---

// TestUpdate verifies a field change round-trips through Update and GetByID.
func (s *RedeemCodeRepoSuite) TestUpdate() {
	code := &service.RedeemCode{
		Code:   "UPDATE-ME",
		Type:   service.RedeemTypeBalance,
		Value:  10,
		Status: service.StatusUnused,
	}
	s.Require().NoError(s.repo.Create(s.ctx, code))

	code.Value = 50
	err := s.repo.Update(s.ctx, code)
	s.Require().NoError(err, "Update")

	got, err := s.repo.GetByID(s.ctx, code.ID)
	s.Require().NoError(err)
	s.Require().Equal(float64(50), got.Value)
}
|
||||
|
||||
// --- Use ---

// TestUse verifies redemption marks the code used, with user and timestamp set.
func (s *RedeemCodeRepoSuite) TestUse() {
	user := s.createUser(uniqueTestValue(s.T(), "use") + "@example.com")
	code := &service.RedeemCode{Code: "USE-ME", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}
	s.Require().NoError(s.repo.Create(s.ctx, code))

	err := s.repo.Use(s.ctx, code.ID, user.ID)
	s.Require().NoError(err, "Use")

	got, err := s.repo.GetByID(s.ctx, code.ID)
	s.Require().NoError(err)
	s.Require().Equal(service.StatusUsed, got.Status)
	s.Require().NotNil(got.UsedBy)
	s.Require().Equal(user.ID, *got.UsedBy)
	s.Require().NotNil(got.UsedAt)
}

// TestUse_Idempotency verifies a second redemption of the same code fails
// with ErrRedeemCodeUsed.
func (s *RedeemCodeRepoSuite) TestUse_Idempotency() {
	user := s.createUser(uniqueTestValue(s.T(), "idem") + "@example.com")
	code := &service.RedeemCode{Code: "IDEM-CODE", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}
	s.Require().NoError(s.repo.Create(s.ctx, code))

	err := s.repo.Use(s.ctx, code.ID, user.ID)
	s.Require().NoError(err, "Use first time")

	// Second use should fail
	err = s.repo.Use(s.ctx, code.ID, user.ID)
	s.Require().Error(err, "Use expected error on second call")
	s.Require().ErrorIs(err, service.ErrRedeemCodeUsed)
}

// TestUse_AlreadyUsed verifies a code created in the used state cannot be redeemed.
func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() {
	user := s.createUser(uniqueTestValue(s.T(), "already") + "@example.com")
	code := &service.RedeemCode{Code: "ALREADY-USED", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUsed}
	s.Require().NoError(s.repo.Create(s.ctx, code))

	err := s.repo.Use(s.ctx, code.ID, user.ID)
	s.Require().Error(err, "expected error for already used code")
	s.Require().ErrorIs(err, service.ErrRedeemCodeUsed)
}
|
||||
|
||||
// --- ListByUser ---

// TestListByUser verifies per-user listing and its used_at DESC ordering,
// using fixed timestamps for determinism.
func (s *RedeemCodeRepoSuite) TestListByUser() {
	user := s.createUser(uniqueTestValue(s.T(), "listby") + "@example.com")
	base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)

	usedAt1 := base
	_, err := s.client.RedeemCode.Create().
		SetCode("USER-1").
		SetType(service.RedeemTypeBalance).
		SetStatus(service.StatusUsed).
		SetValue(0).
		SetNotes("").
		SetValidityDays(30).
		SetUsedBy(user.ID).
		SetUsedAt(usedAt1).
		Save(s.ctx)
	s.Require().NoError(err)

	usedAt2 := base.Add(1 * time.Hour)
	_, err = s.client.RedeemCode.Create().
		SetCode("USER-2").
		SetType(service.RedeemTypeBalance).
		SetStatus(service.StatusUsed).
		SetValue(0).
		SetNotes("").
		SetValidityDays(30).
		SetUsedBy(user.ID).
		SetUsedAt(usedAt2).
		Save(s.ctx)
	s.Require().NoError(err)

	codes, err := s.repo.ListByUser(s.ctx, user.ID, 10)
	s.Require().NoError(err, "ListByUser")
	s.Require().Len(codes, 2)
	// Ordered by used_at DESC, so USER-2 first
	s.Require().Equal("USER-2", codes[0].Code)
	s.Require().Equal("USER-1", codes[1].Code)
}

// TestListByUser_WithGroupPreload verifies the Group edge is loaded on
// per-user results.
func (s *RedeemCodeRepoSuite) TestListByUser_WithGroupPreload() {
	user := s.createUser(uniqueTestValue(s.T(), "grp") + "@example.com")
	group := s.createGroup(uniqueTestValue(s.T(), "g-listby"))

	_, err := s.client.RedeemCode.Create().
		SetCode("WITH-GRP").
		SetType(service.RedeemTypeSubscription).
		SetStatus(service.StatusUsed).
		SetValue(0).
		SetNotes("").
		SetValidityDays(30).
		SetUsedBy(user.ID).
		SetUsedAt(time.Now()).
		SetGroupID(group.ID).
		Save(s.ctx)
	s.Require().NoError(err)

	codes, err := s.repo.ListByUser(s.ctx, user.ID, 10)
	s.Require().NoError(err)
	s.Require().Len(codes, 1)
	s.Require().NotNil(codes[0].Group)
	s.Require().Equal(group.ID, codes[0].Group.ID)
}

// TestListByUser_DefaultLimit verifies a non-positive limit still returns rows
// (the repository substitutes its default of 10).
func (s *RedeemCodeRepoSuite) TestListByUser_DefaultLimit() {
	user := s.createUser(uniqueTestValue(s.T(), "deflimit") + "@example.com")
	_, err := s.client.RedeemCode.Create().
		SetCode("DEF-LIM").
		SetType(service.RedeemTypeBalance).
		SetStatus(service.StatusUsed).
		SetValue(0).
		SetNotes("").
		SetValidityDays(30).
		SetUsedBy(user.ID).
		SetUsedAt(time.Now()).
		Save(s.ctx)
	s.Require().NoError(err)

	// limit <= 0 should default to 10
	codes, err := s.repo.ListByUser(s.ctx, user.ID, 0)
	s.Require().NoError(err)
	s.Require().Len(codes, 1)
}
|
||||
|
||||
// --- Combined original test ---

// TestCreateBatch_Filters_Use_Idempotency_ListByUser is an end-to-end pass
// over batch creation, filtered listing with preloads, single-use redemption
// semantics, and per-user listing order.
func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser() {
	user := s.createUser(uniqueTestValue(s.T(), "rc") + "@example.com")
	group := s.createGroup(uniqueTestValue(s.T(), "g-rc"))
	groupID := group.ID

	codes := []service.RedeemCode{
		{Code: "CODEA", Type: service.RedeemTypeBalance, Value: 1, Status: service.StatusUnused, Notes: ""},
		{Code: "CODEB", Type: service.RedeemTypeSubscription, Value: 0, Status: service.StatusUnused, Notes: "", GroupID: &groupID, ValidityDays: 7},
	}
	s.Require().NoError(s.repo.CreateBatch(s.ctx, codes), "CreateBatch")

	// All three filters at once: type + status + case-insensitive search.
	list, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.RedeemTypeSubscription, service.StatusUnused, "code")
	s.Require().NoError(err, "ListWithFilters")
	s.Require().Equal(int64(1), page.Total)
	s.Require().Len(list, 1)
	s.Require().NotNil(list[0].Group, "expected Group preload")
	s.Require().Equal(group.ID, list[0].Group.ID)

	codeB, err := s.repo.GetByCode(s.ctx, "CODEB")
	s.Require().NoError(err, "GetByCode")
	s.Require().NoError(s.repo.Use(s.ctx, codeB.ID, user.ID), "Use")
	err = s.repo.Use(s.ctx, codeB.ID, user.ID)
	s.Require().Error(err, "Use expected error on second call")
	s.Require().ErrorIs(err, service.ErrRedeemCodeUsed)

	codeA, err := s.repo.GetByCode(s.ctx, "CODEA")
	s.Require().NoError(err, "GetByCode")

	// Use fixed time instead of time.Sleep for deterministic ordering.
	_, err = s.client.RedeemCode.UpdateOneID(codeB.ID).
		SetUsedAt(time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)).
		Save(s.ctx)
	s.Require().NoError(err)
	s.Require().NoError(s.repo.Use(s.ctx, codeA.ID, user.ID), "Use codeA")
	_, err = s.client.RedeemCode.UpdateOneID(codeA.ID).
		SetUsedAt(time.Date(2025, 1, 1, 13, 0, 0, 0, time.UTC)).
		Save(s.ctx)
	s.Require().NoError(err)

	used, err := s.repo.ListByUser(s.ctx, user.ID, 10)
	s.Require().NoError(err, "ListByUser")
	s.Require().Len(used, 2, "expected 2 used codes")
	s.Require().Equal("CODEA", used[0].Code, "expected newest used code first")
}
|
||||
39
backend/internal/repository/redis.go
Normal file
39
backend/internal/repository/redis.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// InitRedis initializes the Redis client.
//
// Performance notes:
// The previous implementation used go-redis defaults with no pool or timeout
// tuning:
//  1. The default pool size may be insufficient under high concurrency.
//  2. Without timeout control, slow operations can block.
//
// This implementation supports configurable pool and timeout parameters:
//  1. PoolSize: caps concurrent connections (default 128).
//  2. MinIdleConns: keeps warm idle connections to reduce cold-start latency (default 10).
//  3. DialTimeout/ReadTimeout/WriteTimeout: per-phase timeout control.
func InitRedis(cfg *config.Config) *redis.Client {
	return redis.NewClient(buildRedisOptions(cfg))
}

// buildRedisOptions assembles the Redis connection options from configuration,
// exposing pool and timeout parameters for production tuning.
func buildRedisOptions(cfg *config.Config) *redis.Options {
	return &redis.Options{
		Addr:         cfg.Redis.Address(),
		Password:     cfg.Redis.Password,
		DB:           cfg.Redis.DB,
		DialTimeout:  time.Duration(cfg.Redis.DialTimeoutSeconds) * time.Second,  // connection establishment timeout
		ReadTimeout:  time.Duration(cfg.Redis.ReadTimeoutSeconds) * time.Second,  // read timeout
		WriteTimeout: time.Duration(cfg.Redis.WriteTimeoutSeconds) * time.Second, // write timeout
		PoolSize:     cfg.Redis.PoolSize,                                         // connection pool size
		MinIdleConns: cfg.Redis.MinIdleConns,                                     // minimum idle connections
	}
}
|
||||
35
backend/internal/repository/redis_test.go
Normal file
35
backend/internal/repository/redis_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestBuildRedisOptions verifies every config field is mapped onto the
// go-redis Options struct, including second-to-Duration conversion.
func TestBuildRedisOptions(t *testing.T) {
	cfg := &config.Config{
		Redis: config.RedisConfig{
			Host:                "localhost",
			Port:                6379,
			Password:            "secret",
			DB:                  2,
			DialTimeoutSeconds:  5,
			ReadTimeoutSeconds:  3,
			WriteTimeoutSeconds: 4,
			PoolSize:            100,
			MinIdleConns:        10,
		},
	}

	opts := buildRedisOptions(cfg)
	require.Equal(t, "localhost:6379", opts.Addr)
	require.Equal(t, "secret", opts.Password)
	require.Equal(t, 2, opts.DB)
	require.Equal(t, 5*time.Second, opts.DialTimeout)
	require.Equal(t, 3*time.Second, opts.ReadTimeout)
	require.Equal(t, 4*time.Second, opts.WriteTimeout)
	require.Equal(t, 100, opts.PoolSize)
	require.Equal(t, 10, opts.MinIdleConns)
}
|
||||
64
backend/internal/repository/req_client_pool.go
Normal file
64
backend/internal/repository/req_client_pool.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
// reqClientOptions describes the parameters a req client is built from.
type reqClientOptions struct {
	ProxyURL    string        // proxy URL (http/https/socks5 supported)
	Timeout     time.Duration // per-request timeout
	Impersonate bool          // whether to mimic a Chrome browser TLS fingerprint
}

// sharedReqClients caches req client instances keyed by their configuration.
//
// Performance notes:
// The previous implementation created a fresh req.Client on every OAuth
// refresh in:
//  1. claude_oauth_service.go
//  2. openai_oauth_service.go
//  3. gemini_oauth_client.go
//
// This implementation caches clients in a sync.Map:
//  1. Identical configurations (proxy + timeout + impersonation) share one client.
//  2. The underlying connection pool is reused, cutting TLS handshake overhead.
//  3. LoadOrStore keeps concurrent access safe and avoids storing duplicates.
var sharedReqClients sync.Map
|
||||
|
||||
// getSharedReqClient returns the shared req client for the given options,
// building and caching one on first use. Identical configurations reuse the
// same client (and its connection pool).
//
// NOTE(review): under a racing first call, both goroutines may build a client;
// LoadOrStore keeps only one and the loser's client is discarded unused.
// The SetProxyURL error is also ignored — an unparsable proxy URL silently
// falls back to req's default behavior; confirm that is acceptable.
func getSharedReqClient(opts reqClientOptions) *req.Client {
	key := buildReqClientKey(opts)
	// Fast path: already cached.
	if cached, ok := sharedReqClients.Load(key); ok {
		if c, ok := cached.(*req.Client); ok {
			return c
		}
	}

	client := req.C().SetTimeout(opts.Timeout)
	if opts.Impersonate {
		client = client.ImpersonateChrome()
	}
	if strings.TrimSpace(opts.ProxyURL) != "" {
		client.SetProxyURL(strings.TrimSpace(opts.ProxyURL))
	}

	// LoadOrStore returns the winner if another goroutine stored first.
	actual, _ := sharedReqClients.LoadOrStore(key, client)
	if c, ok := actual.(*req.Client); ok {
		return c
	}
	return client
}
|
||||
|
||||
func buildReqClientKey(opts reqClientOptions) string {
|
||||
return fmt.Sprintf("%s|%s|%t",
|
||||
strings.TrimSpace(opts.ProxyURL),
|
||||
opts.Timeout.String(),
|
||||
opts.Impersonate,
|
||||
)
|
||||
}
|
||||
276
backend/internal/repository/scheduler_cache.go
Normal file
276
backend/internal/repository/scheduler_cache.go
Normal file
@@ -0,0 +1,276 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// Redis key names/prefixes used by the scheduler cache.
const (
	schedulerBucketSetKey       = "sched:buckets"          // set of all known bucket identifiers
	schedulerOutboxWatermarkKey = "sched:outbox:watermark" // last processed outbox event ID
	schedulerAccountPrefix      = "sched:acc:"             // per-account JSON payloads
	schedulerActivePrefix       = "sched:active:"          // per-bucket pointer to the active snapshot version
	schedulerReadyPrefix        = "sched:ready:"           // per-bucket "snapshot is usable" flag
	schedulerVersionPrefix      = "sched:ver:"             // per-bucket monotonically increasing version counter
	schedulerSnapshotPrefix     = "sched:"                 // versioned per-bucket sorted sets of account IDs
	schedulerLockPrefix         = "sched:lock:"            // per-bucket refresh locks
)

// schedulerCache is the Redis-backed implementation of service.SchedulerCache.
type schedulerCache struct {
	rdb *redis.Client
}

// NewSchedulerCache builds a scheduler cache on top of the given Redis client.
func NewSchedulerCache(rdb *redis.Client) service.SchedulerCache {
	return &schedulerCache{rdb: rdb}
}
|
||||
|
||||
// GetSnapshot returns the cached account list for a bucket.
//
// The second return value reports a cache hit: (nil, false, nil) means "no
// usable snapshot — rebuild from the database". A hit requires the ready flag
// to be "1", an active version pointer to exist, and every account referenced
// by the snapshot's sorted set to still have its payload key; a single
// missing payload degrades the whole call to a miss.
func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) {
	// Step 1: the bucket must be flagged ready.
	readyKey := schedulerBucketKey(schedulerReadyPrefix, bucket)
	readyVal, err := c.rdb.Get(ctx, readyKey).Result()
	if err == redis.Nil {
		return nil, false, nil
	}
	if err != nil {
		return nil, false, err
	}
	if readyVal != "1" {
		return nil, false, nil
	}

	// Step 2: resolve the currently active snapshot version.
	activeKey := schedulerBucketKey(schedulerActivePrefix, bucket)
	activeVal, err := c.rdb.Get(ctx, activeKey).Result()
	if err == redis.Nil {
		return nil, false, nil
	}
	if err != nil {
		return nil, false, err
	}

	// Step 3: read the ordered account IDs of that snapshot version.
	snapshotKey := schedulerSnapshotKey(bucket, activeVal)
	ids, err := c.rdb.ZRange(ctx, snapshotKey, 0, -1).Result()
	if err != nil {
		return nil, false, err
	}
	if len(ids) == 0 {
		// An empty snapshot is still a valid hit.
		return []*service.Account{}, true, nil
	}

	// Step 4: bulk-fetch the per-account payloads in snapshot order.
	keys := make([]string, 0, len(ids))
	for _, id := range ids {
		keys = append(keys, schedulerAccountKey(id))
	}
	values, err := c.rdb.MGet(ctx, keys...).Result()
	if err != nil {
		return nil, false, err
	}

	accounts := make([]*service.Account, 0, len(values))
	for _, val := range values {
		if val == nil {
			// A referenced account payload is gone: treat as a cache miss.
			return nil, false, nil
		}
		account, err := decodeCachedAccount(val)
		if err != nil {
			return nil, false, err
		}
		accounts = append(accounts, account)
	}

	return accounts, true, nil
}
|
||||
|
||||
// SetSnapshot atomically publishes a new snapshot for a bucket: it bumps the
// per-bucket version, writes account payloads and the versioned ordered ID
// set in one pipeline, flips the active pointer and ready flag, registers the
// bucket, and finally best-effort deletes the previous snapshot version.
//
// NOTE(review): the error from the initial Get of the old active version is
// ignored — on a transient Redis error the stale snapshot key is simply left
// behind; confirm that leak is acceptable.
func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.SchedulerBucket, accounts []service.Account) error {
	// Remember the currently active version so it can be cleaned up afterwards.
	activeKey := schedulerBucketKey(schedulerActivePrefix, bucket)
	oldActive, _ := c.rdb.Get(ctx, activeKey).Result()

	// Allocate a fresh, monotonically increasing snapshot version.
	versionKey := schedulerBucketKey(schedulerVersionPrefix, bucket)
	version, err := c.rdb.Incr(ctx, versionKey).Result()
	if err != nil {
		return err
	}

	versionStr := strconv.FormatInt(version, 10)
	snapshotKey := schedulerSnapshotKey(bucket, versionStr)

	pipe := c.rdb.Pipeline()
	for _, account := range accounts {
		payload, err := json.Marshal(account)
		if err != nil {
			return err
		}
		pipe.Set(ctx, schedulerAccountKey(strconv.FormatInt(account.ID, 10)), payload, 0)
	}
	if len(accounts) > 0 {
		// Use the slice index as the score to preserve the ordering returned
		// by the database.
		members := make([]redis.Z, 0, len(accounts))
		for idx, account := range accounts {
			members = append(members, redis.Z{
				Score:  float64(idx),
				Member: strconv.FormatInt(account.ID, 10),
			})
		}
		pipe.ZAdd(ctx, snapshotKey, members...)
	} else {
		pipe.Del(ctx, snapshotKey)
	}
	// Flip the active pointer and ready flag, and register the bucket.
	pipe.Set(ctx, activeKey, versionStr, 0)
	pipe.Set(ctx, schedulerBucketKey(schedulerReadyPrefix, bucket), "1", 0)
	pipe.SAdd(ctx, schedulerBucketSetKey, bucket.String())
	if _, err := pipe.Exec(ctx); err != nil {
		return err
	}

	// Best-effort cleanup of the superseded snapshot version.
	if oldActive != "" && oldActive != versionStr {
		_ = c.rdb.Del(ctx, schedulerSnapshotKey(bucket, oldActive)).Err()
	}

	return nil
}
|
||||
|
||||
func (c *schedulerCache) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) {
|
||||
key := schedulerAccountKey(strconv.FormatInt(accountID, 10))
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err == redis.Nil {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return decodeCachedAccount(val)
|
||||
}
|
||||
|
||||
func (c *schedulerCache) SetAccount(ctx context.Context, account *service.Account) error {
|
||||
if account == nil || account.ID <= 0 {
|
||||
return nil
|
||||
}
|
||||
payload, err := json.Marshal(account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
key := schedulerAccountKey(strconv.FormatInt(account.ID, 10))
|
||||
return c.rdb.Set(ctx, key, payload, 0).Err()
|
||||
}
|
||||
|
||||
func (c *schedulerCache) DeleteAccount(ctx context.Context, accountID int64) error {
|
||||
if accountID <= 0 {
|
||||
return nil
|
||||
}
|
||||
key := schedulerAccountKey(strconv.FormatInt(accountID, 10))
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
// UpdateLastUsed patches the LastUsedAt field of the cached payloads for the
// given accounts. Accounts missing from the cache are skipped silently; a
// payload that fails to decode aborts the whole call.
//
// NOTE(review): this is a non-atomic read-modify-write (MGet, then pipelined
// Set), so a concurrent SetAccount/SetSnapshot between the read and the write
// can be overwritten with stale data — confirm callers tolerate that.
func (c *schedulerCache) UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
	if len(updates) == 0 {
		return nil
	}

	// Keep keys and IDs index-aligned so MGet results map back to their update.
	keys := make([]string, 0, len(updates))
	ids := make([]int64, 0, len(updates))
	for id := range updates {
		keys = append(keys, schedulerAccountKey(strconv.FormatInt(id, 10)))
		ids = append(ids, id)
	}

	values, err := c.rdb.MGet(ctx, keys...).Result()
	if err != nil {
		return err
	}

	pipe := c.rdb.Pipeline()
	for i, val := range values {
		if val == nil {
			// Not cached: nothing to patch.
			continue
		}
		account, err := decodeCachedAccount(val)
		if err != nil {
			return err
		}
		account.LastUsedAt = ptrTime(updates[ids[i]])
		updated, err := json.Marshal(account)
		if err != nil {
			return err
		}
		pipe.Set(ctx, keys[i], updated, 0)
	}
	_, err = pipe.Exec(ctx)
	return err
}
|
||||
|
||||
func (c *schedulerCache) TryLockBucket(ctx context.Context, bucket service.SchedulerBucket, ttl time.Duration) (bool, error) {
|
||||
key := schedulerBucketKey(schedulerLockPrefix, bucket)
|
||||
return c.rdb.SetNX(ctx, key, time.Now().UnixNano(), ttl).Result()
|
||||
}
|
||||
|
||||
func (c *schedulerCache) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) {
|
||||
raw, err := c.rdb.SMembers(ctx, schedulerBucketSetKey).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]service.SchedulerBucket, 0, len(raw))
|
||||
for _, entry := range raw {
|
||||
bucket, ok := service.ParseSchedulerBucket(entry)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
out = append(out, bucket)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *schedulerCache) GetOutboxWatermark(ctx context.Context) (int64, error) {
|
||||
val, err := c.rdb.Get(ctx, schedulerOutboxWatermarkKey).Result()
|
||||
if err == redis.Nil {
|
||||
return 0, nil
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
id, err := strconv.ParseInt(val, 10, 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (c *schedulerCache) SetOutboxWatermark(ctx context.Context, id int64) error {
|
||||
return c.rdb.Set(ctx, schedulerOutboxWatermarkKey, strconv.FormatInt(id, 10), 0).Err()
|
||||
}
|
||||
|
||||
func schedulerBucketKey(prefix string, bucket service.SchedulerBucket) string {
|
||||
return fmt.Sprintf("%s%d:%s:%s", prefix, bucket.GroupID, bucket.Platform, bucket.Mode)
|
||||
}
|
||||
|
||||
func schedulerSnapshotKey(bucket service.SchedulerBucket, version string) string {
|
||||
return fmt.Sprintf("%s%d:%s:%s:v%s", schedulerSnapshotPrefix, bucket.GroupID, bucket.Platform, bucket.Mode, version)
|
||||
}
|
||||
|
||||
func schedulerAccountKey(id string) string {
|
||||
return schedulerAccountPrefix + id
|
||||
}
|
||||
|
||||
func ptrTime(t time.Time) *time.Time {
|
||||
return &t
|
||||
}
|
||||
|
||||
func decodeCachedAccount(val any) (*service.Account, error) {
|
||||
var payload []byte
|
||||
switch raw := val.(type) {
|
||||
case string:
|
||||
payload = []byte(raw)
|
||||
case []byte:
|
||||
payload = raw
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected account cache type: %T", val)
|
||||
}
|
||||
var account service.Account
|
||||
if err := json.Unmarshal(payload, &account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &account, nil
|
||||
}
|
||||
96
backend/internal/repository/scheduler_outbox_repo.go
Normal file
96
backend/internal/repository/scheduler_outbox_repo.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// schedulerOutboxRepository reads scheduler outbox events from the
// scheduler_outbox table via a raw *sql.DB handle.
type schedulerOutboxRepository struct {
	db *sql.DB
}
// NewSchedulerOutboxRepository returns a SchedulerOutboxRepository backed by db.
func NewSchedulerOutboxRepository(db *sql.DB) service.SchedulerOutboxRepository {
	return &schedulerOutboxRepository{db: db}
}
func (r *schedulerOutboxRepository) ListAfter(ctx context.Context, afterID int64, limit int) ([]service.SchedulerOutboxEvent, error) {
|
||||
if limit <= 0 {
|
||||
limit = 100
|
||||
}
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, event_type, account_id, group_id, payload, created_at
|
||||
FROM scheduler_outbox
|
||||
WHERE id > $1
|
||||
ORDER BY id ASC
|
||||
LIMIT $2
|
||||
`, afterID, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
}()
|
||||
|
||||
events := make([]service.SchedulerOutboxEvent, 0, limit)
|
||||
for rows.Next() {
|
||||
var (
|
||||
payloadRaw []byte
|
||||
accountID sql.NullInt64
|
||||
groupID sql.NullInt64
|
||||
event service.SchedulerOutboxEvent
|
||||
)
|
||||
if err := rows.Scan(&event.ID, &event.EventType, &accountID, &groupID, &payloadRaw, &event.CreatedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if accountID.Valid {
|
||||
v := accountID.Int64
|
||||
event.AccountID = &v
|
||||
}
|
||||
if groupID.Valid {
|
||||
v := groupID.Int64
|
||||
event.GroupID = &v
|
||||
}
|
||||
if len(payloadRaw) > 0 {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(payloadRaw, &payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
event.Payload = payload
|
||||
}
|
||||
events = append(events, event)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return events, nil
|
||||
}
|
||||
|
||||
func (r *schedulerOutboxRepository) MaxID(ctx context.Context) (int64, error) {
|
||||
var maxID int64
|
||||
if err := r.db.QueryRowContext(ctx, "SELECT COALESCE(MAX(id), 0) FROM scheduler_outbox").Scan(&maxID); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return maxID, nil
|
||||
}
|
||||
|
||||
func enqueueSchedulerOutbox(ctx context.Context, exec sqlExecutor, eventType string, accountID *int64, groupID *int64, payload any) error {
|
||||
if exec == nil {
|
||||
return nil
|
||||
}
|
||||
var payloadArg any
|
||||
if payload != nil {
|
||||
encoded, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
payloadArg = encoded
|
||||
}
|
||||
_, err := exec.ExecContext(ctx, `
|
||||
INSERT INTO scheduler_outbox (event_type, account_id, group_id, payload)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
`, eventType, accountID, groupID, payloadArg)
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestSchedulerSnapshotOutboxReplay verifies the end-to-end outbox replay
// loop: an account update recorded through the repository is picked up by the
// snapshot service poller and eventually applied to the Redis account cache.
func TestSchedulerSnapshotOutboxReplay(t *testing.T) {
	ctx := context.Background()
	rdb := testRedis(t)
	client := testEntClient(t)

	// Best-effort cleanup of events left by earlier runs; the error is
	// deliberately ignored.
	_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox")

	accountRepo := newAccountRepositoryWithSQL(client, integrationDB)
	outboxRepo := NewSchedulerOutboxRepository(integrationDB)
	cache := NewSchedulerCache(rdb)

	// Poll every second with full rebuilds disabled so the test exercises
	// only the incremental outbox path.
	cfg := &config.Config{
		RunMode: config.RunModeStandard,
		Gateway: config.GatewayConfig{
			Scheduling: config.GatewaySchedulingConfig{
				OutboxPollIntervalSeconds:  1,
				FullRebuildIntervalSeconds: 0,
				DbFallbackEnabled:          true,
			},
		},
	}

	// Timestamp suffix keeps the account name unique across repeated runs.
	account := &service.Account{
		Name:        "outbox-replay-" + time.Now().Format("150405.000000"),
		Platform:    service.PlatformOpenAI,
		Type:        service.AccountTypeAPIKey,
		Status:      service.StatusActive,
		Schedulable: true,
		Concurrency: 3,
		Priority:    1,
		Credentials: map[string]any{},
		Extra:       map[string]any{},
	}
	require.NoError(t, accountRepo.Create(ctx, account))
	// Seed the cache so the replay has an entry to update in place.
	require.NoError(t, cache.SetAccount(ctx, account))

	svc := service.NewSchedulerSnapshotService(cache, outboxRepo, accountRepo, nil, cfg)
	svc.Start()
	t.Cleanup(svc.Stop)

	// Trigger the change that should flow through the outbox.
	require.NoError(t, accountRepo.UpdateLastUsed(ctx, account.ID))
	updated, err := accountRepo.GetByID(ctx, account.ID)
	require.NoError(t, err)
	require.NotNil(t, updated.LastUsedAt)
	expectedUnix := updated.LastUsedAt.Unix()

	// The poller runs asynchronously, so wait for the cache to converge.
	require.Eventually(t, func() bool {
		cached, err := cache.GetAccount(ctx, account.ID)
		if err != nil || cached == nil || cached.LastUsedAt == nil {
			return false
		}
		return cached.LastUsedAt.Unix() == expectedUnix
	}, 5*time.Second, 100*time.Millisecond)
}
105
backend/internal/repository/setting_repo.go
Normal file
105
backend/internal/repository/setting_repo.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/setting"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// settingRepository persists key/value application settings through ent.
type settingRepository struct {
	client *ent.Client
}
// NewSettingRepository returns a SettingRepository backed by the ent client.
func NewSettingRepository(client *ent.Client) service.SettingRepository {
	return &settingRepository{client: client}
}
func (r *settingRepository) Get(ctx context.Context, key string) (*service.Setting, error) {
|
||||
m, err := r.client.Setting.Query().Where(setting.KeyEQ(key)).Only(ctx)
|
||||
if err != nil {
|
||||
if ent.IsNotFound(err) {
|
||||
return nil, service.ErrSettingNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &service.Setting{
|
||||
ID: m.ID,
|
||||
Key: m.Key,
|
||||
Value: m.Value,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *settingRepository) GetValue(ctx context.Context, key string) (string, error) {
|
||||
setting, err := r.Get(ctx, key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return setting.Value, nil
|
||||
}
|
||||
|
||||
func (r *settingRepository) Set(ctx context.Context, key, value string) error {
|
||||
now := time.Now()
|
||||
return r.client.Setting.
|
||||
Create().
|
||||
SetKey(key).
|
||||
SetValue(value).
|
||||
SetUpdatedAt(now).
|
||||
OnConflictColumns(setting.FieldKey).
|
||||
UpdateNewValues().
|
||||
Exec(ctx)
|
||||
}
|
||||
|
||||
func (r *settingRepository) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||
if len(keys) == 0 {
|
||||
return map[string]string{}, nil
|
||||
}
|
||||
settings, err := r.client.Setting.Query().Where(setting.KeyIn(keys...)).All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make(map[string]string)
|
||||
for _, s := range settings {
|
||||
result[s.Key] = s.Value
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *settingRepository) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||
if len(settings) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
builders := make([]*ent.SettingCreate, 0, len(settings))
|
||||
for key, value := range settings {
|
||||
builders = append(builders, r.client.Setting.Create().SetKey(key).SetValue(value).SetUpdatedAt(now))
|
||||
}
|
||||
return r.client.Setting.
|
||||
CreateBulk(builders...).
|
||||
OnConflictColumns(setting.FieldKey).
|
||||
UpdateNewValues().
|
||||
Exec(ctx)
|
||||
}
|
||||
|
||||
func (r *settingRepository) GetAll(ctx context.Context) (map[string]string, error) {
|
||||
settings, err := r.client.Setting.Query().All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make(map[string]string)
|
||||
for _, s := range settings {
|
||||
result[s.Key] = s.Value
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *settingRepository) Delete(ctx context.Context, key string) error {
|
||||
_, err := r.client.Setting.Delete().Where(setting.KeyEQ(key)).Exec(ctx)
|
||||
return err
|
||||
}
|
||||
163
backend/internal/repository/setting_repo_integration_test.go
Normal file
163
backend/internal/repository/setting_repo_integration_test.go
Normal file
@@ -0,0 +1,163 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// SettingRepoSuite exercises settingRepository against a real database,
// with each test running inside its own rolled-back transaction.
type SettingRepoSuite struct {
	suite.Suite
	ctx  context.Context
	repo *settingRepository
}
// SetupTest opens a fresh per-test ent transaction and builds the repository
// on top of it, so writes are isolated between tests.
func (s *SettingRepoSuite) SetupTest() {
	s.ctx = context.Background()
	tx := testEntTx(s.T())
	s.repo = NewSettingRepository(tx.Client()).(*settingRepository)
}
// TestSettingRepoSuite is the go test entry point for SettingRepoSuite.
func TestSettingRepoSuite(t *testing.T) {
	suite.Run(t, new(SettingRepoSuite))
}
func (s *SettingRepoSuite) TestSetAndGetValue() {
|
||||
s.Require().NoError(s.repo.Set(s.ctx, "k1", "v1"), "Set")
|
||||
got, err := s.repo.GetValue(s.ctx, "k1")
|
||||
s.Require().NoError(err, "GetValue")
|
||||
s.Require().Equal("v1", got, "GetValue mismatch")
|
||||
}
|
||||
|
||||
func (s *SettingRepoSuite) TestSet_Upsert() {
|
||||
s.Require().NoError(s.repo.Set(s.ctx, "k1", "v1"), "Set")
|
||||
s.Require().NoError(s.repo.Set(s.ctx, "k1", "v2"), "Set upsert")
|
||||
got, err := s.repo.GetValue(s.ctx, "k1")
|
||||
s.Require().NoError(err, "GetValue after upsert")
|
||||
s.Require().Equal("v2", got, "upsert mismatch")
|
||||
}
|
||||
|
||||
func (s *SettingRepoSuite) TestGetValue_Missing() {
|
||||
_, err := s.repo.GetValue(s.ctx, "nonexistent")
|
||||
s.Require().Error(err, "expected error for missing key")
|
||||
s.Require().ErrorIs(err, service.ErrSettingNotFound)
|
||||
}
|
||||
|
||||
func (s *SettingRepoSuite) TestSetMultiple_AndGetMultiple() {
|
||||
s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"k2": "v2", "k3": "v3"}), "SetMultiple")
|
||||
m, err := s.repo.GetMultiple(s.ctx, []string{"k2", "k3"})
|
||||
s.Require().NoError(err, "GetMultiple")
|
||||
s.Require().Equal("v2", m["k2"])
|
||||
s.Require().Equal("v3", m["k3"])
|
||||
}
|
||||
|
||||
func (s *SettingRepoSuite) TestGetMultiple_EmptyKeys() {
|
||||
m, err := s.repo.GetMultiple(s.ctx, []string{})
|
||||
s.Require().NoError(err, "GetMultiple with empty keys")
|
||||
s.Require().Empty(m, "expected empty map")
|
||||
}
|
||||
|
||||
func (s *SettingRepoSuite) TestGetMultiple_Subset() {
|
||||
s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"a": "1", "b": "2", "c": "3"}))
|
||||
m, err := s.repo.GetMultiple(s.ctx, []string{"a", "c", "nonexistent"})
|
||||
s.Require().NoError(err, "GetMultiple subset")
|
||||
s.Require().Equal("1", m["a"])
|
||||
s.Require().Equal("3", m["c"])
|
||||
_, exists := m["nonexistent"]
|
||||
s.Require().False(exists, "nonexistent key should not be in map")
|
||||
}
|
||||
|
||||
func (s *SettingRepoSuite) TestGetAll() {
|
||||
s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"x": "1", "y": "2"}))
|
||||
all, err := s.repo.GetAll(s.ctx)
|
||||
s.Require().NoError(err, "GetAll")
|
||||
s.Require().GreaterOrEqual(len(all), 2, "expected at least 2 settings")
|
||||
s.Require().Equal("1", all["x"])
|
||||
s.Require().Equal("2", all["y"])
|
||||
}
|
||||
|
||||
func (s *SettingRepoSuite) TestDelete() {
|
||||
s.Require().NoError(s.repo.Set(s.ctx, "todelete", "val"))
|
||||
s.Require().NoError(s.repo.Delete(s.ctx, "todelete"), "Delete")
|
||||
_, err := s.repo.GetValue(s.ctx, "todelete")
|
||||
s.Require().Error(err, "expected missing key error after Delete")
|
||||
s.Require().ErrorIs(err, service.ErrSettingNotFound)
|
||||
}
|
||||
|
||||
// TestDelete_Idempotent verifies deleting a nonexistent key succeeds.
func (s *SettingRepoSuite) TestDelete_Idempotent() {
	// Delete a key that doesn't exist should not error
	s.Require().NoError(s.repo.Delete(s.ctx, "nonexistent_delete"), "Delete nonexistent should be idempotent")
}
func (s *SettingRepoSuite) TestSetMultiple_Upsert() {
|
||||
s.Require().NoError(s.repo.Set(s.ctx, "upsert_key", "old_value"))
|
||||
s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"upsert_key": "new_value", "new_key": "new_val"}))
|
||||
|
||||
got, err := s.repo.GetValue(s.ctx, "upsert_key")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal("new_value", got, "SetMultiple should upsert existing key")
|
||||
|
||||
got2, err := s.repo.GetValue(s.ctx, "new_key")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal("new_val", got2)
|
||||
}
|
||||
|
||||
// TestSet_EmptyValue stores an empty string value.
// Regression test: optional settings (site logo, API endpoint, etc.) must be
// storable as empty strings.
func (s *SettingRepoSuite) TestSet_EmptyValue() {
	// Set must accept an empty value.
	s.Require().NoError(s.repo.Set(s.ctx, "empty_key", ""), "Set with empty value should succeed")

	got, err := s.repo.GetValue(s.ctx, "empty_key")
	s.Require().NoError(err, "GetValue for empty value")
	s.Require().Equal("", got, "empty value should be preserved")
}
// TestSetMultiple_WithEmptyValues bulk-saves settings where some values are
// empty strings, mimicking a user saving site settings with optional fields
// left blank.
func (s *SettingRepoSuite) TestSetMultiple_WithEmptyValues() {
	// Simulate saving site settings: some fields populated, some empty.
	settings := map[string]string{
		"site_name":     "AICodex2API",
		"site_subtitle": "Subscription to API",
		"site_logo":     "", // user did not upload a logo
		"api_base_url":  "", // user did not set an API base URL
		"contact_info":  "", // user did not set contact info
		"doc_url":       "", // user did not set a docs link
	}

	s.Require().NoError(s.repo.SetMultiple(s.ctx, settings), "SetMultiple with empty values should succeed")

	// All values, including the empty ones, must round-trip intact.
	result, err := s.repo.GetMultiple(s.ctx, []string{"site_name", "site_subtitle", "site_logo", "api_base_url", "contact_info", "doc_url"})
	s.Require().NoError(err, "GetMultiple after SetMultiple with empty values")

	s.Require().Equal("AICodex2API", result["site_name"])
	s.Require().Equal("Subscription to API", result["site_subtitle"])
	s.Require().Equal("", result["site_logo"], "empty site_logo should be preserved")
	s.Require().Equal("", result["api_base_url"], "empty api_base_url should be preserved")
	s.Require().Equal("", result["contact_info"], "empty contact_info should be preserved")
	s.Require().Equal("", result["doc_url"], "empty doc_url should be preserved")
}
// TestSetMultiple_UpdateToEmpty updates an existing value to the empty string,
// ensuring a user can clear a previously configured setting.
func (s *SettingRepoSuite) TestSetMultiple_UpdateToEmpty() {
	// Start with a non-empty value.
	s.Require().NoError(s.repo.Set(s.ctx, "clearable_key", "initial_value"))

	got, err := s.repo.GetValue(s.ctx, "clearable_key")
	s.Require().NoError(err)
	s.Require().Equal("initial_value", got)

	// Overwrite it with the empty string.
	s.Require().NoError(s.repo.SetMultiple(s.ctx, map[string]string{"clearable_key": ""}), "Update to empty should succeed")

	got, err = s.repo.GetValue(s.ctx, "clearable_key")
	s.Require().NoError(err)
	s.Require().Equal("", got, "value should be updated to empty string")
}
216
backend/internal/repository/soft_delete_ent_integration_test.go
Normal file
216
backend/internal/repository/soft_delete_ent_integration_test.go
Normal file
@@ -0,0 +1,216 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// uniqueSoftDeleteValue derives a per-test unique value from prefix plus the
// current test name, with "/" and " " replaced by "_" so subtests stay safe
// to embed in emails and names.
func uniqueSoftDeleteValue(t *testing.T, prefix string) string {
	t.Helper()
	name := strings.ReplaceAll(strings.ReplaceAll(t.Name(), "/", "_"), " ", "_")
	return prefix + "-" + name
}
// createEntUser inserts a user row with the given email and a fixed dummy
// password hash, failing the test on error.
func createEntUser(t *testing.T, ctx context.Context, client *dbent.Client, email string) *dbent.User {
	t.Helper()

	u, err := client.User.Create().
		SetEmail(email).
		SetPasswordHash("test-password-hash").
		Save(ctx)
	require.NoError(t, err, "create ent user")
	return u
}
// TestEntSoftDelete_ApiKey_DefaultFilterAndSkip verifies that a soft-deleted
// API key is hidden from the repository and from default ent queries, but is
// still reachable (with deleted_at set) when SkipSoftDelete is applied.
func TestEntSoftDelete_ApiKey_DefaultFilterAndSkip(t *testing.T) {
	ctx := context.Background()
	// Use the global ent client so soft-delete behavior is verified against
	// actually persisted data.
	client := testEntClient(t)

	u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user")+"@example.com")

	repo := NewAPIKeyRepository(client)
	key := &service.APIKey{
		UserID: u.ID,
		Key:    uniqueSoftDeleteValue(t, "sk-soft-delete"),
		Name:   "soft-delete",
		Status: service.StatusActive,
	}
	require.NoError(t, repo.Create(ctx, key), "create api key")

	require.NoError(t, repo.Delete(ctx, key.ID), "soft delete api key")

	// The repository surface must report not-found for soft-deleted rows.
	_, err := repo.GetByID(ctx, key.ID)
	require.ErrorIs(t, err, service.ErrAPIKeyNotFound, "deleted rows should be hidden by default")

	// Plain ent queries are filtered by the soft-delete interceptor too.
	_, err = client.APIKey.Query().Where(apikey.IDEQ(key.ID)).Only(ctx)
	require.Error(t, err, "default ent query should not see soft-deleted rows")
	require.True(t, dbent.IsNotFound(err), "expected ent not-found after default soft delete filter")

	// SkipSoftDelete bypasses the filter and exposes the tombstoned row.
	got, err := client.APIKey.Query().
		Where(apikey.IDEQ(key.ID)).
		Only(mixins.SkipSoftDelete(ctx))
	require.NoError(t, err, "SkipSoftDelete should include soft-deleted rows")
	require.NotNil(t, got.DeletedAt, "deleted_at should be set after soft delete")
}
// TestEntSoftDelete_ApiKey_DeleteIdempotent verifies deleting an already
// soft-deleted API key succeeds instead of erroring.
func TestEntSoftDelete_ApiKey_DeleteIdempotent(t *testing.T) {
	ctx := context.Background()
	// Use the global ent client so a transaction rollback cannot interfere
	// with the idempotency check.
	client := testEntClient(t)

	u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user2")+"@example.com")

	repo := NewAPIKeyRepository(client)
	key := &service.APIKey{
		UserID: u.ID,
		Key:    uniqueSoftDeleteValue(t, "sk-soft-delete2"),
		Name:   "soft-delete2",
		Status: service.StatusActive,
	}
	require.NoError(t, repo.Create(ctx, key), "create api key")

	require.NoError(t, repo.Delete(ctx, key.ID), "first delete")
	require.NoError(t, repo.Delete(ctx, key.ID), "second delete should be idempotent")
}
// TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete verifies that a delete
// issued under SkipSoftDelete physically removes the row rather than being
// rewritten into a deleted_at update by the soft-delete hook.
func TestEntSoftDelete_ApiKey_HardDeleteViaSkipSoftDelete(t *testing.T) {
	ctx := context.Background()
	// Use the global ent client so the hard-delete semantics of
	// SkipSoftDelete can be verified on persisted data.
	client := testEntClient(t)

	u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-user3")+"@example.com")

	repo := NewAPIKeyRepository(client)
	key := &service.APIKey{
		UserID: u.ID,
		Key:    uniqueSoftDeleteValue(t, "sk-soft-delete3"),
		Name:   "soft-delete3",
		Status: service.StatusActive,
	}
	require.NoError(t, repo.Create(ctx, key), "create api key")

	require.NoError(t, repo.Delete(ctx, key.ID), "soft delete api key")

	// Hard delete using SkipSoftDelete so the hook doesn't convert it to update-deleted_at.
	_, err := client.APIKey.Delete().Where(apikey.IDEQ(key.ID)).Exec(mixins.SkipSoftDelete(ctx))
	require.NoError(t, err, "hard delete")

	// Even a filter-bypassing query must now find nothing.
	_, err = client.APIKey.Query().
		Where(apikey.IDEQ(key.ID)).
		Only(mixins.SkipSoftDelete(ctx))
	require.True(t, dbent.IsNotFound(err), "expected row to be hard deleted")
}
// --- UserSubscription soft-delete tests ---
// createEntGroup inserts an active group row with the given name, failing
// the test on error.
func createEntGroup(t *testing.T, ctx context.Context, client *dbent.Client, name string) *dbent.Group {
	t.Helper()

	g, err := client.Group.Create().
		SetName(name).
		SetStatus(service.StatusActive).
		Save(ctx)
	require.NoError(t, err, "create ent group")
	return g
}
// TestEntSoftDelete_UserSubscription_DefaultFilterAndSkip mirrors the API-key
// test for user subscriptions: a soft-deleted subscription is hidden from the
// repository and default ent queries, but visible via SkipSoftDelete with a
// populated deleted_at.
func TestEntSoftDelete_UserSubscription_DefaultFilterAndSkip(t *testing.T) {
	ctx := context.Background()
	client := testEntClient(t)

	u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-user")+"@example.com")
	g := createEntGroup(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-group"))

	repo := NewUserSubscriptionRepository(client)
	sub := &service.UserSubscription{
		UserID:    u.ID,
		GroupID:   g.ID,
		Status:    service.SubscriptionStatusActive,
		ExpiresAt: time.Now().Add(24 * time.Hour),
	}
	require.NoError(t, repo.Create(ctx, sub), "create user subscription")

	require.NoError(t, repo.Delete(ctx, sub.ID), "soft delete user subscription")

	// Repository lookups must no longer find the row.
	_, err := repo.GetByID(ctx, sub.ID)
	require.Error(t, err, "deleted rows should be hidden by default")

	// Default ent queries are filtered as well.
	_, err = client.UserSubscription.Query().Where(usersubscription.IDEQ(sub.ID)).Only(ctx)
	require.Error(t, err, "default ent query should not see soft-deleted rows")
	require.True(t, dbent.IsNotFound(err), "expected ent not-found after default soft delete filter")

	// SkipSoftDelete exposes the tombstoned row.
	got, err := client.UserSubscription.Query().
		Where(usersubscription.IDEQ(sub.ID)).
		Only(mixins.SkipSoftDelete(ctx))
	require.NoError(t, err, "SkipSoftDelete should include soft-deleted rows")
	require.NotNil(t, got.DeletedAt, "deleted_at should be set after soft delete")
}
// TestEntSoftDelete_UserSubscription_DeleteIdempotent verifies deleting an
// already soft-deleted subscription succeeds instead of erroring.
func TestEntSoftDelete_UserSubscription_DeleteIdempotent(t *testing.T) {
	ctx := context.Background()
	client := testEntClient(t)

	u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-user2")+"@example.com")
	g := createEntGroup(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-group2"))

	repo := NewUserSubscriptionRepository(client)
	sub := &service.UserSubscription{
		UserID:    u.ID,
		GroupID:   g.ID,
		Status:    service.SubscriptionStatusActive,
		ExpiresAt: time.Now().Add(24 * time.Hour),
	}
	require.NoError(t, repo.Create(ctx, sub), "create user subscription")

	require.NoError(t, repo.Delete(ctx, sub.ID), "first delete")
	require.NoError(t, repo.Delete(ctx, sub.ID), "second delete should be idempotent")
}
// TestEntSoftDelete_UserSubscription_ListExcludesDeleted verifies that list
// queries skip soft-deleted subscriptions: of two subscriptions for one user,
// only the surviving one is returned after the other is soft-deleted.
func TestEntSoftDelete_UserSubscription_ListExcludesDeleted(t *testing.T) {
	ctx := context.Background()
	client := testEntClient(t)

	u := createEntUser(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-user3")+"@example.com")
	g1 := createEntGroup(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-group3a"))
	g2 := createEntGroup(t, ctx, client, uniqueSoftDeleteValue(t, "sd-sub-group3b"))

	repo := NewUserSubscriptionRepository(client)

	sub1 := &service.UserSubscription{
		UserID:    u.ID,
		GroupID:   g1.ID,
		Status:    service.SubscriptionStatusActive,
		ExpiresAt: time.Now().Add(24 * time.Hour),
	}
	require.NoError(t, repo.Create(ctx, sub1), "create subscription 1")

	sub2 := &service.UserSubscription{
		UserID:    u.ID,
		GroupID:   g2.ID,
		Status:    service.SubscriptionStatusActive,
		ExpiresAt: time.Now().Add(24 * time.Hour),
	}
	require.NoError(t, repo.Create(ctx, sub2), "create subscription 2")

	// Soft-delete sub1.
	require.NoError(t, repo.Delete(ctx, sub1.ID), "soft delete subscription 1")

	// ListByUserID must return only non-deleted subscriptions.
	subs, err := repo.ListByUserID(ctx, u.ID)
	require.NoError(t, err, "ListByUserID")
	require.Len(t, subs, 1, "should only return non-deleted subscriptions")
	require.Equal(t, sub2.ID, subs[0].ID, "expected sub2 to be returned")
}
backend/internal/repository/sql_scan.go
Normal file
42
backend/internal/repository/sql_scan.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// sqlQueryer is the minimal query surface needed by scanSingleRow; both
// *sql.DB, *sql.Tx and ent's transaction types satisfy it.
type sqlQueryer interface {
	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}
// scanSingleRow executes the query and scans the first row into dest.
// An empty result is reported so that errors.Is(err, sql.ErrNoRows) holds.
// If closing the rows fails, that error is joined with the original one.
// Design note: only QueryContext is required (not QueryRowContext, which is
// bound to *sql.Tx/*sql.DB), so ent.Tx can also serve as the
// sqlExecutor/queryer.
func scanSingleRow(ctx context.Context, q sqlQueryer, query string, args []any, dest ...any) (err error) {
	rows, err := q.QueryContext(ctx, query, args...)
	if err != nil {
		return err
	}
	// Named return lets the deferred Close merge its error into err.
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			err = errors.Join(err, closeErr)
		}
	}()

	if !rows.Next() {
		// Distinguish "no rows" from an iteration error.
		if err = rows.Err(); err != nil {
			return err
		}
		return sql.ErrNoRows
	}
	if err = rows.Scan(dest...); err != nil {
		return err
	}
	if err = rows.Err(); err != nil {
		return err
	}
	return nil
}
91
backend/internal/repository/temp_unsched_cache.go
Normal file
91
backend/internal/repository/temp_unsched_cache.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// tempUnschedPrefix namespaces per-account temporary-unschedulable keys.
const tempUnschedPrefix = "temp_unsched:account:"

// tempUnschedSetScript atomically writes a temp-unsched entry, but only when
// the new expiry (ARGV[1], unix seconds) is later than the existing entry's
// until_unix — i.e. the window can only be extended, never shortened.
// Returns 1 when the value was written, 0 when the existing entry wins.
var tempUnschedSetScript = redis.NewScript(`
local key = KEYS[1]
local new_until = tonumber(ARGV[1])
local new_value = ARGV[2]
local new_ttl = tonumber(ARGV[3])

local existing = redis.call('GET', key)
if existing then
    local ok, existing_data = pcall(cjson.decode, existing)
    if ok and existing_data and existing_data.until_unix then
        local existing_until = tonumber(existing_data.until_unix)
        if existing_until and new_until <= existing_until then
            return 0
        end
    end
end

redis.call('SET', key, new_value, 'EX', new_ttl)
return 1
`)
// tempUnschedCache stores temporary "do not schedule" markers per account
// in Redis.
type tempUnschedCache struct {
	rdb *redis.Client
}
// NewTempUnschedCache returns a TempUnschedCache backed by the Redis client.
func NewTempUnschedCache(rdb *redis.Client) service.TempUnschedCache {
	return &tempUnschedCache{rdb: rdb}
}
// SetTempUnsched 设置临时不可调度状态(只延长不缩短)
|
||||
func (c *tempUnschedCache) SetTempUnsched(ctx context.Context, accountID int64, state *service.TempUnschedState) error {
|
||||
key := fmt.Sprintf("%s%d", tempUnschedPrefix, accountID)
|
||||
|
||||
stateJSON, err := json.Marshal(state)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal state: %w", err)
|
||||
}
|
||||
|
||||
ttl := time.Until(time.Unix(state.UntilUnix, 0))
|
||||
if ttl <= 0 {
|
||||
return nil // 已过期,不设置
|
||||
}
|
||||
|
||||
ttlSeconds := int(ttl.Seconds())
|
||||
if ttlSeconds < 1 {
|
||||
ttlSeconds = 1
|
||||
}
|
||||
|
||||
_, err = tempUnschedSetScript.Run(ctx, c.rdb, []string{key}, state.UntilUnix, string(stateJSON), ttlSeconds).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
// GetTempUnsched 获取临时不可调度状态
|
||||
func (c *tempUnschedCache) GetTempUnsched(ctx context.Context, accountID int64) (*service.TempUnschedState, error) {
|
||||
key := fmt.Sprintf("%s%d", tempUnschedPrefix, accountID)
|
||||
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err == redis.Nil {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var state service.TempUnschedState
|
||||
if err := json.Unmarshal([]byte(val), &state); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal state: %w", err)
|
||||
}
|
||||
|
||||
return &state, nil
|
||||
}
|
||||
|
||||
// DeleteTempUnsched 删除临时不可调度状态
|
||||
func (c *tempUnschedCache) DeleteTempUnsched(ctx context.Context, accountID int64) error {
|
||||
key := fmt.Sprintf("%s%d", tempUnschedPrefix, accountID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
80
backend/internal/repository/timeout_counter_cache.go
Normal file
80
backend/internal/repository/timeout_counter_cache.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// timeoutCounterPrefix namespaces per-account timeout counter keys.
const timeoutCounterPrefix = "timeout_count:account:"

// timeoutCounterIncrScript atomically increments the counter and returns the
// current value. When the key is freshly created (count == 1) it also sets
// the expiry, so the window starts at the first timeout.
var timeoutCounterIncrScript = redis.NewScript(`
local key = KEYS[1]
local ttl = tonumber(ARGV[1])

local count = redis.call('INCR', key)
if count == 1 then
    redis.call('EXPIRE', key, ttl)
end

return count
`)
type timeoutCounterCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
// NewTimeoutCounterCache 创建超时计数器缓存实例
|
||||
func NewTimeoutCounterCache(rdb *redis.Client) service.TimeoutCounterCache {
|
||||
return &timeoutCounterCache{rdb: rdb}
|
||||
}
|
||||
|
||||
// IncrementTimeoutCount increments the account's timeout counter and returns
// the updated value. windowMinutes is the counting window in minutes; the
// counter expires (and thus resets) automatically once the window elapses.
func (c *timeoutCounterCache) IncrementTimeoutCount(ctx context.Context, accountID int64, windowMinutes int) (int64, error) {
	key := fmt.Sprintf("%s%d", timeoutCounterPrefix, accountID)

	ttlSeconds := windowMinutes * 60
	if ttlSeconds < 60 {
		ttlSeconds = 60 // clamp to a minimum window of 1 minute
	}

	// The Lua script makes INCR + conditional EXPIRE atomic, so the TTL is
	// armed exactly once when the key is first created.
	result, err := timeoutCounterIncrScript.Run(ctx, c.rdb, []string{key}, ttlSeconds).Int64()
	if err != nil {
		return 0, fmt.Errorf("increment timeout count: %w", err)
	}

	return result, nil
}
|
||||
|
||||
// GetTimeoutCount 获取账户当前的超时计数
|
||||
func (c *timeoutCounterCache) GetTimeoutCount(ctx context.Context, accountID int64) (int64, error) {
|
||||
key := fmt.Sprintf("%s%d", timeoutCounterPrefix, accountID)
|
||||
|
||||
val, err := c.rdb.Get(ctx, key).Int64()
|
||||
if err == redis.Nil {
|
||||
return 0, nil
|
||||
}
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("get timeout count: %w", err)
|
||||
}
|
||||
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// ResetTimeoutCount 重置账户的超时计数
|
||||
func (c *timeoutCounterCache) ResetTimeoutCount(ctx context.Context, accountID int64) error {
|
||||
key := fmt.Sprintf("%s%d", timeoutCounterPrefix, accountID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
// GetTimeoutCountTTL 获取计数器剩余过期时间
|
||||
func (c *timeoutCounterCache) GetTimeoutCountTTL(ctx context.Context, accountID int64) (time.Duration, error) {
|
||||
key := fmt.Sprintf("%s%d", timeoutCounterPrefix, accountID)
|
||||
return c.rdb.TTL(ctx, key).Result()
|
||||
}
|
||||
63
backend/internal/repository/turnstile_service.go
Normal file
63
backend/internal/repository/turnstile_service.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// turnstileVerifyURL is Cloudflare's siteverify endpoint for Turnstile tokens.
const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify"

// turnstileVerifier verifies Turnstile tokens over HTTP. verifyURL is kept
// as a field (rather than using the constant directly) so tests can point
// it at an in-process server.
type turnstileVerifier struct {
	httpClient *http.Client
	verifyURL  string
}
|
||||
|
||||
// NewTurnstileVerifier constructs the Cloudflare Turnstile verifier. It
// prefers the shared hardened HTTP client (10s timeout, resolved-IP
// validation); if that client cannot be built it falls back to a plain
// http.Client with the same timeout so verification still works.
func NewTurnstileVerifier() service.TurnstileVerifier {
	sharedClient, err := httpclient.GetClient(httpclient.Options{
		Timeout:            10 * time.Second,
		ValidateResolvedIP: true,
	})
	if err != nil {
		sharedClient = &http.Client{Timeout: 10 * time.Second}
	}
	return &turnstileVerifier{
		httpClient: sharedClient,
		verifyURL:  turnstileVerifyURL,
	}
}
|
||||
|
||||
func (v *turnstileVerifier) VerifyToken(ctx context.Context, secretKey, token, remoteIP string) (*service.TurnstileVerifyResponse, error) {
|
||||
formData := url.Values{}
|
||||
formData.Set("secret", secretKey)
|
||||
formData.Set("response", token)
|
||||
if remoteIP != "" {
|
||||
formData.Set("remoteip", remoteIP)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, v.verifyURL, strings.NewReader(formData.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := v.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
var result service.TurnstileVerifyResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("decode response: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
141
backend/internal/repository/turnstile_service_test.go
Normal file
141
backend/internal/repository/turnstile_service_test.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// TurnstileServiceSuite tests turnstileVerifier against an in-process HTTP
// transport; received buffers the form values captured by the fake server.
type TurnstileServiceSuite struct {
	suite.Suite
	ctx      context.Context
	verifier *turnstileVerifier
	received chan url.Values
}
|
||||
|
||||
// SetupTest builds a fresh verifier and a buffered capture channel so each
// test starts from a clean state.
func (s *TurnstileServiceSuite) SetupTest() {
	s.ctx = context.Background()
	s.received = make(chan url.Values, 1)
	// Downcast to the concrete type so tests can swap verifyURL/httpClient.
	verifier, ok := NewTurnstileVerifier().(*turnstileVerifier)
	require.True(s.T(), ok, "type assertion failed")
	s.verifier = verifier
}
|
||||
|
||||
// setupTransport routes the verifier's HTTP traffic to the given handler via
// an in-process transport (no real network). newInProcessTransport is a test
// helper defined elsewhere in this package.
func (s *TurnstileServiceSuite) setupTransport(handler http.HandlerFunc) {
	s.verifier.verifyURL = "http://in-process/turnstile"
	s.verifier.httpClient = &http.Client{
		Transport: newInProcessTransport(handler, nil),
	}
}
|
||||
|
||||
// TestVerifyToken_SendsFormAndDecodesJSON checks that VerifyToken posts the
// secret/response/remoteip form fields and decodes a JSON success reply.
func (s *TurnstileServiceSuite) TestVerifyToken_SendsFormAndDecodesJSON() {
	s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Capture form data in main goroutine context later
		body, _ := io.ReadAll(r.Body)
		values, _ := url.ParseQuery(string(body))
		s.received <- values

		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{Success: true})
	}))

	resp, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
	require.NoError(s.T(), err, "VerifyToken")
	require.NotNil(s.T(), resp)
	require.True(s.T(), resp.Success, "expected success response")

	// Assert form fields in main goroutine
	select {
	case values := <-s.received:
		require.Equal(s.T(), "sk", values.Get("secret"))
		require.Equal(s.T(), "token", values.Get("response"))
		require.Equal(s.T(), "1.1.1.1", values.Get("remoteip"))
	default:
		require.Fail(s.T(), "expected server to receive request")
	}
}
|
||||
|
||||
// TestVerifyToken_ContentType checks the request is sent as form-urlencoded.
// NOTE(review): contentType is written inside the handler and read after
// VerifyToken returns; this presumes the in-process transport completes the
// handler synchronously before Do returns — confirm against the helper.
func (s *TurnstileServiceSuite) TestVerifyToken_ContentType() {
	var contentType string
	s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		contentType = r.Header.Get("Content-Type")
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{Success: true})
	}))

	_, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
	require.NoError(s.T(), err)
	require.True(s.T(), strings.HasPrefix(contentType, "application/x-www-form-urlencoded"), "unexpected content-type: %s", contentType)
}
|
||||
|
||||
// TestVerifyToken_EmptyRemoteIP_NotSent checks that an empty remoteIP is not
// included in the posted form.
func (s *TurnstileServiceSuite) TestVerifyToken_EmptyRemoteIP_NotSent() {
	s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		body, _ := io.ReadAll(r.Body)
		values, _ := url.ParseQuery(string(body))
		s.received <- values

		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{Success: true})
	}))

	_, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "")
	require.NoError(s.T(), err)

	select {
	case values := <-s.received:
		require.Equal(s.T(), "", values.Get("remoteip"), "remoteip should be empty or not sent")
	default:
		require.Fail(s.T(), "expected server to receive request")
	}
}
|
||||
|
||||
// TestVerifyToken_RequestError checks that a transport-level failure is
// surfaced as an error. roundTripFunc is a test helper defined elsewhere
// in this package.
func (s *TurnstileServiceSuite) TestVerifyToken_RequestError() {
	s.verifier.verifyURL = "http://in-process/turnstile"
	s.verifier.httpClient = &http.Client{
		Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
			return nil, errors.New("dial failed")
		}),
	}

	_, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
	require.Error(s.T(), err, "expected error when server is closed")
}
|
||||
|
||||
// TestVerifyToken_InvalidJSON checks that a malformed JSON body yields a
// decode error.
func (s *TurnstileServiceSuite) TestVerifyToken_InvalidJSON() {
	s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		_, _ = io.WriteString(w, "not-valid-json")
	}))

	_, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
	require.Error(s.T(), err, "expected error for invalid JSON response")
}
|
||||
|
||||
// TestVerifyToken_SuccessFalse checks that a success=false verdict is
// returned to the caller as data, not as a Go error.
func (s *TurnstileServiceSuite) TestVerifyToken_SuccessFalse() {
	s.setupTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(service.TurnstileVerifyResponse{
			Success:    false,
			ErrorCodes: []string{"invalid-input-response"},
		})
	}))

	resp, err := s.verifier.VerifyToken(s.ctx, "sk", "token", "1.1.1.1")
	require.NoError(s.T(), err, "VerifyToken should not error on success=false")
	require.NotNil(s.T(), resp)
	require.False(s.T(), resp.Success)
	require.Contains(s.T(), resp.ErrorCodes, "invalid-input-response")
}
|
||||
|
||||
// TestTurnstileServiceSuite wires the suite into the standard test runner.
func TestTurnstileServiceSuite(t *testing.T) {
	suite.Run(t, new(TurnstileServiceSuite))
}
|
||||
27
backend/internal/repository/update_cache.go
Normal file
27
backend/internal/repository/update_cache.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// updateCacheKey is the single Redis key holding the latest update payload.
const updateCacheKey = "update:latest"

// updateCache is the Redis-backed implementation of service.UpdateCache.
type updateCache struct {
	rdb *redis.Client
}
|
||||
|
||||
func NewUpdateCache(rdb *redis.Client) service.UpdateCache {
|
||||
return &updateCache{rdb: rdb}
|
||||
}
|
||||
|
||||
// GetUpdateInfo returns the cached update payload; a missing key surfaces
// as redis.Nil, which callers are expected to treat as "no cached info".
func (c *updateCache) GetUpdateInfo(ctx context.Context) (string, error) {
	return c.rdb.Get(ctx, updateCacheKey).Result()
}
|
||||
|
||||
// SetUpdateInfo stores the update payload with the given TTL; a zero ttl
// persists the key with no expiry (Redis SET semantics).
func (c *updateCache) SetUpdateInfo(ctx context.Context, data string, ttl time.Duration) error {
	return c.rdb.Set(ctx, updateCacheKey, data, ttl).Err()
}
|
||||
73
backend/internal/repository/update_cache_integration_test.go
Normal file
73
backend/internal/repository/update_cache_integration_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// UpdateCacheSuite exercises updateCache against a real Redis instance
// (integration build); IntegrationRedisSuite supplies ctx and rdb.
type UpdateCacheSuite struct {
	IntegrationRedisSuite
	cache *updateCache
}
|
||||
|
||||
// SetupTest resets the shared Redis fixture and builds a fresh cache.
func (s *UpdateCacheSuite) SetupTest() {
	s.IntegrationRedisSuite.SetupTest()
	s.cache = NewUpdateCache(s.rdb).(*updateCache)
}
|
||||
|
||||
// TestGetUpdateInfo_Missing checks that an absent key yields redis.Nil.
func (s *UpdateCacheSuite) TestGetUpdateInfo_Missing() {
	_, err := s.cache.GetUpdateInfo(s.ctx)
	require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing update info")
}
|
||||
|
||||
// TestSetAndGetUpdateInfo checks a round-trip through the cache.
func (s *UpdateCacheSuite) TestSetAndGetUpdateInfo() {
	updateTTL := 5 * time.Minute
	require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v1.2.3", updateTTL), "SetUpdateInfo")

	info, err := s.cache.GetUpdateInfo(s.ctx)
	require.NoError(s.T(), err, "GetUpdateInfo")
	require.Equal(s.T(), "v1.2.3", info, "update info mismatch")
}
|
||||
|
||||
// TestSetUpdateInfo_TTL checks that the key carries the requested expiry.
func (s *UpdateCacheSuite) TestSetUpdateInfo_TTL() {
	updateTTL := 5 * time.Minute
	require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v1.2.3", updateTTL))

	ttl, err := s.rdb.TTL(s.ctx, updateCacheKey).Result()
	require.NoError(s.T(), err, "TTL updateCacheKey")
	s.AssertTTLWithin(ttl, 1*time.Second, updateTTL)
}
|
||||
|
||||
// TestSetUpdateInfo_Overwrite checks that a second Set replaces the value.
func (s *UpdateCacheSuite) TestSetUpdateInfo_Overwrite() {
	require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v1.0.0", 5*time.Minute))
	require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v2.0.0", 5*time.Minute))

	info, err := s.cache.GetUpdateInfo(s.ctx)
	require.NoError(s.T(), err)
	require.Equal(s.T(), "v2.0.0", info, "expected overwritten value")
}
|
||||
|
||||
// TestSetUpdateInfo_ZeroTTL checks that ttl=0 persists the key without expiry.
func (s *UpdateCacheSuite) TestSetUpdateInfo_ZeroTTL() {
	// TTL=0 means persist forever (no expiry) in Redis SET command
	require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v0.0.0", 0))

	info, err := s.cache.GetUpdateInfo(s.ctx)
	require.NoError(s.T(), err)
	require.Equal(s.T(), "v0.0.0", info)

	ttl, err := s.rdb.TTL(s.ctx, updateCacheKey).Result()
	require.NoError(s.T(), err)
	// TTL=-1 means no expiry, TTL=-2 means key doesn't exist
	require.Equal(s.T(), time.Duration(-1), ttl, "expected TTL=-1 for key with no expiry")
}
|
||||
|
||||
// TestUpdateCacheSuite wires the suite into the standard test runner.
func TestUpdateCacheSuite(t *testing.T) {
	suite.Run(t, new(UpdateCacheSuite))
}
|
||||
2271
backend/internal/repository/usage_log_repo.go
Normal file
2271
backend/internal/repository/usage_log_repo.go
Normal file
File diff suppressed because it is too large
Load Diff
1215
backend/internal/repository/usage_log_repo_integration_test.go
Normal file
1215
backend/internal/repository/usage_log_repo_integration_test.go
Normal file
File diff suppressed because it is too large
Load Diff
385
backend/internal/repository/user_attribute_repo.go
Normal file
385
backend/internal/repository/user_attribute_repo.go
Normal file
@@ -0,0 +1,385 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userattributedefinition"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userattributevalue"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// userAttributeDefinitionRepository is the ent-backed implementation of
// service.UserAttributeDefinitionRepository.
type userAttributeDefinitionRepository struct {
	client *dbent.Client
}

// NewUserAttributeDefinitionRepository creates a new repository instance
// backed by the given ent client.
func NewUserAttributeDefinitionRepository(client *dbent.Client) service.UserAttributeDefinitionRepository {
	return &userAttributeDefinitionRepository{client: client}
}
|
||||
|
||||
// Create persists a new attribute definition. A unique-key violation is
// mapped to service.ErrAttributeKeyExists. On success the DB-generated
// fields (ID, DisplayOrder, timestamps) are copied back onto def.
func (r *userAttributeDefinitionRepository) Create(ctx context.Context, def *service.UserAttributeDefinition) error {
	// clientFromContext picks up an ambient transaction if one is running.
	client := clientFromContext(ctx, r.client)

	created, err := client.UserAttributeDefinition.Create().
		SetKey(def.Key).
		SetName(def.Name).
		SetDescription(def.Description).
		SetType(string(def.Type)).
		SetOptions(toEntOptions(def.Options)).
		SetRequired(def.Required).
		SetValidation(toEntValidation(def.Validation)).
		SetPlaceholder(def.Placeholder).
		SetEnabled(def.Enabled).
		Save(ctx)

	if err != nil {
		return translatePersistenceError(err, nil, service.ErrAttributeKeyExists)
	}

	def.ID = created.ID
	def.DisplayOrder = created.DisplayOrder
	def.CreatedAt = created.CreatedAt
	def.UpdatedAt = created.UpdatedAt
	return nil
}
|
||||
|
||||
func (r *userAttributeDefinitionRepository) GetByID(ctx context.Context, id int64) (*service.UserAttributeDefinition, error) {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
|
||||
e, err := client.UserAttributeDefinition.Query().
|
||||
Where(userattributedefinition.IDEQ(id)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrAttributeDefinitionNotFound, nil)
|
||||
}
|
||||
return defEntityToService(e), nil
|
||||
}
|
||||
|
||||
func (r *userAttributeDefinitionRepository) GetByKey(ctx context.Context, key string) (*service.UserAttributeDefinition, error) {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
|
||||
e, err := client.UserAttributeDefinition.Query().
|
||||
Where(userattributedefinition.KeyEQ(key)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrAttributeDefinitionNotFound, nil)
|
||||
}
|
||||
return defEntityToService(e), nil
|
||||
}
|
||||
|
||||
// Update rewrites all mutable fields of an attribute definition (the key is
// immutable and not touched). Not-found maps to
// service.ErrAttributeDefinitionNotFound and unique violations to
// service.ErrAttributeKeyExists; UpdatedAt is copied back on success.
func (r *userAttributeDefinitionRepository) Update(ctx context.Context, def *service.UserAttributeDefinition) error {
	client := clientFromContext(ctx, r.client)

	updated, err := client.UserAttributeDefinition.UpdateOneID(def.ID).
		SetName(def.Name).
		SetDescription(def.Description).
		SetType(string(def.Type)).
		SetOptions(toEntOptions(def.Options)).
		SetRequired(def.Required).
		SetValidation(toEntValidation(def.Validation)).
		SetPlaceholder(def.Placeholder).
		SetDisplayOrder(def.DisplayOrder).
		SetEnabled(def.Enabled).
		Save(ctx)

	if err != nil {
		return translatePersistenceError(err, service.ErrAttributeDefinitionNotFound, service.ErrAttributeKeyExists)
	}

	def.UpdatedAt = updated.UpdatedAt
	return nil
}
|
||||
|
||||
func (r *userAttributeDefinitionRepository) Delete(ctx context.Context, id int64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
|
||||
_, err := client.UserAttributeDefinition.Delete().
|
||||
Where(userattributedefinition.IDEQ(id)).
|
||||
Exec(ctx)
|
||||
return translatePersistenceError(err, service.ErrAttributeDefinitionNotFound, nil)
|
||||
}
|
||||
|
||||
// List returns all attribute definitions ordered by display order; when
// enabledOnly is true, disabled definitions are filtered out.
func (r *userAttributeDefinitionRepository) List(ctx context.Context, enabledOnly bool) ([]service.UserAttributeDefinition, error) {
	client := clientFromContext(ctx, r.client)

	q := client.UserAttributeDefinition.Query()
	if enabledOnly {
		q = q.Where(userattributedefinition.EnabledEQ(true))
	}

	entities, err := q.Order(dbent.Asc(userattributedefinition.FieldDisplayOrder)).All(ctx)
	if err != nil {
		return nil, err
	}

	result := make([]service.UserAttributeDefinition, 0, len(entities))
	for _, e := range entities {
		result = append(result, *defEntityToService(e))
	}
	return result, nil
}
|
||||
|
||||
// UpdateDisplayOrders applies a batch of {definition ID -> display order}
// updates inside one transaction; any missing ID aborts and rolls back the
// whole batch. The deferred Rollback is a no-op once Commit succeeds.
// NOTE(review): unlike the other methods this starts its own tx from
// r.client rather than using clientFromContext — confirm it is never called
// from within an ambient transaction.
func (r *userAttributeDefinitionRepository) UpdateDisplayOrders(ctx context.Context, orders map[int64]int) error {
	tx, err := r.client.Tx(ctx)
	if err != nil {
		return err
	}
	defer func() { _ = tx.Rollback() }()

	for id, order := range orders {
		if _, err := tx.UserAttributeDefinition.UpdateOneID(id).
			SetDisplayOrder(order).
			Save(ctx); err != nil {
			return translatePersistenceError(err, service.ErrAttributeDefinitionNotFound, nil)
		}
	}

	return tx.Commit()
}
|
||||
|
||||
func (r *userAttributeDefinitionRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
return client.UserAttributeDefinition.Query().
|
||||
Where(userattributedefinition.KeyEQ(key)).
|
||||
Exist(ctx)
|
||||
}
|
||||
|
||||
// userAttributeValueRepository is the ent-backed implementation of
// service.UserAttributeValueRepository.
type userAttributeValueRepository struct {
	client *dbent.Client
}

// NewUserAttributeValueRepository creates a new repository instance backed
// by the given ent client.
func NewUserAttributeValueRepository(client *dbent.Client) service.UserAttributeValueRepository {
	return &userAttributeValueRepository{client: client}
}
|
||||
|
||||
func (r *userAttributeValueRepository) GetByUserID(ctx context.Context, userID int64) ([]service.UserAttributeValue, error) {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
|
||||
entities, err := client.UserAttributeValue.Query().
|
||||
Where(userattributevalue.UserIDEQ(userID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]service.UserAttributeValue, 0, len(entities))
|
||||
for _, e := range entities {
|
||||
result = append(result, service.UserAttributeValue{
|
||||
ID: e.ID,
|
||||
UserID: e.UserID,
|
||||
AttributeID: e.AttributeID,
|
||||
Value: e.Value,
|
||||
CreatedAt: e.CreatedAt,
|
||||
UpdatedAt: e.UpdatedAt,
|
||||
})
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *userAttributeValueRepository) GetByUserIDs(ctx context.Context, userIDs []int64) ([]service.UserAttributeValue, error) {
|
||||
if len(userIDs) == 0 {
|
||||
return []service.UserAttributeValue{}, nil
|
||||
}
|
||||
|
||||
client := clientFromContext(ctx, r.client)
|
||||
|
||||
entities, err := client.UserAttributeValue.Query().
|
||||
Where(userattributevalue.UserIDIn(userIDs...)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]service.UserAttributeValue, 0, len(entities))
|
||||
for _, e := range entities {
|
||||
result = append(result, service.UserAttributeValue{
|
||||
ID: e.ID,
|
||||
UserID: e.UserID,
|
||||
AttributeID: e.AttributeID,
|
||||
Value: e.Value,
|
||||
CreatedAt: e.CreatedAt,
|
||||
UpdatedAt: e.UpdatedAt,
|
||||
})
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// UpsertBatch writes a batch of attribute values for one user inside a
// single transaction, using ON CONFLICT DO UPDATE so existing
// (user_id, attribute_id) rows are updated in place. An empty batch is a
// no-op. The deferred Rollback is a no-op once Commit succeeds.
// NOTE(review): this starts its own tx from r.client rather than using
// clientFromContext — confirm it is never called inside an ambient tx.
func (r *userAttributeValueRepository) UpsertBatch(ctx context.Context, userID int64, inputs []service.UpdateUserAttributeInput) error {
	if len(inputs) == 0 {
		return nil
	}

	tx, err := r.client.Tx(ctx)
	if err != nil {
		return err
	}
	defer func() { _ = tx.Rollback() }()

	for _, input := range inputs {
		// Use upsert (ON CONFLICT DO UPDATE)
		err := tx.UserAttributeValue.Create().
			SetUserID(userID).
			SetAttributeID(input.AttributeID).
			SetValue(input.Value).
			OnConflictColumns(userattributevalue.FieldUserID, userattributevalue.FieldAttributeID).
			UpdateValue().
			UpdateUpdatedAt().
			Exec(ctx)
		if err != nil {
			return err
		}
	}

	return tx.Commit()
}
|
||||
|
||||
func (r *userAttributeValueRepository) DeleteByAttributeID(ctx context.Context, attributeID int64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
|
||||
_, err := client.UserAttributeValue.Delete().
|
||||
Where(userattributevalue.AttributeIDEQ(attributeID)).
|
||||
Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *userAttributeValueRepository) DeleteByUserID(ctx context.Context, userID int64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
|
||||
_, err := client.UserAttributeValue.Delete().
|
||||
Where(userattributevalue.UserIDEQ(userID)).
|
||||
Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// defEntityToService converts an ent attribute-definition row into its
// service-layer counterpart; a nil entity maps to nil.
func defEntityToService(e *dbent.UserAttributeDefinition) *service.UserAttributeDefinition {
	if e == nil {
		return nil
	}
	return &service.UserAttributeDefinition{
		ID:           e.ID,
		Key:          e.Key,
		Name:         e.Name,
		Description:  e.Description,
		Type:         service.UserAttributeType(e.Type),
		Options:      toServiceOptions(e.Options),
		Required:     e.Required,
		Validation:   toServiceValidation(e.Validation),
		Placeholder:  e.Placeholder,
		DisplayOrder: e.DisplayOrder,
		Enabled:      e.Enabled,
		CreatedAt:    e.CreatedAt,
		UpdatedAt:    e.UpdatedAt,
	}
}
|
||||
|
||||
// Type conversion helpers (map types <-> service types)
|
||||
func toEntOptions(opts []service.UserAttributeOption) []map[string]any {
|
||||
if opts == nil {
|
||||
return []map[string]any{}
|
||||
}
|
||||
result := make([]map[string]any, len(opts))
|
||||
for i, o := range opts {
|
||||
result[i] = map[string]any{"value": o.Value, "label": o.Label}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func toServiceOptions(opts []map[string]any) []service.UserAttributeOption {
|
||||
if opts == nil {
|
||||
return []service.UserAttributeOption{}
|
||||
}
|
||||
result := make([]service.UserAttributeOption, len(opts))
|
||||
for i, o := range opts {
|
||||
result[i] = service.UserAttributeOption{
|
||||
Value: getString(o, "value"),
|
||||
Label: getString(o, "label"),
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func toEntValidation(v service.UserAttributeValidation) map[string]any {
|
||||
result := map[string]any{}
|
||||
if v.MinLength != nil {
|
||||
result["min_length"] = *v.MinLength
|
||||
}
|
||||
if v.MaxLength != nil {
|
||||
result["max_length"] = *v.MaxLength
|
||||
}
|
||||
if v.Min != nil {
|
||||
result["min"] = *v.Min
|
||||
}
|
||||
if v.Max != nil {
|
||||
result["max"] = *v.Max
|
||||
}
|
||||
if v.Pattern != nil {
|
||||
result["pattern"] = *v.Pattern
|
||||
}
|
||||
if v.Message != nil {
|
||||
result["message"] = *v.Message
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func toServiceValidation(v map[string]any) service.UserAttributeValidation {
|
||||
result := service.UserAttributeValidation{}
|
||||
if val := getInt(v, "min_length"); val != nil {
|
||||
result.MinLength = val
|
||||
}
|
||||
if val := getInt(v, "max_length"); val != nil {
|
||||
result.MaxLength = val
|
||||
}
|
||||
if val := getInt(v, "min"); val != nil {
|
||||
result.Min = val
|
||||
}
|
||||
if val := getInt(v, "max"); val != nil {
|
||||
result.Max = val
|
||||
}
|
||||
if val := getStringPtr(v, "pattern"); val != nil {
|
||||
result.Pattern = val
|
||||
}
|
||||
if val := getStringPtr(v, "message"); val != nil {
|
||||
result.Message = val
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// getString returns m[key] as a string, or "" when the key is absent or
// holds a non-string value.
func getString(m map[string]any, key string) string {
	s, _ := m[key].(string)
	return s
}
|
||||
|
||||
// getStringPtr returns a pointer to m[key] when it holds a string, or nil
// when the key is absent or the value has another type.
func getStringPtr(m map[string]any, key string) *string {
	if s, ok := m[key].(string); ok {
		return &s
	}
	return nil
}
|
||||
|
||||
// getInt returns m[key] coerced to *int, accepting int, int64, or float64
// (truncated toward zero, as int() conversion does). Absent keys and any
// other type yield nil.
func getInt(m map[string]any, key string) *int {
	v, ok := m[key]
	if !ok {
		return nil
	}
	var i int
	switch n := v.(type) {
	case int:
		i = n
	case int64:
		i = int(n)
	case float64:
		i = int(n)
	default:
		return nil
	}
	return &i
}
|
||||
468
backend/internal/repository/user_repo.go
Normal file
468
backend/internal/repository/user_repo.go
Normal file
@@ -0,0 +1,468 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
dbuser "github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
|
||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// userRepository is the ent-backed implementation of service.UserRepository;
// sql is a raw executor used by query paths that bypass ent.
type userRepository struct {
	client *dbent.Client
	sql    sqlExecutor
}

// NewUserRepository builds the repository from an ent client and raw *sql.DB.
func NewUserRepository(client *dbent.Client, sqlDB *sql.DB) service.UserRepository {
	return newUserRepositoryWithSQL(client, sqlDB)
}

// newUserRepositoryWithSQL is the test seam: it accepts any sqlExecutor
// (e.g. a transaction) instead of a concrete *sql.DB.
func newUserRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *userRepository {
	return &userRepository{client: client, sql: sqlq}
}
|
||||
|
||||
// Create inserts a user and synchronizes their allowed groups atomically.
// A nil input is a no-op. Unique-email violations map to
// service.ErrEmailExists. On success, DB-generated fields are copied back
// onto userIn via applyUserEntityToService.
func (r *userRepository) Create(ctx context.Context, userIn *service.User) error {
	if userIn == nil {
		return nil
	}

	// Use an ent transaction so the user row and its allowed-group rows are
	// written atomically, and to avoid the ExecQuerier assertion errors that
	// come from hand-building an ent client around a *sql.Tx.
	tx, err := r.client.Tx(ctx)
	if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
		return err
	}

	var txClient *dbent.Client
	if err == nil {
		// We own this transaction: roll back on any early return (the
		// deferred Rollback is a no-op after a successful Commit).
		defer func() { _ = tx.Rollback() }()
		txClient = tx.Client()
	} else {
		// Already inside an outer transaction (ErrTxStarted): reuse the
		// current client; the caller owns commit/rollback.
		txClient = r.client
	}

	created, err := txClient.User.Create().
		SetEmail(userIn.Email).
		SetUsername(userIn.Username).
		SetNotes(userIn.Notes).
		SetPasswordHash(userIn.PasswordHash).
		SetRole(userIn.Role).
		SetBalance(userIn.Balance).
		SetConcurrency(userIn.Concurrency).
		SetStatus(userIn.Status).
		Save(ctx)
	if err != nil {
		return translatePersistenceError(err, nil, service.ErrEmailExists)
	}

	if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, created.ID, userIn.AllowedGroups); err != nil {
		return err
	}

	// Commit only when we started the transaction ourselves.
	if tx != nil {
		if err := tx.Commit(); err != nil {
			return err
		}
	}

	applyUserEntityToService(userIn, created)
	return nil
}
|
||||
|
||||
// GetByID loads one user by primary key, including their allowed group IDs.
// Not-found maps to service.ErrUserNotFound.
func (r *userRepository) GetByID(ctx context.Context, id int64) (*service.User, error) {
	m, err := r.client.User.Query().Where(dbuser.IDEQ(id)).Only(ctx)
	if err != nil {
		return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
	}

	out := userEntityToService(m)
	// Allowed groups live in a separate table; fetch and attach them.
	groups, err := r.loadAllowedGroups(ctx, []int64{id})
	if err != nil {
		return nil, err
	}
	if v, ok := groups[id]; ok {
		out.AllowedGroups = v
	}
	return out, nil
}
|
||||
|
||||
// GetByEmail loads one user by their unique email, including allowed group
// IDs. Not-found maps to service.ErrUserNotFound.
func (r *userRepository) GetByEmail(ctx context.Context, email string) (*service.User, error) {
	m, err := r.client.User.Query().Where(dbuser.EmailEQ(email)).Only(ctx)
	if err != nil {
		return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
	}

	out := userEntityToService(m)
	// Allowed groups live in a separate table; fetch and attach them.
	groups, err := r.loadAllowedGroups(ctx, []int64{m.ID})
	if err != nil {
		return nil, err
	}
	if v, ok := groups[m.ID]; ok {
		out.AllowedGroups = v
	}
	return out, nil
}
|
||||
|
||||
// Update rewrites all mutable user fields and synchronizes the allowed
// groups atomically. A nil input is a no-op. Not-found maps to
// service.ErrUserNotFound and unique-email violations to
// service.ErrEmailExists; UpdatedAt is copied back on success.
func (r *userRepository) Update(ctx context.Context, userIn *service.User) error {
	if userIn == nil {
		return nil
	}

	// Wrap the user update and the allowed_groups sync in one ent
	// transaction to avoid cross-layer inconsistency.
	tx, err := r.client.Tx(ctx)
	if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
		return err
	}

	var txClient *dbent.Client
	if err == nil {
		// We own this transaction: roll back on any early return (the
		// deferred Rollback is a no-op after a successful Commit).
		defer func() { _ = tx.Rollback() }()
		txClient = tx.Client()
	} else {
		// Already inside an outer transaction (ErrTxStarted): reuse the
		// current client; the caller owns commit/rollback.
		txClient = r.client
	}

	updated, err := txClient.User.UpdateOneID(userIn.ID).
		SetEmail(userIn.Email).
		SetUsername(userIn.Username).
		SetNotes(userIn.Notes).
		SetPasswordHash(userIn.PasswordHash).
		SetRole(userIn.Role).
		SetBalance(userIn.Balance).
		SetConcurrency(userIn.Concurrency).
		SetStatus(userIn.Status).
		Save(ctx)
	if err != nil {
		return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
	}

	if err := r.syncUserAllowedGroupsWithClient(ctx, txClient, updated.ID, userIn.AllowedGroups); err != nil {
		return err
	}

	// Commit only when we started the transaction ourselves.
	if tx != nil {
		if err := tx.Commit(); err != nil {
			return err
		}
	}

	userIn.UpdatedAt = updated.UpdatedAt
	return nil
}
|
||||
|
||||
func (r *userRepository) Delete(ctx context.Context, id int64) error {
|
||||
affected, err := r.client.User.Delete().Where(dbuser.IDEQ(id)).Exec(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
if affected == 0 {
|
||||
return service.ErrUserNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, service.UserListFilters{})
|
||||
}
|
||||
|
||||
// ListWithFilters returns one page of users matching the given filters,
// with active subscriptions (group preloaded) and allowed-group IDs
// attached to each returned user.
//
// Filters: Status and Role are exact matches; Search does a case-insensitive
// substring match on email OR username; Attributes narrows to users whose
// attribute values match ALL given filters (resolved to an ID set first via
// filterUsersByAttributes). Results are ordered by ID descending.
func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
	q := r.client.User.Query()

	if filters.Status != "" {
		q = q.Where(dbuser.StatusEQ(filters.Status))
	}
	if filters.Role != "" {
		q = q.Where(dbuser.RoleEQ(filters.Role))
	}
	if filters.Search != "" {
		// Case-insensitive substring match on either email or username.
		q = q.Where(
			dbuser.Or(
				dbuser.EmailContainsFold(filters.Search),
				dbuser.UsernameContainsFold(filters.Search),
			),
		)
	}

	// If attribute filters are specified, we need to filter by user IDs first.
	var allowedUserIDs []int64
	if len(filters.Attributes) > 0 {
		var attrErr error
		allowedUserIDs, attrErr = r.filterUsersByAttributes(ctx, filters.Attributes)
		if attrErr != nil {
			return nil, nil, attrErr
		}
		if len(allowedUserIDs) == 0 {
			// No users match the attribute filters — short-circuit with an
			// empty page instead of running a query guaranteed to be empty.
			return []service.User{}, paginationResultFromTotal(0, params), nil
		}
		q = q.Where(dbuser.IDIn(allowedUserIDs...))
	}

	// Count on a clone so the pagination Offset/Limit below does not affect
	// the total.
	total, err := q.Clone().Count(ctx)
	if err != nil {
		return nil, nil, err
	}

	users, err := q.
		Offset(params.Offset()).
		Limit(params.Limit()).
		Order(dbent.Desc(dbuser.FieldID)).
		All(ctx)
	if err != nil {
		return nil, nil, err
	}

	outUsers := make([]service.User, 0, len(users))
	if len(users) == 0 {
		return outUsers, paginationResultFromTotal(int64(total), params), nil
	}

	// userMap holds pointers into outUsers so the enrichment passes below
	// mutate the returned slice in place. This is safe only because
	// outUsers was pre-allocated with capacity len(users) and is never
	// appended to beyond that — no reallocation can invalidate pointers.
	userIDs := make([]int64, 0, len(users))
	userMap := make(map[int64]*service.User, len(users))
	for i := range users {
		userIDs = append(userIDs, users[i].ID)
		u := userEntityToService(users[i])
		outUsers = append(outUsers, *u)
		userMap[u.ID] = &outUsers[len(outUsers)-1]
	}

	// Batch load active subscriptions with groups to avoid N+1.
	subs, err := r.client.UserSubscription.Query().
		Where(
			usersubscription.UserIDIn(userIDs...),
			usersubscription.StatusEQ(service.SubscriptionStatusActive),
		).
		WithGroup().
		All(ctx)
	if err != nil {
		return nil, nil, err
	}

	for i := range subs {
		if u, ok := userMap[subs[i].UserID]; ok {
			u.Subscriptions = append(u.Subscriptions, *userSubscriptionEntityToService(subs[i]))
		}
	}

	// Batch load allowed groups for the whole page in one query.
	allowedGroupsByUser, err := r.loadAllowedGroups(ctx, userIDs)
	if err != nil {
		return nil, nil, err
	}
	for id, u := range userMap {
		if groups, ok := allowedGroupsByUser[id]; ok {
			u.AllowedGroups = groups
		}
	}

	return outUsers, paginationResultFromTotal(int64(total), params), nil
}
|
||||
|
||||
// filterUsersByAttributes returns user IDs that match ALL the given
// attribute filters (attribute ID -> case-insensitive substring of its
// value).
//
// Each filter becomes one "(attribute_id = ? AND value ILIKE ?)" clause;
// the clauses are OR-ed and the HAVING COUNT(DISTINCT attribute_id) =
// len(attrs) condition enforces the AND-across-attributes semantics. All
// user-supplied values are passed as query parameters, so building the
// clause list with Sprintf is injection-safe — only placeholder indices are
// interpolated. ILIKE and $N placeholders are PostgreSQL syntax.
//
// An empty filter map returns (nil, nil). Requires the raw SQL executor to
// be configured.
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) ([]int64, error) {
	if len(attrs) == 0 {
		return nil, nil
	}

	if r.sql == nil {
		return nil, fmt.Errorf("sql executor is not configured")
	}

	// Build one clause per attribute filter; argIndex tracks the next $N
	// placeholder (two args per clause: attribute ID + ILIKE pattern).
	clauses := make([]string, 0, len(attrs))
	args := make([]any, 0, len(attrs)*2+1)
	argIndex := 1
	for attrID, value := range attrs {
		clauses = append(clauses, fmt.Sprintf("(attribute_id = $%d AND value ILIKE $%d)", argIndex, argIndex+1))
		args = append(args, attrID, "%"+value+"%")
		argIndex += 2
	}

	// After the loop argIndex points at the final placeholder, used for the
	// HAVING count that enforces "matches ALL filters".
	query := fmt.Sprintf(
		`SELECT user_id
		FROM user_attribute_values
		WHERE %s
		GROUP BY user_id
		HAVING COUNT(DISTINCT attribute_id) = $%d`,
		strings.Join(clauses, " OR "),
		argIndex,
	)
	args = append(args, len(attrs))

	rows, err := r.sql.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	result := make([]int64, 0)
	for rows.Next() {
		var userID int64
		if scanErr := rows.Scan(&userID); scanErr != nil {
			return nil, scanErr
		}
		result = append(result, userID)
	}
	// rows.Err surfaces iteration errors that Next() swallows.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return result, nil
}
|
||||
|
||||
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount).Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
if n == 0 {
|
||||
return service.ErrUserNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeductBalance 扣除用户余额
|
||||
// 透支策略:允许余额变为负数,确保当前请求能够完成
|
||||
// 中间件会阻止余额 <= 0 的用户发起后续请求
|
||||
func (r *userRepository) DeductBalance(ctx context.Context, id int64, amount float64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
n, err := client.User.Update().
|
||||
Where(dbuser.IDEQ(id)).
|
||||
AddBalance(-amount).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n == 0 {
|
||||
return service.ErrUserNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddConcurrency(amount).Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
if n == 0 {
|
||||
return service.ErrUserNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||
return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx)
|
||||
}
|
||||
|
||||
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
|
||||
// 仅操作 user_allowed_groups 联接表,legacy users.allowed_groups 列已弃用。
|
||||
affected, err := r.client.UserAllowedGroup.Delete().
|
||||
Where(userallowedgroup.GroupIDEQ(groupID)).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int64(affected), nil
|
||||
}
|
||||
|
||||
func (r *userRepository) GetFirstAdmin(ctx context.Context) (*service.User, error) {
|
||||
m, err := r.client.User.Query().
|
||||
Where(
|
||||
dbuser.RoleEQ(service.RoleAdmin),
|
||||
dbuser.StatusEQ(service.StatusActive),
|
||||
).
|
||||
Order(dbent.Asc(dbuser.FieldID)).
|
||||
First(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||
}
|
||||
|
||||
out := userEntityToService(m)
|
||||
groups, err := r.loadAllowedGroups(ctx, []int64{m.ID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if v, ok := groups[m.ID]; ok {
|
||||
out.AllowedGroups = v
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) loadAllowedGroups(ctx context.Context, userIDs []int64) (map[int64][]int64, error) {
|
||||
out := make(map[int64][]int64, len(userIDs))
|
||||
if len(userIDs) == 0 {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
rows, err := r.client.UserAllowedGroup.Query().
|
||||
Where(userallowedgroup.UserIDIn(userIDs...)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range rows {
|
||||
out[rows[i].UserID] = append(out[rows[i].UserID], rows[i].GroupID)
|
||||
}
|
||||
|
||||
for userID := range out {
|
||||
sort.Slice(out[userID], func(i, j int) bool { return out[userID][i] < out[userID][j] })
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// syncUserAllowedGroupsWithClient 在 ent client/事务内同步用户允许分组:
|
||||
// 仅操作 user_allowed_groups 联接表,legacy users.allowed_groups 列已弃用。
|
||||
func (r *userRepository) syncUserAllowedGroupsWithClient(ctx context.Context, client *dbent.Client, userID int64, groupIDs []int64) error {
|
||||
if client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Keep join table as the source of truth for reads.
|
||||
if _, err := client.UserAllowedGroup.Delete().Where(userallowedgroup.UserIDEQ(userID)).Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
unique := make(map[int64]struct{}, len(groupIDs))
|
||||
for _, id := range groupIDs {
|
||||
if id <= 0 {
|
||||
continue
|
||||
}
|
||||
unique[id] = struct{}{}
|
||||
}
|
||||
|
||||
if len(unique) > 0 {
|
||||
creates := make([]*dbent.UserAllowedGroupCreate, 0, len(unique))
|
||||
for groupID := range unique {
|
||||
creates = append(creates, client.UserAllowedGroup.Create().SetUserID(userID).SetGroupID(groupID))
|
||||
}
|
||||
if err := client.UserAllowedGroup.
|
||||
CreateBulk(creates...).
|
||||
OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
|
||||
DoNothing().
|
||||
Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyUserEntityToService(dst *service.User, src *dbent.User) {
|
||||
if dst == nil || src == nil {
|
||||
return
|
||||
}
|
||||
dst.ID = src.ID
|
||||
dst.CreatedAt = src.CreatedAt
|
||||
dst.UpdatedAt = src.UpdatedAt
|
||||
}
|
||||
537
backend/internal/repository/user_repo_integration_test.go
Normal file
537
backend/internal/repository/user_repo_integration_test.go
Normal file
@@ -0,0 +1,537 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// UserRepoSuite is the integration-test suite for userRepository; it runs
// against a real database (see the `integration` build tag at the top of
// the file).
type UserRepoSuite struct {
	suite.Suite
	ctx    context.Context // base context used for all repository calls
	client *dbent.Client   // ent client backed by the integration database
	repo   *userRepository // repository under test
}
|
||||
|
||||
// SetupTest runs before every test method: it wires a fresh repository
// against the shared integration database and wipes user-related tables so
// each test starts from a clean state.
func (s *UserRepoSuite) SetupTest() {
	s.ctx = context.Background()
	s.client = testEntClient(s.T())
	s.repo = newUserRepositoryWithSQL(s.client, integrationDB)

	// Clean test data so every test starts from a clean state. Child tables
	// are cleared before users — presumably to satisfy foreign keys; errors
	// are deliberately ignored (tables may be empty or absent).
	_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_subscriptions")
	_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_allowed_groups")
	_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM users")
}
|
||||
|
||||
// TestUserRepoSuite is the `go test` entry point that runs every test
// method defined on UserRepoSuite.
func TestUserRepoSuite(t *testing.T) {
	suite.Run(t, new(UserRepoSuite))
}
|
||||
|
||||
// mustCreateUser persists u through the repository under test, first
// filling sensible defaults for any zero-valued required field, and fails
// the test on error. Returns the same pointer with its ID populated by
// Create.
func (s *UserRepoSuite) mustCreateUser(u *service.User) *service.User {
	s.T().Helper()

	if u.Email == "" {
		// Nanosecond-resolution timestamp keeps generated emails unique
		// across calls within one test run.
		u.Email = "user-" + time.Now().Format(time.RFC3339Nano) + "@example.com"
	}
	if u.PasswordHash == "" {
		u.PasswordHash = "test-password-hash"
	}
	if u.Role == "" {
		u.Role = service.RoleUser
	}
	if u.Status == "" {
		u.Status = service.StatusActive
	}
	if u.Concurrency == 0 {
		u.Concurrency = 5
	}

	s.Require().NoError(s.repo.Create(s.ctx, u), "create user")
	return u
}
|
||||
|
||||
func (s *UserRepoSuite) mustCreateGroup(name string) *service.Group {
|
||||
s.T().Helper()
|
||||
|
||||
g, err := s.client.Group.Create().
|
||||
SetName(name).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err, "create group")
|
||||
return groupEntityToService(g)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) mustCreateSubscription(userID, groupID int64, mutate func(*dbent.UserSubscriptionCreate)) *dbent.UserSubscription {
|
||||
s.T().Helper()
|
||||
|
||||
now := time.Now()
|
||||
create := s.client.UserSubscription.Create().
|
||||
SetUserID(userID).
|
||||
SetGroupID(groupID).
|
||||
SetStartsAt(now.Add(-1 * time.Hour)).
|
||||
SetExpiresAt(now.Add(24 * time.Hour)).
|
||||
SetStatus(service.SubscriptionStatusActive).
|
||||
SetAssignedAt(now).
|
||||
SetNotes("")
|
||||
|
||||
if mutate != nil {
|
||||
mutate(create)
|
||||
}
|
||||
|
||||
sub, err := create.Save(s.ctx)
|
||||
s.Require().NoError(err, "create subscription")
|
||||
return sub
|
||||
}
|
||||
|
||||
// --- Create / GetByID / GetByEmail / Update / Delete ---
|
||||
|
||||
func (s *UserRepoSuite) TestCreate() {
|
||||
user := s.mustCreateUser(&service.User{
|
||||
Email: "create@test.com",
|
||||
Username: "testuser",
|
||||
PasswordHash: "test-password-hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
|
||||
s.Require().NotZero(user.ID, "expected ID to be set")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal("create@test.com", got.Email)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestGetByID_NotFound() {
|
||||
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||
s.Require().Error(err, "expected error for non-existent ID")
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestGetByEmail() {
|
||||
user := s.mustCreateUser(&service.User{Email: "byemail@test.com"})
|
||||
|
||||
got, err := s.repo.GetByEmail(s.ctx, user.Email)
|
||||
s.Require().NoError(err, "GetByEmail")
|
||||
s.Require().Equal(user.ID, got.ID)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestGetByEmail_NotFound() {
|
||||
_, err := s.repo.GetByEmail(s.ctx, "nonexistent@test.com")
|
||||
s.Require().Error(err, "expected error for non-existent email")
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestUpdate() {
|
||||
user := s.mustCreateUser(&service.User{Email: "update@test.com", Username: "original"})
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
got.Username = "updated"
|
||||
s.Require().NoError(s.repo.Update(s.ctx, got), "Update")
|
||||
|
||||
updated, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err, "GetByID after update")
|
||||
s.Require().Equal("updated", updated.Username)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestDelete() {
|
||||
user := s.mustCreateUser(&service.User{Email: "delete@test.com"})
|
||||
|
||||
err := s.repo.Delete(s.ctx, user.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
|
||||
_, err = s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().Error(err, "expected error after delete")
|
||||
}
|
||||
|
||||
// --- List / ListWithFilters ---
|
||||
|
||||
func (s *UserRepoSuite) TestList() {
|
||||
s.mustCreateUser(&service.User{Email: "list1@test.com"})
|
||||
s.mustCreateUser(&service.User{Email: "list2@test.com"})
|
||||
|
||||
users, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "List")
|
||||
s.Require().Len(users, 2)
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_Status() {
|
||||
s.mustCreateUser(&service.User{Email: "active@test.com", Status: service.StatusActive})
|
||||
s.mustCreateUser(&service.User{Email: "disabled@test.com", Status: service.StatusDisabled})
|
||||
|
||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Status: service.StatusActive})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(users, 1)
|
||||
s.Require().Equal(service.StatusActive, users[0].Status)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_Role() {
|
||||
s.mustCreateUser(&service.User{Email: "user@test.com", Role: service.RoleUser})
|
||||
s.mustCreateUser(&service.User{Email: "admin@test.com", Role: service.RoleAdmin})
|
||||
|
||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Role: service.RoleAdmin})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(users, 1)
|
||||
s.Require().Equal(service.RoleAdmin, users[0].Role)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_Search() {
|
||||
s.mustCreateUser(&service.User{Email: "alice@test.com", Username: "Alice"})
|
||||
s.mustCreateUser(&service.User{Email: "bob@test.com", Username: "Bob"})
|
||||
|
||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Search: "alice"})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(users, 1)
|
||||
s.Require().Contains(users[0].Email, "alice")
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() {
|
||||
s.mustCreateUser(&service.User{Email: "u1@test.com", Username: "JohnDoe"})
|
||||
s.mustCreateUser(&service.User{Email: "u2@test.com", Username: "JaneSmith"})
|
||||
|
||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Search: "john"})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(users, 1)
|
||||
s.Require().Equal("JohnDoe", users[0].Username)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
|
||||
user := s.mustCreateUser(&service.User{Email: "sub@test.com", Status: service.StatusActive})
|
||||
groupActive := s.mustCreateGroup("g-sub-active")
|
||||
groupExpired := s.mustCreateGroup("g-sub-expired")
|
||||
|
||||
_ = s.mustCreateSubscription(user.ID, groupActive.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetStatus(service.SubscriptionStatusActive)
|
||||
c.SetExpiresAt(time.Now().Add(1 * time.Hour))
|
||||
})
|
||||
_ = s.mustCreateSubscription(user.ID, groupExpired.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetStatus(service.SubscriptionStatusExpired)
|
||||
c.SetExpiresAt(time.Now().Add(-1 * time.Hour))
|
||||
})
|
||||
|
||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Search: "sub@"})
|
||||
s.Require().NoError(err, "ListWithFilters")
|
||||
s.Require().Len(users, 1, "expected 1 user")
|
||||
s.Require().Len(users[0].Subscriptions, 1, "expected 1 active subscription")
|
||||
s.Require().NotNil(users[0].Subscriptions[0].Group, "expected subscription group preload")
|
||||
s.Require().Equal(groupActive.ID, users[0].Subscriptions[0].Group.ID, "group ID mismatch")
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
|
||||
s.mustCreateUser(&service.User{
|
||||
Email: "a@example.com",
|
||||
Username: "Alice",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
})
|
||||
target := s.mustCreateUser(&service.User{
|
||||
Email: "b@example.com",
|
||||
Username: "Bob",
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusActive,
|
||||
Balance: 1,
|
||||
})
|
||||
s.mustCreateUser(&service.User{
|
||||
Email: "c@example.com",
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusDisabled,
|
||||
})
|
||||
|
||||
users, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Status: service.StatusActive, Role: service.RoleAdmin, Search: "b@"})
|
||||
s.Require().NoError(err, "ListWithFilters")
|
||||
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
|
||||
s.Require().Len(users, 1, "ListWithFilters len mismatch")
|
||||
s.Require().Equal(target.ID, users[0].ID, "ListWithFilters result mismatch")
|
||||
}
|
||||
|
||||
// --- Balance operations ---
|
||||
|
||||
func (s *UserRepoSuite) TestUpdateBalance() {
|
||||
user := s.mustCreateUser(&service.User{Email: "bal@test.com", Balance: 10})
|
||||
|
||||
err := s.repo.UpdateBalance(s.ctx, user.ID, 2.5)
|
||||
s.Require().NoError(err, "UpdateBalance")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().InDelta(12.5, got.Balance, 1e-6)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestUpdateBalance_Negative() {
|
||||
user := s.mustCreateUser(&service.User{Email: "balneg@test.com", Balance: 10})
|
||||
|
||||
err := s.repo.UpdateBalance(s.ctx, user.ID, -3)
|
||||
s.Require().NoError(err, "UpdateBalance with negative")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().InDelta(7.0, got.Balance, 1e-6)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestDeductBalance() {
|
||||
user := s.mustCreateUser(&service.User{Email: "deduct@test.com", Balance: 10})
|
||||
|
||||
err := s.repo.DeductBalance(s.ctx, user.ID, 5)
|
||||
s.Require().NoError(err, "DeductBalance")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().InDelta(5.0, got.Balance, 1e-6)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() {
|
||||
user := s.mustCreateUser(&service.User{Email: "insuf@test.com", Balance: 5})
|
||||
|
||||
// 透支策略:允许扣除超过余额的金额
|
||||
err := s.repo.DeductBalance(s.ctx, user.ID, 999)
|
||||
s.Require().NoError(err, "DeductBalance should allow overdraft")
|
||||
|
||||
// 验证余额变为负数
|
||||
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().InDelta(-994.0, got.Balance, 1e-6, "Balance should be negative after overdraft")
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
|
||||
user := s.mustCreateUser(&service.User{Email: "exact@test.com", Balance: 10})
|
||||
|
||||
err := s.repo.DeductBalance(s.ctx, user.ID, 10)
|
||||
s.Require().NoError(err, "DeductBalance exact amount")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().InDelta(0.0, got.Balance, 1e-6)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestDeductBalance_AllowsOverdraft() {
|
||||
user := s.mustCreateUser(&service.User{Email: "overdraft@test.com", Balance: 5.0})
|
||||
|
||||
// 扣除超过余额的金额 - 应该成功
|
||||
err := s.repo.DeductBalance(s.ctx, user.ID, 10.0)
|
||||
s.Require().NoError(err, "DeductBalance should allow overdraft")
|
||||
|
||||
// 验证余额为负
|
||||
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().InDelta(-5.0, got.Balance, 1e-6, "Balance should be -5.0 after overdraft")
|
||||
}
|
||||
|
||||
// --- Concurrency ---
|
||||
|
||||
func (s *UserRepoSuite) TestUpdateConcurrency() {
|
||||
user := s.mustCreateUser(&service.User{Email: "conc@test.com", Concurrency: 5})
|
||||
|
||||
err := s.repo.UpdateConcurrency(s.ctx, user.ID, 3)
|
||||
s.Require().NoError(err, "UpdateConcurrency")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(8, got.Concurrency)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestUpdateConcurrency_Negative() {
|
||||
user := s.mustCreateUser(&service.User{Email: "concneg@test.com", Concurrency: 5})
|
||||
|
||||
err := s.repo.UpdateConcurrency(s.ctx, user.ID, -2)
|
||||
s.Require().NoError(err, "UpdateConcurrency negative")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(3, got.Concurrency)
|
||||
}
|
||||
|
||||
// --- ExistsByEmail ---
|
||||
|
||||
func (s *UserRepoSuite) TestExistsByEmail() {
|
||||
s.mustCreateUser(&service.User{Email: "exists@test.com"})
|
||||
|
||||
exists, err := s.repo.ExistsByEmail(s.ctx, "exists@test.com")
|
||||
s.Require().NoError(err, "ExistsByEmail")
|
||||
s.Require().True(exists)
|
||||
|
||||
notExists, err := s.repo.ExistsByEmail(s.ctx, "notexists@test.com")
|
||||
s.Require().NoError(err)
|
||||
s.Require().False(notExists)
|
||||
}
|
||||
|
||||
// --- RemoveGroupFromAllowedGroups ---
|
||||
|
||||
func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups() {
|
||||
target := s.mustCreateGroup("target-42")
|
||||
other := s.mustCreateGroup("other-7")
|
||||
|
||||
userA := s.mustCreateUser(&service.User{
|
||||
Email: "a1@example.com",
|
||||
AllowedGroups: []int64{target.ID, other.ID},
|
||||
})
|
||||
s.mustCreateUser(&service.User{
|
||||
Email: "a2@example.com",
|
||||
AllowedGroups: []int64{other.ID},
|
||||
})
|
||||
|
||||
affected, err := s.repo.RemoveGroupFromAllowedGroups(s.ctx, target.ID)
|
||||
s.Require().NoError(err, "RemoveGroupFromAllowedGroups")
|
||||
s.Require().Equal(int64(1), affected, "expected 1 affected row")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, userA.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().NotContains(got.AllowedGroups, target.ID)
|
||||
s.Require().Contains(got.AllowedGroups, other.ID)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups_NoMatch() {
|
||||
groupA := s.mustCreateGroup("nomatch-a")
|
||||
groupB := s.mustCreateGroup("nomatch-b")
|
||||
|
||||
s.mustCreateUser(&service.User{
|
||||
Email: "nomatch@test.com",
|
||||
AllowedGroups: []int64{groupA.ID, groupB.ID},
|
||||
})
|
||||
|
||||
affected, err := s.repo.RemoveGroupFromAllowedGroups(s.ctx, 999999)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(affected, "expected no affected rows")
|
||||
}
|
||||
|
||||
// --- GetFirstAdmin ---
|
||||
|
||||
func (s *UserRepoSuite) TestGetFirstAdmin() {
|
||||
admin1 := s.mustCreateUser(&service.User{
|
||||
Email: "admin1@example.com",
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
s.mustCreateUser(&service.User{
|
||||
Email: "admin2@example.com",
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
|
||||
got, err := s.repo.GetFirstAdmin(s.ctx)
|
||||
s.Require().NoError(err, "GetFirstAdmin")
|
||||
s.Require().Equal(admin1.ID, got.ID, "GetFirstAdmin mismatch")
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestGetFirstAdmin_NoAdmin() {
|
||||
s.mustCreateUser(&service.User{
|
||||
Email: "user@example.com",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
|
||||
_, err := s.repo.GetFirstAdmin(s.ctx)
|
||||
s.Require().Error(err, "expected error when no admin exists")
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestGetFirstAdmin_DisabledAdminIgnored() {
|
||||
s.mustCreateUser(&service.User{
|
||||
Email: "disabled@example.com",
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusDisabled,
|
||||
})
|
||||
activeAdmin := s.mustCreateUser(&service.User{
|
||||
Email: "active@example.com",
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
|
||||
got, err := s.repo.GetFirstAdmin(s.ctx)
|
||||
s.Require().NoError(err, "GetFirstAdmin")
|
||||
s.Require().Equal(activeAdmin.ID, got.ID, "should return only active admin")
|
||||
}
|
||||
|
||||
// --- Combined ---
|
||||
|
||||
func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
|
||||
user1 := s.mustCreateUser(&service.User{
|
||||
Email: "a@example.com",
|
||||
Username: "Alice",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
})
|
||||
user2 := s.mustCreateUser(&service.User{
|
||||
Email: "b@example.com",
|
||||
Username: "Bob",
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusActive,
|
||||
Balance: 1,
|
||||
})
|
||||
s.mustCreateUser(&service.User{
|
||||
Email: "c@example.com",
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusDisabled,
|
||||
})
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, user1.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal(user1.Email, got.Email, "GetByID email mismatch")
|
||||
|
||||
gotByEmail, err := s.repo.GetByEmail(s.ctx, user2.Email)
|
||||
s.Require().NoError(err, "GetByEmail")
|
||||
s.Require().Equal(user2.ID, gotByEmail.ID, "GetByEmail ID mismatch")
|
||||
|
||||
got.Username = "Alice2"
|
||||
s.Require().NoError(s.repo.Update(s.ctx, got), "Update")
|
||||
got2, err := s.repo.GetByID(s.ctx, user1.ID)
|
||||
s.Require().NoError(err, "GetByID after update")
|
||||
s.Require().Equal("Alice2", got2.Username, "Update did not persist")
|
||||
|
||||
s.Require().NoError(s.repo.UpdateBalance(s.ctx, user1.ID, 2.5), "UpdateBalance")
|
||||
got3, err := s.repo.GetByID(s.ctx, user1.ID)
|
||||
s.Require().NoError(err, "GetByID after UpdateBalance")
|
||||
s.Require().InDelta(12.5, got3.Balance, 1e-6)
|
||||
|
||||
s.Require().NoError(s.repo.DeductBalance(s.ctx, user1.ID, 5), "DeductBalance")
|
||||
got4, err := s.repo.GetByID(s.ctx, user1.ID)
|
||||
s.Require().NoError(err, "GetByID after DeductBalance")
|
||||
s.Require().InDelta(7.5, got4.Balance, 1e-6)
|
||||
|
||||
// 透支策略:允许扣除超过余额的金额
|
||||
err = s.repo.DeductBalance(s.ctx, user1.ID, 999)
|
||||
s.Require().NoError(err, "DeductBalance should allow overdraft")
|
||||
gotOverdraft, err := s.repo.GetByID(s.ctx, user1.ID)
|
||||
s.Require().NoError(err, "GetByID after overdraft")
|
||||
s.Require().Less(gotOverdraft.Balance, 0.0, "Balance should be negative after overdraft")
|
||||
|
||||
s.Require().NoError(s.repo.UpdateConcurrency(s.ctx, user1.ID, 3), "UpdateConcurrency")
|
||||
got5, err := s.repo.GetByID(s.ctx, user1.ID)
|
||||
s.Require().NoError(err, "GetByID after UpdateConcurrency")
|
||||
s.Require().Equal(user1.Concurrency+3, got5.Concurrency)
|
||||
|
||||
params := pagination.PaginationParams{Page: 1, PageSize: 10}
|
||||
users, page, err := s.repo.ListWithFilters(s.ctx, params, service.UserListFilters{Status: service.StatusActive, Role: service.RoleAdmin, Search: "b@"})
|
||||
s.Require().NoError(err, "ListWithFilters")
|
||||
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
|
||||
s.Require().Len(users, 1, "ListWithFilters len mismatch")
|
||||
s.Require().Equal(user2.ID, users[0].ID, "ListWithFilters result mismatch")
|
||||
}
|
||||
|
||||
// --- UpdateBalance/UpdateConcurrency 影响行数校验测试 ---
|
||||
|
||||
func (s *UserRepoSuite) TestUpdateBalance_NotFound() {
|
||||
err := s.repo.UpdateBalance(s.ctx, 999999, 10.0)
|
||||
s.Require().Error(err, "expected error for non-existent user")
|
||||
s.Require().ErrorIs(err, service.ErrUserNotFound)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestUpdateConcurrency_NotFound() {
|
||||
err := s.repo.UpdateConcurrency(s.ctx, 999999, 5)
|
||||
s.Require().Error(err, "expected error for non-existent user")
|
||||
s.Require().ErrorIs(err, service.ErrUserNotFound)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestDeductBalance_NotFound() {
|
||||
err := s.repo.DeductBalance(s.ctx, 999999, 5)
|
||||
s.Require().Error(err, "expected error for non-existent user")
|
||||
// DeductBalance 在用户不存在时返回 ErrUserNotFound
|
||||
s.Require().ErrorIs(err, service.ErrUserNotFound)
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user