Merge branch 'dev-release'
This commit is contained in:
@@ -3,6 +3,7 @@ package admin
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -789,57 +790,40 @@ func (h *AccountHandler) BatchUpdateCredentials(c *gin.Context) {
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
success := 0
|
||||
failed := 0
|
||||
results := []gin.H{}
|
||||
|
||||
// 阶段一:预验证所有账号存在,收集 credentials
|
||||
type accountUpdate struct {
|
||||
ID int64
|
||||
Credentials map[string]any
|
||||
}
|
||||
updates := make([]accountUpdate, 0, len(req.AccountIDs))
|
||||
for _, accountID := range req.AccountIDs {
|
||||
// Get account
|
||||
account, err := h.adminService.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
failed++
|
||||
results = append(results, gin.H{
|
||||
"account_id": accountID,
|
||||
"success": false,
|
||||
"error": "Account not found",
|
||||
})
|
||||
continue
|
||||
response.Error(c, 404, fmt.Sprintf("Account %d not found", accountID))
|
||||
return
|
||||
}
|
||||
|
||||
// Update credentials field
|
||||
if account.Credentials == nil {
|
||||
account.Credentials = make(map[string]any)
|
||||
}
|
||||
|
||||
account.Credentials[req.Field] = req.Value
|
||||
updates = append(updates, accountUpdate{ID: accountID, Credentials: account.Credentials})
|
||||
}
|
||||
|
||||
// Update account
|
||||
// 阶段二:依次更新,任何失败立即返回(避免部分成功部分失败)
|
||||
for _, u := range updates {
|
||||
updateInput := &service.UpdateAccountInput{
|
||||
Credentials: account.Credentials,
|
||||
Credentials: u.Credentials,
|
||||
}
|
||||
|
||||
_, err = h.adminService.UpdateAccount(ctx, accountID, updateInput)
|
||||
if err != nil {
|
||||
failed++
|
||||
results = append(results, gin.H{
|
||||
"account_id": accountID,
|
||||
"success": false,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
if _, err := h.adminService.UpdateAccount(ctx, u.ID, updateInput); err != nil {
|
||||
response.Error(c, 500, fmt.Sprintf("Failed to update account %d: %v", u.ID, err))
|
||||
return
|
||||
}
|
||||
|
||||
success++
|
||||
results = append(results, gin.H{
|
||||
"account_id": accountID,
|
||||
"success": true,
|
||||
})
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"success": success,
|
||||
"failed": failed,
|
||||
"results": results,
|
||||
"success": len(updates),
|
||||
"failed": 0,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
200
backend/internal/handler/admin/batch_update_credentials_test.go
Normal file
200
backend/internal/handler/admin/batch_update_credentials_test.go
Normal file
@@ -0,0 +1,200 @@
|
||||
//go:build unit
|
||||
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// failingAdminService 嵌入 stubAdminService,可配置 UpdateAccount 在指定 ID 时失败。
|
||||
type failingAdminService struct {
|
||||
*stubAdminService
|
||||
failOnAccountID int64
|
||||
updateCallCount atomic.Int64
|
||||
}
|
||||
|
||||
func (f *failingAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) {
|
||||
f.updateCallCount.Add(1)
|
||||
if id == f.failOnAccountID {
|
||||
return nil, errors.New("database error")
|
||||
}
|
||||
return f.stubAdminService.UpdateAccount(ctx, id, input)
|
||||
}
|
||||
|
||||
func setupAccountHandlerWithService(adminSvc service.AdminService) (*gin.Engine, *AccountHandler) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
router.POST("/api/v1/admin/accounts/batch-update-credentials", handler.BatchUpdateCredentials)
|
||||
return router, handler
|
||||
}
|
||||
|
||||
func TestBatchUpdateCredentials_AllSuccess(t *testing.T) {
|
||||
svc := &failingAdminService{stubAdminService: newStubAdminService()}
|
||||
router, _ := setupAccountHandlerWithService(svc)
|
||||
|
||||
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
|
||||
AccountIDs: []int64{1, 2, 3},
|
||||
Field: "account_uuid",
|
||||
Value: "test-uuid",
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code, "全部成功时应返回 200")
|
||||
require.Equal(t, int64(3), svc.updateCallCount.Load(), "应调用 3 次 UpdateAccount")
|
||||
}
|
||||
|
||||
func TestBatchUpdateCredentials_FailFast(t *testing.T) {
|
||||
// 让第 2 个账号(ID=2)更新时失败
|
||||
svc := &failingAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
failOnAccountID: 2,
|
||||
}
|
||||
router, _ := setupAccountHandlerWithService(svc)
|
||||
|
||||
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
|
||||
AccountIDs: []int64{1, 2, 3},
|
||||
Field: "org_uuid",
|
||||
Value: "test-org",
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusInternalServerError, w.Code, "ID=2 失败时应返回 500")
|
||||
// 验证 fail-fast:ID=1 更新成功,ID=2 失败,ID=3 不应被调用
|
||||
require.Equal(t, int64(2), svc.updateCallCount.Load(),
|
||||
"fail-fast: 应只调用 2 次 UpdateAccount(ID=1 成功、ID=2 失败后停止)")
|
||||
}
|
||||
|
||||
func TestBatchUpdateCredentials_FirstAccountNotFound(t *testing.T) {
|
||||
// GetAccount 在 stubAdminService 中总是成功的,需要创建一个 GetAccount 会失败的 stub
|
||||
svc := &getAccountFailingService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
failOnAccountID: 1,
|
||||
}
|
||||
router, _ := setupAccountHandlerWithService(svc)
|
||||
|
||||
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
|
||||
AccountIDs: []int64{1, 2, 3},
|
||||
Field: "account_uuid",
|
||||
Value: "test",
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusNotFound, w.Code, "第一阶段验证失败应返回 404")
|
||||
}
|
||||
|
||||
// getAccountFailingService 模拟 GetAccount 在特定 ID 时返回 not found。
|
||||
type getAccountFailingService struct {
|
||||
*stubAdminService
|
||||
failOnAccountID int64
|
||||
}
|
||||
|
||||
func (f *getAccountFailingService) GetAccount(ctx context.Context, id int64) (*service.Account, error) {
|
||||
if id == f.failOnAccountID {
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
return f.stubAdminService.GetAccount(ctx, id)
|
||||
}
|
||||
|
||||
func TestBatchUpdateCredentials_InterceptWarmupRequests_NonBool(t *testing.T) {
|
||||
svc := &failingAdminService{stubAdminService: newStubAdminService()}
|
||||
router, _ := setupAccountHandlerWithService(svc)
|
||||
|
||||
// intercept_warmup_requests 传入非 bool 类型(string),应返回 400
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"account_ids": []int64{1},
|
||||
"field": "intercept_warmup_requests",
|
||||
"value": "not-a-bool",
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, w.Code,
|
||||
"intercept_warmup_requests 传入非 bool 值应返回 400")
|
||||
}
|
||||
|
||||
func TestBatchUpdateCredentials_InterceptWarmupRequests_ValidBool(t *testing.T) {
|
||||
svc := &failingAdminService{stubAdminService: newStubAdminService()}
|
||||
router, _ := setupAccountHandlerWithService(svc)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"account_ids": []int64{1},
|
||||
"field": "intercept_warmup_requests",
|
||||
"value": true,
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code,
|
||||
"intercept_warmup_requests 传入合法 bool 值应返回 200")
|
||||
}
|
||||
|
||||
func TestBatchUpdateCredentials_AccountUUID_NonString(t *testing.T) {
|
||||
svc := &failingAdminService{stubAdminService: newStubAdminService()}
|
||||
router, _ := setupAccountHandlerWithService(svc)
|
||||
|
||||
// account_uuid 传入非 string 类型(number),应返回 400
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"account_ids": []int64{1},
|
||||
"field": "account_uuid",
|
||||
"value": 12345,
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, w.Code,
|
||||
"account_uuid 传入非 string 值应返回 400")
|
||||
}
|
||||
|
||||
func TestBatchUpdateCredentials_AccountUUID_NullValue(t *testing.T) {
|
||||
svc := &failingAdminService{stubAdminService: newStubAdminService()}
|
||||
router, _ := setupAccountHandlerWithService(svc)
|
||||
|
||||
// account_uuid 传入 null(设置为空),应正常通过
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"account_ids": []int64{1},
|
||||
"field": "account_uuid",
|
||||
"value": nil,
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code,
|
||||
"account_uuid 传入 null 应返回 200")
|
||||
}
|
||||
@@ -379,7 +379,7 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs)
|
||||
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs, time.Time{}, time.Time{})
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get user usage stats")
|
||||
return
|
||||
@@ -407,7 +407,7 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs)
|
||||
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs, time.Time{}, time.Time{})
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get API key usage stats")
|
||||
return
|
||||
|
||||
97
backend/internal/handler/admin/search_truncate_test.go
Normal file
97
backend/internal/handler/admin/search_truncate_test.go
Normal file
@@ -0,0 +1,97 @@
|
||||
//go:build unit
|
||||
|
||||
package admin
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// truncateSearchByRune 模拟 user_handler.go 中的 search 截断逻辑
|
||||
func truncateSearchByRune(search string, maxRunes int) string {
|
||||
if runes := []rune(search); len(runes) > maxRunes {
|
||||
return string(runes[:maxRunes])
|
||||
}
|
||||
return search
|
||||
}
|
||||
|
||||
func TestTruncateSearchByRune(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
maxRunes int
|
||||
wantLen int // 期望的 rune 长度
|
||||
}{
|
||||
{
|
||||
name: "纯中文超长",
|
||||
input: string(make([]rune, 150)),
|
||||
maxRunes: 100,
|
||||
wantLen: 100,
|
||||
},
|
||||
{
|
||||
name: "纯 ASCII 超长",
|
||||
input: string(make([]byte, 150)),
|
||||
maxRunes: 100,
|
||||
wantLen: 100,
|
||||
},
|
||||
{
|
||||
name: "空字符串",
|
||||
input: "",
|
||||
maxRunes: 100,
|
||||
wantLen: 0,
|
||||
},
|
||||
{
|
||||
name: "恰好 100 个字符",
|
||||
input: string(make([]rune, 100)),
|
||||
maxRunes: 100,
|
||||
wantLen: 100,
|
||||
},
|
||||
{
|
||||
name: "不足 100 字符不截断",
|
||||
input: "hello世界",
|
||||
maxRunes: 100,
|
||||
wantLen: 7,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := truncateSearchByRune(tc.input, tc.maxRunes)
|
||||
require.Equal(t, tc.wantLen, len([]rune(result)))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateSearchByRune_PreservesMultibyte(t *testing.T) {
|
||||
// 101 个中文字符,截断到 100 个后应该仍然是有效 UTF-8
|
||||
input := ""
|
||||
for i := 0; i < 101; i++ {
|
||||
input += "中"
|
||||
}
|
||||
result := truncateSearchByRune(input, 100)
|
||||
|
||||
require.Equal(t, 100, len([]rune(result)))
|
||||
// 验证截断结果是有效的 UTF-8(每个中文字符 3 字节)
|
||||
require.Equal(t, 300, len(result))
|
||||
}
|
||||
|
||||
func TestTruncateSearchByRune_MixedASCIIAndMultibyte(t *testing.T) {
|
||||
// 50 个 ASCII + 51 个中文 = 101 个 rune
|
||||
input := ""
|
||||
for i := 0; i < 50; i++ {
|
||||
input += "a"
|
||||
}
|
||||
for i := 0; i < 51; i++ {
|
||||
input += "中"
|
||||
}
|
||||
result := truncateSearchByRune(input, 100)
|
||||
|
||||
runes := []rune(result)
|
||||
require.Equal(t, 100, len(runes))
|
||||
// 前 50 个应该是 'a',后 50 个应该是 '中'
|
||||
require.Equal(t, 'a', runes[0])
|
||||
require.Equal(t, 'a', runes[49])
|
||||
require.Equal(t, '中', runes[50])
|
||||
require.Equal(t, '中', runes[99])
|
||||
}
|
||||
@@ -70,8 +70,8 @@ func (h *UserHandler) List(c *gin.Context) {
|
||||
search := c.Query("search")
|
||||
// 标准化和验证 search 参数
|
||||
search = strings.TrimSpace(search)
|
||||
if len(search) > 100 {
|
||||
search = search[:100]
|
||||
if runes := []rune(search); len(runes) > 100 {
|
||||
search = string(runes[:100])
|
||||
}
|
||||
|
||||
filters := service.UserListFilters{
|
||||
|
||||
Reference in New Issue
Block a user