feat(table): 表格排序与搜索改为后端处理
This commit is contained in:
@@ -471,21 +471,58 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
|
||||
case service.StatusActive:
|
||||
q = q.Where(
|
||||
dbaccount.StatusEQ(status),
|
||||
dbaccount.SchedulableEQ(true),
|
||||
dbaccount.Or(
|
||||
dbaccount.RateLimitResetAtIsNil(),
|
||||
dbaccount.RateLimitResetAtLTE(time.Now()),
|
||||
),
|
||||
dbpredicate.Account(func(s *entsql.Selector) {
|
||||
col := s.C("temp_unschedulable_until")
|
||||
s.Where(entsql.Or(
|
||||
entsql.IsNull(col),
|
||||
entsql.LTE(col, entsql.Expr("NOW()")),
|
||||
))
|
||||
}),
|
||||
)
|
||||
case "rate_limited":
|
||||
q = q.Where(dbaccount.RateLimitResetAtGT(time.Now()))
|
||||
q = q.Where(
|
||||
dbaccount.StatusEQ(service.StatusActive),
|
||||
dbaccount.RateLimitResetAtGT(time.Now()),
|
||||
dbpredicate.Account(func(s *entsql.Selector) {
|
||||
col := s.C("temp_unschedulable_until")
|
||||
s.Where(entsql.Or(
|
||||
entsql.IsNull(col),
|
||||
entsql.LTE(col, entsql.Expr("NOW()")),
|
||||
))
|
||||
}),
|
||||
)
|
||||
case "temp_unschedulable":
|
||||
q = q.Where(dbpredicate.Account(func(s *entsql.Selector) {
|
||||
col := s.C("temp_unschedulable_until")
|
||||
s.Where(entsql.And(
|
||||
entsql.Not(entsql.IsNull(col)),
|
||||
entsql.GT(col, entsql.Expr("NOW()")),
|
||||
))
|
||||
}))
|
||||
q = q.Where(
|
||||
dbaccount.StatusEQ(service.StatusActive),
|
||||
dbpredicate.Account(func(s *entsql.Selector) {
|
||||
col := s.C("temp_unschedulable_until")
|
||||
s.Where(entsql.And(
|
||||
entsql.Not(entsql.IsNull(col)),
|
||||
entsql.GT(col, entsql.Expr("NOW()")),
|
||||
))
|
||||
}),
|
||||
)
|
||||
case "unschedulable":
|
||||
q = q.Where(
|
||||
dbaccount.StatusEQ(service.StatusActive),
|
||||
dbaccount.SchedulableEQ(false),
|
||||
dbaccount.Or(
|
||||
dbaccount.RateLimitResetAtIsNil(),
|
||||
dbaccount.RateLimitResetAtLTE(time.Now()),
|
||||
),
|
||||
dbpredicate.Account(func(s *entsql.Selector) {
|
||||
col := s.C("temp_unschedulable_until")
|
||||
s.Where(entsql.Or(
|
||||
entsql.IsNull(col),
|
||||
entsql.LTE(col, entsql.Expr("NOW()")),
|
||||
))
|
||||
}),
|
||||
)
|
||||
default:
|
||||
q = q.Where(dbaccount.StatusEQ(status))
|
||||
}
|
||||
@@ -518,11 +555,14 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
accounts, err := q.
|
||||
accountsQuery := q.
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(dbaccount.FieldID)).
|
||||
All(ctx)
|
||||
Limit(params.Limit())
|
||||
for _, order := range accountListOrder(params) {
|
||||
accountsQuery = accountsQuery.Order(order)
|
||||
}
|
||||
|
||||
accounts, err := accountsQuery.All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -534,6 +574,50 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati
|
||||
return outAccounts, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func accountListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
|
||||
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
|
||||
sortOrder := params.NormalizedSortOrder(pagination.SortOrderAsc)
|
||||
|
||||
field := dbaccount.FieldName
|
||||
defaultOrder := true
|
||||
switch sortBy {
|
||||
case "", "name":
|
||||
field = dbaccount.FieldName
|
||||
case "id":
|
||||
field = dbaccount.FieldID
|
||||
defaultOrder = false
|
||||
case "status":
|
||||
field = dbaccount.FieldStatus
|
||||
defaultOrder = false
|
||||
case "schedulable":
|
||||
field = dbaccount.FieldSchedulable
|
||||
defaultOrder = false
|
||||
case "priority":
|
||||
field = dbaccount.FieldPriority
|
||||
defaultOrder = false
|
||||
case "rate_multiplier":
|
||||
field = dbaccount.FieldRateMultiplier
|
||||
defaultOrder = false
|
||||
case "last_used_at":
|
||||
field = dbaccount.FieldLastUsedAt
|
||||
defaultOrder = false
|
||||
case "expires_at":
|
||||
field = dbaccount.FieldExpiresAt
|
||||
defaultOrder = false
|
||||
case "created_at":
|
||||
field = dbaccount.FieldCreatedAt
|
||||
defaultOrder = false
|
||||
}
|
||||
|
||||
if sortOrder == pagination.SortOrderDesc {
|
||||
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(dbaccount.FieldID)}
|
||||
}
|
||||
if defaultOrder {
|
||||
return []func(*entsql.Selector){dbent.Asc(dbaccount.FieldName), dbent.Asc(dbaccount.FieldID)}
|
||||
}
|
||||
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(dbaccount.FieldID)}
|
||||
}
|
||||
|
||||
func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) {
|
||||
accounts, err := r.queryAccountsByGroup(ctx, groupID, accountGroupQueryOptions{
|
||||
status: service.StatusActive,
|
||||
|
||||
@@ -256,7 +256,7 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter_by_status_active_excludes_rate_limited",
|
||||
name: "filter_by_status_active_excludes_runtime_blocked_accounts",
|
||||
setup: func(client *dbent.Client) {
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "active-normal", Status: service.StatusActive})
|
||||
rateLimited := mustCreateAccount(s.T(), client, &service.Account{Name: "active-rate-limited", Status: service.StatusActive})
|
||||
@@ -264,6 +264,16 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
||||
SetRateLimitResetAt(time.Now().Add(10 * time.Minute)).
|
||||
Exec(context.Background())
|
||||
s.Require().NoError(err)
|
||||
tempUnsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-temp-unsched", Status: service.StatusActive})
|
||||
err = client.Account.UpdateOneID(tempUnsched.ID).
|
||||
SetTempUnschedulableUntil(time.Now().Add(15 * time.Minute)).
|
||||
Exec(context.Background())
|
||||
s.Require().NoError(err)
|
||||
unsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-unsched", Status: service.StatusActive})
|
||||
err = client.Account.UpdateOneID(unsched.ID).
|
||||
SetSchedulable(false).
|
||||
Exec(context.Background())
|
||||
s.Require().NoError(err)
|
||||
},
|
||||
status: service.StatusActive,
|
||||
wantCount: 1,
|
||||
@@ -271,6 +281,75 @@ func (s *AccountRepoSuite) TestListWithFilters() {
|
||||
s.Require().Equal("active-normal", accounts[0].Name)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter_by_status_unschedulable_excludes_rate_limited_and_temp_unschedulable",
|
||||
setup: func(client *dbent.Client) {
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "active-normal", Status: service.StatusActive, Schedulable: true})
|
||||
unsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-unsched", Status: service.StatusActive})
|
||||
err := client.Account.UpdateOneID(unsched.ID).
|
||||
SetSchedulable(false).
|
||||
Exec(context.Background())
|
||||
s.Require().NoError(err)
|
||||
rateLimited := mustCreateAccount(s.T(), client, &service.Account{Name: "active-rate-limited", Status: service.StatusActive})
|
||||
err = client.Account.UpdateOneID(rateLimited.ID).
|
||||
SetSchedulable(false).
|
||||
SetRateLimitResetAt(time.Now().Add(10 * time.Minute)).
|
||||
Exec(context.Background())
|
||||
s.Require().NoError(err)
|
||||
tempUnsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-temp-unsched", Status: service.StatusActive})
|
||||
err = client.Account.UpdateOneID(tempUnsched.ID).
|
||||
SetSchedulable(false).
|
||||
SetTempUnschedulableUntil(time.Now().Add(15 * time.Minute)).
|
||||
Exec(context.Background())
|
||||
s.Require().NoError(err)
|
||||
},
|
||||
status: "unschedulable",
|
||||
wantCount: 1,
|
||||
validate: func(accounts []service.Account) {
|
||||
s.Require().Equal("active-unsched", accounts[0].Name)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter_by_status_rate_limited_excludes_temp_unschedulable",
|
||||
setup: func(client *dbent.Client) {
|
||||
rateLimited := mustCreateAccount(s.T(), client, &service.Account{Name: "active-rate-limited", Status: service.StatusActive})
|
||||
err := client.Account.UpdateOneID(rateLimited.ID).
|
||||
SetRateLimitResetAt(time.Now().Add(10 * time.Minute)).
|
||||
Exec(context.Background())
|
||||
s.Require().NoError(err)
|
||||
tempUnsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-temp-unsched", Status: service.StatusActive})
|
||||
err = client.Account.UpdateOneID(tempUnsched.ID).
|
||||
SetRateLimitResetAt(time.Now().Add(20 * time.Minute)).
|
||||
SetTempUnschedulableUntil(time.Now().Add(15 * time.Minute)).
|
||||
Exec(context.Background())
|
||||
s.Require().NoError(err)
|
||||
},
|
||||
status: "rate_limited",
|
||||
wantCount: 1,
|
||||
validate: func(accounts []service.Account) {
|
||||
s.Require().Equal("active-rate-limited", accounts[0].Name)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter_by_status_temp_unschedulable_excludes_manually_unschedulable",
|
||||
setup: func(client *dbent.Client) {
|
||||
tempUnsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-temp-unsched", Status: service.StatusActive, Schedulable: true})
|
||||
err := client.Account.UpdateOneID(tempUnsched.ID).
|
||||
SetTempUnschedulableUntil(time.Now().Add(15 * time.Minute)).
|
||||
Exec(context.Background())
|
||||
s.Require().NoError(err)
|
||||
unsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-unsched", Status: service.StatusActive})
|
||||
err = client.Account.UpdateOneID(unsched.ID).
|
||||
SetSchedulable(false).
|
||||
Exec(context.Background())
|
||||
s.Require().NoError(err)
|
||||
},
|
||||
status: "temp_unschedulable",
|
||||
wantCount: 1,
|
||||
validate: func(accounts []service.Account) {
|
||||
s.Require().Equal("active-temp-unsched", accounts[0].Name)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter_by_search",
|
||||
setup: func(client *dbent.Client) {
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func (s *AccountRepoSuite) TestList_DefaultSortByNameAsc() {
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{Name: "z-account"})
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{Name: "a-account"})
|
||||
|
||||
accounts, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(accounts, 2)
|
||||
s.Require().Equal("a-account", accounts[0].Name)
|
||||
s.Require().Equal("z-account", accounts[1].Name)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestListWithFilters_SortByPriorityDesc() {
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{Name: "low-priority", Priority: 10})
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{Name: "high-priority", Priority: 90})
|
||||
|
||||
accounts, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
|
||||
Page: 1,
|
||||
PageSize: 10,
|
||||
SortBy: "priority",
|
||||
SortOrder: "desc",
|
||||
}, "", "", "", "", 0, "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(accounts, 2)
|
||||
s.Require().Equal("high-priority", accounts[0].Name)
|
||||
s.Require().Equal("low-priority", accounts[1].Name)
|
||||
}
|
||||
@@ -2,12 +2,15 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/announcement"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
)
|
||||
|
||||
type announcementRepository struct {
|
||||
@@ -128,11 +131,14 @@ func (r *announcementRepository) List(
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
items, err := q.
|
||||
itemsQuery := q.
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(announcement.FieldID)).
|
||||
All(ctx)
|
||||
Limit(params.Limit())
|
||||
for _, order := range announcementListOrders(params) {
|
||||
itemsQuery = itemsQuery.Order(order)
|
||||
}
|
||||
|
||||
items, err := itemsQuery.All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -141,6 +147,56 @@ func (r *announcementRepository) List(
|
||||
return out, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func announcementListOrder(params pagination.PaginationParams) (string, string) {
|
||||
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
|
||||
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
|
||||
|
||||
switch sortBy {
|
||||
case "title":
|
||||
return announcement.FieldTitle, sortOrder
|
||||
case "status":
|
||||
return announcement.FieldStatus, sortOrder
|
||||
case "notify_mode":
|
||||
return announcement.FieldNotifyMode, sortOrder
|
||||
case "starts_at":
|
||||
return announcement.FieldStartsAt, sortOrder
|
||||
case "ends_at":
|
||||
return announcement.FieldEndsAt, sortOrder
|
||||
case "id":
|
||||
return announcement.FieldID, sortOrder
|
||||
case "", "created_at":
|
||||
return announcement.FieldCreatedAt, sortOrder
|
||||
default:
|
||||
return announcement.FieldCreatedAt, pagination.SortOrderDesc
|
||||
}
|
||||
}
|
||||
|
||||
func announcementListOrders(params pagination.PaginationParams) []func(*entsql.Selector) {
|
||||
field, sortOrder := announcementListOrder(params)
|
||||
|
||||
if sortOrder == pagination.SortOrderAsc {
|
||||
if field == announcement.FieldID {
|
||||
return []func(*entsql.Selector){
|
||||
dbent.Asc(field),
|
||||
}
|
||||
}
|
||||
return []func(*entsql.Selector){
|
||||
dbent.Asc(field),
|
||||
dbent.Asc(announcement.FieldID),
|
||||
}
|
||||
}
|
||||
|
||||
if field == announcement.FieldID {
|
||||
return []func(*entsql.Selector){
|
||||
dbent.Desc(field),
|
||||
}
|
||||
}
|
||||
return []func(*entsql.Selector){
|
||||
dbent.Desc(field),
|
||||
dbent.Desc(announcement.FieldID),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *announcementRepository) ListActive(ctx context.Context, now time.Time) ([]service.Announcement, error) {
|
||||
q := r.client.Announcement.Query().
|
||||
Where(
|
||||
|
||||
63
backend/internal/repository/announcement_repo_sort_test.go
Normal file
63
backend/internal/repository/announcement_repo_sort_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
func TestAnnouncementListOrder(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
params pagination.PaginationParams
|
||||
wantBy string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "default created_at desc",
|
||||
params: pagination.PaginationParams{},
|
||||
wantBy: "created_at",
|
||||
want: "desc",
|
||||
},
|
||||
{
|
||||
name: "title asc",
|
||||
params: pagination.PaginationParams{
|
||||
SortBy: "title",
|
||||
SortOrder: "ASC",
|
||||
},
|
||||
wantBy: "title",
|
||||
want: "asc",
|
||||
},
|
||||
{
|
||||
name: "status desc",
|
||||
params: pagination.PaginationParams{
|
||||
SortBy: "status",
|
||||
SortOrder: "desc",
|
||||
},
|
||||
wantBy: "status",
|
||||
want: "desc",
|
||||
},
|
||||
{
|
||||
name: "invalid falls back",
|
||||
params: pagination.PaginationParams{
|
||||
SortBy: "sideways",
|
||||
SortOrder: "wat",
|
||||
},
|
||||
wantBy: "created_at",
|
||||
want: "desc",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gotBy, gotOrder := announcementListOrder(tt.params)
|
||||
if gotBy != tt.wantBy || gotOrder != tt.want {
|
||||
t.Fatalf("announcementListOrder(%+v) = (%q, %q), want (%q, %q)", tt.params, gotBy, gotOrder, tt.wantBy, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
@@ -14,6 +15,8 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
)
|
||||
|
||||
type apiKeyRepository struct {
|
||||
@@ -309,12 +312,15 @@ func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
keys, err := q.
|
||||
keysQuery := q.
|
||||
WithGroup().
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(apikey.FieldID)).
|
||||
All(ctx)
|
||||
Limit(params.Limit())
|
||||
for _, order := range apiKeyListOrder(params) {
|
||||
keysQuery = keysQuery.Order(order)
|
||||
}
|
||||
|
||||
keys, err := keysQuery.All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -359,12 +365,15 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
keys, err := q.
|
||||
keysQuery := q.
|
||||
WithUser().
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(apikey.FieldID)).
|
||||
All(ctx)
|
||||
Limit(params.Limit())
|
||||
for _, order := range apiKeyListOrder(params) {
|
||||
keysQuery = keysQuery.Order(order)
|
||||
}
|
||||
|
||||
keys, err := keysQuery.All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -377,6 +386,34 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
|
||||
return outKeys, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func apiKeyListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
|
||||
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
|
||||
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
|
||||
|
||||
field := apikey.FieldID
|
||||
switch sortBy {
|
||||
case "name":
|
||||
field = apikey.FieldName
|
||||
case "status":
|
||||
field = apikey.FieldStatus
|
||||
case "expires_at":
|
||||
field = apikey.FieldExpiresAt
|
||||
case "last_used_at":
|
||||
field = apikey.FieldLastUsedAt
|
||||
case "created_at":
|
||||
field = apikey.FieldCreatedAt
|
||||
case "id", "":
|
||||
field = apikey.FieldID
|
||||
default:
|
||||
field = apikey.FieldID
|
||||
}
|
||||
|
||||
if sortOrder == pagination.SortOrderAsc {
|
||||
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(apikey.FieldID)}
|
||||
}
|
||||
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(apikey.FieldID)}
|
||||
}
|
||||
|
||||
// SearchAPIKeys searches API keys by user ID and/or keyword (name)
|
||||
func (r *apiKeyRepository) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
|
||||
q := r.activeQuery()
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func (s *APIKeyRepoSuite) TestListByUserID_SortByNameAsc() {
|
||||
user := s.mustCreateUser("sort-name@example.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-z", "z-key", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-a", "a-key", nil)
|
||||
|
||||
keys, _, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{
|
||||
Page: 1,
|
||||
PageSize: 10,
|
||||
SortBy: "name",
|
||||
SortOrder: "asc",
|
||||
}, service.APIKeyListFilters{})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(keys, 2)
|
||||
s.Require().Equal("a-key", keys[0].Name)
|
||||
s.Require().Equal("z-key", keys[1].Name)
|
||||
}
|
||||
@@ -188,8 +188,8 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
|
||||
// 查询 channel 列表
|
||||
dataQuery := fmt.Sprintf(
|
||||
`SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.created_at, c.updated_at
|
||||
FROM channels c WHERE %s ORDER BY c.id ASC LIMIT $%d OFFSET $%d`,
|
||||
whereClause, argIdx, argIdx+1,
|
||||
FROM channels c WHERE %s ORDER BY %s LIMIT $%d OFFSET $%d`,
|
||||
whereClause, channelListOrderBy(params), argIdx, argIdx+1,
|
||||
)
|
||||
args = append(args, pageSize, offset)
|
||||
|
||||
@@ -246,6 +246,31 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
|
||||
return channels, paginationResult, nil
|
||||
}
|
||||
|
||||
func channelListOrderBy(params pagination.PaginationParams) string {
|
||||
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
|
||||
sortOrder := strings.ToUpper(params.NormalizedSortOrder(pagination.SortOrderAsc))
|
||||
|
||||
column := "c.id"
|
||||
switch sortBy {
|
||||
case "":
|
||||
column = "c.id"
|
||||
sortOrder = "ASC"
|
||||
case "id":
|
||||
column = "c.id"
|
||||
case "name":
|
||||
column = "c.name"
|
||||
case "status":
|
||||
column = "c.status"
|
||||
case "created_at":
|
||||
column = "c.created_at"
|
||||
default:
|
||||
column = "c.id"
|
||||
sortOrder = "ASC"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s %s, c.id %s", column, sortOrder, sortOrder)
|
||||
}
|
||||
|
||||
func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) {
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at FROM channels ORDER BY id`,
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/lib/pq"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -225,3 +226,12 @@ func TestIsUniqueViolation(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChannelListOrderBy_AllowsDescendingIDSort(t *testing.T) {
|
||||
params := pagination.PaginationParams{
|
||||
SortBy: "id",
|
||||
SortOrder: "desc",
|
||||
}
|
||||
|
||||
require.Equal(t, "c.id DESC, c.id DESC", channelListOrderBy(params))
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
@@ -14,6 +15,8 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/lib/pq"
|
||||
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
)
|
||||
|
||||
type sqlExecutor interface {
|
||||
@@ -231,11 +234,18 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
groups, err := q.
|
||||
if strings.EqualFold(strings.TrimSpace(params.SortBy), "account_count") {
|
||||
return r.listWithAccountCountSort(ctx, q, params, total)
|
||||
}
|
||||
|
||||
groupsQuery := q.
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)).
|
||||
All(ctx)
|
||||
Limit(params.Limit())
|
||||
for _, order := range groupListOrder(params) {
|
||||
groupsQuery = groupsQuery.Order(order)
|
||||
}
|
||||
|
||||
groups, err := groupsQuery.All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -261,6 +271,104 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
|
||||
return outGroups, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) listWithAccountCountSort(ctx context.Context, q *dbent.GroupQuery, params pagination.PaginationParams, total int) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
groups, err := q.
|
||||
Order(dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
groupIDs := make([]int64, 0, len(groups))
|
||||
outGroups := make([]service.Group, 0, len(groups))
|
||||
for i := range groups {
|
||||
g := groupEntityToService(groups[i])
|
||||
outGroups = append(outGroups, *g)
|
||||
groupIDs = append(groupIDs, g.ID)
|
||||
}
|
||||
|
||||
counts, err := r.loadAccountCounts(ctx, groupIDs)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
for i := range outGroups {
|
||||
c := counts[outGroups[i].ID]
|
||||
outGroups[i].AccountCount = c.Total
|
||||
outGroups[i].ActiveAccountCount = c.Active
|
||||
outGroups[i].RateLimitedAccountCount = c.RateLimited
|
||||
}
|
||||
|
||||
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
|
||||
sort.SliceStable(outGroups, func(i, j int) bool {
|
||||
if outGroups[i].AccountCount == outGroups[j].AccountCount {
|
||||
if outGroups[i].SortOrder == outGroups[j].SortOrder {
|
||||
return outGroups[i].ID < outGroups[j].ID
|
||||
}
|
||||
return outGroups[i].SortOrder < outGroups[j].SortOrder
|
||||
}
|
||||
if sortOrder == pagination.SortOrderAsc {
|
||||
return outGroups[i].AccountCount < outGroups[j].AccountCount
|
||||
}
|
||||
return outGroups[i].AccountCount > outGroups[j].AccountCount
|
||||
})
|
||||
|
||||
return paginateSlice(outGroups, params), paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func groupListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
|
||||
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
|
||||
sortOrder := params.NormalizedSortOrder(pagination.SortOrderAsc)
|
||||
|
||||
field := group.FieldSortOrder
|
||||
tieField := group.FieldID
|
||||
defaultOrder := true
|
||||
switch sortBy {
|
||||
case "", "sort_order":
|
||||
field = group.FieldSortOrder
|
||||
case "name":
|
||||
field = group.FieldName
|
||||
defaultOrder = false
|
||||
case "platform":
|
||||
field = group.FieldPlatform
|
||||
defaultOrder = false
|
||||
case "billing_type", "subscription_type":
|
||||
field = group.FieldSubscriptionType
|
||||
defaultOrder = false
|
||||
case "rate_multiplier":
|
||||
field = group.FieldRateMultiplier
|
||||
defaultOrder = false
|
||||
case "is_exclusive":
|
||||
field = group.FieldIsExclusive
|
||||
defaultOrder = false
|
||||
case "status":
|
||||
field = group.FieldStatus
|
||||
defaultOrder = false
|
||||
case "created_at":
|
||||
field = group.FieldCreatedAt
|
||||
defaultOrder = false
|
||||
case "id":
|
||||
field = group.FieldID
|
||||
defaultOrder = false
|
||||
tieField = ""
|
||||
default:
|
||||
field = group.FieldSortOrder
|
||||
}
|
||||
|
||||
if sortOrder == pagination.SortOrderDesc && sortBy != "" {
|
||||
if tieField == "" {
|
||||
return []func(*entsql.Selector){dbent.Desc(field)}
|
||||
}
|
||||
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(tieField)}
|
||||
}
|
||||
if defaultOrder {
|
||||
return []func(*entsql.Selector){dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)}
|
||||
}
|
||||
if tieField == "" {
|
||||
return []func(*entsql.Selector){dbent.Asc(field)}
|
||||
}
|
||||
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(tieField)}
|
||||
}
|
||||
|
||||
func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, error) {
|
||||
groups, err := r.client.Group.Query().
|
||||
Where(group.StatusEQ(service.StatusActive)).
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func (s *GroupRepoSuite) TestList_DefaultSortBySortOrderAsc() {
|
||||
g1 := &service.Group{Name: "g1", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 20}
|
||||
g2 := &service.Group{Name: "g2", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 10}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, g1))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, g2))
|
||||
|
||||
groups, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(groups, 2)
|
||||
s.Require().Equal(g2.ID, groups[0].ID)
|
||||
s.Require().Equal(g1.ID, groups[1].ID)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestList_SortBySortOrderDesc() {
|
||||
g1 := &service.Group{Name: "g1", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 40}
|
||||
g2 := &service.Group{Name: "g2", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 50}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, g1))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, g2))
|
||||
|
||||
groups, _, err := s.repo.List(s.ctx, pagination.PaginationParams{
|
||||
Page: 1,
|
||||
PageSize: 10,
|
||||
SortBy: "sort_order",
|
||||
SortOrder: "desc",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.Require().GreaterOrEqual(len(groups), 2)
|
||||
indexByID := make(map[int64]int, len(groups))
|
||||
for i, group := range groups {
|
||||
indexByID[group.ID] = i
|
||||
}
|
||||
s.Require().Contains(indexByID, g1.ID)
|
||||
s.Require().Contains(indexByID, g2.ID)
|
||||
s.Require().Less(indexByID[g2.ID], indexByID[g1.ID])
|
||||
}
|
||||
@@ -14,3 +14,22 @@ func paginationResultFromTotal(total int64, params pagination.PaginationParams)
|
||||
Pages: pages,
|
||||
}
|
||||
}
|
||||
|
||||
func paginateSlice[T any](items []T, params pagination.PaginationParams) []T {
|
||||
if len(items) == 0 {
|
||||
return []T{}
|
||||
}
|
||||
|
||||
offset := params.Offset()
|
||||
if offset >= len(items) {
|
||||
return []T{}
|
||||
}
|
||||
|
||||
limit := params.Limit()
|
||||
end := offset + limit
|
||||
if end > len(items) {
|
||||
end = len(items)
|
||||
}
|
||||
|
||||
return items[offset:end]
|
||||
}
|
||||
|
||||
@@ -2,12 +2,15 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
)
|
||||
|
||||
type promoCodeRepository struct {
|
||||
@@ -137,11 +140,14 @@ func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagina
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
codes, err := q.
|
||||
codesQuery := q.
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(promocode.FieldID)).
|
||||
All(ctx)
|
||||
Limit(params.Limit())
|
||||
for _, order := range promoCodeListOrder(params) {
|
||||
codesQuery = codesQuery.Order(order)
|
||||
}
|
||||
|
||||
codes, err := codesQuery.All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -151,6 +157,34 @@ func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagina
|
||||
return outCodes, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func promoCodeListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
|
||||
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
|
||||
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
|
||||
|
||||
field := promocode.FieldID
|
||||
switch sortBy {
|
||||
case "bonus_amount":
|
||||
field = promocode.FieldBonusAmount
|
||||
case "status":
|
||||
field = promocode.FieldStatus
|
||||
case "expires_at":
|
||||
field = promocode.FieldExpiresAt
|
||||
case "created_at":
|
||||
field = promocode.FieldCreatedAt
|
||||
case "code":
|
||||
field = promocode.FieldCode
|
||||
case "id", "":
|
||||
field = promocode.FieldID
|
||||
default:
|
||||
field = promocode.FieldID
|
||||
}
|
||||
|
||||
if sortOrder == pagination.SortOrderAsc {
|
||||
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(promocode.FieldID)}
|
||||
}
|
||||
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(promocode.FieldID)}
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) CreateUsage(ctx context.Context, usage *service.PromoCodeUsage) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
created, err := client.PromoCodeUsage.Create().
|
||||
|
||||
@@ -3,12 +3,16 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/proxy"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
)
|
||||
|
||||
type sqlQuerier interface {
|
||||
@@ -135,11 +139,14 @@ func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
proxies, err := q.
|
||||
proxiesQuery := q.
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(proxy.FieldID)).
|
||||
All(ctx)
|
||||
Limit(params.Limit())
|
||||
for _, order := range proxyListOrder(params) {
|
||||
proxiesQuery = proxiesQuery.Order(order)
|
||||
}
|
||||
|
||||
proxies, err := proxiesQuery.All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -170,22 +177,58 @@ func (r *proxyRepository) ListWithFiltersAndAccountCount(ctx context.Context, pa
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
proxies, err := q.
|
||||
if strings.EqualFold(strings.TrimSpace(params.SortBy), "account_count") {
|
||||
return r.listWithAccountCountSort(ctx, q, params, total)
|
||||
}
|
||||
|
||||
proxiesQuery := q.
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Limit(params.Limit())
|
||||
for _, order := range proxyListOrder(params) {
|
||||
proxiesQuery = proxiesQuery.Order(order)
|
||||
}
|
||||
|
||||
proxies, err := proxiesQuery.All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return r.buildProxyWithAccountCountResult(ctx, proxies, params, int64(total))
|
||||
}
|
||||
|
||||
func (r *proxyRepository) listWithAccountCountSort(ctx context.Context, q *dbent.ProxyQuery, params pagination.PaginationParams, total int) ([]service.ProxyWithAccountCount, *pagination.PaginationResult, error) {
|
||||
proxies, err := q.
|
||||
Order(dbent.Desc(proxy.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Get account counts
|
||||
result, _, err := r.buildProxyWithAccountCountResult(ctx, proxies, params, int64(total))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
|
||||
sort.SliceStable(result, func(i, j int) bool {
|
||||
if result[i].AccountCount == result[j].AccountCount {
|
||||
return result[i].ID > result[j].ID
|
||||
}
|
||||
if sortOrder == pagination.SortOrderAsc {
|
||||
return result[i].AccountCount < result[j].AccountCount
|
||||
}
|
||||
return result[i].AccountCount > result[j].AccountCount
|
||||
})
|
||||
|
||||
return paginateSlice(result, params), paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *proxyRepository) buildProxyWithAccountCountResult(ctx context.Context, proxies []*dbent.Proxy, params pagination.PaginationParams, total int64) ([]service.ProxyWithAccountCount, *pagination.PaginationResult, error) {
|
||||
counts, err := r.GetAccountCountsForProxies(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Build result with account counts
|
||||
result := make([]service.ProxyWithAccountCount, 0, len(proxies))
|
||||
for i := range proxies {
|
||||
proxyOut := proxyEntityToService(proxies[i])
|
||||
@@ -198,7 +241,33 @@ func (r *proxyRepository) ListWithFiltersAndAccountCount(ctx context.Context, pa
|
||||
})
|
||||
}
|
||||
|
||||
return result, paginationResultFromTotal(int64(total), params), nil
|
||||
return result, paginationResultFromTotal(total, params), nil
|
||||
}
|
||||
|
||||
func proxyListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
|
||||
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
|
||||
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
|
||||
|
||||
field := proxy.FieldID
|
||||
switch sortBy {
|
||||
case "name":
|
||||
field = proxy.FieldName
|
||||
case "protocol":
|
||||
field = proxy.FieldProtocol
|
||||
case "status":
|
||||
field = proxy.FieldStatus
|
||||
case "created_at":
|
||||
field = proxy.FieldCreatedAt
|
||||
case "id", "":
|
||||
field = proxy.FieldID
|
||||
default:
|
||||
field = proxy.FieldID
|
||||
}
|
||||
|
||||
if sortOrder == pagination.SortOrderAsc {
|
||||
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(proxy.FieldID)}
|
||||
}
|
||||
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(proxy.FieldID)}
|
||||
}
|
||||
|
||||
func (r *proxyRepository) ListActive(ctx context.Context) ([]service.Proxy, error) {
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func (s *ProxyRepoSuite) TestListWithFiltersAndAccountCount_SortByAccountCountDesc() {
|
||||
p1 := s.mustCreateProxy(&service.Proxy{Name: "p1", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
|
||||
p2 := s.mustCreateProxy(&service.Proxy{Name: "p2", Protocol: "http", Host: "127.0.0.1", Port: 8081, Status: service.StatusActive})
|
||||
s.mustInsertAccount("a1", &p1.ID)
|
||||
s.mustInsertAccount("a2", &p1.ID)
|
||||
s.mustInsertAccount("a3", &p2.ID)
|
||||
|
||||
proxies, _, err := s.repo.ListWithFiltersAndAccountCount(s.ctx, pagination.PaginationParams{
|
||||
Page: 1,
|
||||
PageSize: 10,
|
||||
SortBy: "account_count",
|
||||
SortOrder: "desc",
|
||||
}, "", "", "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(proxies, 2)
|
||||
s.Require().Equal(p1.ID, proxies[0].ID)
|
||||
s.Require().Equal(int64(2), proxies[0].AccountCount)
|
||||
s.Require().Equal(p2.ID, proxies[1].ID)
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
@@ -9,6 +10,8 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
)
|
||||
|
||||
type redeemCodeRepository struct {
|
||||
@@ -120,13 +123,16 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
codes, err := q.
|
||||
codesQuery := q.
|
||||
WithUser().
|
||||
WithGroup().
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(redeemcode.FieldID)).
|
||||
All(ctx)
|
||||
Limit(params.Limit())
|
||||
for _, order := range redeemCodeListOrder(params) {
|
||||
codesQuery = codesQuery.Order(order)
|
||||
}
|
||||
|
||||
codes, err := codesQuery.All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -136,6 +142,36 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin
|
||||
return outCodes, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func redeemCodeListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
|
||||
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
|
||||
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
|
||||
|
||||
field := redeemcode.FieldID
|
||||
switch sortBy {
|
||||
case "type":
|
||||
field = redeemcode.FieldType
|
||||
case "value":
|
||||
field = redeemcode.FieldValue
|
||||
case "status":
|
||||
field = redeemcode.FieldStatus
|
||||
case "used_at":
|
||||
field = redeemcode.FieldUsedAt
|
||||
case "created_at":
|
||||
field = redeemcode.FieldCreatedAt
|
||||
case "code":
|
||||
field = redeemcode.FieldCode
|
||||
case "id", "":
|
||||
field = redeemcode.FieldID
|
||||
default:
|
||||
field = redeemcode.FieldID
|
||||
}
|
||||
|
||||
if sortOrder == pagination.SortOrderAsc {
|
||||
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(redeemcode.FieldID)}
|
||||
}
|
||||
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(redeemcode.FieldID)}
|
||||
}
|
||||
|
||||
func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemCode) error {
|
||||
up := r.client.RedeemCode.UpdateOneID(code.ID).
|
||||
SetCode(code.Code).
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestListWithFilters_SortByValueAsc() {
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "VALUE-20", Type: service.RedeemTypeBalance, Value: 20, Status: service.StatusUnused}))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "VALUE-10", Type: service.RedeemTypeBalance, Value: 10, Status: service.StatusUnused}))
|
||||
|
||||
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
|
||||
Page: 1,
|
||||
PageSize: 10,
|
||||
SortBy: "value",
|
||||
SortOrder: "asc",
|
||||
}, "", "", "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(codes, 2)
|
||||
s.Require().Equal("VALUE-10", codes[0].Code)
|
||||
s.Require().Equal("VALUE-20", codes[1].Code)
|
||||
}
|
||||
@@ -3771,7 +3771,7 @@ func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, wh
|
||||
limitPos := len(args) + 1
|
||||
offsetPos := len(args) + 2
|
||||
listArgs := append(append([]any{}, args...), params.Limit(), params.Offset())
|
||||
query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY id DESC LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, limitPos, offsetPos)
|
||||
query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY %s LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, usageLogOrderBy(params), limitPos, offsetPos)
|
||||
logs, err := r.queryUsageLogs(ctx, query, listArgs...)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
@@ -3786,7 +3786,7 @@ func (r *usageLogRepository) listUsageLogsWithFastPagination(ctx context.Context
|
||||
limitPos := len(args) + 1
|
||||
offsetPos := len(args) + 2
|
||||
listArgs := append(append([]any{}, args...), limit+1, offset)
|
||||
query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY id DESC LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, limitPos, offsetPos)
|
||||
query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY %s LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, usageLogOrderBy(params), limitPos, offsetPos)
|
||||
|
||||
logs, err := r.queryUsageLogs(ctx, query, listArgs...)
|
||||
if err != nil {
|
||||
@@ -3808,6 +3808,28 @@ func (r *usageLogRepository) listUsageLogsWithFastPagination(ctx context.Context
|
||||
return logs, paginationResultFromTotal(total, params), nil
|
||||
}
|
||||
|
||||
func usageLogOrderBy(params pagination.PaginationParams) string {
|
||||
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
|
||||
sortOrder := strings.ToUpper(params.NormalizedSortOrder(pagination.SortOrderDesc))
|
||||
|
||||
column := "id"
|
||||
switch sortBy {
|
||||
case "model":
|
||||
column = "COALESCE(NULLIF(TRIM(requested_model), ''), model)"
|
||||
case "created_at":
|
||||
column = "created_at"
|
||||
case "id", "":
|
||||
column = "id"
|
||||
default:
|
||||
column = "id"
|
||||
}
|
||||
|
||||
if column == "id" {
|
||||
return fmt.Sprintf("id %s", sortOrder)
|
||||
}
|
||||
return fmt.Sprintf("%s %s, id %s", column, sortOrder, sortOrder)
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) (logs []service.UsageLog, err error) {
|
||||
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func (s *UsageLogRepoSuite) TestListWithFilters_SortByModelAsc() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "usage-sort@example.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-usage-sort", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "usage-sort-account"})
|
||||
|
||||
first := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.New().String(),
|
||||
Model: "z-model",
|
||||
RequestedModel: "z-model",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
_, err := s.repo.Create(s.ctx, first)
|
||||
s.Require().NoError(err)
|
||||
|
||||
second := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.New().String(),
|
||||
Model: "a-model",
|
||||
RequestedModel: "a-model",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
CreatedAt: time.Now().Add(time.Second),
|
||||
}
|
||||
_, err = s.repo.Create(s.ctx, second)
|
||||
s.Require().NoError(err)
|
||||
|
||||
logs, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
|
||||
Page: 1,
|
||||
PageSize: 10,
|
||||
SortBy: "model",
|
||||
SortOrder: "asc",
|
||||
}, usagestats.UsageLogFilters{UserID: user.ID})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(logs, 2)
|
||||
s.Require().Equal("a-model", logs[0].RequestedModel)
|
||||
s.Require().Equal("z-model", logs[1].RequestedModel)
|
||||
}
|
||||
@@ -17,6 +17,8 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/ent/usersubscription"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
)
|
||||
|
||||
type userRepository struct {
|
||||
@@ -224,11 +226,14 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
users, err := q.
|
||||
usersQuery := q.
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(dbuser.FieldID)).
|
||||
All(ctx)
|
||||
Limit(params.Limit())
|
||||
for _, order := range userListOrder(params) {
|
||||
usersQuery = usersQuery.Order(order)
|
||||
}
|
||||
|
||||
users, err := usersQuery.All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -281,6 +286,52 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
|
||||
return outUsers, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
|
||||
sortBy := strings.ToLower(strings.TrimSpace(params.SortBy))
|
||||
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
|
||||
|
||||
field := dbuser.FieldID
|
||||
defaultField := true
|
||||
switch sortBy {
|
||||
case "email":
|
||||
field = dbuser.FieldEmail
|
||||
defaultField = false
|
||||
case "id", "":
|
||||
field = dbuser.FieldID
|
||||
case "username":
|
||||
field = dbuser.FieldUsername
|
||||
defaultField = false
|
||||
case "role":
|
||||
field = dbuser.FieldRole
|
||||
defaultField = false
|
||||
case "balance":
|
||||
field = dbuser.FieldBalance
|
||||
defaultField = false
|
||||
case "concurrency":
|
||||
field = dbuser.FieldConcurrency
|
||||
defaultField = false
|
||||
case "status":
|
||||
field = dbuser.FieldStatus
|
||||
defaultField = false
|
||||
case "created_at":
|
||||
field = dbuser.FieldCreatedAt
|
||||
defaultField = false
|
||||
default:
|
||||
field = dbuser.FieldID
|
||||
}
|
||||
|
||||
if sortOrder == pagination.SortOrderAsc {
|
||||
if defaultField && field == dbuser.FieldID {
|
||||
return []func(*entsql.Selector){dbent.Asc(dbuser.FieldID)}
|
||||
}
|
||||
return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(dbuser.FieldID)}
|
||||
}
|
||||
if defaultField && field == dbuser.FieldID {
|
||||
return []func(*entsql.Selector){dbent.Desc(dbuser.FieldID)}
|
||||
}
|
||||
return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(dbuser.FieldID)}
|
||||
}
|
||||
|
||||
// filterUsersByAttributes returns user IDs that match ALL the given attribute filters
|
||||
func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) ([]int64, error) {
|
||||
if len(attrs) == 0 {
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_SortByEmailAsc() {
|
||||
s.mustCreateUser(&service.User{Email: "z-last@example.com", Username: "z-user"})
|
||||
s.mustCreateUser(&service.User{Email: "a-first@example.com", Username: "a-user"})
|
||||
|
||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{
|
||||
Page: 1,
|
||||
PageSize: 10,
|
||||
SortBy: "email",
|
||||
SortOrder: "asc",
|
||||
}, service.UserListFilters{})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(users, 2)
|
||||
s.Require().Equal("a-first@example.com", users[0].Email)
|
||||
s.Require().Equal("z-last@example.com", users[1].Email)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestList_DefaultSortByNewestFirst() {
|
||||
first := s.mustCreateUser(&service.User{Email: "first@example.com"})
|
||||
second := s.mustCreateUser(&service.User{Email: "second@example.com"})
|
||||
|
||||
users, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(users, 2)
|
||||
s.Require().Equal(second.ID, users[0].ID)
|
||||
s.Require().Equal(first.ID, users[1].ID)
|
||||
}
|
||||
|
||||
func TestUserRepoSortSuiteSmoke(_ *testing.T) {}
|
||||
Reference in New Issue
Block a user