diff --git a/backend/internal/handler/admin/account_data.go b/backend/internal/handler/admin/account_data.go index 20cc09ee..00da4821 100644 --- a/backend/internal/handler/admin/account_data.go +++ b/backend/internal/handler/admin/account_data.go @@ -10,6 +10,7 @@ import ( "log/slog" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/service" @@ -359,7 +360,7 @@ func (h *AccountHandler) listAllProxies(ctx context.Context) ([]service.Proxy, e pageSize := dataPageCap var out []service.Proxy for { - items, total, err := h.adminService.ListProxies(ctx, page, pageSize, "", "", "") + items, total, err := h.adminService.ListProxies(ctx, page, pageSize, "", "", "", "created_at", "desc") if err != nil { return nil, err } @@ -372,12 +373,12 @@ func (h *AccountHandler) listAllProxies(ctx context.Context) ([]service.Proxy, e return out, nil } -func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, accountType, status, search string) ([]service.Account, error) { +func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, accountType, status, search string, groupID int64, privacyMode, sortBy, sortOrder string) ([]service.Account, error) { page := 1 pageSize := dataPageCap var out []service.Account for { - items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0, "") + items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, groupID, privacyMode, sortBy, sortOrder) if err != nil { return nil, err } @@ -409,11 +410,28 @@ func (h *AccountHandler) resolveExportAccounts(ctx context.Context, ids []int64, platform := c.Query("platform") accountType := c.Query("type") status := c.Query("status") + privacyMode := strings.TrimSpace(c.Query("privacy_mode")) search := strings.TrimSpace(c.Query("search")) + sortBy := c.DefaultQuery("sort_by", "name") + sortOrder := c.DefaultQuery("sort_order", "asc") if len(search) > 100 { search = search[:100] } - return h.listAccountsFiltered(ctx, platform, accountType, status, search) + + groupID := int64(0) + if groupIDStr := c.Query("group"); groupIDStr != "" { + if groupIDStr == accountListGroupUngroupedQueryValue { + groupID = service.AccountListGroupUngrouped + } else { + parsedGroupID, parseErr := strconv.ParseInt(groupIDStr, 10, 64) + if parseErr != nil || parsedGroupID <= 0 { + return nil, infraerrors.BadRequest("INVALID_GROUP_FILTER", "invalid group filter") + } + groupID = parsedGroupID + } + } + + return h.listAccountsFiltered(ctx, platform, accountType, status, search, groupID, privacyMode, sortBy, sortOrder) } func (h *AccountHandler) resolveExportProxies(ctx context.Context, accounts []service.Account) ([]service.Proxy, error) { diff --git a/backend/internal/handler/admin/account_data_handler_test.go b/backend/internal/handler/admin/account_data_handler_test.go index 285033a1..5793983c 100644 --- a/backend/internal/handler/admin/account_data_handler_test.go +++ b/backend/internal/handler/admin/account_data_handler_test.go @@ -172,6 +172,51 @@ func TestExportDataWithoutProxies(t *testing.T) { require.Nil(t, resp.Data.Accounts[0].ProxyKey) } +func TestExportDataPassesAccountFiltersAndSort(t *testing.T) { + router, adminSvc := setupAccountDataRouter() + adminSvc.accounts = []service.Account{ + {ID: 1, Name: "acc-1", Status: service.StatusActive}, + } + + rec := httptest.NewRecorder() + req := httptest.NewRequest( + http.MethodGet, + "/api/v1/admin/accounts/data?platform=openai&type=oauth&status=active&group=12&privacy_mode=blocked&search=keyword&sort_by=priority&sort_order=desc", + nil, + ) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + require.Equal(t, 1, adminSvc.lastListAccounts.calls) + require.Equal(t, "openai", adminSvc.lastListAccounts.platform) + require.Equal(t, "oauth", adminSvc.lastListAccounts.accountType) + require.Equal(t, "active", adminSvc.lastListAccounts.status) + require.Equal(t, int64(12), adminSvc.lastListAccounts.groupID) + require.Equal(t, "blocked", adminSvc.lastListAccounts.privacyMode) + require.Equal(t, "keyword", adminSvc.lastListAccounts.search) + require.Equal(t, "priority", adminSvc.lastListAccounts.sortBy) + require.Equal(t, "desc", adminSvc.lastListAccounts.sortOrder) +} + +func TestExportDataSelectedIDsOverrideFilters(t *testing.T) { + router, adminSvc := setupAccountDataRouter() + + rec := httptest.NewRecorder() + req := httptest.NewRequest( + http.MethodGet, + "/api/v1/admin/accounts/data?ids=1,2&platform=openai&search=keyword&sort_by=priority&sort_order=desc", + nil, + ) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + var resp dataResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Len(t, resp.Data.Accounts, 2) + require.Equal(t, 0, adminSvc.lastListAccounts.calls) +} + func TestImportDataReusesProxyAndSkipsDefaultGroup(t *testing.T) { router, adminSvc := setupAccountDataRouter() diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 9aed64d5..9e985a79 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -221,6 +221,8 @@ func (h *AccountHandler) List(c *gin.Context) { status := c.Query("status") search := c.Query("search") privacyMode := strings.TrimSpace(c.Query("privacy_mode")) + sortBy := c.DefaultQuery("sort_by", "name") + sortOrder := c.DefaultQuery("sort_order", "asc") // 标准化和验证 search 参数 search = strings.TrimSpace(search) if len(search) > 100 { @@ -246,7 +248,7 @@ func (h *AccountHandler) List(c *gin.Context) { } } - accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID, privacyMode) + accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID, privacyMode, sortBy, sortOrder) if err != nil { response.ErrorFrom(c, err) return @@ -2029,7 +2031,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) { accounts := make([]*service.Account, 0) if len(req.AccountIDs) == 0 { - allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0, "") + allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0, "", "name", "asc") if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index 60d68913..6d1ef1b6 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -31,6 +31,33 @@ type stubAdminService struct { platform string groupIDs []int64 } + lastListAccounts struct { + platform string + accountType string + status string + search string + groupID int64 + privacyMode string + sortBy string + sortOrder string + calls int + } + lastListProxies struct { + protocol string + status string + search string + sortBy string + sortOrder string + calls int + } + lastListRedeemCodes struct { + codeType string + status string + search string + sortBy string + sortOrder string + calls int + } mu sync.Mutex } @@ -99,7 +126,7 @@ func newStubAdminService() *stubAdminService { } } -func (s *stubAdminService) ListUsers(ctx context.Context, page, pageSize int, filters service.UserListFilters) ([]service.User, int64, error) { +func (s *stubAdminService) ListUsers(ctx context.Context, page, pageSize int, filters service.UserListFilters, sortBy, sortOrder string) ([]service.User, int64, error) { return s.users, int64(len(s.users)), nil } @@ -132,7 +159,7 @@ func (s *stubAdminService) UpdateUserBalance(ctx context.Context, userID int64, return &user, nil } -func (s *stubAdminService) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]service.APIKey, int64, error) { +func (s *stubAdminService) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]service.APIKey, int64, error) { return s.apiKeys, int64(len(s.apiKeys)), nil } @@ -140,7 +167,7 @@ func (s *stubAdminService) GetUserUsageStats(ctx context.Context, userID int64, return map[string]any{"user_id": userID}, nil } -func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]service.Group, int64, error) { +func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]service.Group, int64, error) { return s.groups, int64(len(s.groups)), nil } @@ -187,7 +214,16 @@ func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int return nil } -func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]service.Account, int64, error) { +func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, sortBy, sortOrder string) ([]service.Account, int64, error) { + s.lastListAccounts.platform = platform + s.lastListAccounts.accountType = accountType + s.lastListAccounts.status = status + s.lastListAccounts.search = search + s.lastListAccounts.groupID = groupID + s.lastListAccounts.privacyMode = privacyMode + s.lastListAccounts.sortBy = sortBy + s.lastListAccounts.sortOrder = sortOrder + s.lastListAccounts.calls++ return s.accounts, int64(len(s.accounts)), nil } @@ -261,7 +297,13 @@ func (s *stubAdminService) CheckMixedChannelRisk(ctx context.Context, currentAcc return s.checkMixedErr } -func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) { +func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]service.Proxy, int64, error) { + s.lastListProxies.protocol = protocol + s.lastListProxies.status = status + s.lastListProxies.search = search + s.lastListProxies.sortBy = sortBy + s.lastListProxies.sortOrder = sortOrder + s.lastListProxies.calls++ search = strings.TrimSpace(strings.ToLower(search)) filtered := make([]service.Proxy, 0, len(s.proxies)) for _, proxy := range s.proxies { @@ -283,7 +325,7 @@ func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, return filtered, int64(len(filtered)), nil } -func (s *stubAdminService) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.ProxyWithAccountCount, int64, error) { +func (s *stubAdminService) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]service.ProxyWithAccountCount, int64, error) { return s.proxyCounts, int64(len(s.proxyCounts)), nil } @@ -384,7 +426,13 @@ func (s *stubAdminService) CheckProxyQuality(ctx context.Context, id int64) (*se }, nil } -func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]service.RedeemCode, int64, error) { +func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string, sortBy, sortOrder string) ([]service.RedeemCode, int64, error) { + s.lastListRedeemCodes.codeType = codeType + s.lastListRedeemCodes.status = status + s.lastListRedeemCodes.search = search + s.lastListRedeemCodes.sortBy = sortBy + s.lastListRedeemCodes.sortOrder = sortOrder + s.lastListRedeemCodes.calls++ return s.redeems, int64(len(s.redeems)), nil } diff --git a/backend/internal/handler/admin/announcement_handler.go b/backend/internal/handler/admin/announcement_handler.go index d1312bc0..d3b9d173 100644 --- a/backend/internal/handler/admin/announcement_handler.go +++ b/backend/internal/handler/admin/announcement_handler.go @@ -52,13 +52,17 @@ func (h *AnnouncementHandler) List(c *gin.Context) { page, pageSize := response.ParsePagination(c) status := strings.TrimSpace(c.Query("status")) search := strings.TrimSpace(c.Query("search")) + sortBy := c.DefaultQuery("sort_by", "created_at") + sortOrder := c.DefaultQuery("sort_order", "desc") if len(search) > 200 { search = search[:200] } params := pagination.PaginationParams{ - Page: page, - PageSize: pageSize, + Page: page, + PageSize: pageSize, + SortBy: sortBy, + SortOrder: sortOrder, } items, paginationResult, err := h.announcementService.List( @@ -227,8 +231,10 @@ func (h *AnnouncementHandler) ListReadStatus(c *gin.Context) { page, pageSize := response.ParsePagination(c) params := pagination.PaginationParams{ - Page: page, - PageSize: pageSize, + Page: page, + PageSize: pageSize, + SortBy: c.DefaultQuery("sort_by", "email"), + SortOrder: c.DefaultQuery("sort_order", "asc"), } search := strings.TrimSpace(c.Query("search")) if len(search) > 200 { diff --git a/backend/internal/handler/admin/announcement_handler_sort_test.go b/backend/internal/handler/admin/announcement_handler_sort_test.go new file mode 100644 index 00000000..545e619e --- /dev/null +++ b/backend/internal/handler/admin/announcement_handler_sort_test.go @@ -0,0 +1,138 @@ +package admin + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type announcementRepoCapture struct { + service.AnnouncementRepository + listParams pagination.PaginationParams +} + +func (r *announcementRepoCapture) List(ctx context.Context, params pagination.PaginationParams, filters service.AnnouncementListFilters) ([]service.Announcement, *pagination.PaginationResult, error) { + r.listParams = params + return []service.Announcement{}, &pagination.PaginationResult{ + Total: 0, + Page: params.Page, + PageSize: params.PageSize, + Pages: 0, + }, nil +} + +func (r *announcementRepoCapture) GetByID(ctx context.Context, id int64) (*service.Announcement, error) { + return &service.Announcement{ + ID: id, + Title: "announcement", + Content: "content", + Status: service.AnnouncementStatusActive, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, nil +} + +type announcementUserRepoCapture struct { + service.UserRepository + listParams pagination.PaginationParams +} + +func (r *announcementUserRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) { + r.listParams = params + return []service.User{}, &pagination.PaginationResult{ + Total: 0, + Page: params.Page, + PageSize: params.PageSize, + Pages: 0, + }, nil +} + +type announcementReadRepoCapture struct { + service.AnnouncementReadRepository +} + +func (r *announcementReadRepoCapture) GetReadMapByUsers(ctx context.Context, announcementID int64, userIDs []int64) (map[int64]time.Time, error) { + return map[int64]time.Time{}, nil +} + +type announcementUserSubRepoCapture struct { + service.UserSubscriptionRepository +} + +func newAnnouncementSortTestRouter(announcementRepo *announcementRepoCapture, userRepo *announcementUserRepoCapture) *gin.Engine { + gin.SetMode(gin.TestMode) + svc := service.NewAnnouncementService( + announcementRepo, + &announcementReadRepoCapture{}, + userRepo, + &announcementUserSubRepoCapture{}, + ) + handler := NewAnnouncementHandler(svc) + router := gin.New() + router.GET("/admin/announcements", handler.List) + router.GET("/admin/announcements/:id/read-status", handler.ListReadStatus) + return router +} + +func TestAdminAnnouncementListSortParams(t *testing.T) { + announcementRepo := &announcementRepoCapture{} + userRepo := &announcementUserRepoCapture{} + router := newAnnouncementSortTestRouter(announcementRepo, userRepo) + + req := httptest.NewRequest(http.MethodGet, "/admin/announcements?sort_by=title&sort_order=ASC", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "title", announcementRepo.listParams.SortBy) + require.Equal(t, "ASC", announcementRepo.listParams.SortOrder) +} + +func TestAdminAnnouncementListSortDefaults(t *testing.T) { + announcementRepo := &announcementRepoCapture{} + userRepo := &announcementUserRepoCapture{} + router := newAnnouncementSortTestRouter(announcementRepo, userRepo) + + req := httptest.NewRequest(http.MethodGet, "/admin/announcements", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "created_at", announcementRepo.listParams.SortBy) + require.Equal(t, "desc", announcementRepo.listParams.SortOrder) +} + +func TestAdminAnnouncementReadStatusSortParams(t *testing.T) { + announcementRepo := &announcementRepoCapture{} + userRepo := &announcementUserRepoCapture{} + router := newAnnouncementSortTestRouter(announcementRepo, userRepo) + + req := httptest.NewRequest(http.MethodGet, "/admin/announcements/1/read-status?sort_by=balance&sort_order=DESC", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "balance", userRepo.listParams.SortBy) + require.Equal(t, "DESC", userRepo.listParams.SortOrder) +} + +func TestAdminAnnouncementReadStatusSortDefaults(t *testing.T) { + announcementRepo := &announcementRepoCapture{} + userRepo := &announcementUserRepoCapture{} + router := newAnnouncementSortTestRouter(announcementRepo, userRepo) + + req := httptest.NewRequest(http.MethodGet, "/admin/announcements/1/read-status", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "email", userRepo.listParams.SortBy) + require.Equal(t, "asc", userRepo.listParams.SortOrder) +} diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go index b503e5c3..c92b35bb 100644 --- a/backend/internal/handler/admin/channel_handler.go +++ b/backend/internal/handler/admin/channel_handler.go @@ -245,7 +245,12 @@ func (h *ChannelHandler) List(c *gin.Context) { search = search[:100] } - channels, pag, err := h.channelService.List(c.Request.Context(), pagination.PaginationParams{Page: page, PageSize: pageSize}, status, search) + channels, pag, err := h.channelService.List(c.Request.Context(), pagination.PaginationParams{ + Page: page, + PageSize: pageSize, + SortBy: c.DefaultQuery("sort_by", "created_at"), + SortOrder: c.DefaultQuery("sort_order", "desc"), + }, status, search) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index 8b6b056d..cb2bd201 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -162,6 +162,8 @@ func (h *GroupHandler) List(c *gin.Context) { search = search[:100] } isExclusiveStr := c.Query("is_exclusive") + sortBy := c.DefaultQuery("sort_by", "sort_order") + sortOrder := c.DefaultQuery("sort_order", "asc") var isExclusive *bool if isExclusiveStr != "" { @@ -169,7 +171,7 @@ func (h *GroupHandler) List(c *gin.Context) { isExclusive = &val } - groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, search, isExclusive) + groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, search, isExclusive, sortBy, sortOrder) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/admin/promo_handler.go b/backend/internal/handler/admin/promo_handler.go index 3eafa380..77d5f171 100644 --- a/backend/internal/handler/admin/promo_handler.go +++ b/backend/internal/handler/admin/promo_handler.go @@ -55,8 +55,10 @@ func (h *PromoHandler) List(c *gin.Context) { } params := pagination.PaginationParams{ - Page: page, - PageSize: pageSize, + Page: page, + PageSize: pageSize, + SortBy: c.DefaultQuery("sort_by", "created_at"), + SortOrder: c.DefaultQuery("sort_order", "desc"), } codes, paginationResult, err := h.promoService.List(c.Request.Context(), params, status, search) diff --git a/backend/internal/handler/admin/proxy_data.go b/backend/internal/handler/admin/proxy_data.go index 72ecd6c1..8149ce3b 100644 --- a/backend/internal/handler/admin/proxy_data.go +++ b/backend/internal/handler/admin/proxy_data.go @@ -33,11 +33,13 @@ func (h *ProxyHandler) ExportData(c *gin.Context) { protocol := c.Query("protocol") status := c.Query("status") search := strings.TrimSpace(c.Query("search")) + sortBy := c.DefaultQuery("sort_by", "id") + sortOrder := c.DefaultQuery("sort_order", "desc") if len(search) > 100 { search = search[:100] } - proxies, err = h.listProxiesFiltered(ctx, protocol, status, search) + proxies, err = h.listProxiesFiltered(ctx, protocol, status, search, sortBy, sortOrder) if err != nil { response.ErrorFrom(c, err) return @@ -89,7 +91,7 @@ func (h *ProxyHandler) ImportData(c *gin.Context) { ctx := c.Request.Context() result := DataImportResult{} - existingProxies, err := h.listProxiesFiltered(ctx, "", "", "") + existingProxies, err := h.listProxiesFiltered(ctx, "", "", "", "id", "desc") if err != nil { response.ErrorFrom(c, err) return @@ -220,18 +222,33 @@ func parseProxyIDs(c *gin.Context) ([]int64, error) { return ids, nil } -func (h *ProxyHandler) listProxiesFiltered(ctx context.Context, protocol, status, search string) ([]service.Proxy, error) { +func (h *ProxyHandler) listProxiesFiltered(ctx context.Context, protocol, status, search, sortBy, sortOrder string) ([]service.Proxy, error) { page := 1 pageSize := dataPageCap var out []service.Proxy + sortBy = strings.TrimSpace(sortBy) + useAccountCountSort := strings.EqualFold(sortBy, "account_count") for { - items, total, err := h.adminService.ListProxies(ctx, page, pageSize, protocol, status, search) - if err != nil { - return nil, err - } - out = append(out, items...) - if len(out) >= int(total) || len(items) == 0 { - break + if useAccountCountSort { + items, total, err := h.adminService.ListProxiesWithAccountCount(ctx, page, pageSize, protocol, status, search, sortBy, sortOrder) + if err != nil { + return nil, err + } + for i := range items { + out = append(out, items[i].Proxy) + } + if len(out) >= int(total) || len(items) == 0 { + break + } + } else { + items, total, err := h.adminService.ListProxies(ctx, page, pageSize, protocol, status, search, sortBy, sortOrder) + if err != nil { + return nil, err + } + out = append(out, items...) + if len(out) >= int(total) || len(items) == 0 { + break + } } page++ } diff --git a/backend/internal/handler/admin/proxy_data_handler_test.go b/backend/internal/handler/admin/proxy_data_handler_test.go index 803f9b61..8cd035ed 100644 --- a/backend/internal/handler/admin/proxy_data_handler_test.go +++ b/backend/internal/handler/admin/proxy_data_handler_test.go @@ -74,6 +74,10 @@ func TestProxyExportDataRespectsFilters(t *testing.T) { require.Len(t, resp.Data.Proxies, 1) require.Len(t, resp.Data.Accounts, 0) require.Equal(t, "https", resp.Data.Proxies[0].Protocol) + require.Equal(t, 1, adminSvc.lastListProxies.calls) + require.Equal(t, "https", adminSvc.lastListProxies.protocol) + require.Equal(t, "id", adminSvc.lastListProxies.sortBy) + require.Equal(t, "desc", adminSvc.lastListProxies.sortOrder) } func TestProxyExportDataWithSelectedIDs(t *testing.T) { @@ -113,6 +117,96 @@ func TestProxyExportDataWithSelectedIDs(t *testing.T) { require.Len(t, resp.Data.Proxies, 1) require.Equal(t, "https", resp.Data.Proxies[0].Protocol) require.Equal(t, "10.0.0.2", resp.Data.Proxies[0].Host) + require.Equal(t, 0, adminSvc.lastListProxies.calls) +} + +func TestProxyExportDataPassesSortParams(t *testing.T) { + router, adminSvc := setupProxyDataRouter() + + adminSvc.proxies = []service.Proxy{ + { + ID: 1, + Name: "proxy-a", + Protocol: "http", + Host: "127.0.0.1", + Port: 8080, + Username: "user", + Password: "pass", + Status: service.StatusActive, + }, + } + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/data?protocol=http&status=active&search=proxy&sort_by=name&sort_order=asc", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + require.Equal(t, 1, adminSvc.lastListProxies.calls) + require.Equal(t, "http", adminSvc.lastListProxies.protocol) + require.Equal(t, "active", adminSvc.lastListProxies.status) + require.Equal(t, "proxy", adminSvc.lastListProxies.search) + require.Equal(t, "name", adminSvc.lastListProxies.sortBy) + require.Equal(t, "asc", adminSvc.lastListProxies.sortOrder) +} + +func TestProxyExportDataSortByAccountCountUsesAccountCountListing(t *testing.T) { + router, adminSvc := setupProxyDataRouter() + + adminSvc.proxies = []service.Proxy{ + { + ID: 1, + Name: "proxy-id-1", + Protocol: "http", + Host: "127.0.0.1", + Port: 8080, + Status: service.StatusActive, + }, + { + ID: 2, + Name: "proxy-id-2", + Protocol: "http", + Host: "127.0.0.2", + Port: 8081, + Status: service.StatusActive, + }, + } + adminSvc.proxyCounts = []service.ProxyWithAccountCount{ + { + Proxy: service.Proxy{ + ID: 2, + Name: "proxy-count-high", + Protocol: "http", + Host: "127.0.0.2", + Port: 8081, + Status: service.StatusActive, + }, + AccountCount: 9, + }, + { + Proxy: service.Proxy{ + ID: 1, + Name: "proxy-count-low", + Protocol: "http", + Host: "127.0.0.1", + Port: 8080, + Status: service.StatusActive, + }, + AccountCount: 1, + }, + } + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/data?sort_by=account_count&sort_order=desc", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + var resp proxyDataResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Len(t, resp.Data.Proxies, 2) + require.Equal(t, "proxy-count-high", resp.Data.Proxies[0].Name) + require.Equal(t, "proxy-count-low", resp.Data.Proxies[1].Name) + require.Equal(t, 0, adminSvc.lastListProxies.calls) } func TestProxyImportDataReusesAndTriggersLatencyProbe(t *testing.T) { diff --git a/backend/internal/handler/admin/proxy_handler.go b/backend/internal/handler/admin/proxy_handler.go index e8ae0ce2..f97fcb0a 100644 --- a/backend/internal/handler/admin/proxy_handler.go +++ b/backend/internal/handler/admin/proxy_handler.go @@ -52,13 +52,15 @@ func (h *ProxyHandler) List(c *gin.Context) { protocol := c.Query("protocol") status := c.Query("status") search := c.Query("search") + sortBy := c.DefaultQuery("sort_by", "id") + sortOrder := c.DefaultQuery("sort_order", "desc") // 标准化和验证 search 参数 search = strings.TrimSpace(search) if len(search) > 100 { search = search[:100] } - proxies, total, err := h.adminService.ListProxiesWithAccountCount(c.Request.Context(), page, pageSize, protocol, status, search) + proxies, total, err := h.adminService.ListProxiesWithAccountCount(c.Request.Context(), page, pageSize, protocol, status, search, sortBy, sortOrder) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/admin/redeem_export_handler_test.go b/backend/internal/handler/admin/redeem_export_handler_test.go new file mode 100644 index 00000000..9983fe31 --- /dev/null +++ b/backend/internal/handler/admin/redeem_export_handler_test.go @@ -0,0 +1,49 @@ +package admin + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func setupRedeemExportRouter() (*gin.Engine, *stubAdminService) { + gin.SetMode(gin.TestMode) + router := gin.New() + adminSvc := newStubAdminService() + + h := NewRedeemHandler(adminSvc, nil) + router.GET("/api/v1/admin/redeem-codes/export", h.Export) + return router, adminSvc +} + +func TestRedeemExportPassesSearchAndSort(t *testing.T) { + router, adminSvc := setupRedeemExportRouter() + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes/export?type=balance&status=unused&search=ABC&sort_by=value&sort_order=asc", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + require.Equal(t, 1, adminSvc.lastListRedeemCodes.calls) + require.Equal(t, "balance", adminSvc.lastListRedeemCodes.codeType) + require.Equal(t, "unused", adminSvc.lastListRedeemCodes.status) + require.Equal(t, "ABC", adminSvc.lastListRedeemCodes.search) + require.Equal(t, "value", adminSvc.lastListRedeemCodes.sortBy) + require.Equal(t, "asc", adminSvc.lastListRedeemCodes.sortOrder) +} + +func TestRedeemExportSortDefaults(t *testing.T) { + router, adminSvc := setupRedeemExportRouter() + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes/export", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + require.Equal(t, 1, adminSvc.lastListRedeemCodes.calls) + require.Equal(t, "id", adminSvc.lastListRedeemCodes.sortBy) + require.Equal(t, "desc", adminSvc.lastListRedeemCodes.sortOrder) +} diff --git a/backend/internal/handler/admin/redeem_handler.go b/backend/internal/handler/admin/redeem_handler.go index c494e5fb..24365f3d 100644 --- a/backend/internal/handler/admin/redeem_handler.go +++ b/backend/internal/handler/admin/redeem_handler.go @@ -59,13 +59,15 @@ func (h *RedeemHandler) List(c *gin.Context) { codeType := c.Query("type") status := c.Query("status") search := c.Query("search") + sortBy := c.DefaultQuery("sort_by", "id") + sortOrder := c.DefaultQuery("sort_order", "desc") // 标准化和验证 search 参数 search = strings.TrimSpace(search) if len(search) > 100 { search = search[:100] } - codes, total, err := h.adminService.ListRedeemCodes(c.Request.Context(), page, pageSize, codeType, status, search) + codes, total, err := h.adminService.ListRedeemCodes(c.Request.Context(), page, pageSize, codeType, status, search, sortBy, sortOrder) if err != nil { response.ErrorFrom(c, err) return @@ -300,9 +302,15 @@ func (h *RedeemHandler) GetStats(c *gin.Context) { func (h *RedeemHandler) Export(c *gin.Context) { codeType := c.Query("type") status := c.Query("status") + search := strings.TrimSpace(c.Query("search")) + sortBy := c.DefaultQuery("sort_by", "id") + sortOrder := c.DefaultQuery("sort_order", "desc") + if len(search) > 100 { + search = search[:100] + } // Get all codes without pagination (use large page size) - codes, _, err := h.adminService.ListRedeemCodes(c.Request.Context(), 1, 10000, codeType, status, "") + codes, _, err := h.adminService.ListRedeemCodes(c.Request.Context(), 1, 10000, codeType, status, search, sortBy, sortOrder) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index cbda151a..ba751131 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -150,6 +150,8 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { HideCcsImportButton: settings.HideCcsImportButton, PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, + TableDefaultPageSize: settings.TableDefaultPageSize, + TablePageSizeOptions: settings.TablePageSizeOptions, CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems), CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints), DefaultConcurrency: settings.DefaultConcurrency, @@ -261,6 +263,8 @@ type UpdateSettingsRequest struct { HideCcsImportButton bool `json:"hide_ccs_import_button"` PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"` PurchaseSubscriptionURL *string `json:"purchase_subscription_url"` + TableDefaultPageSize int `json:"table_default_page_size"` + TablePageSizeOptions []int `json:"table_page_size_options"` CustomMenuItems *[]dto.CustomMenuItem `json:"custom_menu_items"` CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"` @@ -345,6 +349,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { if req.DefaultBalance < 0 { req.DefaultBalance = 0 } + // 通用表格配置:兼容旧客户端未传字段时保留当前值。 + if req.TableDefaultPageSize <= 0 { + req.TableDefaultPageSize = previousSettings.TableDefaultPageSize + } + if req.TablePageSizeOptions == nil { + req.TablePageSizeOptions = previousSettings.TablePageSizeOptions + } req.SMTPHost = strings.TrimSpace(req.SMTPHost) req.SMTPUsername = strings.TrimSpace(req.SMTPUsername) req.SMTPPassword = strings.TrimSpace(req.SMTPPassword) @@ -810,6 +821,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { HideCcsImportButton: req.HideCcsImportButton, PurchaseSubscriptionEnabled: purchaseEnabled, PurchaseSubscriptionURL: purchaseURL, + TableDefaultPageSize: req.TableDefaultPageSize, + TablePageSizeOptions: req.TablePageSizeOptions, CustomMenuItems: customMenuJSON, CustomEndpoints: customEndpointsJSON, DefaultConcurrency: req.DefaultConcurrency, @@ -989,6 +1002,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { HideCcsImportButton: updatedSettings.HideCcsImportButton, PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL, + TableDefaultPageSize: updatedSettings.TableDefaultPageSize, + TablePageSizeOptions: updatedSettings.TablePageSizeOptions, CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems), CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints), DefaultConcurrency: updatedSettings.DefaultConcurrency, @@ -1278,6 +1293,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.PurchaseSubscriptionURL != after.PurchaseSubscriptionURL { changed = append(changed, "purchase_subscription_url") } + if before.TableDefaultPageSize != after.TableDefaultPageSize { + changed = append(changed, "table_default_page_size") + } + if !equalIntSlice(before.TablePageSizeOptions, after.TablePageSizeOptions) { + changed = append(changed, "table_page_size_options") + } if before.CustomMenuItems != after.CustomMenuItems { changed = append(changed, "custom_menu_items") } @@ -1334,6 +1355,18 @@ func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool { return true } +func equalIntSlice(a, b []int) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + // TestSMTPRequest 测试SMTP连接请求 type TestSMTPRequest struct { SMTPHost string `json:"smtp_host"` diff --git a/backend/internal/handler/admin/usage_handler.go b/backend/internal/handler/admin/usage_handler.go index 2967b384..0857a138 100644 --- a/backend/internal/handler/admin/usage_handler.go +++ b/backend/internal/handler/admin/usage_handler.go @@ -165,7 +165,12 @@ func (h *UsageHandler) List(c *gin.Context) { endTime = &t } - params := pagination.PaginationParams{Page: page, PageSize: pageSize} + params := pagination.PaginationParams{ + Page: page, + PageSize: pageSize, + SortBy: c.DefaultQuery("sort_by", "created_at"), + SortOrder: c.DefaultQuery("sort_order", "desc"), + } filters := usagestats.UsageLogFilters{ UserID: userID, APIKeyID: apiKeyID, @@ -339,7 +344,7 @@ func (h *UsageHandler) SearchUsers(c *gin.Context) { } // Limit to 30 results - users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, service.UserListFilters{Search: keyword}) + users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, service.UserListFilters{Search: keyword}, "email", "asc") if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/admin/usage_handler_request_type_test.go b/backend/internal/handler/admin/usage_handler_request_type_test.go index 3f158316..882cbe93 100644 --- a/backend/internal/handler/admin/usage_handler_request_type_test.go +++ b/backend/internal/handler/admin/usage_handler_request_type_test.go @@ -15,11 +15,13 @@ import ( type adminUsageRepoCapture struct { service.UsageLogRepository + listParams pagination.PaginationParams listFilters usagestats.UsageLogFilters statsFilters usagestats.UsageLogFilters } func (s *adminUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) { + s.listParams = params s.listFilters = filters return []service.UsageLog{}, &pagination.PaginationResult{ Total: 0, diff --git a/backend/internal/handler/admin/usage_handler_sort_test.go b/backend/internal/handler/admin/usage_handler_sort_test.go new file mode 100644 index 00000000..dac82676 --- /dev/null +++ b/backend/internal/handler/admin/usage_handler_sort_test.go @@ -0,0 +1,35 @@ +package admin + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAdminUsageListSortParams(t *testing.T) { + repo := &adminUsageRepoCapture{} + router := newAdminUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/usage?sort_by=model&sort_order=ASC", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "model", repo.listParams.SortBy) + require.Equal(t, "ASC", repo.listParams.SortOrder) +} + +func TestAdminUsageListSortDefaults(t *testing.T) { + repo := &adminUsageRepoCapture{} + router := newAdminUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/usage", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "created_at", repo.listParams.SortBy) + require.Equal(t, "desc", repo.listParams.SortOrder) +} diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index a357657e..1453bd07 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -91,12 +91,14 @@ func (h *UserHandler) List(c *gin.Context) { GroupName: strings.TrimSpace(c.Query("group_name")), Attributes: parseAttributeFilters(c), } + sortBy := c.DefaultQuery("sort_by", "created_at") + sortOrder := c.DefaultQuery("sort_order", "desc") if raw, ok := c.GetQuery("include_subscriptions"); ok { includeSubscriptions := parseBoolQueryWithDefault(raw, true) filters.IncludeSubscriptions = &includeSubscriptions } - users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, filters) + users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, filters, sortBy, sortOrder) if err != nil { response.ErrorFrom(c, err) return @@ -290,8 +292,10 @@ func (h *UserHandler) GetUserAPIKeys(c *gin.Context) { } page, pageSize := response.ParsePagination(c) + sortBy := c.DefaultQuery("sort_by", "created_at") + sortOrder := c.DefaultQuery("sort_order", "desc") - keys, total, err := h.adminService.GetUserAPIKeys(c.Request.Context(), userID, page, pageSize) + keys, total, err := h.adminService.GetUserAPIKeys(c.Request.Context(), userID, page, pageSize, sortBy, sortOrder) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/api_key_handler.go b/backend/internal/handler/api_key_handler.go index 951aed08..9d6c6c15 100644 --- a/backend/internal/handler/api_key_handler.go +++ b/backend/internal/handler/api_key_handler.go @@ -72,7 +72,12 @@ func (h *APIKeyHandler) List(c *gin.Context) { } page, pageSize := response.ParsePagination(c) - params := pagination.PaginationParams{Page: page, PageSize: pageSize} + params := pagination.PaginationParams{ + Page: page, + PageSize: pageSize, + SortBy: c.DefaultQuery("sort_by", "created_at"), + SortOrder: c.DefaultQuery("sort_order", "desc"), + } // Parse filter parameters var filters service.APIKeyListFilters diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index a791051d..cbbe9216 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -84,6 +84,8 @@ type SystemSettings struct { HideCcsImportButton bool `json:"hide_ccs_import_button"` PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` PurchaseSubscriptionURL string `json:"purchase_subscription_url"` + TableDefaultPageSize int `json:"table_default_page_size"` + TablePageSizeOptions []int `json:"table_page_size_options"` CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` CustomEndpoints []CustomEndpoint `json:"custom_endpoints"` @@ -170,6 +172,8 @@ type PublicSettings struct { HideCcsImportButton bool `json:"hide_ccs_import_button"` PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` PurchaseSubscriptionURL string `json:"purchase_subscription_url"` + TableDefaultPageSize int `json:"table_default_page_size"` + TablePageSizeOptions []int `json:"table_page_size_options"` CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` CustomEndpoints []CustomEndpoint `json:"custom_endpoints"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index 09a68783..54a92a8c 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -51,6 +51,8 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { HideCcsImportButton: settings.HideCcsImportButton, PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, + TableDefaultPageSize: settings.TableDefaultPageSize, + TablePageSizeOptions: settings.TablePageSizeOptions, CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems), CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints), LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, diff --git a/backend/internal/handler/usage_handler.go b/backend/internal/handler/usage_handler.go index 483f5105..b8506154 100644 --- a/backend/internal/handler/usage_handler.go +++ b/backend/internal/handler/usage_handler.go @@ -119,7 +119,12 @@ func (h *UsageHandler) List(c *gin.Context) { endTime = &t } - params := pagination.PaginationParams{Page: page, PageSize: pageSize} + params := pagination.PaginationParams{ + Page: page, + PageSize: pageSize, + SortBy: c.DefaultQuery("sort_by", "created_at"), + SortOrder: c.DefaultQuery("sort_order", "desc"), + } filters := usagestats.UsageLogFilters{ UserID: subject.UserID, // Always filter by current user for security APIKeyID: apiKeyID, diff --git a/backend/internal/handler/usage_handler_request_type_test.go b/backend/internal/handler/usage_handler_request_type_test.go index 7c4c7913..b49ed59b 100644 --- a/backend/internal/handler/usage_handler_request_type_test.go +++ b/backend/internal/handler/usage_handler_request_type_test.go @@ -16,10 +16,12 @@ import ( type userUsageRepoCapture struct { service.UsageLogRepository + listParams pagination.PaginationParams listFilters usagestats.UsageLogFilters } func (s *userUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) { + s.listParams = params s.listFilters = filters return []service.UsageLog{}, &pagination.PaginationResult{ Total: 0, diff --git a/backend/internal/handler/usage_handler_sort_test.go b/backend/internal/handler/usage_handler_sort_test.go new file mode 100644 index 00000000..1af313b0 --- /dev/null +++ b/backend/internal/handler/usage_handler_sort_test.go @@ -0,0 +1,35 @@ +package handler + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUserUsageListSortParams(t *testing.T) { + repo := &userUsageRepoCapture{} + router := newUserUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/usage?sort_by=model&sort_order=ASC", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "model", repo.listParams.SortBy) + require.Equal(t, "ASC", repo.listParams.SortOrder) +} + +func TestUserUsageListSortDefaults(t *testing.T) { + repo := &userUsageRepoCapture{} + router := newUserUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/usage", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "created_at", repo.listParams.SortBy) + require.Equal(t, "desc", repo.listParams.SortOrder) +} diff --git a/backend/internal/pkg/pagination/pagination.go b/backend/internal/pkg/pagination/pagination.go index c162588a..ce8e74b8 100644 --- a/backend/internal/pkg/pagination/pagination.go +++ b/backend/internal/pkg/pagination/pagination.go @@ -1,10 +1,19 @@ // Package pagination provides types and helpers for paginated responses. package pagination +import "strings" + +const ( + SortOrderAsc = "asc" + SortOrderDesc = "desc" +) + // PaginationParams 分页参数 type PaginationParams struct { - Page int - PageSize int + Page int + PageSize int + SortBy string + SortOrder string } // PaginationResult 分页结果 @@ -18,8 +27,9 @@ type PaginationResult struct { // DefaultPagination 默认分页参数 func DefaultPagination() PaginationParams { return PaginationParams{ - Page: 1, - PageSize: 20, + Page: 1, + PageSize: 20, + SortOrder: SortOrderDesc, } } @@ -36,8 +46,32 @@ func (p PaginationParams) Limit() int { if p.PageSize < 1 { return 20 } - if p.PageSize > 100 { - return 100 + if p.PageSize > 1000 { + return 1000 } return p.PageSize } + +// NormalizeSortOrder normalizes sort order to asc/desc and falls back to defaultOrder. +func NormalizeSortOrder(order string, defaultOrder string) string { + switch strings.ToLower(strings.TrimSpace(defaultOrder)) { + case SortOrderAsc: + defaultOrder = SortOrderAsc + default: + defaultOrder = SortOrderDesc + } + + switch strings.ToLower(strings.TrimSpace(order)) { + case SortOrderAsc: + return SortOrderAsc + case SortOrderDesc: + return SortOrderDesc + default: + return defaultOrder + } +} + +// NormalizedSortOrder returns the normalized sort order using defaultOrder as fallback. +func (p PaginationParams) NormalizedSortOrder(defaultOrder string) string { + return NormalizeSortOrder(p.SortOrder, defaultOrder) +} diff --git a/backend/internal/pkg/pagination/pagination_test.go b/backend/internal/pkg/pagination/pagination_test.go new file mode 100644 index 00000000..9a3b069d --- /dev/null +++ b/backend/internal/pkg/pagination/pagination_test.go @@ -0,0 +1,71 @@ +package pagination + +import "testing" + +func TestNormalizeSortOrder(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + defaultOrder string + want string + }{ + {name: "asc", input: "asc", defaultOrder: "desc", want: "asc"}, + {name: "uppercase asc", input: "ASC", defaultOrder: "desc", want: "asc"}, + {name: "desc", input: "desc", defaultOrder: "asc", want: "desc"}, + {name: "trim spaces", input: " desc ", defaultOrder: "asc", want: "desc"}, + {name: "invalid falls back", input: "sideways", defaultOrder: "asc", want: "asc"}, + {name: "empty falls back", input: "", defaultOrder: "desc", want: "desc"}, + {name: "invalid default falls back to desc", input: "", defaultOrder: "wat", want: "desc"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := NormalizeSortOrder(tt.input, tt.defaultOrder); got != tt.want { + t.Fatalf("NormalizeSortOrder(%q, %q) = %q, want %q", tt.input, tt.defaultOrder, got, tt.want) + } + }) + } +} + +func TestPaginationParamsNormalizedSortOrder(t *testing.T) { + t.Parallel() + + params := PaginationParams{SortOrder: "ASC"} + if got := params.NormalizedSortOrder("desc"); got != "asc" { + t.Fatalf("NormalizedSortOrder = %q, want asc", got) + } + + params = PaginationParams{SortOrder: "bad"} + if got := params.NormalizedSortOrder("asc"); got != "asc" { + t.Fatalf("NormalizedSortOrder invalid fallback = %q, want asc", got) + } +} + +func TestPaginationParamsLimit(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + pageSize int + want int + }{ + {name: "non-positive falls back to default", pageSize: 0, want: 20}, + {name: "negative falls back to default", pageSize: -1, want: 20}, + {name: "normal value keeps", pageSize: 50, want: 50}, + {name: "max value keeps", pageSize: 1000, want: 1000}, + {name: "beyond max clamps to 1000", pageSize: 1500, want: 1000}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + p := PaginationParams{PageSize: tt.pageSize} + if got := p.Limit(); got != tt.want { + t.Fatalf("Limit() for PageSize=%d = %d, want %d", tt.pageSize, got, tt.want) + } + }) + } +} diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index 14498715..24115c33 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -471,21 +471,58 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati case service.StatusActive: q = q.Where( dbaccount.StatusEQ(status), + dbaccount.SchedulableEQ(true), dbaccount.Or( dbaccount.RateLimitResetAtIsNil(), dbaccount.RateLimitResetAtLTE(time.Now()), ), + dbpredicate.Account(func(s *entsql.Selector) { + col := s.C("temp_unschedulable_until") + s.Where(entsql.Or( + entsql.IsNull(col), + entsql.LTE(col, entsql.Expr("NOW()")), + )) + }), ) case "rate_limited": - q = q.Where(dbaccount.RateLimitResetAtGT(time.Now())) + q = q.Where( + dbaccount.StatusEQ(service.StatusActive), + dbaccount.RateLimitResetAtGT(time.Now()), + dbpredicate.Account(func(s *entsql.Selector) { + col := s.C("temp_unschedulable_until") + s.Where(entsql.Or( + entsql.IsNull(col), + entsql.LTE(col, entsql.Expr("NOW()")), + )) + }), + ) case "temp_unschedulable": - q = q.Where(dbpredicate.Account(func(s *entsql.Selector) { - col := s.C("temp_unschedulable_until") - s.Where(entsql.And( - entsql.Not(entsql.IsNull(col)), - entsql.GT(col, entsql.Expr("NOW()")), - )) - })) + q = q.Where( + dbaccount.StatusEQ(service.StatusActive), + dbpredicate.Account(func(s *entsql.Selector) { + col := s.C("temp_unschedulable_until") + s.Where(entsql.And( + entsql.Not(entsql.IsNull(col)), + entsql.GT(col, entsql.Expr("NOW()")), + )) + }), + ) + case "unschedulable": + q = q.Where( + dbaccount.StatusEQ(service.StatusActive), + dbaccount.SchedulableEQ(false), + dbaccount.Or( + dbaccount.RateLimitResetAtIsNil(), + dbaccount.RateLimitResetAtLTE(time.Now()), + ), + dbpredicate.Account(func(s *entsql.Selector) { + col := s.C("temp_unschedulable_until") + s.Where(entsql.Or( + entsql.IsNull(col), + entsql.LTE(col, entsql.Expr("NOW()")), + )) + }), + ) default: q = q.Where(dbaccount.StatusEQ(status)) } @@ -518,11 +555,14 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati return nil, nil, err } - accounts, err := q. + accountsQuery := q. Offset(params.Offset()). - Limit(params.Limit()). - Order(dbent.Desc(dbaccount.FieldID)). - All(ctx) + Limit(params.Limit()) + for _, order := range accountListOrder(params) { + accountsQuery = accountsQuery.Order(order) + } + + accounts, err := accountsQuery.All(ctx) if err != nil { return nil, nil, err } @@ -534,6 +574,50 @@ func (r *accountRepository) ListWithFilters(ctx context.Context, params paginati return outAccounts, paginationResultFromTotal(int64(total), params), nil } +func accountListOrder(params pagination.PaginationParams) []func(*entsql.Selector) { + sortBy := strings.ToLower(strings.TrimSpace(params.SortBy)) + sortOrder := params.NormalizedSortOrder(pagination.SortOrderAsc) + + field := dbaccount.FieldName + defaultOrder := true + switch sortBy { + case "", "name": + field = dbaccount.FieldName + case "id": + field = dbaccount.FieldID + defaultOrder = false + case "status": + field = dbaccount.FieldStatus + defaultOrder = false + case "schedulable": + field = dbaccount.FieldSchedulable + defaultOrder = false + case "priority": + field = dbaccount.FieldPriority + defaultOrder = false + case "rate_multiplier": + field = dbaccount.FieldRateMultiplier + defaultOrder = false + case "last_used_at": + field = dbaccount.FieldLastUsedAt + defaultOrder = false + case "expires_at": + field = dbaccount.FieldExpiresAt + defaultOrder = false + case "created_at": + field = dbaccount.FieldCreatedAt + defaultOrder = false + } + + if sortOrder == pagination.SortOrderDesc { + return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(dbaccount.FieldID)} + } + if defaultOrder { + return []func(*entsql.Selector){dbent.Asc(dbaccount.FieldName), dbent.Asc(dbaccount.FieldID)} + } + return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(dbaccount.FieldID)} +} + func (r *accountRepository) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) { accounts, err := r.queryAccountsByGroup(ctx, groupID, accountGroupQueryOptions{ status: service.StatusActive, diff --git a/backend/internal/repository/account_repo_integration_test.go b/backend/internal/repository/account_repo_integration_test.go index f3e3f745..b249bb61 100644 --- a/backend/internal/repository/account_repo_integration_test.go +++ b/backend/internal/repository/account_repo_integration_test.go @@ -256,7 +256,7 @@ func (s *AccountRepoSuite) TestListWithFilters() { }, }, { - name: "filter_by_status_active_excludes_rate_limited", + name: "filter_by_status_active_excludes_runtime_blocked_accounts", setup: func(client *dbent.Client) { mustCreateAccount(s.T(), client, &service.Account{Name: "active-normal", Status: service.StatusActive}) rateLimited := mustCreateAccount(s.T(), client, &service.Account{Name: "active-rate-limited", Status: service.StatusActive}) @@ -264,6 +264,16 @@ func (s *AccountRepoSuite) TestListWithFilters() { SetRateLimitResetAt(time.Now().Add(10 * time.Minute)). Exec(context.Background()) s.Require().NoError(err) + tempUnsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-temp-unsched", Status: service.StatusActive}) + err = client.Account.UpdateOneID(tempUnsched.ID). + SetTempUnschedulableUntil(time.Now().Add(15 * time.Minute)). + Exec(context.Background()) + s.Require().NoError(err) + unsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-unsched", Status: service.StatusActive}) + err = client.Account.UpdateOneID(unsched.ID). + SetSchedulable(false). + Exec(context.Background()) + s.Require().NoError(err) }, status: service.StatusActive, wantCount: 1, @@ -271,6 +281,75 @@ func (s *AccountRepoSuite) TestListWithFilters() { s.Require().Equal("active-normal", accounts[0].Name) }, }, + { + name: "filter_by_status_unschedulable_excludes_rate_limited_and_temp_unschedulable", + setup: func(client *dbent.Client) { + mustCreateAccount(s.T(), client, &service.Account{Name: "active-normal", Status: service.StatusActive, Schedulable: true}) + unsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-unsched", Status: service.StatusActive}) + err := client.Account.UpdateOneID(unsched.ID). + SetSchedulable(false). + Exec(context.Background()) + s.Require().NoError(err) + rateLimited := mustCreateAccount(s.T(), client, &service.Account{Name: "active-rate-limited", Status: service.StatusActive}) + err = client.Account.UpdateOneID(rateLimited.ID). + SetSchedulable(false). + SetRateLimitResetAt(time.Now().Add(10 * time.Minute)). + Exec(context.Background()) + s.Require().NoError(err) + tempUnsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-temp-unsched", Status: service.StatusActive}) + err = client.Account.UpdateOneID(tempUnsched.ID). + SetSchedulable(false). + SetTempUnschedulableUntil(time.Now().Add(15 * time.Minute)). + Exec(context.Background()) + s.Require().NoError(err) + }, + status: "unschedulable", + wantCount: 1, + validate: func(accounts []service.Account) { + s.Require().Equal("active-unsched", accounts[0].Name) + }, + }, + { + name: "filter_by_status_rate_limited_excludes_temp_unschedulable", + setup: func(client *dbent.Client) { + rateLimited := mustCreateAccount(s.T(), client, &service.Account{Name: "active-rate-limited", Status: service.StatusActive}) + err := client.Account.UpdateOneID(rateLimited.ID). + SetRateLimitResetAt(time.Now().Add(10 * time.Minute)). + Exec(context.Background()) + s.Require().NoError(err) + tempUnsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-temp-unsched", Status: service.StatusActive}) + err = client.Account.UpdateOneID(tempUnsched.ID). + SetRateLimitResetAt(time.Now().Add(20 * time.Minute)). + SetTempUnschedulableUntil(time.Now().Add(15 * time.Minute)). + Exec(context.Background()) + s.Require().NoError(err) + }, + status: "rate_limited", + wantCount: 1, + validate: func(accounts []service.Account) { + s.Require().Equal("active-rate-limited", accounts[0].Name) + }, + }, + { + name: "filter_by_status_temp_unschedulable_excludes_manually_unschedulable", + setup: func(client *dbent.Client) { + tempUnsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-temp-unsched", Status: service.StatusActive, Schedulable: true}) + err := client.Account.UpdateOneID(tempUnsched.ID). + SetTempUnschedulableUntil(time.Now().Add(15 * time.Minute)). + Exec(context.Background()) + s.Require().NoError(err) + unsched := mustCreateAccount(s.T(), client, &service.Account{Name: "active-unsched", Status: service.StatusActive}) + err = client.Account.UpdateOneID(unsched.ID). + SetSchedulable(false). + Exec(context.Background()) + s.Require().NoError(err) + }, + status: "temp_unschedulable", + wantCount: 1, + validate: func(accounts []service.Account) { + s.Require().Equal("active-temp-unsched", accounts[0].Name) + }, + }, { name: "filter_by_search", setup: func(client *dbent.Client) { diff --git a/backend/internal/repository/account_repo_sort_integration_test.go b/backend/internal/repository/account_repo_sort_integration_test.go new file mode 100644 index 00000000..098dde7b --- /dev/null +++ b/backend/internal/repository/account_repo_sort_integration_test.go @@ -0,0 +1,35 @@ +//go:build integration + +package repository + +import ( + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +func (s *AccountRepoSuite) TestList_DefaultSortByNameAsc() { + mustCreateAccount(s.T(), s.client, &service.Account{Name: "z-account"}) + mustCreateAccount(s.T(), s.client, &service.Account{Name: "a-account"}) + + accounts, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}) + s.Require().NoError(err) + s.Require().Len(accounts, 2) + s.Require().Equal("a-account", accounts[0].Name) + s.Require().Equal("z-account", accounts[1].Name) +} + +func (s *AccountRepoSuite) TestListWithFilters_SortByPriorityDesc() { + mustCreateAccount(s.T(), s.client, &service.Account{Name: "low-priority", Priority: 10}) + mustCreateAccount(s.T(), s.client, &service.Account{Name: "high-priority", Priority: 90}) + + accounts, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{ + Page: 1, + PageSize: 10, + SortBy: "priority", + SortOrder: "desc", + }, "", "", "", "", 0, "") + s.Require().NoError(err) + s.Require().Len(accounts, 2) + s.Require().Equal("high-priority", accounts[0].Name) + s.Require().Equal("low-priority", accounts[1].Name) +} diff --git a/backend/internal/repository/announcement_repo.go b/backend/internal/repository/announcement_repo.go index 53dc335f..afe1fb25 100644 --- a/backend/internal/repository/announcement_repo.go +++ b/backend/internal/repository/announcement_repo.go @@ -2,12 +2,15 @@ package repository import ( "context" + "strings" "time" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" + + entsql "entgo.io/ent/dialect/sql" ) type announcementRepository struct { @@ -128,11 +131,14 @@ func (r *announcementRepository) List( return nil, nil, err } - items, err := q. + itemsQuery := q. Offset(params.Offset()). - Limit(params.Limit()). - Order(dbent.Desc(announcement.FieldID)). - All(ctx) + Limit(params.Limit()) + for _, order := range announcementListOrders(params) { + itemsQuery = itemsQuery.Order(order) + } + + items, err := itemsQuery.All(ctx) if err != nil { return nil, nil, err } @@ -141,6 +147,56 @@ func (r *announcementRepository) List( return out, paginationResultFromTotal(int64(total), params), nil } +func announcementListOrder(params pagination.PaginationParams) (string, string) { + sortBy := strings.ToLower(strings.TrimSpace(params.SortBy)) + sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc) + + switch sortBy { + case "title": + return announcement.FieldTitle, sortOrder + case "status": + return announcement.FieldStatus, sortOrder + case "notify_mode": + return announcement.FieldNotifyMode, sortOrder + case "starts_at": + return announcement.FieldStartsAt, sortOrder + case "ends_at": + return announcement.FieldEndsAt, sortOrder + case "id": + return announcement.FieldID, sortOrder + case "", "created_at": + return announcement.FieldCreatedAt, sortOrder + default: + return announcement.FieldCreatedAt, pagination.SortOrderDesc + } +} + +func announcementListOrders(params pagination.PaginationParams) []func(*entsql.Selector) { + field, sortOrder := announcementListOrder(params) + + if sortOrder == pagination.SortOrderAsc { + if field == announcement.FieldID { + return []func(*entsql.Selector){ + dbent.Asc(field), + } + } + return []func(*entsql.Selector){ + dbent.Asc(field), + dbent.Asc(announcement.FieldID), + } + } + + if field == announcement.FieldID { + return []func(*entsql.Selector){ + dbent.Desc(field), + } + } + return []func(*entsql.Selector){ + dbent.Desc(field), + dbent.Desc(announcement.FieldID), + } +} + func (r *announcementRepository) ListActive(ctx context.Context, now time.Time) ([]service.Announcement, error) { q := r.client.Announcement.Query(). Where( diff --git a/backend/internal/repository/announcement_repo_sort_test.go b/backend/internal/repository/announcement_repo_sort_test.go new file mode 100644 index 00000000..e47f98dc --- /dev/null +++ b/backend/internal/repository/announcement_repo_sort_test.go @@ -0,0 +1,63 @@ +package repository + +import ( + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" +) + +func TestAnnouncementListOrder(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + params pagination.PaginationParams + wantBy string + want string + }{ + { + name: "default created_at desc", + params: pagination.PaginationParams{}, + wantBy: "created_at", + want: "desc", + }, + { + name: "title asc", + params: pagination.PaginationParams{ + SortBy: "title", + SortOrder: "ASC", + }, + wantBy: "title", + want: "asc", + }, + { + name: "status desc", + params: pagination.PaginationParams{ + SortBy: "status", + SortOrder: "desc", + }, + wantBy: "status", + want: "desc", + }, + { + name: "invalid falls back", + params: pagination.PaginationParams{ + SortBy: "sideways", + SortOrder: "wat", + }, + wantBy: "created_at", + want: "desc", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + gotBy, gotOrder := announcementListOrder(tt.params) + if gotBy != tt.wantBy || gotOrder != tt.want { + t.Fatalf("announcementListOrder(%+v) = (%q, %q), want (%q, %q)", tt.params, gotBy, gotOrder, tt.wantBy, tt.want) + } + }) + } +} diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index b3b12e81..7fd98855 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "strings" "time" dbent "github.com/Wei-Shaw/sub2api/ent" @@ -14,6 +15,8 @@ import ( "github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + + entsql "entgo.io/ent/dialect/sql" ) type apiKeyRepository struct { @@ -164,6 +167,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se group.FieldSupportedModelScopes, group.FieldAllowMessagesDispatch, group.FieldDefaultMappedModel, + group.FieldMessagesDispatchModelConfig, ) }). Only(ctx) @@ -309,12 +313,15 @@ func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, param return nil, nil, err } - keys, err := q. + keysQuery := q. WithGroup(). Offset(params.Offset()). - Limit(params.Limit()). - Order(dbent.Desc(apikey.FieldID)). - All(ctx) + Limit(params.Limit()) + for _, order := range apiKeyListOrder(params) { + keysQuery = keysQuery.Order(order) + } + + keys, err := keysQuery.All(ctx) if err != nil { return nil, nil, err } @@ -359,12 +366,15 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par return nil, nil, err } - keys, err := q. + keysQuery := q. WithUser(). Offset(params.Offset()). - Limit(params.Limit()). - Order(dbent.Desc(apikey.FieldID)). - All(ctx) + Limit(params.Limit()) + for _, order := range apiKeyListOrder(params) { + keysQuery = keysQuery.Order(order) + } + + keys, err := keysQuery.All(ctx) if err != nil { return nil, nil, err } @@ -377,6 +387,32 @@ func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par return outKeys, paginationResultFromTotal(int64(total), params), nil } +func apiKeyListOrder(params pagination.PaginationParams) []func(*entsql.Selector) { + sortBy := strings.ToLower(strings.TrimSpace(params.SortBy)) + sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc) + + var field string + switch sortBy { + case "name": + field = apikey.FieldName + case "status": + field = apikey.FieldStatus + case "expires_at": + field = apikey.FieldExpiresAt + case "last_used_at": + field = apikey.FieldLastUsedAt + case "created_at": + field = apikey.FieldCreatedAt + default: + field = apikey.FieldID + } + + if sortOrder == pagination.SortOrderAsc { + return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(apikey.FieldID)} + } + return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(apikey.FieldID)} +} + // SearchAPIKeys searches API keys by user ID and/or keyword (name) func (r *apiKeyRepository) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) { q := r.activeQuery() @@ -654,6 +690,7 @@ func groupEntityToService(g *dbent.Group) *service.Group { RequireOAuthOnly: g.RequireOauthOnly, RequirePrivacySet: g.RequirePrivacySet, DefaultMappedModel: g.DefaultMappedModel, + MessagesDispatchModelConfig: g.MessagesDispatchModelConfig, CreatedAt: g.CreatedAt, UpdatedAt: g.UpdatedAt, } diff --git a/backend/internal/repository/api_key_repo_integration_test.go b/backend/internal/repository/api_key_repo_integration_test.go index 7d5c1826..e926ed86 100644 --- a/backend/internal/repository/api_key_repo_integration_test.go +++ b/backend/internal/repository/api_key_repo_integration_test.go @@ -86,6 +86,45 @@ func (s *APIKeyRepoSuite) TestGetByKey_NotFound() { s.Require().Error(err, "expected error for non-existent key") } +func (s *APIKeyRepoSuite) TestGetByKeyForAuth_PreservesMessagesDispatchModelConfig() { + user := s.mustCreateUser("getbykey-auth-dispatch@test.com") + group, err := s.client.Group.Create(). + SetName("g-auth-dispatch"). + SetPlatform(service.PlatformOpenAI). + SetStatus(service.StatusActive). + SetSubscriptionType(service.SubscriptionTypeStandard). + SetRateMultiplier(1). + SetAllowMessagesDispatch(true). + SetDefaultMappedModel("gpt-5.4"). + SetMessagesDispatchModelConfig(service.OpenAIMessagesDispatchModelConfig{ + OpusMappedModel: "gpt-5.4-nano", + SonnetMappedModel: "gpt-5.3-codex", + HaikuMappedModel: "gpt-5.4-mini", + ExactModelMappings: map[string]string{ + "claude-sonnet-4.5": "gpt-5.4-nano", + }, + }). + Save(s.ctx) + s.Require().NoError(err) + + key := &service.APIKey{ + UserID: user.ID, + Key: "sk-getbykey-auth-dispatch", + Name: "Dispatch Key", + GroupID: &group.ID, + Status: service.StatusActive, + } + s.Require().NoError(s.repo.Create(s.ctx, key)) + + got, err := s.repo.GetByKeyForAuth(s.ctx, key.Key) + s.Require().NoError(err) + s.Require().NotNil(got.Group) + s.Require().True(got.Group.AllowMessagesDispatch) + s.Require().Equal("gpt-5.4", got.Group.DefaultMappedModel) + s.Require().Equal("gpt-5.4-nano", got.Group.MessagesDispatchModelConfig.OpusMappedModel) + s.Require().Equal("gpt-5.4-nano", got.Group.MessagesDispatchModelConfig.ExactModelMappings["claude-sonnet-4.5"]) +} + // --- Update --- func (s *APIKeyRepoSuite) TestUpdate() { diff --git a/backend/internal/repository/api_key_repo_messages_dispatch_unit_test.go b/backend/internal/repository/api_key_repo_messages_dispatch_unit_test.go new file mode 100644 index 00000000..aba62ead --- /dev/null +++ b/backend/internal/repository/api_key_repo_messages_dispatch_unit_test.go @@ -0,0 +1,74 @@ +package repository + +import ( + "context" + "testing" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestGroupEntityToService_PreservesMessagesDispatchModelConfig(t *testing.T) { + group := &dbent.Group{ + ID: 1, + Name: "openai-dispatch", + Platform: service.PlatformOpenAI, + Status: service.StatusActive, + SubscriptionType: service.SubscriptionTypeStandard, + RateMultiplier: 1, + AllowMessagesDispatch: true, + DefaultMappedModel: "gpt-5.4", + MessagesDispatchModelConfig: service.OpenAIMessagesDispatchModelConfig{ + OpusMappedModel: "gpt-5.4-nano", + SonnetMappedModel: "gpt-5.3-codex", + HaikuMappedModel: "gpt-5.4-mini", + ExactModelMappings: map[string]string{ + "claude-sonnet-4.5": "gpt-5.4-nano", + }, + }, + } + + got := groupEntityToService(group) + require.NotNil(t, got) + require.Equal(t, group.MessagesDispatchModelConfig, got.MessagesDispatchModelConfig) +} + +func TestAPIKeyRepository_GetByKeyForAuth_PreservesMessagesDispatchModelConfig_SQLite(t *testing.T) { + repo, client := newAPIKeyRepoSQLite(t) + ctx := context.Background() + user := mustCreateAPIKeyRepoUser(t, ctx, client, "getbykey-auth-dispatch-unit@test.com") + + group, err := client.Group.Create(). + SetName("g-auth-dispatch-unit"). + SetPlatform(service.PlatformOpenAI). + SetStatus(service.StatusActive). + SetSubscriptionType(service.SubscriptionTypeStandard). + SetRateMultiplier(1). + SetAllowMessagesDispatch(true). + SetDefaultMappedModel("gpt-5.4"). + SetMessagesDispatchModelConfig(service.OpenAIMessagesDispatchModelConfig{ + OpusMappedModel: "gpt-5.4-nano", + SonnetMappedModel: "gpt-5.3-codex", + HaikuMappedModel: "gpt-5.4-mini", + ExactModelMappings: map[string]string{ + "claude-sonnet-4.5": "gpt-5.4-nano", + }, + }). + Save(ctx) + require.NoError(t, err) + + key := &service.APIKey{ + UserID: user.ID, + Key: "sk-getbykey-auth-dispatch-unit", + Name: "Dispatch Key Unit", + GroupID: &group.ID, + Status: service.StatusActive, + } + require.NoError(t, repo.Create(ctx, key)) + + got, err := repo.GetByKeyForAuth(ctx, key.Key) + require.NoError(t, err) + require.NotNil(t, got.Group) + require.Equal(t, group.MessagesDispatchModelConfig, got.Group.MessagesDispatchModelConfig) +} diff --git a/backend/internal/repository/api_key_repo_sort_integration_test.go b/backend/internal/repository/api_key_repo_sort_integration_test.go new file mode 100644 index 00000000..69812882 --- /dev/null +++ b/backend/internal/repository/api_key_repo_sort_integration_test.go @@ -0,0 +1,25 @@ +//go:build integration + +package repository + +import ( + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +func (s *APIKeyRepoSuite) TestListByUserID_SortByNameAsc() { + user := s.mustCreateUser("sort-name@example.com") + s.mustCreateApiKey(user.ID, "sk-z", "z-key", nil) + s.mustCreateApiKey(user.ID, "sk-a", "a-key", nil) + + keys, _, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{ + Page: 1, + PageSize: 10, + SortBy: "name", + SortOrder: "asc", + }, service.APIKeyListFilters{}) + s.Require().NoError(err) + s.Require().Len(keys, 2) + s.Require().Equal("a-key", keys[0].Name) + s.Require().Equal("z-key", keys[1].Name) +} diff --git a/backend/internal/repository/channel_repo.go b/backend/internal/repository/channel_repo.go index 1e2c2e4c..49c2d8d9 100644 --- a/backend/internal/repository/channel_repo.go +++ b/backend/internal/repository/channel_repo.go @@ -188,8 +188,8 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati // 查询 channel 列表 dataQuery := fmt.Sprintf( `SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.created_at, c.updated_at - FROM channels c WHERE %s ORDER BY c.id ASC LIMIT $%d OFFSET $%d`, - whereClause, argIdx, argIdx+1, + FROM channels c WHERE %s ORDER BY %s LIMIT $%d OFFSET $%d`, + whereClause, channelListOrderBy(params), argIdx, argIdx+1, ) args = append(args, pageSize, offset) @@ -246,6 +246,31 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati return channels, paginationResult, nil } +func channelListOrderBy(params pagination.PaginationParams) string { + sortBy := strings.ToLower(strings.TrimSpace(params.SortBy)) + sortOrder := strings.ToUpper(params.NormalizedSortOrder(pagination.SortOrderAsc)) + + var column string + switch sortBy { + case "": + column = "c.id" + sortOrder = "ASC" + case "id": + column = "c.id" + case "name": + column = "c.name" + case "status": + column = "c.status" + case "created_at": + column = "c.created_at" + default: + column = "c.id" + sortOrder = "ASC" + } + + return fmt.Sprintf("%s %s, c.id %s", column, sortOrder, sortOrder) +} + func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) { rows, err := r.db.QueryContext(ctx, `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at FROM channels ORDER BY id`, diff --git a/backend/internal/repository/channel_repo_test.go b/backend/internal/repository/channel_repo_test.go index 5a59948d..e761866d 100644 --- a/backend/internal/repository/channel_repo_test.go +++ b/backend/internal/repository/channel_repo_test.go @@ -8,6 +8,7 @@ import ( "fmt" "testing" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/lib/pq" "github.com/stretchr/testify/require" ) @@ -225,3 +226,12 @@ func TestIsUniqueViolation(t *testing.T) { }) } } + +func TestChannelListOrderBy_AllowsDescendingIDSort(t *testing.T) { + params := pagination.PaginationParams{ + SortBy: "id", + SortOrder: "desc", + } + + require.Equal(t, "c.id DESC, c.id DESC", channelListOrderBy(params)) +} diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 1803cf30..c17e3365 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -5,6 +5,7 @@ import ( "database/sql" "errors" "fmt" + "sort" "strings" dbent "github.com/Wei-Shaw/sub2api/ent" @@ -14,6 +15,8 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/lib/pq" + + entsql "entgo.io/ent/dialect/sql" ) type sqlExecutor interface { @@ -40,6 +43,7 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er SetDescription(groupIn.Description). SetPlatform(groupIn.Platform). SetRateMultiplier(groupIn.RateMultiplier). + SetSortOrder(groupIn.SortOrder). SetIsExclusive(groupIn.IsExclusive). SetStatus(groupIn.Status). SetSubscriptionType(groupIn.SubscriptionType). @@ -233,11 +237,18 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination return nil, nil, err } - groups, err := q. + if strings.EqualFold(strings.TrimSpace(params.SortBy), "account_count") { + return r.listWithAccountCountSort(ctx, q, params, total) + } + + groupsQuery := q. Offset(params.Offset()). - Limit(params.Limit()). - Order(dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)). - All(ctx) + Limit(params.Limit()) + for _, order := range groupListOrder(params) { + groupsQuery = groupsQuery.Order(order) + } + + groups, err := groupsQuery.All(ctx) if err != nil { return nil, nil, err } @@ -263,6 +274,104 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination return outGroups, paginationResultFromTotal(int64(total), params), nil } +func (r *groupRepository) listWithAccountCountSort(ctx context.Context, q *dbent.GroupQuery, params pagination.PaginationParams, total int) ([]service.Group, *pagination.PaginationResult, error) { + groups, err := q. + Order(dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)). + All(ctx) + if err != nil { + return nil, nil, err + } + + groupIDs := make([]int64, 0, len(groups)) + outGroups := make([]service.Group, 0, len(groups)) + for i := range groups { + g := groupEntityToService(groups[i]) + outGroups = append(outGroups, *g) + groupIDs = append(groupIDs, g.ID) + } + + counts, err := r.loadAccountCounts(ctx, groupIDs) + if err != nil { + return nil, nil, err + } + for i := range outGroups { + c := counts[outGroups[i].ID] + outGroups[i].AccountCount = c.Total + outGroups[i].ActiveAccountCount = c.Active + outGroups[i].RateLimitedAccountCount = c.RateLimited + } + + sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc) + sort.SliceStable(outGroups, func(i, j int) bool { + if outGroups[i].AccountCount == outGroups[j].AccountCount { + if outGroups[i].SortOrder == outGroups[j].SortOrder { + return outGroups[i].ID < outGroups[j].ID + } + return outGroups[i].SortOrder < outGroups[j].SortOrder + } + if sortOrder == pagination.SortOrderAsc { + return outGroups[i].AccountCount < outGroups[j].AccountCount + } + return outGroups[i].AccountCount > outGroups[j].AccountCount + }) + + return paginateSlice(outGroups, params), paginationResultFromTotal(int64(total), params), nil +} + +func groupListOrder(params pagination.PaginationParams) []func(*entsql.Selector) { + sortBy := strings.ToLower(strings.TrimSpace(params.SortBy)) + sortOrder := params.NormalizedSortOrder(pagination.SortOrderAsc) + + var field string + tieField := group.FieldID + defaultOrder := true + switch sortBy { + case "", "sort_order": + field = group.FieldSortOrder + case "name": + field = group.FieldName + defaultOrder = false + case "platform": + field = group.FieldPlatform + defaultOrder = false + case "billing_type", "subscription_type": + field = group.FieldSubscriptionType + defaultOrder = false + case "rate_multiplier": + field = group.FieldRateMultiplier + defaultOrder = false + case "is_exclusive": + field = group.FieldIsExclusive + defaultOrder = false + case "status": + field = group.FieldStatus + defaultOrder = false + case "created_at": + field = group.FieldCreatedAt + defaultOrder = false + case "id": + field = group.FieldID + defaultOrder = false + tieField = "" + default: + field = group.FieldSortOrder + } + + if sortOrder == pagination.SortOrderDesc && sortBy != "" { + if tieField == "" { + return []func(*entsql.Selector){dbent.Desc(field)} + } + return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(tieField)} + } + if defaultOrder { + return []func(*entsql.Selector){dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)} + } + if tieField == "" { + return []func(*entsql.Selector){dbent.Asc(field)} + } + return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(tieField)} +} + func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, error) { groups, err := r.client.Group.Query(). Where(group.StatusEQ(service.StatusActive)). diff --git a/backend/internal/repository/group_repo_integration_test.go b/backend/internal/repository/group_repo_integration_test.go index eccf5cea..f91dae43 100644 --- a/backend/internal/repository/group_repo_integration_test.go +++ b/backend/internal/repository/group_repo_integration_test.go @@ -113,6 +113,33 @@ func (s *GroupRepoSuite) TestUpdate() { s.Require().Equal("updated", got.Name) } +func (s *GroupRepoSuite) TestGetByID_PreservesMessagesDispatchModelConfig() { + group := &service.Group{ + Name: "openai-dispatch", + Platform: service.PlatformOpenAI, + RateMultiplier: 1.0, + IsExclusive: false, + Status: service.StatusActive, + SubscriptionType: service.SubscriptionTypeStandard, + AllowMessagesDispatch: true, + DefaultMappedModel: "gpt-5.4", + MessagesDispatchModelConfig: service.OpenAIMessagesDispatchModelConfig{ + OpusMappedModel: "gpt-5.4", + SonnetMappedModel: "gpt-5.3-codex", + HaikuMappedModel: "gpt-5.4-mini", + ExactModelMappings: map[string]string{ + "claude-sonnet-4.5": "gpt-5.4-nano", + }, + }, + } + + s.Require().NoError(s.repo.Create(s.ctx, group)) + + got, err := s.repo.GetByID(s.ctx, group.ID) + s.Require().NoError(err) + s.Require().Equal(group.MessagesDispatchModelConfig, got.MessagesDispatchModelConfig) +} + func (s *GroupRepoSuite) TestDelete() { group := &service.Group{ Name: "to-delete", diff --git a/backend/internal/repository/group_repo_sort_integration_test.go b/backend/internal/repository/group_repo_sort_integration_test.go new file mode 100644 index 00000000..85b2efcc --- /dev/null +++ b/backend/internal/repository/group_repo_sort_integration_test.go @@ -0,0 +1,50 @@ +//go:build integration + +package repository + +import ( + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +func (s *GroupRepoSuite) TestList_DefaultSortBySortOrderAsc() { + g1 := &service.Group{Name: "g1", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 20} + g2 := &service.Group{Name: "g2", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 10} + s.Require().NoError(s.repo.Create(s.ctx, g1)) + s.Require().NoError(s.repo.Create(s.ctx, g2)) + + groups, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 100}) + s.Require().NoError(err) + s.Require().GreaterOrEqual(len(groups), 2) + indexByID := make(map[int64]int, len(groups)) + for i, g := range groups { + indexByID[g.ID] = i + } + s.Require().Contains(indexByID, g1.ID) + s.Require().Contains(indexByID, g2.ID) + // g2 has SortOrder=10, g1 has SortOrder=20; ascending means g2 comes first + s.Require().Less(indexByID[g2.ID], indexByID[g1.ID]) +} + +func (s *GroupRepoSuite) TestList_SortBySortOrderDesc() { + g1 := &service.Group{Name: "g1", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 40} + g2 := &service.Group{Name: "g2", Platform: service.PlatformAnthropic, RateMultiplier: 1, Status: service.StatusActive, SubscriptionType: service.SubscriptionTypeStandard, SortOrder: 50} + s.Require().NoError(s.repo.Create(s.ctx, g1)) + s.Require().NoError(s.repo.Create(s.ctx, g2)) + + groups, _, err := s.repo.List(s.ctx, pagination.PaginationParams{ + Page: 1, + PageSize: 10, + SortBy: "sort_order", + SortOrder: "desc", + }) + s.Require().NoError(err) + s.Require().GreaterOrEqual(len(groups), 2) + indexByID := make(map[int64]int, len(groups)) + for i, group := range groups { + indexByID[group.ID] = i + } + s.Require().Contains(indexByID, g1.ID) + s.Require().Contains(indexByID, g2.ID) + s.Require().Less(indexByID[g2.ID], indexByID[g1.ID]) +} diff --git a/backend/internal/repository/pagination.go b/backend/internal/repository/pagination.go index ff08c34b..87c42a59 100644 --- a/backend/internal/repository/pagination.go +++ b/backend/internal/repository/pagination.go @@ -14,3 +14,22 @@ func paginationResultFromTotal(total int64, params pagination.PaginationParams) Pages: pages, } } + +func paginateSlice[T any](items []T, params pagination.PaginationParams) []T { + if len(items) == 0 { + return []T{} + } + + offset := params.Offset() + if offset >= len(items) { + return []T{} + } + + limit := params.Limit() + end := offset + limit + if end > len(items) { + end = len(items) + } + + return items[offset:end] +} diff --git a/backend/internal/repository/promo_code_repo.go b/backend/internal/repository/promo_code_repo.go index 95ce687a..d9c76bb3 100644 --- a/backend/internal/repository/promo_code_repo.go +++ b/backend/internal/repository/promo_code_repo.go @@ -2,12 +2,15 @@ package repository import ( "context" + "strings" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" + + entsql "entgo.io/ent/dialect/sql" ) type promoCodeRepository struct { @@ -137,11 +140,14 @@ func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagina return nil, nil, err } - codes, err := q. + codesQuery := q. Offset(params.Offset()). - Limit(params.Limit()). - Order(dbent.Desc(promocode.FieldID)). - All(ctx) + Limit(params.Limit()) + for _, order := range promoCodeListOrder(params) { + codesQuery = codesQuery.Order(order) + } + + codes, err := codesQuery.All(ctx) if err != nil { return nil, nil, err } @@ -151,6 +157,32 @@ func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagina return outCodes, paginationResultFromTotal(int64(total), params), nil } +func promoCodeListOrder(params pagination.PaginationParams) []func(*entsql.Selector) { + sortBy := strings.ToLower(strings.TrimSpace(params.SortBy)) + sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc) + + var field string + switch sortBy { + case "bonus_amount": + field = promocode.FieldBonusAmount + case "status": + field = promocode.FieldStatus + case "expires_at": + field = promocode.FieldExpiresAt + case "created_at": + field = promocode.FieldCreatedAt + case "code": + field = promocode.FieldCode + default: + field = promocode.FieldID + } + + if sortOrder == pagination.SortOrderAsc { + return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(promocode.FieldID)} + } + return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(promocode.FieldID)} +} + func (r *promoCodeRepository) CreateUsage(ctx context.Context, usage *service.PromoCodeUsage) error { client := clientFromContext(ctx, r.client) created, err := client.PromoCodeUsage.Create(). diff --git a/backend/internal/repository/proxy_repo.go b/backend/internal/repository/proxy_repo.go index 07c2a204..60b2f069 100644 --- a/backend/internal/repository/proxy_repo.go +++ b/backend/internal/repository/proxy_repo.go @@ -3,12 +3,16 @@ package repository import ( "context" "database/sql" + "sort" + "strings" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/proxy" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + + entsql "entgo.io/ent/dialect/sql" ) type sqlQuerier interface { @@ -135,11 +139,14 @@ func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination return nil, nil, err } - proxies, err := q. + proxiesQuery := q. Offset(params.Offset()). - Limit(params.Limit()). - Order(dbent.Desc(proxy.FieldID)). - All(ctx) + Limit(params.Limit()) + for _, order := range proxyListOrder(params) { + proxiesQuery = proxiesQuery.Order(order) + } + + proxies, err := proxiesQuery.All(ctx) if err != nil { return nil, nil, err } @@ -170,22 +177,58 @@ func (r *proxyRepository) ListWithFiltersAndAccountCount(ctx context.Context, pa return nil, nil, err } - proxies, err := q. + if strings.EqualFold(strings.TrimSpace(params.SortBy), "account_count") { + return r.listWithAccountCountSort(ctx, q, params, total) + } + + proxiesQuery := q. Offset(params.Offset()). - Limit(params.Limit()). + Limit(params.Limit()) + for _, order := range proxyListOrder(params) { + proxiesQuery = proxiesQuery.Order(order) + } + + proxies, err := proxiesQuery.All(ctx) + if err != nil { + return nil, nil, err + } + + return r.buildProxyWithAccountCountResult(ctx, proxies, params, int64(total)) +} + +func (r *proxyRepository) listWithAccountCountSort(ctx context.Context, q *dbent.ProxyQuery, params pagination.PaginationParams, total int) ([]service.ProxyWithAccountCount, *pagination.PaginationResult, error) { + proxies, err := q. Order(dbent.Desc(proxy.FieldID)). All(ctx) if err != nil { return nil, nil, err } - // Get account counts + result, _, err := r.buildProxyWithAccountCountResult(ctx, proxies, params, int64(total)) + if err != nil { + return nil, nil, err + } + + sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc) + sort.SliceStable(result, func(i, j int) bool { + if result[i].AccountCount == result[j].AccountCount { + return result[i].ID > result[j].ID + } + if sortOrder == pagination.SortOrderAsc { + return result[i].AccountCount < result[j].AccountCount + } + return result[i].AccountCount > result[j].AccountCount + }) + + return paginateSlice(result, params), paginationResultFromTotal(int64(total), params), nil +} + +func (r *proxyRepository) buildProxyWithAccountCountResult(ctx context.Context, proxies []*dbent.Proxy, params pagination.PaginationParams, total int64) ([]service.ProxyWithAccountCount, *pagination.PaginationResult, error) { counts, err := r.GetAccountCountsForProxies(ctx) if err != nil { return nil, nil, err } - // Build result with account counts result := make([]service.ProxyWithAccountCount, 0, len(proxies)) for i := range proxies { proxyOut := proxyEntityToService(proxies[i]) @@ -198,7 +241,31 @@ func (r *proxyRepository) ListWithFiltersAndAccountCount(ctx context.Context, pa }) } - return result, paginationResultFromTotal(int64(total), params), nil + return result, paginationResultFromTotal(total, params), nil +} + +func proxyListOrder(params pagination.PaginationParams) []func(*entsql.Selector) { + sortBy := strings.ToLower(strings.TrimSpace(params.SortBy)) + sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc) + + var field string + switch sortBy { + case "name": + field = proxy.FieldName + case "protocol": + field = proxy.FieldProtocol + case "status": + field = proxy.FieldStatus + case "created_at": + field = proxy.FieldCreatedAt + default: + field = proxy.FieldID + } + + if sortOrder == pagination.SortOrderAsc { + return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(proxy.FieldID)} + } + return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(proxy.FieldID)} } func (r *proxyRepository) ListActive(ctx context.Context) ([]service.Proxy, error) { diff --git a/backend/internal/repository/proxy_repo_sort_integration_test.go b/backend/internal/repository/proxy_repo_sort_integration_test.go new file mode 100644 index 00000000..fe1c2873 --- /dev/null +++ b/backend/internal/repository/proxy_repo_sort_integration_test.go @@ -0,0 +1,28 @@ +//go:build integration + +package repository + +import ( + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +func (s *ProxyRepoSuite) TestListWithFiltersAndAccountCount_SortByAccountCountDesc() { + p1 := s.mustCreateProxy(&service.Proxy{Name: "p1", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive}) + p2 := s.mustCreateProxy(&service.Proxy{Name: "p2", Protocol: "http", Host: "127.0.0.1", Port: 8081, Status: service.StatusActive}) + s.mustInsertAccount("a1", &p1.ID) + s.mustInsertAccount("a2", &p1.ID) + s.mustInsertAccount("a3", &p2.ID) + + proxies, _, err := s.repo.ListWithFiltersAndAccountCount(s.ctx, pagination.PaginationParams{ + Page: 1, + PageSize: 10, + SortBy: "account_count", + SortOrder: "desc", + }, "", "", "") + s.Require().NoError(err) + s.Require().Len(proxies, 2) + s.Require().Equal(p1.ID, proxies[0].ID) + s.Require().Equal(int64(2), proxies[0].AccountCount) + s.Require().Equal(p2.ID, proxies[1].ID) +} diff --git a/backend/internal/repository/redeem_code_repo.go b/backend/internal/repository/redeem_code_repo.go index 934a3095..07975970 100644 --- a/backend/internal/repository/redeem_code_repo.go +++ b/backend/internal/repository/redeem_code_repo.go @@ -2,6 +2,7 @@ package repository import ( "context" + "strings" "time" dbent "github.com/Wei-Shaw/sub2api/ent" @@ -9,6 +10,8 @@ import ( "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" + + entsql "entgo.io/ent/dialect/sql" ) type redeemCodeRepository struct { @@ -120,13 +123,16 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin return nil, nil, err } - codes, err := q. + codesQuery := q. WithUser(). WithGroup(). Offset(params.Offset()). - Limit(params.Limit()). - Order(dbent.Desc(redeemcode.FieldID)). - All(ctx) + Limit(params.Limit()) + for _, order := range redeemCodeListOrder(params) { + codesQuery = codesQuery.Order(order) + } + + codes, err := codesQuery.All(ctx) if err != nil { return nil, nil, err } @@ -136,6 +142,34 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin return outCodes, paginationResultFromTotal(int64(total), params), nil } +func redeemCodeListOrder(params pagination.PaginationParams) []func(*entsql.Selector) { + sortBy := strings.ToLower(strings.TrimSpace(params.SortBy)) + sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc) + + var field string + switch sortBy { + case "type": + field = redeemcode.FieldType + case "value": + field = redeemcode.FieldValue + case "status": + field = redeemcode.FieldStatus + case "used_at": + field = redeemcode.FieldUsedAt + case "created_at": + field = redeemcode.FieldCreatedAt + case "code": + field = redeemcode.FieldCode + default: + field = redeemcode.FieldID + } + + if sortOrder == pagination.SortOrderAsc { + return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(redeemcode.FieldID)} + } + return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(redeemcode.FieldID)} +} + func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemCode) error { up := r.client.RedeemCode.UpdateOneID(code.ID). SetCode(code.Code). diff --git a/backend/internal/repository/redeem_code_repo_sort_integration_test.go b/backend/internal/repository/redeem_code_repo_sort_integration_test.go new file mode 100644 index 00000000..30d32f4c --- /dev/null +++ b/backend/internal/repository/redeem_code_repo_sort_integration_test.go @@ -0,0 +1,24 @@ +//go:build integration + +package repository + +import ( + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +func (s *RedeemCodeRepoSuite) TestListWithFilters_SortByValueAsc() { + s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "VALUE-20", Type: service.RedeemTypeBalance, Value: 20, Status: service.StatusUnused})) + s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "VALUE-10", Type: service.RedeemTypeBalance, Value: 10, Status: service.StatusUnused})) + + codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{ + Page: 1, + PageSize: 10, + SortBy: "value", + SortOrder: "asc", + }, "", "", "") + s.Require().NoError(err) + s.Require().Len(codes, 2) + s.Require().Equal("VALUE-10", codes[0].Code) + s.Require().Equal("VALUE-20", codes[1].Code) +} diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index d7bcd094..3ba2191e 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -3771,7 +3771,7 @@ func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, wh limitPos := len(args) + 1 offsetPos := len(args) + 2 listArgs := append(append([]any{}, args...), params.Limit(), params.Offset()) - query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY id DESC LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, limitPos, offsetPos) + query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY %s LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, usageLogOrderBy(params), limitPos, offsetPos) logs, err := r.queryUsageLogs(ctx, query, listArgs...) if err != nil { return nil, nil, err @@ -3786,7 +3786,7 @@ func (r *usageLogRepository) listUsageLogsWithFastPagination(ctx context.Context limitPos := len(args) + 1 offsetPos := len(args) + 2 listArgs := append(append([]any{}, args...), limit+1, offset) - query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY id DESC LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, limitPos, offsetPos) + query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY %s LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, usageLogOrderBy(params), limitPos, offsetPos) logs, err := r.queryUsageLogs(ctx, query, listArgs...) if err != nil { @@ -3808,6 +3808,26 @@ func (r *usageLogRepository) listUsageLogsWithFastPagination(ctx context.Context return logs, paginationResultFromTotal(total, params), nil } +func usageLogOrderBy(params pagination.PaginationParams) string { + sortBy := strings.ToLower(strings.TrimSpace(params.SortBy)) + sortOrder := strings.ToUpper(params.NormalizedSortOrder(pagination.SortOrderDesc)) + + var column string + switch sortBy { + case "model": + column = "COALESCE(NULLIF(TRIM(requested_model), ''), model)" + case "created_at": + column = "created_at" + default: + column = "id" + } + + if column == "id" { + return fmt.Sprintf("id %s", sortOrder) + } + return fmt.Sprintf("%s %s, id %s", column, sortOrder, sortOrder) +} + func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) (logs []service.UsageLog, err error) { rows, err := r.sql.QueryContext(ctx, query, args...) if err != nil { diff --git a/backend/internal/repository/usage_log_repo_request_type_test.go b/backend/internal/repository/usage_log_repo_request_type_test.go index ce0c5f00..b9cb6a13 100644 --- a/backend/internal/repository/usage_log_repo_request_type_test.go +++ b/backend/internal/repository/usage_log_repo_request_type_test.go @@ -330,6 +330,15 @@ func TestUsageLogRepositoryGetStatsWithFiltersRequestTypePriority(t *testing.T) "total_account_cost", "avg_duration_ms", }).AddRow(int64(1), int64(2), int64(3), int64(4), 1.2, 1.0, 1.2, 20.0)) + mock.ExpectQuery("SELECT COALESCE\\(NULLIF\\(TRIM\\(inbound_endpoint\\), ''\\), 'unknown'\\) AS endpoint"). + WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), requestType). + WillReturnRows(sqlmock.NewRows([]string{"endpoint", "requests", "total_tokens", "cost", "actual_cost"})) + mock.ExpectQuery("SELECT COALESCE\\(NULLIF\\(TRIM\\(upstream_endpoint\\), ''\\), 'unknown'\\) AS endpoint"). + WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), requestType). + WillReturnRows(sqlmock.NewRows([]string{"endpoint", "requests", "total_tokens", "cost", "actual_cost"})) + mock.ExpectQuery("SELECT CONCAT\\("). + WithArgs(sqlmock.AnyArg(), sqlmock.AnyArg(), requestType). + WillReturnRows(sqlmock.NewRows([]string{"endpoint", "requests", "total_tokens", "cost", "actual_cost"})) stats, err := repo.GetStatsWithFilters(context.Background(), filters) require.NoError(t, err) diff --git a/backend/internal/repository/usage_log_repo_sort_integration_test.go b/backend/internal/repository/usage_log_repo_sort_integration_test.go new file mode 100644 index 00000000..4c69f975 --- /dev/null +++ b/backend/internal/repository/usage_log_repo_sort_integration_test.go @@ -0,0 +1,61 @@ +//go:build integration + +package repository + +import ( + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/google/uuid" +) + +func (s *UsageLogRepoSuite) TestListWithFilters_SortByModelAsc() { + user := mustCreateUser(s.T(), s.client, &service.User{Email: "usage-sort@example.com"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-usage-sort", Name: "k"}) + account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "usage-sort-account"}) + + first := &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: uuid.New().String(), + Model: "z-model", + RequestedModel: "z-model", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now(), + } + _, err := s.repo.Create(s.ctx, first) + s.Require().NoError(err) + + second := &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: uuid.New().String(), + Model: "a-model", + RequestedModel: "a-model", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 0.5, + ActualCost: 0.5, + CreatedAt: time.Now().Add(time.Second), + } + _, err = s.repo.Create(s.ctx, second) + s.Require().NoError(err) + + logs, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{ + Page: 1, + PageSize: 10, + SortBy: "model", + SortOrder: "asc", + }, usagestats.UsageLogFilters{UserID: user.ID}) + s.Require().NoError(err) + s.Require().Len(logs, 2) + s.Require().Equal("a-model", logs[0].RequestedModel) + s.Require().Equal("z-model", logs[1].RequestedModel) +} diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index 06c79113..d5a13607 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -17,6 +17,8 @@ import ( "github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" + + entsql "entgo.io/ent/dialect/sql" ) type userRepository struct { @@ -224,11 +226,14 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination. return nil, nil, err } - users, err := q. + usersQuery := q. Offset(params.Offset()). - Limit(params.Limit()). - Order(dbent.Desc(dbuser.FieldID)). - All(ctx) + Limit(params.Limit()) + for _, order := range userListOrder(params) { + usersQuery = usersQuery.Order(order) + } + + users, err := usersQuery.All(ctx) if err != nil { return nil, nil, err } @@ -281,6 +286,50 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination. return outUsers, paginationResultFromTotal(int64(total), params), nil } +func userListOrder(params pagination.PaginationParams) []func(*entsql.Selector) { + sortBy := strings.ToLower(strings.TrimSpace(params.SortBy)) + sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc) + + var field string + defaultField := true + switch sortBy { + case "email": + field = dbuser.FieldEmail + defaultField = false + case "username": + field = dbuser.FieldUsername + defaultField = false + case "role": + field = dbuser.FieldRole + defaultField = false + case "balance": + field = dbuser.FieldBalance + defaultField = false + case "concurrency": + field = dbuser.FieldConcurrency + defaultField = false + case "status": + field = dbuser.FieldStatus + defaultField = false + case "created_at": + field = dbuser.FieldCreatedAt + defaultField = false + default: + field = dbuser.FieldID + } + + if sortOrder == pagination.SortOrderAsc { + if defaultField && field == dbuser.FieldID { + return []func(*entsql.Selector){dbent.Asc(dbuser.FieldID)} + } + return []func(*entsql.Selector){dbent.Asc(field), dbent.Asc(dbuser.FieldID)} + } + if defaultField && field == dbuser.FieldID { + return []func(*entsql.Selector){dbent.Desc(dbuser.FieldID)} + } + return []func(*entsql.Selector){dbent.Desc(field), dbent.Desc(dbuser.FieldID)} +} + // filterUsersByAttributes returns user IDs that match ALL the given attribute filters func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[int64]string) ([]int64, error) { if len(attrs) == 0 { diff --git a/backend/internal/repository/user_repo_sort_integration_test.go b/backend/internal/repository/user_repo_sort_integration_test.go new file mode 100644 index 00000000..ab84b0e9 --- /dev/null +++ b/backend/internal/repository/user_repo_sort_integration_test.go @@ -0,0 +1,39 @@ +//go:build integration + +package repository + +import ( + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +func (s *UserRepoSuite) TestListWithFilters_SortByEmailAsc() { + s.mustCreateUser(&service.User{Email: "z-last@example.com", Username: "z-user"}) + s.mustCreateUser(&service.User{Email: "a-first@example.com", Username: "a-user"}) + + users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{ + Page: 1, + PageSize: 10, + SortBy: "email", + SortOrder: "asc", + }, service.UserListFilters{}) + s.Require().NoError(err) + s.Require().Len(users, 2) + s.Require().Equal("a-first@example.com", users[0].Email) + s.Require().Equal("z-last@example.com", users[1].Email) +} + +func (s *UserRepoSuite) TestList_DefaultSortByNewestFirst() { + first := s.mustCreateUser(&service.User{Email: "first@example.com"}) + second := s.mustCreateUser(&service.User{Email: "second@example.com"}) + + users, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}) + s.Require().NoError(err) + s.Require().Len(users, 2) + s.Require().Equal(second.ID, users[0].ID) + s.Require().Equal(first.ID, users[1].ID) +} + +func TestUserRepoSortSuiteSmoke(_ *testing.T) {} diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index be2fef38..1a4892fa 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -491,8 +491,10 @@ func TestAPIContracts(t *testing.T) { service.SettingKeyContactInfo: "support", service.SettingKeyDocURL: "https://docs.example.com", - service.SettingKeyDefaultConcurrency: "5", - service.SettingKeyDefaultBalance: "1.25", + service.SettingKeyDefaultConcurrency: "5", + service.SettingKeyDefaultBalance: "1.25", + service.SettingKeyTableDefaultPageSize: "20", + service.SettingKeyTablePageSizeOptions: "[10,20,50,100]", service.SettingKeyOpsMonitoringEnabled: "false", service.SettingKeyOpsRealtimeMonitoringEnabled: "true", @@ -576,6 +578,8 @@ func TestAPIContracts(t *testing.T) { "hide_ccs_import_button": false, "purchase_subscription_enabled": false, "purchase_subscription_url": "", + "table_default_page_size": 20, + "table_page_size_options": [10, 20, 50, 100], "min_claude_code_version": "", "max_claude_code_version": "", "allow_ungrouped_key_scheduling": false, diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index c2553eee..97b42c24 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -21,13 +21,13 @@ import ( // AdminService interface defines admin management operations type AdminService interface { // User management - ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error) + ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters, sortBy, sortOrder string) ([]User, int64, error) GetUser(ctx context.Context, id int64) (*User, error) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) DeleteUser(ctx context.Context, id int64) error UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) - GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) + GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) // GetUserBalanceHistory returns paginated balance/concurrency change records for a user. // codeType is optional - pass empty string to return all types. @@ -35,7 +35,7 @@ type AdminService interface { GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) // Group management - ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) + ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error) GetAllGroups(ctx context.Context) ([]Group, error) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error) GetGroup(ctx context.Context, id int64) (*Group, error) @@ -55,7 +55,7 @@ type AdminService interface { ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error) // Account management - ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, int64, error) + ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, sortBy, sortOrder string) ([]Account, int64, error) GetAccount(ctx context.Context, id int64) (*Account, error) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) @@ -77,8 +77,8 @@ type AdminService interface { CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error // Proxy management - ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) - ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]ProxyWithAccountCount, int64, error) + ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]Proxy, int64, error) + ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]ProxyWithAccountCount, int64, error) GetAllProxies(ctx context.Context) ([]Proxy, error) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) GetProxy(ctx context.Context, id int64) (*Proxy, error) @@ -93,7 +93,7 @@ type AdminService interface { CheckProxyQuality(ctx context.Context, id int64) (*ProxyQualityCheckResult, error) // Redeem code management - ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) + ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string, sortBy, sortOrder string) ([]RedeemCode, int64, error) GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error) DeleteRedeemCode(ctx context.Context, id int64) error @@ -485,8 +485,8 @@ func NewAdminService( } // User management implementations -func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error) { - params := pagination.PaginationParams{Page: page, PageSize: pageSize} +func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters, sortBy, sortOrder string) ([]User, int64, error) { + params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder} users, result, err := s.userRepo.ListWithFilters(ctx, params, filters) if err != nil { return nil, 0, err @@ -753,8 +753,8 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, return user, nil } -func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) { - params := pagination.PaginationParams{Page: page, PageSize: pageSize} +func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error) { + params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder} keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, APIKeyListFilters{}) if err != nil { return nil, 0, err @@ -789,8 +789,8 @@ func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int } // Group management implementations -func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) { - params := pagination.PaginationParams{Page: page, PageSize: pageSize} +func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error) { + params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder} groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, search, isExclusive) if err != nil { return nil, 0, err @@ -1464,8 +1464,8 @@ func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGrou } // Account management implementations -func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string) ([]Account, int64, error) { - params := pagination.PaginationParams{Page: page, PageSize: pageSize} +func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, sortBy, sortOrder string) ([]Account, int64, error) { + params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder} accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search, groupID, privacyMode) if err != nil { return nil, 0, err @@ -1893,8 +1893,8 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, } // Proxy management implementations -func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) { - params := pagination.PaginationParams{Page: page, PageSize: pageSize} +func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]Proxy, int64, error) { + params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder} proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search) if err != nil { return nil, 0, err @@ -1902,8 +1902,8 @@ func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, return proxies, result.Total, nil } -func (s *adminServiceImpl) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]ProxyWithAccountCount, int64, error) { - params := pagination.PaginationParams{Page: page, PageSize: pageSize} +func (s *adminServiceImpl) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string, sortBy, sortOrder string) ([]ProxyWithAccountCount, int64, error) { + params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder} proxies, result, err := s.proxyRepo.ListWithFiltersAndAccountCount(ctx, params, protocol, status, search) if err != nil { return nil, 0, err @@ -2040,8 +2040,8 @@ func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, po } // Redeem code management implementations -func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) { - params := pagination.PaginationParams{Page: page, PageSize: pageSize} +func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string, sortBy, sortOrder string) ([]RedeemCode, int64, error) { + params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder} codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search) if err != nil { return nil, 0, err diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go index 364022bd..a4c6d0ca 100644 --- a/backend/internal/service/admin_service_group_test.go +++ b/backend/internal/service/admin_service_group_test.go @@ -125,6 +125,22 @@ func (s *groupRepoStubForAdmin) UpdateSortOrders(_ context.Context, _ []GroupSor return nil } +func TestAdminService_ListGroups_PassesSortParams(t *testing.T) { + repo := &groupRepoStubForAdmin{ + listWithFiltersGroups: []Group{{ID: 1, Name: "g1"}}, + } + svc := &adminServiceImpl{groupRepo: repo} + + _, _, err := svc.ListGroups(context.Background(), 3, 25, PlatformOpenAI, StatusActive, "needle", nil, "account_count", "ASC") + require.NoError(t, err) + require.Equal(t, pagination.PaginationParams{ + Page: 3, + PageSize: 25, + SortBy: "account_count", + SortOrder: "ASC", + }, repo.listWithFiltersParams) +} + // TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递 func TestAdminService_CreateGroup_WithImagePricing(t *testing.T) { repo := &groupRepoStubForAdmin{} @@ -373,7 +389,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) { } svc := &adminServiceImpl{groupRepo: repo} - groups, total, err := svc.ListGroups(context.Background(), 1, 20, "", "", "alpha", nil) + groups, total, err := svc.ListGroups(context.Background(), 1, 20, "", "", "alpha", nil, "", "") require.NoError(t, err) require.Equal(t, int64(1), total) require.Equal(t, []Group{{ID: 1, Name: "alpha"}}, groups) @@ -391,7 +407,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) { } svc := &adminServiceImpl{groupRepo: repo} - groups, total, err := svc.ListGroups(context.Background(), 2, 10, "", "", "", nil) + groups, total, err := svc.ListGroups(context.Background(), 2, 10, "", "", "", nil, "", "") require.NoError(t, err) require.Empty(t, groups) require.Equal(t, int64(0), total) @@ -410,7 +426,7 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) { } svc := &adminServiceImpl{groupRepo: repo} - groups, total, err := svc.ListGroups(context.Background(), 3, 50, PlatformAntigravity, StatusActive, "beta", &isExclusive) + groups, total, err := svc.ListGroups(context.Background(), 3, 50, PlatformAntigravity, StatusActive, "beta", &isExclusive, "", "") require.NoError(t, err) require.Equal(t, int64(42), total) require.Equal(t, []Group{{ID: 2, Name: "beta"}}, groups) diff --git a/backend/internal/service/admin_service_list_users_test.go b/backend/internal/service/admin_service_list_users_test.go index 37f348df..ceeb52c2 100644 --- a/backend/internal/service/admin_service_list_users_test.go +++ b/backend/internal/service/admin_service_list_users_test.go @@ -13,11 +13,13 @@ import ( type userRepoStubForListUsers struct { userRepoStub - users []User - err error + users []User + err error + listWithFiltersParams pagination.PaginationParams } func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pagination.PaginationParams, _ UserListFilters) ([]User, *pagination.PaginationResult, error) { + s.listWithFiltersParams = params if s.err != nil { return nil, nil, s.err } @@ -103,7 +105,7 @@ func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) { userGroupRateRepo: rateRepo, } - users, total, err := svc.ListUsers(context.Background(), 1, 20, UserListFilters{}) + users, total, err := svc.ListUsers(context.Background(), 1, 20, UserListFilters{}, "", "") require.NoError(t, err) require.Equal(t, int64(2), total) require.Len(t, users, 2) @@ -112,3 +114,19 @@ func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) { require.Equal(t, 1.1, users[0].GroupRates[11]) require.Equal(t, 2.2, users[1].GroupRates[22]) } + +func TestAdminService_ListUsers_PassesSortParams(t *testing.T) { + userRepo := &userRepoStubForListUsers{ + users: []User{{ID: 1, Email: "a@example.com"}}, + } + svc := &adminServiceImpl{userRepo: userRepo} + + _, _, err := svc.ListUsers(context.Background(), 2, 50, UserListFilters{}, "email", "ASC") + require.NoError(t, err) + require.Equal(t, pagination.PaginationParams{ + Page: 2, + PageSize: 50, + SortBy: "email", + SortOrder: "ASC", + }, userRepo.listWithFiltersParams) +} diff --git a/backend/internal/service/admin_service_search_test.go b/backend/internal/service/admin_service_search_test.go index eb213e6a..595e99e3 100644 --- a/backend/internal/service/admin_service_search_test.go +++ b/backend/internal/service/admin_service_search_test.go @@ -170,13 +170,13 @@ func TestAdminService_ListAccounts_WithSearch(t *testing.T) { } svc := &adminServiceImpl{accountRepo: repo} - accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0, "") + accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0, "", "name", "ASC") require.NoError(t, err) require.Equal(t, int64(10), total) require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts) require.Equal(t, 1, repo.listWithFiltersCalls) - require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams) + require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20, SortBy: "name", SortOrder: "ASC"}, repo.listWithFiltersParams) require.Equal(t, PlatformGemini, repo.listWithFiltersPlatform) require.Equal(t, AccountTypeOAuth, repo.listWithFiltersType) require.Equal(t, StatusActive, repo.listWithFiltersStatus) @@ -192,7 +192,7 @@ func TestAdminService_ListAccounts_WithPrivacyMode(t *testing.T) { } svc := &adminServiceImpl{accountRepo: repo} - accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, StatusActive, "acc2", 0, PrivacyModeCFBlocked) + accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, StatusActive, "acc2", 0, PrivacyModeCFBlocked, "", "") require.NoError(t, err) require.Equal(t, int64(1), total) require.Equal(t, []Account{{ID: 2, Name: "acc2"}}, accounts) @@ -208,13 +208,13 @@ func TestAdminService_ListProxies_WithSearch(t *testing.T) { } svc := &adminServiceImpl{proxyRepo: repo} - proxies, total, err := svc.ListProxies(context.Background(), 3, 50, "http", StatusActive, "p1") + proxies, total, err := svc.ListProxies(context.Background(), 3, 50, "http", StatusActive, "p1", "name", "ASC") require.NoError(t, err) require.Equal(t, int64(7), total) require.Equal(t, []Proxy{{ID: 2, Name: "p1"}}, proxies) require.Equal(t, 1, repo.listWithFiltersCalls) - require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams) + require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50, SortBy: "name", SortOrder: "ASC"}, repo.listWithFiltersParams) require.Equal(t, "http", repo.listWithFiltersProtocol) require.Equal(t, StatusActive, repo.listWithFiltersStatus) require.Equal(t, "p1", repo.listWithFiltersSearch) @@ -229,13 +229,13 @@ func TestAdminService_ListProxiesWithAccountCount_WithSearch(t *testing.T) { } svc := &adminServiceImpl{proxyRepo: repo} - proxies, total, err := svc.ListProxiesWithAccountCount(context.Background(), 2, 10, "socks5", StatusDisabled, "p2") + proxies, total, err := svc.ListProxiesWithAccountCount(context.Background(), 2, 10, "socks5", StatusDisabled, "p2", "account_count", "DESC") require.NoError(t, err) require.Equal(t, int64(9), total) require.Equal(t, []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}}, proxies) require.Equal(t, 1, repo.listWithFiltersAndAccountCountCalls) - require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersAndAccountCountParams) + require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10, SortBy: "account_count", SortOrder: "DESC"}, repo.listWithFiltersAndAccountCountParams) require.Equal(t, "socks5", repo.listWithFiltersAndAccountCountProtocol) require.Equal(t, StatusDisabled, repo.listWithFiltersAndAccountCountStatus) require.Equal(t, "p2", repo.listWithFiltersAndAccountCountSearch) @@ -250,13 +250,13 @@ func TestAdminService_ListRedeemCodes_WithSearch(t *testing.T) { } svc := &adminServiceImpl{redeemCodeRepo: repo} - codes, total, err := svc.ListRedeemCodes(context.Background(), 1, 20, RedeemTypeBalance, StatusUnused, "ABC") + codes, total, err := svc.ListRedeemCodes(context.Background(), 1, 20, RedeemTypeBalance, StatusUnused, "ABC", "value", "ASC") require.NoError(t, err) require.Equal(t, int64(3), total) require.Equal(t, []RedeemCode{{ID: 4, Code: "ABC"}}, codes) require.Equal(t, 1, repo.listWithFiltersCalls) - require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams) + require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20, SortBy: "value", SortOrder: "ASC"}, repo.listWithFiltersParams) require.Equal(t, RedeemTypeBalance, repo.listWithFiltersType) require.Equal(t, StatusUnused, repo.listWithFiltersStatus) require.Equal(t, "ABC", repo.listWithFiltersSearch) diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go index ad6ba0e9..c2e96df1 100644 --- a/backend/internal/service/api_key_auth_cache.go +++ b/backend/internal/service/api_key_auth_cache.go @@ -4,6 +4,7 @@ import "time" // APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段) type APIKeyAuthSnapshot struct { + Version int `json:"version"` APIKeyID int64 `json:"api_key_id"` UserID int64 `json:"user_id"` GroupID *int64 `json:"group_id,omitempty"` @@ -63,8 +64,9 @@ type APIKeyAuthGroupSnapshot struct { SupportedModelScopes []string `json:"supported_model_scopes,omitempty"` // OpenAI Messages 调度配置(仅 openai 平台使用) - AllowMessagesDispatch bool `json:"allow_messages_dispatch"` - DefaultMappedModel string `json:"default_mapped_model,omitempty"` + AllowMessagesDispatch bool `json:"allow_messages_dispatch"` + DefaultMappedModel string `json:"default_mapped_model,omitempty"` + MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"` } // APIKeyAuthCacheEntry 缓存条目,支持负缓存 diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index 64a70e8c..8069ed4f 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -13,6 +13,8 @@ import ( "github.com/dgraph-io/ristretto" ) +const apiKeyAuthSnapshotVersion = 3 + type apiKeyAuthCacheConfig struct { l1Size int l1TTL time.Duration @@ -192,6 +194,9 @@ func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEn if entry.Snapshot == nil { return nil, false, nil } + if entry.Snapshot.Version != apiKeyAuthSnapshotVersion { + return nil, false, nil + } return s.snapshotToAPIKey(key, entry.Snapshot), true, nil } @@ -200,6 +205,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { return nil } snapshot := &APIKeyAuthSnapshot{ + Version: apiKeyAuthSnapshotVersion, APIKeyID: apiKey.ID, UserID: apiKey.UserID, GroupID: apiKey.GroupID, @@ -243,6 +249,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { SupportedModelScopes: apiKey.Group.SupportedModelScopes, AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch, DefaultMappedModel: apiKey.Group.DefaultMappedModel, + MessagesDispatchModelConfig: apiKey.Group.MessagesDispatchModelConfig, } } return snapshot @@ -298,6 +305,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho SupportedModelScopes: snapshot.Group.SupportedModelScopes, AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch, DefaultMappedModel: snapshot.Group.DefaultMappedModel, + MessagesDispatchModelConfig: snapshot.Group.MessagesDispatchModelConfig, } } s.compileAPIKeyIPRules(apiKey) diff --git a/backend/internal/service/api_key_service_cache_test.go b/backend/internal/service/api_key_service_cache_test.go index 357f8def..3c2f7dbb 100644 --- a/backend/internal/service/api_key_service_cache_test.go +++ b/backend/internal/service/api_key_service_cache_test.go @@ -188,6 +188,7 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) { groupID := int64(9) cacheEntry := &APIKeyAuthCacheEntry{ Snapshot: &APIKeyAuthSnapshot{ + Version: apiKeyAuthSnapshotVersion, APIKeyID: 1, UserID: 2, GroupID: &groupID, @@ -226,6 +227,129 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) { require.Equal(t, map[string][]int64{"claude-opus-*": {1, 2}}, apiKey.Group.ModelRouting) } +func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t *testing.T) { + svc := NewAPIKeyService(nil, nil, nil, nil, nil, nil, &config.Config{}) + groupID := int64(9) + apiKey := &APIKey{ + ID: 1, + UserID: 2, + GroupID: &groupID, + Key: "k-roundtrip", + Status: StatusActive, + User: &User{ + ID: 2, + Status: StatusActive, + Role: RoleUser, + Balance: 10, + Concurrency: 3, + }, + Group: &Group{ + ID: groupID, + Name: "openai", + Platform: PlatformOpenAI, + Status: StatusActive, + SubscriptionType: SubscriptionTypeStandard, + RateMultiplier: 1, + AllowMessagesDispatch: true, + DefaultMappedModel: "gpt-5.4", + MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{ + OpusMappedModel: "gpt-5.4-nano", + SonnetMappedModel: "gpt-5.3-codex", + HaikuMappedModel: "gpt-5.4-mini", + ExactModelMappings: map[string]string{ + "claude-sonnet-4.5": "gpt-5.4-nano", + }, + }, + }, + } + + snapshot := svc.snapshotFromAPIKey(apiKey) + roundTrip := svc.snapshotToAPIKey(apiKey.Key, snapshot) + + require.NotNil(t, roundTrip) + require.NotNil(t, roundTrip.Group) + require.Equal(t, apiKey.Group.MessagesDispatchModelConfig, roundTrip.Group.MessagesDispatchModelConfig) +} + +func TestAPIKeyService_GetByKey_IgnoresLegacyAuthCacheSnapshotWithoutMessagesDispatchConfig(t *testing.T) { + cache := &authCacheStub{} + var repoCalls int32 + repo := &authRepoStub{ + getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) { + atomic.AddInt32(&repoCalls, 1) + groupID := int64(9) + return &APIKey{ + ID: 1, + UserID: 2, + GroupID: &groupID, + Status: StatusActive, + User: &User{ + ID: 2, + Status: StatusActive, + Role: RoleUser, + Balance: 10, + Concurrency: 3, + }, + Group: &Group{ + ID: groupID, + Name: "openai", + Platform: PlatformOpenAI, + Status: StatusActive, + Hydrated: true, + SubscriptionType: SubscriptionTypeStandard, + RateMultiplier: 1, + AllowMessagesDispatch: true, + DefaultMappedModel: "gpt-5.4", + MessagesDispatchModelConfig: OpenAIMessagesDispatchModelConfig{ + OpusMappedModel: "gpt-5.4-nano", + }, + }, + }, nil + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + L2TTLSeconds: 60, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg) + + groupID := int64(9) + cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { + return &APIKeyAuthCacheEntry{ + Snapshot: &APIKeyAuthSnapshot{ + APIKeyID: 1, + UserID: 2, + GroupID: &groupID, + Status: StatusActive, + User: APIKeyAuthUserSnapshot{ + ID: 2, + Status: StatusActive, + Role: RoleUser, + Balance: 10, + Concurrency: 3, + }, + Group: &APIKeyAuthGroupSnapshot{ + ID: groupID, + Name: "openai", + Platform: PlatformOpenAI, + Status: StatusActive, + SubscriptionType: SubscriptionTypeStandard, + RateMultiplier: 1, + AllowMessagesDispatch: true, + DefaultMappedModel: "gpt-5.4", + }, + }, + }, nil + } + + apiKey, err := svc.GetByKey(context.Background(), "k-legacy") + require.NoError(t, err) + require.Equal(t, int32(1), atomic.LoadInt32(&repoCalls)) + require.NotNil(t, apiKey.Group) + require.Equal(t, "gpt-5.4-nano", apiKey.Group.MessagesDispatchModelConfig.OpusMappedModel) +} + func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) { cache := &authCacheStub{} repo := &authRepoStub{ diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index e194f921..68d7da3b 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -143,6 +143,8 @@ const ( SettingKeyHideCcsImportButton = "hide_ccs_import_button" // 是否隐藏 API Keys 页面的导入 CCS 按钮 SettingKeyPurchaseSubscriptionEnabled = "purchase_subscription_enabled" // 是否展示"购买订阅"页面入口 SettingKeyPurchaseSubscriptionURL = "purchase_subscription_url" // "购买订阅"页面 URL(作为 iframe src) + SettingKeyTableDefaultPageSize = "table_default_page_size" // 表格默认每页条数 + SettingKeyTablePageSizeOptions = "table_page_size_options" // 表格可选每页条数(JSON 数组) SettingKeyCustomMenuItems = "custom_menu_items" // 自定义菜单项(JSON 数组) SettingKeyCustomEndpoints = "custom_endpoints" // 自定义端点列表(JSON 数组) diff --git a/backend/internal/service/openai_ws_ratelimit_signal_test.go b/backend/internal/service/openai_ws_ratelimit_signal_test.go index ffe79152..6313d0c0 100644 --- a/backend/internal/service/openai_ws_ratelimit_signal_test.go +++ b/backend/internal/service/openai_ws_ratelimit_signal_test.go @@ -492,7 +492,7 @@ func TestAdminService_ListAccounts_ExhaustedCodexExtraReturnsRateLimitedAccount( } svc := &adminServiceImpl{accountRepo: repo} - accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, "", "", 0, "") + accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformOpenAI, AccountTypeOAuth, "", "", 0, "", "", "") require.NoError(t, err) require.Equal(t, int64(1), total) require.Len(t, accounts, 1) diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 4b7bd988..48f25da0 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -9,6 +9,7 @@ import ( "fmt" "log/slog" "net/url" + "sort" "strconv" "strings" "sync/atomic" @@ -161,6 +162,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings SettingKeyHideCcsImportButton, SettingKeyPurchaseSubscriptionEnabled, SettingKeyPurchaseSubscriptionURL, + SettingKeyTableDefaultPageSize, + SettingKeyTablePageSizeOptions, SettingKeyCustomMenuItems, SettingKeyCustomEndpoints, SettingKeyLinuxDoConnectEnabled, @@ -201,6 +204,10 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings registrationEmailSuffixWhitelist := ParseRegistrationEmailSuffixWhitelist( settings[SettingKeyRegistrationEmailSuffixWhitelist], ) + tableDefaultPageSize, tablePageSizeOptions := parseTablePreferences( + settings[SettingKeyTableDefaultPageSize], + settings[SettingKeyTablePageSizeOptions], + ) return &PublicSettings{ RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", @@ -222,6 +229,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true", PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true", PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), + TableDefaultPageSize: tableDefaultPageSize, + TablePageSizeOptions: tablePageSizeOptions, CustomMenuItems: settings[SettingKeyCustomMenuItems], CustomEndpoints: settings[SettingKeyCustomEndpoints], LinuxDoOAuthEnabled: linuxDoEnabled, @@ -272,6 +281,8 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any HideCcsImportButton bool `json:"hide_ccs_import_button"` PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"` + TableDefaultPageSize int `json:"table_default_page_size"` + TablePageSizeOptions []int `json:"table_page_size_options"` CustomMenuItems json.RawMessage `json:"custom_menu_items"` CustomEndpoints json.RawMessage `json:"custom_endpoints"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` @@ -300,6 +311,8 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any HideCcsImportButton: settings.HideCcsImportButton, PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, + TableDefaultPageSize: settings.TableDefaultPageSize, + TablePageSizeOptions: settings.TablePageSizeOptions, CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems), CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints), LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, @@ -526,6 +539,16 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyHideCcsImportButton] = strconv.FormatBool(settings.HideCcsImportButton) updates[SettingKeyPurchaseSubscriptionEnabled] = strconv.FormatBool(settings.PurchaseSubscriptionEnabled) updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL) + tableDefaultPageSize, tablePageSizeOptions := normalizeTablePreferences( + settings.TableDefaultPageSize, + settings.TablePageSizeOptions, + ) + updates[SettingKeyTableDefaultPageSize] = strconv.Itoa(tableDefaultPageSize) + tablePageSizeOptionsJSON, err := json.Marshal(tablePageSizeOptions) + if err != nil { + return fmt.Errorf("marshal table page size options: %w", err) + } + updates[SettingKeyTablePageSizeOptions] = string(tablePageSizeOptionsJSON) updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems updates[SettingKeyCustomEndpoints] = settings.CustomEndpoints @@ -879,6 +902,8 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeySiteLogo: "", SettingKeyPurchaseSubscriptionEnabled: "false", SettingKeyPurchaseSubscriptionURL: "", + SettingKeyTableDefaultPageSize: "20", + SettingKeyTablePageSizeOptions: "[10,20,50,100]", SettingKeyCustomMenuItems: "[]", SettingKeyCustomEndpoints: "[]", SettingKeyOIDCConnectEnabled: "false", @@ -950,6 +975,10 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin CustomEndpoints: settings[SettingKeyCustomEndpoints], BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true", } + result.TableDefaultPageSize, result.TablePageSizeOptions = parseTablePreferences( + settings[SettingKeyTableDefaultPageSize], + settings[SettingKeyTablePageSizeOptions], + ) // 解析整数类型 if port, err := strconv.Atoi(settings[SettingKeySMTPPort]); err == nil { @@ -1225,6 +1254,50 @@ func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting { return normalized } +func parseTablePreferences(defaultPageSizeRaw, optionsRaw string) (int, []int) { + defaultPageSize := 20 + if v, err := strconv.Atoi(strings.TrimSpace(defaultPageSizeRaw)); err == nil { + defaultPageSize = v + } + + var options []int + if strings.TrimSpace(optionsRaw) != "" { + _ = json.Unmarshal([]byte(optionsRaw), &options) + } + + return normalizeTablePreferences(defaultPageSize, options) +} + +func normalizeTablePreferences(defaultPageSize int, options []int) (int, []int) { + const minPageSize = 5 + const maxPageSize = 1000 + const fallbackPageSize = 20 + + seen := make(map[int]struct{}, len(options)) + normalizedOptions := make([]int, 0, len(options)) + for _, option := range options { + if option < minPageSize || option > maxPageSize { + continue + } + if _, ok := seen[option]; ok { + continue + } + seen[option] = struct{}{} + normalizedOptions = append(normalizedOptions, option) + } + sort.Ints(normalizedOptions) + + if defaultPageSize < minPageSize || defaultPageSize > maxPageSize { + defaultPageSize = fallbackPageSize + } + + if len(normalizedOptions) == 0 { + normalizedOptions = []int{10, 20, 50} + } + + return defaultPageSize, normalizedOptions +} + // getStringOrDefault 获取字符串值或默认值 func (s *SettingService) getStringOrDefault(settings map[string]string, key, defaultValue string) string { if value, ok := settings[key]; ok && value != "" { diff --git a/backend/internal/service/setting_service_public_test.go b/backend/internal/service/setting_service_public_test.go index b511cd29..6dfa627c 100644 --- a/backend/internal/service/setting_service_public_test.go +++ b/backend/internal/service/setting_service_public_test.go @@ -62,3 +62,18 @@ func TestSettingService_GetPublicSettings_ExposesRegistrationEmailSuffixWhitelis require.NoError(t, err) require.Equal(t, []string{"@example.com", "@foo.bar"}, settings.RegistrationEmailSuffixWhitelist) } + +func TestSettingService_GetPublicSettings_ExposesTablePreferences(t *testing.T) { + repo := &settingPublicRepoStub{ + values: map[string]string{ + SettingKeyTableDefaultPageSize: "50", + SettingKeyTablePageSizeOptions: "[20,50,100]", + }, + } + svc := NewSettingService(repo, &config.Config{}) + + settings, err := svc.GetPublicSettings(context.Background()) + require.NoError(t, err) + require.Equal(t, 50, settings.TableDefaultPageSize) + require.Equal(t, []int{20, 50, 100}, settings.TablePageSizeOptions) +} diff --git a/backend/internal/service/setting_service_update_test.go b/backend/internal/service/setting_service_update_test.go index 1de08611..28c7ad02 100644 --- a/backend/internal/service/setting_service_update_test.go +++ b/backend/internal/service/setting_service_update_test.go @@ -202,3 +202,24 @@ func TestParseDefaultSubscriptions_NormalizesValues(t *testing.T) { {GroupID: 12, ValidityDays: MaxValidityDays}, }, got) } + +func TestSettingService_UpdateSettings_TablePreferences(t *testing.T) { + repo := &settingUpdateRepoStub{} + svc := NewSettingService(repo, &config.Config{}) + + err := svc.UpdateSettings(context.Background(), &SystemSettings{ + TableDefaultPageSize: 50, + TablePageSizeOptions: []int{20, 50, 100}, + }) + require.NoError(t, err) + require.Equal(t, "50", repo.updates[SettingKeyTableDefaultPageSize]) + require.Equal(t, "[20,50,100]", repo.updates[SettingKeyTablePageSizeOptions]) + + err = svc.UpdateSettings(context.Background(), &SystemSettings{ + TableDefaultPageSize: 1000, + TablePageSizeOptions: []int{20, 100}, + }) + require.NoError(t, err) + require.Equal(t, "1000", repo.updates[SettingKeyTableDefaultPageSize]) + require.Equal(t, "[20,100]", repo.updates[SettingKeyTablePageSizeOptions]) +} diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index 68076013..de92b796 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -66,6 +66,8 @@ type SystemSettings struct { HideCcsImportButton bool PurchaseSubscriptionEnabled bool PurchaseSubscriptionURL string + TableDefaultPageSize int + TablePageSizeOptions []int CustomMenuItems string // JSON array of custom menu items CustomEndpoints string // JSON array of custom endpoints @@ -132,6 +134,8 @@ type PublicSettings struct { PurchaseSubscriptionEnabled bool PurchaseSubscriptionURL string + TableDefaultPageSize int + TablePageSizeOptions []int CustomMenuItems string // JSON array of custom menu items CustomEndpoints string // JSON array of custom endpoints diff --git a/frontend/src/api/admin/accounts.ts b/frontend/src/api/admin/accounts.ts index 8e40e18b..a146f1f7 100644 --- a/frontend/src/api/admin/accounts.ts +++ b/frontend/src/api/admin/accounts.ts @@ -38,6 +38,8 @@ export async function list( search?: string privacy_mode?: string lite?: string + sort_by?: string + sort_order?: 'asc' | 'desc' }, options?: { signal?: AbortSignal @@ -71,6 +73,8 @@ export async function listWithEtag( search?: string privacy_mode?: string lite?: string + sort_by?: string + sort_order?: 'asc' | 'desc' }, options?: { signal?: AbortSignal @@ -500,7 +504,11 @@ export async function exportData(options?: { platform?: string type?: string status?: string + group?: string + privacy_mode?: string search?: string + sort_by?: string + sort_order?: 'asc' | 'desc' } includeProxies?: boolean }): Promise { @@ -508,11 +516,15 @@ export async function exportData(options?: { if (options?.ids && options.ids.length > 0) { params.ids = options.ids.join(',') } else if (options?.filters) { - const { platform, type, status, search } = options.filters + const { platform, type, status, group, privacy_mode, search, sort_by, sort_order } = options.filters if (platform) params.platform = platform if (type) params.type = type if (status) params.status = status + if (group) params.group = group + if (privacy_mode) params.privacy_mode = privacy_mode if (search) params.search = search + if (sort_by) params.sort_by = sort_by + if (sort_order) params.sort_order = sort_order } if (options?.includeProxies === false) { params.include_proxies = 'false' diff --git a/frontend/src/api/admin/announcements.ts b/frontend/src/api/admin/announcements.ts index d02fdda7..92392a67 100644 --- a/frontend/src/api/admin/announcements.ts +++ b/frontend/src/api/admin/announcements.ts @@ -17,10 +17,16 @@ export async function list( filters?: { status?: string search?: string + sort_by?: string + sort_order?: 'asc' | 'desc' + }, + options?: { + signal?: AbortSignal } ): Promise> { const { data } = await apiClient.get>('/admin/announcements', { - params: { page, page_size: pageSize, ...filters } + params: { page, page_size: pageSize, ...filters }, + signal: options?.signal }) return data } @@ -49,11 +55,21 @@ export async function getReadStatus( id: number, page: number = 1, pageSize: number = 20, - search: string = '' + filters?: { + search?: string + sort_by?: string + sort_order?: 'asc' | 'desc' + }, + options?: { + signal?: AbortSignal + } ): Promise> { const { data } = await apiClient.get>( `/admin/announcements/${id}/read-status`, - { params: { page, page_size: pageSize, search } } + { + params: { page, page_size: pageSize, ...filters }, + signal: options?.signal + } ) return data } @@ -68,4 +84,3 @@ const announcementsAPI = { } export default announcementsAPI - diff --git a/frontend/src/api/admin/channels.ts b/frontend/src/api/admin/channels.ts index 5334dd47..b3455022 100644 --- a/frontend/src/api/admin/channels.ts +++ b/frontend/src/api/admin/channels.ts @@ -83,6 +83,8 @@ export async function list( filters?: { status?: string search?: string + sort_by?: string + sort_order?: 'asc' | 'desc' }, options?: { signal?: AbortSignal } ): Promise> { diff --git a/frontend/src/api/admin/groups.ts b/frontend/src/api/admin/groups.ts index 5885dc6a..8739d5cb 100644 --- a/frontend/src/api/admin/groups.ts +++ b/frontend/src/api/admin/groups.ts @@ -27,6 +27,8 @@ export async function list( status?: 'active' | 'inactive' is_exclusive?: boolean search?: string + sort_by?: string + sort_order?: 'asc' | 'desc' }, options?: { signal?: AbortSignal diff --git a/frontend/src/api/admin/promo.ts b/frontend/src/api/admin/promo.ts index 6a8c4559..b24dffc2 100644 --- a/frontend/src/api/admin/promo.ts +++ b/frontend/src/api/admin/promo.ts @@ -17,10 +17,16 @@ export async function list( filters?: { status?: string search?: string + sort_by?: string + sort_order?: 'asc' | 'desc' + }, + options?: { + signal?: AbortSignal } ): Promise> { const { data } = await apiClient.get>('/admin/promo-codes', { - params: { page, page_size: pageSize, ...filters } + params: { page, page_size: pageSize, ...filters }, + signal: options?.signal }) return data } diff --git a/frontend/src/api/admin/proxies.ts b/frontend/src/api/admin/proxies.ts index 5e31ae20..3e041ba9 100644 --- a/frontend/src/api/admin/proxies.ts +++ b/frontend/src/api/admin/proxies.ts @@ -29,6 +29,8 @@ export async function list( protocol?: string status?: 'active' | 'inactive' search?: string + sort_by?: string + sort_order?: 'asc' | 'desc' }, options?: { signal?: AbortSignal @@ -227,16 +229,20 @@ export async function exportData(options?: { protocol?: string status?: 'active' | 'inactive' search?: string + sort_by?: string + sort_order?: 'asc' | 'desc' } }): Promise { const params: Record = {} if (options?.ids && options.ids.length > 0) { params.ids = options.ids.join(',') } else if (options?.filters) { - const { protocol, status, search } = options.filters + const { protocol, status, search, sort_by, sort_order } = options.filters if (protocol) params.protocol = protocol if (status) params.status = status if (search) params.search = search + if (sort_by) params.sort_by = sort_by + if (sort_order) params.sort_order = sort_order } const { data } = await apiClient.get('/admin/proxies/data', { params }) return data diff --git a/frontend/src/api/admin/redeem.ts b/frontend/src/api/admin/redeem.ts index a53c3566..57626b1e 100644 --- a/frontend/src/api/admin/redeem.ts +++ b/frontend/src/api/admin/redeem.ts @@ -25,6 +25,8 @@ export async function list( type?: RedeemCodeType status?: 'active' | 'used' | 'expired' | 'unused' search?: string + sort_by?: string + sort_order?: 'asc' | 'desc' }, options?: { signal?: AbortSignal @@ -151,7 +153,10 @@ export async function getStats(): Promise<{ */ export async function exportCodes(filters?: { type?: RedeemCodeType - status?: 'active' | 'used' | 'expired' + status?: 'used' | 'expired' | 'unused' + search?: string + sort_by?: string + sort_order?: 'asc' | 'desc' }): Promise { const response = await apiClient.get('/admin/redeem-codes/export', { params: filters, diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index d725528d..504abe9c 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -38,6 +38,8 @@ export interface SystemSettings { doc_url: string home_content: string hide_ccs_import_button: boolean + table_default_page_size: number + table_page_size_options: number[] backend_mode_enabled: boolean custom_menu_items: CustomMenuItem[] custom_endpoints: CustomEndpoint[] @@ -154,6 +156,8 @@ export interface UpdateSettingsRequest { doc_url?: string home_content?: string hide_ccs_import_button?: boolean + table_default_page_size?: number + table_page_size_options?: number[] backend_mode_enabled?: boolean custom_menu_items?: CustomMenuItem[] custom_endpoints?: CustomEndpoint[] diff --git a/frontend/src/api/admin/usage.ts b/frontend/src/api/admin/usage.ts index d21b28dc..37df7553 100644 --- a/frontend/src/api/admin/usage.ts +++ b/frontend/src/api/admin/usage.ts @@ -81,6 +81,8 @@ export interface AdminUsageQueryParams extends UsageQueryParams { user_id?: number exact_total?: boolean billing_mode?: string + sort_by?: string + sort_order?: 'asc' | 'desc' } // ==================== API Functions ==================== diff --git a/frontend/src/api/admin/users.ts b/frontend/src/api/admin/users.ts index bbf0ab51..39cb1dfa 100644 --- a/frontend/src/api/admin/users.ts +++ b/frontend/src/api/admin/users.ts @@ -24,6 +24,8 @@ export async function list( group_name?: string // fuzzy filter by allowed group name attributes?: Record // attributeId -> value include_subscriptions?: boolean + sort_by?: string + sort_order?: 'asc' | 'desc' }, options?: { signal?: AbortSignal @@ -37,7 +39,9 @@ export async function list( role: filters?.role, search: filters?.search, group_name: filters?.group_name, - include_subscriptions: filters?.include_subscriptions + include_subscriptions: filters?.include_subscriptions, + sort_by: filters?.sort_by, + sort_order: filters?.sort_order } // Add attribute filters as attr[id]=value diff --git a/frontend/src/api/keys.ts b/frontend/src/api/keys.ts index 137e10ba..34dd5b4b 100644 --- a/frontend/src/api/keys.ts +++ b/frontend/src/api/keys.ts @@ -17,7 +17,13 @@ import type { ApiKey, CreateApiKeyRequest, UpdateApiKeyRequest, PaginatedRespons export async function list( page: number = 1, pageSize: number = 10, - filters?: { search?: string; status?: string; group_id?: number | string }, + filters?: { + search?: string + status?: string + group_id?: number | string + sort_by?: string + sort_order?: 'asc' | 'desc' + }, options?: { signal?: AbortSignal } diff --git a/frontend/src/api/usage.ts b/frontend/src/api/usage.ts index 6efd7657..802c428f 100644 --- a/frontend/src/api/usage.ts +++ b/frontend/src/api/usage.ts @@ -91,7 +91,7 @@ export async function list( * @returns Paginated list of usage logs */ export async function query( - params: UsageQueryParams, + params: UsageQueryParams & { sort_by?: string; sort_order?: 'asc' | 'desc' }, config: { signal?: AbortSignal } = {} ): Promise> { const { data } = await apiClient.get>('/usage', { diff --git a/frontend/src/components/admin/account/AccountTableFilters.vue b/frontend/src/components/admin/account/AccountTableFilters.vue index 6b474183..b33dad84 100644 --- a/frontend/src/components/admin/account/AccountTableFilters.vue +++ b/frontend/src/components/admin/account/AccountTableFilters.vue @@ -27,7 +27,7 @@ const updatePrivacyMode = (value: string | number | boolean | null) => { emit('u const updateGroup = (value: string | number | boolean | null) => { emit('update:filters', { ...props.filters, group: value }) } const pOpts = computed(() => [{ value: '', label: t('admin.accounts.allPlatforms') }, { value: 'anthropic', label: 'Anthropic' }, { value: 'openai', label: 'OpenAI' }, { value: 'gemini', label: 'Gemini' }, { value: 'antigravity', label: 'Antigravity' }]) const tOpts = computed(() => [{ value: '', label: t('admin.accounts.allTypes') }, { value: 'oauth', label: t('admin.accounts.oauthType') }, { value: 'setup-token', label: t('admin.accounts.setupToken') }, { value: 'apikey', label: t('admin.accounts.apiKey') }, { value: 'bedrock', label: 'AWS Bedrock' }]) -const sOpts = computed(() => [{ value: '', label: t('admin.accounts.allStatus') }, { value: 'active', label: t('admin.accounts.status.active') }, { value: 'inactive', label: t('admin.accounts.status.inactive') }, { value: 'error', label: t('admin.accounts.status.error') }, { value: 'rate_limited', label: t('admin.accounts.status.rateLimited') }, { value: 'temp_unschedulable', label: t('admin.accounts.status.tempUnschedulable') }]) +const sOpts = computed(() => [{ value: '', label: t('admin.accounts.allStatus') }, { value: 'active', label: t('admin.accounts.status.active') }, { value: 'inactive', label: t('admin.accounts.status.inactive') }, { value: 'error', label: t('admin.accounts.status.error') }, { value: 'rate_limited', label: t('admin.accounts.status.rateLimited') }, { value: 'temp_unschedulable', label: t('admin.accounts.status.tempUnschedulable') }, { value: 'unschedulable', label: t('admin.accounts.status.unschedulable') }]) const privacyOpts = computed(() => [ { value: '', label: t('admin.accounts.allPrivacyModes') }, { value: '__unset__', label: t('admin.accounts.privacyUnset') }, diff --git a/frontend/src/components/admin/account/__tests__/AccountTableFilters.spec.ts b/frontend/src/components/admin/account/__tests__/AccountTableFilters.spec.ts deleted file mode 100644 index 5a0044e5..00000000 --- a/frontend/src/components/admin/account/__tests__/AccountTableFilters.spec.ts +++ /dev/null @@ -1,56 +0,0 @@ -import { describe, expect, it, vi } from 'vitest' -import { mount } from '@vue/test-utils' - -import AccountTableFilters from '../AccountTableFilters.vue' - -vi.mock('vue-i18n', async () => { - const actual = await vi.importActual('vue-i18n') - return { - ...actual, - useI18n: () => ({ - t: (key: string) => key - }) - } -}) - -describe('AccountTableFilters', () => { - it('renders privacy mode options and emits privacy_mode updates', async () => { - const wrapper = mount(AccountTableFilters, { - props: { - searchQuery: '', - filters: { - platform: '', - type: '', - status: '', - group: '', - privacy_mode: '' - }, - groups: [] - }, - global: { - stubs: { - SearchInput: { - template: '
' - }, - Select: { - props: ['modelValue', 'options'], - emits: ['update:modelValue', 'change'], - template: '
' - } - } - } - }) - - const selects = wrapper.findAll('.select-stub') - expect(selects).toHaveLength(5) - - const privacyOptions = JSON.parse(selects[3].attributes('data-options')) - expect(privacyOptions).toEqual([ - { value: '', label: 'admin.accounts.allPrivacyModes' }, - { value: '__unset__', label: 'admin.accounts.privacyUnset' }, - { value: 'training_off', label: 'Privacy' }, - { value: 'training_set_cf_blocked', label: 'CF' }, - { value: 'training_set_failed', label: 'Fail' } - ]) - }) -}) diff --git a/frontend/src/components/admin/announcements/AnnouncementReadStatusDialog.vue b/frontend/src/components/admin/announcements/AnnouncementReadStatusDialog.vue index a0d9de3c..60c01c6d 100644 --- a/frontend/src/components/admin/announcements/AnnouncementReadStatusDialog.vue +++ b/frontend/src/components/admin/announcements/AnnouncementReadStatusDialog.vue @@ -21,7 +21,15 @@
- + @@ -62,7 +70,7 @@ diff --git a/frontend/src/components/admin/announcements/__tests__/AnnouncementReadStatusDialog.spec.ts b/frontend/src/components/admin/announcements/__tests__/AnnouncementReadStatusDialog.spec.ts new file mode 100644 index 00000000..26c87d73 --- /dev/null +++ b/frontend/src/components/admin/announcements/__tests__/AnnouncementReadStatusDialog.spec.ts @@ -0,0 +1,95 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest' +import { flushPromises, mount } from '@vue/test-utils' + +import AnnouncementReadStatusDialog from '../AnnouncementReadStatusDialog.vue' + +const { getReadStatus, showError } = vi.hoisted(() => ({ + getReadStatus: vi.fn(), + showError: vi.fn(), +})) + +vi.mock('@/api/admin', () => ({ + adminAPI: { + announcements: { + getReadStatus, + }, + }, +})) + +vi.mock('@/stores/app', () => ({ + useAppStore: () => ({ + showError, + }), +})) + +vi.mock('vue-i18n', async () => { + const actual = await vi.importActual('vue-i18n') + return { + ...actual, + useI18n: () => ({ + t: (key: string) => key, + }), + } +}) + +vi.mock('@/composables/usePersistedPageSize', () => ({ + getPersistedPageSize: () => 20, +})) + +const BaseDialogStub = { + props: ['show', 'title', 'width'], + emits: ['close'], + template: '
', +} + +describe('AnnouncementReadStatusDialog', () => { + beforeEach(() => { + getReadStatus.mockReset() + showError.mockReset() + vi.useFakeTimers() + }) + + it('closes by aborting active requests and clearing debounced reloads', async () => { + let activeSignal: AbortSignal | undefined + getReadStatus.mockImplementation(async (...args: any[]) => { + activeSignal = args[4]?.signal + return new Promise(() => {}) + }) + + const wrapper = mount(AnnouncementReadStatusDialog, { + props: { + show: false, + announcementId: 1, + }, + global: { + stubs: { + BaseDialog: BaseDialogStub, + DataTable: true, + Pagination: true, + Icon: true, + }, + }, + }) + + await wrapper.setProps({ show: true }) + await flushPromises() + + expect(getReadStatus).toHaveBeenCalledTimes(1) + expect(activeSignal?.aborted).toBe(false) + + const setupState = (wrapper.vm as any).$?.setupState + setupState.search = 'alice' + setupState.handleSearch() + + setupState.handleClose() + await flushPromises() + + expect(activeSignal?.aborted).toBe(true) + expect(wrapper.emitted('close')).toHaveLength(1) + + vi.advanceTimersByTime(350) + await flushPromises() + + expect(getReadStatus).toHaveBeenCalledTimes(1) + }) +}) diff --git a/frontend/src/components/admin/group/GroupRateMultipliersModal.vue b/frontend/src/components/admin/group/GroupRateMultipliersModal.vue index cbd18af6..bf79bea2 100644 --- a/frontend/src/components/admin/group/GroupRateMultipliersModal.vue +++ b/frontend/src/components/admin/group/GroupRateMultipliersModal.vue @@ -196,7 +196,6 @@ :total="localEntries.length" :page="currentPage" :page-size="pageSize" - :page-size-options="[10, 20, 50]" @update:page="currentPage = $event" @update:pageSize="handlePageSizeChange" /> diff --git a/frontend/src/components/admin/usage/UsageTable.vue b/frontend/src/components/admin/usage/UsageTable.vue index 9bbdb380..f4494e69 100644 --- a/frontend/src/components/admin/usage/UsageTable.vue +++ b/frontend/src/components/admin/usage/UsageTable.vue @@ -1,7 +1,15 @@