Merge pull request #670 from DaydreamCoding/feat/admin-apikey-group-update

feat(admin): 添加管理员直接修改用户 API Key 分组的功能
This commit is contained in:
Wesley Liddick
2026-02-28 22:20:29 +08:00
committed by GitHub
26 changed files with 1138 additions and 22 deletions

View File

@@ -103,7 +103,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
@@ -192,7 +192,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
errorPassthroughCache := repository.NewErrorPassthroughCache(redisClient)
errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache)
errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler)
adminAPIKeyHandler := admin.NewAdminAPIKeyHandler(adminService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler)
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)

View File

@@ -109,6 +109,7 @@ require (
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/google/subcommands v1.2.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/hashicorp/hcl/v2 v2.18.1 // indirect
@@ -177,6 +178,7 @@ require (
golang.org/x/mod v0.32.0 // indirect
golang.org/x/sys v0.41.0 // indirect
golang.org/x/text v0.34.0 // indirect
golang.org/x/tools v0.41.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
modernc.org/libc v1.67.6 // indirect

View File

@@ -182,6 +182,8 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=

View File

@@ -403,5 +403,23 @@ func (s *stubAdminService) UpdateGroupSortOrders(ctx context.Context, updates []
return nil
}
func (s *stubAdminService) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*service.AdminUpdateAPIKeyGroupIDResult, error) {
for i := range s.apiKeys {
if s.apiKeys[i].ID == keyID {
k := s.apiKeys[i]
if groupID != nil {
if *groupID == 0 {
k.GroupID = nil
} else {
gid := *groupID
k.GroupID = &gid
}
}
return &service.AdminUpdateAPIKeyGroupIDResult{APIKey: &k}, nil
}
}
return nil, service.ErrAPIKeyNotFound
}
// Ensure stub implements interface.
var _ service.AdminService = (*stubAdminService)(nil)

View File

@@ -0,0 +1,63 @@
package admin
import (
"strconv"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// AdminAPIKeyHandler handles admin API key management
type AdminAPIKeyHandler struct {
adminService service.AdminService
}
// NewAdminAPIKeyHandler creates a new admin API key handler
func NewAdminAPIKeyHandler(adminService service.AdminService) *AdminAPIKeyHandler {
return &AdminAPIKeyHandler{
adminService: adminService,
}
}
// AdminUpdateAPIKeyGroupRequest represents the request to update an API key's group
type AdminUpdateAPIKeyGroupRequest struct {
GroupID *int64 `json:"group_id"` // nil=不修改, 0=解绑, >0=绑定到目标分组
}
// UpdateGroup handles updating an API key's group binding
// PUT /api/v1/admin/api-keys/:id
func (h *AdminAPIKeyHandler) UpdateGroup(c *gin.Context) {
keyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid API key ID")
return
}
var req AdminUpdateAPIKeyGroupRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
result, err := h.adminService.AdminUpdateAPIKeyGroupID(c.Request.Context(), keyID, req.GroupID)
if err != nil {
response.ErrorFrom(c, err)
return
}
resp := struct {
APIKey *dto.APIKey `json:"api_key"`
AutoGrantedGroupAccess bool `json:"auto_granted_group_access"`
GrantedGroupID *int64 `json:"granted_group_id,omitempty"`
GrantedGroupName string `json:"granted_group_name,omitempty"`
}{
APIKey: dto.APIKeyFromService(result.APIKey),
AutoGrantedGroupAccess: result.AutoGrantedGroupAccess,
GrantedGroupID: result.GrantedGroupID,
GrantedGroupName: result.GrantedGroupName,
}
response.Success(c, resp)
}

View File

@@ -0,0 +1,202 @@
package admin
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func setupAPIKeyHandler(adminSvc service.AdminService) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
h := NewAdminAPIKeyHandler(adminSvc)
router.PUT("/api/v1/admin/api-keys/:id", h.UpdateGroup)
return router
}
func TestAdminAPIKeyHandler_UpdateGroup_InvalidID(t *testing.T) {
router := setupAPIKeyHandler(newStubAdminService())
body := `{"group_id": 2}`
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/abc", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
require.Contains(t, rec.Body.String(), "Invalid API key ID")
}
func TestAdminAPIKeyHandler_UpdateGroup_InvalidJSON(t *testing.T) {
router := setupAPIKeyHandler(newStubAdminService())
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{bad json`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
require.Contains(t, rec.Body.String(), "Invalid request")
}
func TestAdminAPIKeyHandler_UpdateGroup_KeyNotFound(t *testing.T) {
router := setupAPIKeyHandler(newStubAdminService())
body := `{"group_id": 2}`
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/999", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
// ErrAPIKeyNotFound maps to 404
require.Equal(t, http.StatusNotFound, rec.Code)
}
func TestAdminAPIKeyHandler_UpdateGroup_BindGroup(t *testing.T) {
router := setupAPIKeyHandler(newStubAdminService())
body := `{"group_id": 2}`
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Code int `json:"code"`
Data json.RawMessage `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
var data struct {
APIKey struct {
ID int64 `json:"id"`
GroupID *int64 `json:"group_id"`
} `json:"api_key"`
AutoGrantedGroupAccess bool `json:"auto_granted_group_access"`
}
require.NoError(t, json.Unmarshal(resp.Data, &data))
require.Equal(t, int64(10), data.APIKey.ID)
require.NotNil(t, data.APIKey.GroupID)
require.Equal(t, int64(2), *data.APIKey.GroupID)
}
func TestAdminAPIKeyHandler_UpdateGroup_Unbind(t *testing.T) {
svc := newStubAdminService()
gid := int64(2)
svc.apiKeys[0].GroupID = &gid
router := setupAPIKeyHandler(svc)
body := `{"group_id": 0}`
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Data struct {
APIKey struct {
GroupID *int64 `json:"group_id"`
} `json:"api_key"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Nil(t, resp.Data.APIKey.GroupID)
}
func TestAdminAPIKeyHandler_UpdateGroup_ServiceError(t *testing.T) {
svc := &failingUpdateGroupService{
stubAdminService: newStubAdminService(),
err: errors.New("internal failure"),
}
router := setupAPIKeyHandler(svc)
body := `{"group_id": 2}`
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusInternalServerError, rec.Code)
}
// H2: empty body → group_id is nil → no-op, returns original key
func TestAdminAPIKeyHandler_UpdateGroup_EmptyBody_NoChange(t *testing.T) {
router := setupAPIKeyHandler(newStubAdminService())
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Code int `json:"code"`
Data struct {
APIKey struct {
ID int64 `json:"id"`
} `json:"api_key"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Equal(t, int64(10), resp.Data.APIKey.ID)
}
// M2: service returns GROUP_NOT_ACTIVE → handler maps to 400
func TestAdminAPIKeyHandler_UpdateGroup_GroupNotActive(t *testing.T) {
svc := &failingUpdateGroupService{
stubAdminService: newStubAdminService(),
err: infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active"),
}
router := setupAPIKeyHandler(svc)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{"group_id": 5}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
require.Contains(t, rec.Body.String(), "GROUP_NOT_ACTIVE")
}
// M2: service returns INVALID_GROUP_ID → handler maps to 400
func TestAdminAPIKeyHandler_UpdateGroup_NegativeGroupID(t *testing.T) {
svc := &failingUpdateGroupService{
stubAdminService: newStubAdminService(),
err: infraerrors.BadRequest("INVALID_GROUP_ID", "group_id must be non-negative"),
}
router := setupAPIKeyHandler(svc)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{"group_id": -5}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
require.Contains(t, rec.Body.String(), "INVALID_GROUP_ID")
}
// failingUpdateGroupService overrides AdminUpdateAPIKeyGroupID to return an error.
type failingUpdateGroupService struct {
*stubAdminService
err error
}
func (f *failingUpdateGroupService) AdminUpdateAPIKeyGroupID(_ context.Context, _ int64, _ *int64) (*service.AdminUpdateAPIKeyGroupIDResult, error) {
return nil, f.err
}

View File

@@ -26,6 +26,7 @@ type AdminHandlers struct {
Usage *admin.UsageHandler
UserAttribute *admin.UserAttributeHandler
ErrorPassthrough *admin.ErrorPassthroughHandler
APIKey *admin.AdminAPIKeyHandler
}
// Handlers contains all HTTP handlers

View File

@@ -945,6 +945,9 @@ func (r *stubUserRepoForHandler) RemoveGroupFromAllowedGroups(context.Context, i
func (r *stubUserRepoForHandler) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (r *stubUserRepoForHandler) EnableTotp(context.Context, int64) error { return nil }
func (r *stubUserRepoForHandler) DisableTotp(context.Context, int64) error { return nil }
func (r *stubUserRepoForHandler) AddGroupToAllowedGroups(context.Context, int64, int64) error {
return nil
}
// ==================== NewSoraClientHandler ====================

View File

@@ -29,6 +29,7 @@ func ProvideAdminHandlers(
usageHandler *admin.UsageHandler,
userAttributeHandler *admin.UserAttributeHandler,
errorPassthroughHandler *admin.ErrorPassthroughHandler,
apiKeyHandler *admin.AdminAPIKeyHandler,
) *AdminHandlers {
return &AdminHandlers{
Dashboard: dashboardHandler,
@@ -51,6 +52,7 @@ func ProvideAdminHandlers(
Usage: usageHandler,
UserAttribute: userAttributeHandler,
ErrorPassthrough: errorPassthroughHandler,
APIKey: apiKeyHandler,
}
}
@@ -138,6 +140,7 @@ var ProviderSet = wire.NewSet(
admin.NewUsageHandler,
admin.NewUserAttributeHandler,
admin.NewErrorPassthroughHandler,
admin.NewAdminAPIKeyHandler,
// AdminHandlers and Handlers constructors
ProvideAdminHandlers,

View File

@@ -171,8 +171,9 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
// 则会更新已删除的记录。
// 这里选择 Update().Where(),确保只有未软删除记录能被更新。
// 同时显式设置 updated_at避免二次查询带来的并发可见性问题。
client := clientFromContext(ctx, r.client)
now := time.Now()
builder := r.client.APIKey.Update().
builder := client.APIKey.Update().
Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()).
SetName(key.Name).
SetStatus(key.Status).

View File

@@ -429,6 +429,16 @@ func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool,
return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx)
}
func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
client := clientFromContext(ctx, r.client)
return client.UserAllowedGroup.Create().
SetUserID(userID).
SetGroupID(groupID).
OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
DoNothing().
Exec(ctx)
}
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
// 仅操作 user_allowed_groups 联接表legacy users.allowed_groups 列已弃用。
affected, err := r.client.UserAllowedGroup.Delete().

View File

@@ -619,7 +619,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
@@ -779,6 +779,10 @@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
return 0, errors.New("not implemented")
}
func (r *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
return errors.New("not implemented")
}
func (r *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
return errors.New("not implemented")
}

View File

@@ -181,6 +181,10 @@ func (s *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
panic("unexpected RemoveGroupFromAllowedGroups call")
}
func (s *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
panic("unexpected AddGroupToAllowedGroups call")
}
func (s *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call")
}

View File

@@ -75,6 +75,16 @@ func RegisterAdminRoutes(
// 错误透传规则管理
registerErrorPassthroughRoutes(admin, h)
// API Key 管理
registerAdminAPIKeyRoutes(admin, h)
}
}
func registerAdminAPIKeyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
apiKeys := admin.Group("/api-keys")
{
apiKeys.PUT("/:id", h.Admin.APIKey.UpdateGroup)
}
}

View File

@@ -9,6 +9,8 @@ import (
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
@@ -42,6 +44,9 @@ type AdminService interface {
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
// API Key management (admin)
AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error)
// Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error)
GetAccount(ctx context.Context, id int64) (*Account, error)
@@ -242,6 +247,14 @@ type BulkUpdateAccountResult struct {
Error string `json:"error,omitempty"`
}
// AdminUpdateAPIKeyGroupIDResult is the result of AdminUpdateAPIKeyGroupID.
type AdminUpdateAPIKeyGroupIDResult struct {
APIKey *APIKey
AutoGrantedGroupAccess bool // true if a new exclusive group permission was auto-added
GrantedGroupID *int64 // the group ID that was auto-granted
GrantedGroupName string // the group name that was auto-granted
}
// BulkUpdateAccountsResult is the aggregated response for bulk updates.
type BulkUpdateAccountsResult struct {
Success int `json:"success"`
@@ -406,6 +419,7 @@ type adminServiceImpl struct {
proxyProber ProxyExitInfoProber
proxyLatencyCache ProxyLatencyCache
authCacheInvalidator APIKeyAuthCacheInvalidator
entClient *dbent.Client // 用于开启数据库事务
}
type userGroupRateBatchReader interface {
@@ -430,6 +444,7 @@ func NewAdminService(
proxyProber ProxyExitInfoProber,
proxyLatencyCache ProxyLatencyCache,
authCacheInvalidator APIKeyAuthCacheInvalidator,
entClient *dbent.Client,
) AdminService {
return &adminServiceImpl{
userRepo: userRepo,
@@ -444,6 +459,7 @@ func NewAdminService(
proxyProber: proxyProber,
proxyLatencyCache: proxyLatencyCache,
authCacheInvalidator: authCacheInvalidator,
entClient: entClient,
}
}
@@ -1185,6 +1201,103 @@ func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []
return s.groupRepo.UpdateSortOrders(ctx, updates)
}
// AdminUpdateAPIKeyGroupID 管理员修改 API Key 分组绑定
// groupID: nil=不修改, 指向0=解绑, 指向正整数=绑定到目标分组
func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, keyID)
if err != nil {
return nil, err
}
if groupID == nil {
// nil 表示不修改,直接返回
return &AdminUpdateAPIKeyGroupIDResult{APIKey: apiKey}, nil
}
if *groupID < 0 {
return nil, infraerrors.BadRequest("INVALID_GROUP_ID", "group_id must be non-negative")
}
result := &AdminUpdateAPIKeyGroupIDResult{}
if *groupID == 0 {
// 0 表示解绑分组(不修改 user_allowed_groups避免影响用户其他 Key
apiKey.GroupID = nil
apiKey.Group = nil
} else {
// 验证目标分组存在且状态为 active
group, err := s.groupRepo.GetByID(ctx, *groupID)
if err != nil {
return nil, err
}
if group.Status != StatusActive {
return nil, infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active")
}
// 订阅类型分组:不允许通过此 API 直接绑定,需通过订阅管理流程
if group.IsSubscriptionType() {
return nil, infraerrors.BadRequest("SUBSCRIPTION_GROUP_NOT_ALLOWED", "subscription groups must be managed through the subscription workflow")
}
gid := *groupID
apiKey.GroupID = &gid
apiKey.Group = group
// 专属标准分组:使用事务保证「添加分组权限」与「更新 API Key」的原子性
if group.IsExclusive {
opCtx := ctx
var tx *dbent.Tx
if s.entClient == nil {
logger.LegacyPrintf("service.admin", "Warning: entClient is nil, skipping transaction protection for exclusive group binding")
} else {
var txErr error
tx, txErr = s.entClient.Tx(ctx)
if txErr != nil {
return nil, fmt.Errorf("begin transaction: %w", txErr)
}
defer func() { _ = tx.Rollback() }()
opCtx = dbent.NewTxContext(ctx, tx)
}
if addErr := s.userRepo.AddGroupToAllowedGroups(opCtx, apiKey.UserID, gid); addErr != nil {
return nil, fmt.Errorf("add group to user allowed groups: %w", addErr)
}
if err := s.apiKeyRepo.Update(opCtx, apiKey); err != nil {
return nil, fmt.Errorf("update api key: %w", err)
}
if tx != nil {
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("commit transaction: %w", err)
}
}
result.AutoGrantedGroupAccess = true
result.GrantedGroupID = &gid
result.GrantedGroupName = group.Name
// 失效认证缓存(在事务提交后执行)
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key)
}
result.APIKey = apiKey
return result, nil
}
}
// 非专属分组 / 解绑:无需事务,单步更新即可
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
return nil, fmt.Errorf("update api key: %w", err)
}
// 失效认证缓存
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key)
}
result.APIKey = apiKey
return result, nil
}
// Account management implementations
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}

View File

@@ -0,0 +1,420 @@
//go:build unit
package service
import (
"context"
"errors"
"testing"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// Stubs
// ---------------------------------------------------------------------------
// userRepoStubForGroupUpdate implements UserRepository for AdminUpdateAPIKeyGroupID tests.
type userRepoStubForGroupUpdate struct {
addGroupErr error
addGroupCalled bool
addedUserID int64
addedGroupID int64
}
func (s *userRepoStubForGroupUpdate) AddGroupToAllowedGroups(_ context.Context, userID int64, groupID int64) error {
s.addGroupCalled = true
s.addedUserID = userID
s.addedGroupID = groupID
return s.addGroupErr
}
func (s *userRepoStubForGroupUpdate) Create(context.Context, *User) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) GetByID(context.Context, int64) (*User, error) { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) GetByEmail(context.Context, string) (*User, error) { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, error) { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) UpdateBalance(context.Context, int64, float64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) DeductBalance(context.Context, int64, float64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) UpdateConcurrency(context.Context, int64, int) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (bool, error) { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") }
// apiKeyRepoStubForGroupUpdate implements APIKeyRepository for AdminUpdateAPIKeyGroupID tests.
type apiKeyRepoStubForGroupUpdate struct {
key *APIKey
getErr error
updateErr error
updated *APIKey // captures what was passed to Update
}
func (s *apiKeyRepoStubForGroupUpdate) GetByID(_ context.Context, _ int64) (*APIKey, error) {
if s.getErr != nil {
return nil, s.getErr
}
clone := *s.key
return &clone, nil
}
func (s *apiKeyRepoStubForGroupUpdate) Update(_ context.Context, key *APIKey) error {
if s.updateErr != nil {
return s.updateErr
}
clone := *key
s.updated = &clone
return nil
}
// Unused methods panic on unexpected call.
func (s *apiKeyRepoStubForGroupUpdate) Create(context.Context, *APIKey) error { panic("unexpected") }
func (s *apiKeyRepoStubForGroupUpdate) GetKeyAndOwnerID(context.Context, int64) (string, int64, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) GetByKey(context.Context, string) (*APIKey, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) GetByKeyForAuth(context.Context, string) (*APIKey, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
func (s *apiKeyRepoStubForGroupUpdate) ListByUserID(context.Context, int64, pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) CountByUserID(context.Context, int64) (int64, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) ExistsByKey(context.Context, string) (bool, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) SearchAPIKeys(context.Context, int64, string, int) ([]APIKey, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) ClearGroupIDByGroupID(context.Context, int64) (int64, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) CountByGroupID(context.Context, int64) (int64, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) ListKeysByUserID(context.Context, int64) ([]string, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) ListKeysByGroupID(context.Context, int64) ([]string, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) IncrementQuotaUsed(context.Context, int64, float64) (float64, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) UpdateLastUsed(context.Context, int64, time.Time) error {
panic("unexpected")
}
// groupRepoStubForGroupUpdate implements GroupRepository for AdminUpdateAPIKeyGroupID tests.
type groupRepoStubForGroupUpdate struct {
group *Group
getErr error
lastGetByIDArg int64
}
func (s *groupRepoStubForGroupUpdate) GetByID(_ context.Context, id int64) (*Group, error) {
s.lastGetByIDArg = id
if s.getErr != nil {
return nil, s.getErr
}
clone := *s.group
return &clone, nil
}
// Unused methods panic on unexpected call.
func (s *groupRepoStubForGroupUpdate) Create(context.Context, *Group) error { panic("unexpected") }
func (s *groupRepoStubForGroupUpdate) GetByIDLite(context.Context, int64) (*Group, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) Update(context.Context, *Group) error { panic("unexpected") }
func (s *groupRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
func (s *groupRepoStubForGroupUpdate) DeleteCascade(context.Context, int64) ([]int64, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, *bool) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) ListActive(context.Context) ([]Group, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) ListActiveByPlatform(context.Context, string) ([]Group, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) ExistsByName(context.Context, string) (bool, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) GetAccountIDsByGroupIDs(context.Context, []int64) ([]int64, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) BindAccountsToGroup(context.Context, int64, []int64) error {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) UpdateSortOrders(context.Context, []GroupSortOrderUpdate) error {
panic("unexpected")
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
func TestAdminService_AdminUpdateAPIKeyGroupID_KeyNotFound(t *testing.T) {
repo := &apiKeyRepoStubForGroupUpdate{getErr: ErrAPIKeyNotFound}
svc := &adminServiceImpl{apiKeyRepo: repo}
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 999, int64Ptr(1))
require.ErrorIs(t, err, ErrAPIKeyNotFound)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_NilGroupID_NoOp(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(5)}
repo := &apiKeyRepoStubForGroupUpdate{key: existing}
svc := &adminServiceImpl{apiKeyRepo: repo}
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, nil)
require.NoError(t, err)
require.Equal(t, int64(1), got.APIKey.ID)
// Update should NOT have been called (updated stays nil)
require.Nil(t, repo.updated)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_Unbind(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(5), Group: &Group{ID: 5, Name: "Old"}}
repo := &apiKeyRepoStubForGroupUpdate{key: existing}
cache := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{apiKeyRepo: repo, authCacheInvalidator: cache}
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(0))
require.NoError(t, err)
require.Nil(t, got.APIKey.GroupID, "group_id should be nil after unbind")
require.Nil(t, got.APIKey.Group, "group object should be nil after unbind")
require.NotNil(t, repo.updated, "Update should have been called")
require.Nil(t, repo.updated.GroupID)
require.Equal(t, []string{"sk-test"}, cache.keys, "cache should be invalidated")
}
func TestAdminService_AdminUpdateAPIKeyGroupID_BindActiveGroup(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Pro", Status: StatusActive}}
cache := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, authCacheInvalidator: cache}
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
require.NoError(t, err)
require.NotNil(t, got.APIKey.GroupID)
require.Equal(t, int64(10), *got.APIKey.GroupID)
require.Equal(t, int64(10), *apiKeyRepo.updated.GroupID)
require.Equal(t, []string{"sk-test"}, cache.keys)
// M3: verify correct group ID was passed to repo
require.Equal(t, int64(10), groupRepo.lastGetByIDArg)
// C1 fix: verify Group object is populated
require.NotNil(t, got.APIKey.Group)
require.Equal(t, "Pro", got.APIKey.Group.Name)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_SameGroup_Idempotent(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(10), Group: &Group{ID: 10, Name: "Pro"}}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Pro", Status: StatusActive}}
cache := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, authCacheInvalidator: cache}
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
require.NoError(t, err)
require.NotNil(t, got.APIKey.GroupID)
require.Equal(t, int64(10), *got.APIKey.GroupID)
// Update is still called (current impl doesn't short-circuit on same group)
require.NotNil(t, apiKeyRepo.updated)
require.Equal(t, []string{"sk-test"}, cache.keys)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_GroupNotFound(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test"}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{getErr: ErrGroupNotFound}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo}
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(99))
require.ErrorIs(t, err, ErrGroupNotFound)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_GroupNotActive(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test"}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 5, Status: StatusDisabled}}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo}
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(5))
require.Error(t, err)
require.Equal(t, "GROUP_NOT_ACTIVE", infraerrors.Reason(err))
}
func TestAdminService_AdminUpdateAPIKeyGroupID_UpdateFails(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(3)}
repo := &apiKeyRepoStubForGroupUpdate{key: existing, updateErr: errors.New("db write error")}
svc := &adminServiceImpl{apiKeyRepo: repo}
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(0))
require.Error(t, err)
require.Contains(t, err.Error(), "update api key")
}
func TestAdminService_AdminUpdateAPIKeyGroupID_NegativeGroupID(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test"}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo}
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(-5))
require.Error(t, err)
require.Equal(t, "INVALID_GROUP_ID", infraerrors.Reason(err))
}
func TestAdminService_AdminUpdateAPIKeyGroupID_PointerIsolation(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Pro", Status: StatusActive}}
cache := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, authCacheInvalidator: cache}
inputGID := int64(10)
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, &inputGID)
require.NoError(t, err)
require.NotNil(t, got.APIKey.GroupID)
// Mutating the input pointer must NOT affect the stored value
inputGID = 999
require.Equal(t, int64(10), *got.APIKey.GroupID)
require.Equal(t, int64(10), *apiKeyRepo.updated.GroupID)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_NilCacheInvalidator(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test"}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 7, Status: StatusActive}}
// authCacheInvalidator is nil should not panic
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo}
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(7))
require.NoError(t, err)
require.NotNil(t, got.APIKey.GroupID)
require.Equal(t, int64(7), *got.APIKey.GroupID)
}
// ---------------------------------------------------------------------------
// Tests: AllowedGroup auto-sync
// ---------------------------------------------------------------------------
func TestAdminService_AdminUpdateAPIKeyGroupID_ExclusiveGroup_AddsAllowedGroup(t *testing.T) {
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Exclusive", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeStandard}}
userRepo := &userRepoStubForGroupUpdate{}
cache := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, authCacheInvalidator: cache}
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
require.NoError(t, err)
require.NotNil(t, got.APIKey.GroupID)
require.Equal(t, int64(10), *got.APIKey.GroupID)
// 验证 AddGroupToAllowedGroups 被调用,且参数正确
require.True(t, userRepo.addGroupCalled)
require.Equal(t, int64(42), userRepo.addedUserID)
require.Equal(t, int64(10), userRepo.addedGroupID)
// 验证 result 标记了自动授权
require.True(t, got.AutoGrantedGroupAccess)
require.NotNil(t, got.GrantedGroupID)
require.Equal(t, int64(10), *got.GrantedGroupID)
require.Equal(t, "Exclusive", got.GrantedGroupName)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_NonExclusiveGroup_NoAllowedGroupUpdate(t *testing.T) {
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Public", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeStandard}}
userRepo := &userRepoStubForGroupUpdate{}
cache := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, authCacheInvalidator: cache}
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
require.NoError(t, err)
require.NotNil(t, got.APIKey.GroupID)
// 非专属分组不触发 AddGroupToAllowedGroups
require.False(t, userRepo.addGroupCalled)
require.False(t, got.AutoGrantedGroupAccess)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_Blocked(t *testing.T) {
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeSubscription}}
userRepo := &userRepoStubForGroupUpdate{}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo}
// 订阅类型分组应被阻止绑定
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
require.Error(t, err)
require.Equal(t, "SUBSCRIPTION_GROUP_NOT_ALLOWED", infraerrors.Reason(err))
require.False(t, userRepo.addGroupCalled)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_ExclusiveGroup_AllowedGroupAddFails_ReturnsError(t *testing.T) {
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Exclusive", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeStandard}}
userRepo := &userRepoStubForGroupUpdate{addGroupErr: errors.New("db error")}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo}
// 严格模式AddGroupToAllowedGroups 失败时,整体操作报错
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
require.Error(t, err)
require.Contains(t, err.Error(), "add group to user allowed groups")
require.True(t, userRepo.addGroupCalled)
// apiKey 不应被更新
require.Nil(t, apiKeyRepo.updated)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_Unbind_NoAllowedGroupUpdate(t *testing.T) {
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: int64Ptr(10), Group: &Group{ID: 10, Name: "Exclusive"}}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
userRepo := &userRepoStubForGroupUpdate{}
cache := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, userRepo: userRepo, authCacheInvalidator: cache}
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(0))
require.NoError(t, err)
require.Nil(t, got.APIKey.GroupID)
// 解绑时不修改 allowed_groups
require.False(t, userRepo.addGroupCalled)
require.False(t, got.AutoGrantedGroupAccess)
}

View File

@@ -93,6 +93,10 @@ func (s *userRepoStub) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
panic("unexpected RemoveGroupFromAllowedGroups call")
}
func (s *userRepoStub) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
panic("unexpected AddGroupToAllowedGroups call")
}
func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call")
}

View File

@@ -165,6 +165,9 @@ func (r *stubUserRepoForQuota) RemoveGroupFromAllowedGroups(context.Context, int
func (r *stubUserRepoForQuota) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (r *stubUserRepoForQuota) EnableTotp(context.Context, int64) error { return nil }
func (r *stubUserRepoForQuota) DisableTotp(context.Context, int64) error { return nil }
func (r *stubUserRepoForQuota) AddGroupToAllowedGroups(context.Context, int64, int64) error {
return nil
}
// ==================== 辅助函数:构造带 CDN 缓存的 SoraS3Storage ====================

View File

@@ -40,6 +40,8 @@ type UserRepository interface {
UpdateConcurrency(ctx context.Context, id int64, amount int) error
ExistsByEmail(ctx context.Context, email string) (bool, error)
RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error)
// AddGroupToAllowedGroups 将指定分组增量添加到用户的 allowed_groups幂等冲突忽略
AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error
// TOTP 双因素认证
UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error

View File

@@ -45,7 +45,8 @@ func (m *mockUserRepo) ExistsByEmail(context.Context, string) (bool, error) { re
func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
return 0, nil
}
func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil }
func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil }
func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil }