同步上游至最新版本并重新应用自定义配置
Some checks failed
CI / test (push) Has been cancelled
CI / golangci-lint (push) Has been cancelled
Security Scan / backend-security (push) Has been cancelled
Security Scan / frontend-security (push) Has been cancelled

上游新增功能:
- 双模式用户消息队列(串行队列 + 软性限速)
- 自定义菜单页面(iframe嵌入 + CSP注入)
- 代理URL集中验证与全局fail-fast
- 新用户默认订阅设置
- 指纹缓存TTL懒续期机制
- 分组用量分布图表
- 代理密码可见性 + 复制代理URL
- 大量bug修复和性能优化

自定义配置保留:
- 品牌化:Sub2API → StarFireAPI
- 链接:GitHub → anthropic.edu.pl 官网
- docker-compose:starfireapi镜像、端口6580、外部Redis、项目名xinghuoapi
- 更新模块禁用:后端始终返回无更新、前端隐藏更新UI

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-03 10:15:24 +08:00
780 changed files with 145032 additions and 40124 deletions

View File

@@ -175,22 +175,28 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
return
}
dataPayload := req.Data
if err := validateDataHeader(dataPayload); err != nil {
if err := validateDataHeader(req.Data); err != nil {
response.BadRequest(c, err.Error())
return
}
executeAdminIdempotentJSON(c, "admin.accounts.import_data", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
return h.importData(ctx, req)
})
}
func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest) (DataImportResult, error) {
skipDefaultGroupBind := true
if req.SkipDefaultGroupBind != nil {
skipDefaultGroupBind = *req.SkipDefaultGroupBind
}
dataPayload := req.Data
result := DataImportResult{}
existingProxies, err := h.listAllProxies(c.Request.Context())
existingProxies, err := h.listAllProxies(ctx)
if err != nil {
response.ErrorFrom(c, err)
return
return result, err
}
proxyKeyToID := make(map[string]int64, len(existingProxies))
@@ -221,8 +227,8 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
proxyKeyToID[key] = existingID
result.ProxyReused++
if normalizedStatus != "" {
if proxy, err := h.adminService.GetProxy(c.Request.Context(), existingID); err == nil && proxy != nil && proxy.Status != normalizedStatus {
_, _ = h.adminService.UpdateProxy(c.Request.Context(), existingID, &service.UpdateProxyInput{
if proxy, getErr := h.adminService.GetProxy(ctx, existingID); getErr == nil && proxy != nil && proxy.Status != normalizedStatus {
_, _ = h.adminService.UpdateProxy(ctx, existingID, &service.UpdateProxyInput{
Status: normalizedStatus,
})
}
@@ -230,7 +236,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
continue
}
created, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
created, createErr := h.adminService.CreateProxy(ctx, &service.CreateProxyInput{
Name: defaultProxyName(item.Name),
Protocol: item.Protocol,
Host: item.Host,
@@ -238,13 +244,13 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
Username: item.Username,
Password: item.Password,
})
if err != nil {
if createErr != nil {
result.ProxyFailed++
result.Errors = append(result.Errors, DataImportError{
Kind: "proxy",
Name: item.Name,
ProxyKey: key,
Message: err.Error(),
Message: createErr.Error(),
})
continue
}
@@ -252,7 +258,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
result.ProxyCreated++
if normalizedStatus != "" && normalizedStatus != created.Status {
_, _ = h.adminService.UpdateProxy(c.Request.Context(), created.ID, &service.UpdateProxyInput{
_, _ = h.adminService.UpdateProxy(ctx, created.ID, &service.UpdateProxyInput{
Status: normalizedStatus,
})
}
@@ -303,7 +309,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
SkipDefaultGroupBind: skipDefaultGroupBind,
}
if _, err := h.adminService.CreateAccount(c.Request.Context(), accountInput); err != nil {
if _, err := h.adminService.CreateAccount(ctx, accountInput); err != nil {
result.AccountFailed++
result.Errors = append(result.Errors, DataImportError{
Kind: "account",
@@ -315,7 +321,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) {
result.AccountCreated++
}
response.Success(c, result)
return result, nil
}
func (h *AccountHandler) listAllProxies(ctx context.Context) ([]service.Proxy, error) {
@@ -341,7 +347,7 @@ func (h *AccountHandler) listAccountsFiltered(ctx context.Context, platform, acc
pageSize := dataPageCap
var out []service.Account
for {
items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search)
items, total, err := h.adminService.ListAccounts(ctx, page, pageSize, platform, accountType, status, search, 0)
if err != nil {
return nil, err
}

View File

@@ -64,6 +64,7 @@ func setupAccountDataRouter() (*gin.Engine, *stubAdminService) {
nil,
nil,
nil,
nil,
)
router.GET("/api/v1/admin/accounts/data", h.ExportData)

View File

@@ -2,7 +2,13 @@
package admin
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"sync"
@@ -10,6 +16,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
@@ -46,6 +53,7 @@ type AccountHandler struct {
concurrencyService *service.ConcurrencyService
crsSyncService *service.CRSSyncService
sessionLimitCache service.SessionLimitCache
rpmCache service.RPMCache
tokenCacheInvalidator service.TokenCacheInvalidator
}
@@ -62,6 +70,7 @@ func NewAccountHandler(
concurrencyService *service.ConcurrencyService,
crsSyncService *service.CRSSyncService,
sessionLimitCache service.SessionLimitCache,
rpmCache service.RPMCache,
tokenCacheInvalidator service.TokenCacheInvalidator,
) *AccountHandler {
return &AccountHandler{
@@ -76,6 +85,7 @@ func NewAccountHandler(
concurrencyService: concurrencyService,
crsSyncService: crsSyncService,
sessionLimitCache: sessionLimitCache,
rpmCache: rpmCache,
tokenCacheInvalidator: tokenCacheInvalidator,
}
}
@@ -133,6 +143,13 @@ type BulkUpdateAccountsRequest struct {
ConfirmMixedChannelRisk *bool `json:"confirm_mixed_channel_risk"` // 用户确认混合渠道风险
}
// CheckMixedChannelRequest represents check mixed channel risk request
type CheckMixedChannelRequest struct {
Platform string `json:"platform" binding:"required"`
GroupIDs []int64 `json:"group_ids"`
AccountID *int64 `json:"account_id"`
}
// AccountWithConcurrency extends Account with real-time concurrency info
type AccountWithConcurrency struct {
*dto.Account
@@ -140,6 +157,51 @@ type AccountWithConcurrency struct {
// 以下字段仅对 Anthropic OAuth/SetupToken 账号有效,且仅在启用相应功能时返回
CurrentWindowCost *float64 `json:"current_window_cost,omitempty"` // 当前窗口费用
ActiveSessions *int `json:"active_sessions,omitempty"` // 当前活跃会话数
CurrentRPM *int `json:"current_rpm,omitempty"` // 当前分钟 RPM 计数
}
func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency {
item := AccountWithConcurrency{
Account: dto.AccountFromService(account),
CurrentConcurrency: 0,
}
if account == nil {
return item
}
if h.concurrencyService != nil {
if counts, err := h.concurrencyService.GetAccountConcurrencyBatch(ctx, []int64{account.ID}); err == nil {
item.CurrentConcurrency = counts[account.ID]
}
}
if account.IsAnthropicOAuthOrSetupToken() {
if h.accountUsageService != nil && account.GetWindowCostLimit() > 0 {
startTime := account.GetCurrentWindowStartTime()
if stats, err := h.accountUsageService.GetAccountWindowStats(ctx, account.ID, startTime); err == nil && stats != nil {
cost := stats.StandardCost
item.CurrentWindowCost = &cost
}
}
if h.sessionLimitCache != nil && account.GetMaxSessions() > 0 {
idleTimeout := time.Duration(account.GetSessionIdleTimeoutMinutes()) * time.Minute
idleTimeouts := map[int64]time.Duration{account.ID: idleTimeout}
if sessions, err := h.sessionLimitCache.GetActiveSessionCountBatch(ctx, []int64{account.ID}, idleTimeouts); err == nil {
if count, ok := sessions[account.ID]; ok {
item.ActiveSessions = &count
}
}
}
if h.rpmCache != nil && account.GetBaseRPM() > 0 {
if rpm, err := h.rpmCache.GetRPM(ctx, account.ID); err == nil {
item.CurrentRPM = &rpm
}
}
}
return item
}
// List handles listing all accounts with pagination
@@ -156,7 +218,12 @@ func (h *AccountHandler) List(c *gin.Context) {
search = search[:100]
}
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search)
var groupID int64
if groupIDStr := c.Query("group"); groupIDStr != "" {
groupID, _ = strconv.ParseInt(groupIDStr, 10, 64)
}
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search, groupID)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -174,9 +241,10 @@ func (h *AccountHandler) List(c *gin.Context) {
concurrencyCounts = make(map[int64]int)
}
// 识别需要查询窗口费用会话数的账号Anthropic OAuth/SetupToken 且启用了相应功能)
// 识别需要查询窗口费用会话数和 RPM 的账号Anthropic OAuth/SetupToken 且启用了相应功能)
windowCostAccountIDs := make([]int64, 0)
sessionLimitAccountIDs := make([]int64, 0)
rpmAccountIDs := make([]int64, 0)
sessionIdleTimeouts := make(map[int64]time.Duration) // 各账号的会话空闲超时配置
for i := range accounts {
acc := &accounts[i]
@@ -188,12 +256,24 @@ func (h *AccountHandler) List(c *gin.Context) {
sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
}
if acc.GetBaseRPM() > 0 {
rpmAccountIDs = append(rpmAccountIDs, acc.ID)
}
}
}
// 并行获取窗口费用活跃会话数
// 并行获取窗口费用活跃会话数和 RPM 计数
var windowCosts map[int64]float64
var activeSessions map[int64]int
var rpmCounts map[int64]int
// 获取 RPM 计数(批量查询)
if len(rpmAccountIDs) > 0 && h.rpmCache != nil {
rpmCounts, _ = h.rpmCache.GetRPMBatch(c.Request.Context(), rpmAccountIDs)
if rpmCounts == nil {
rpmCounts = make(map[int64]int)
}
}
// 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置)
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
@@ -254,12 +334,81 @@ func (h *AccountHandler) List(c *gin.Context) {
}
}
// 添加 RPM 计数(仅当启用时)
if rpmCounts != nil {
if rpm, ok := rpmCounts[acc.ID]; ok {
item.CurrentRPM = &rpm
}
}
result[i] = item
}
etag := buildAccountsListETag(result, total, page, pageSize, platform, accountType, status, search)
if etag != "" {
c.Header("ETag", etag)
c.Header("Vary", "If-None-Match")
if ifNoneMatchMatched(c.GetHeader("If-None-Match"), etag) {
c.Status(http.StatusNotModified)
return
}
}
response.Paginated(c, result, total, page, pageSize)
}
func buildAccountsListETag(
items []AccountWithConcurrency,
total int64,
page, pageSize int,
platform, accountType, status, search string,
) string {
payload := struct {
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
Platform string `json:"platform"`
AccountType string `json:"type"`
Status string `json:"status"`
Search string `json:"search"`
Items []AccountWithConcurrency `json:"items"`
}{
Total: total,
Page: page,
PageSize: pageSize,
Platform: platform,
AccountType: accountType,
Status: status,
Search: search,
Items: items,
}
raw, err := json.Marshal(payload)
if err != nil {
return ""
}
sum := sha256.Sum256(raw)
return "\"" + hex.EncodeToString(sum[:]) + "\""
}
func ifNoneMatchMatched(ifNoneMatch, etag string) bool {
if etag == "" || ifNoneMatch == "" {
return false
}
for _, token := range strings.Split(ifNoneMatch, ",") {
candidate := strings.TrimSpace(token)
if candidate == "*" {
return true
}
if candidate == etag {
return true
}
if strings.HasPrefix(candidate, "W/") && strings.TrimPrefix(candidate, "W/") == etag {
return true
}
}
return false
}
// GetByID handles getting an account by ID
// GET /api/v1/admin/accounts/:id
func (h *AccountHandler) GetByID(c *gin.Context) {
@@ -275,7 +424,51 @@ func (h *AccountHandler) GetByID(c *gin.Context) {
return
}
response.Success(c, dto.AccountFromService(account))
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
}
// CheckMixedChannel handles checking mixed channel risk for account-group binding.
// POST /api/v1/admin/accounts/check-mixed-channel
func (h *AccountHandler) CheckMixedChannel(c *gin.Context) {
var req CheckMixedChannelRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if len(req.GroupIDs) == 0 {
response.Success(c, gin.H{"has_risk": false})
return
}
accountID := int64(0)
if req.AccountID != nil {
accountID = *req.AccountID
}
err := h.adminService.CheckMixedChannelRisk(c.Request.Context(), accountID, req.Platform, req.GroupIDs)
if err != nil {
var mixedErr *service.MixedChannelError
if errors.As(err, &mixedErr) {
response.Success(c, gin.H{
"has_risk": true,
"error": "mixed_channel_warning",
"message": mixedErr.Error(),
"details": gin.H{
"group_id": mixedErr.GroupID,
"group_name": mixedErr.GroupName,
"current_platform": mixedErr.CurrentPlatform,
"other_platform": mixedErr.OtherPlatform,
},
})
return
}
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"has_risk": false})
}
// Create handles creating a new account
@@ -290,50 +483,57 @@ func (h *AccountHandler) Create(c *gin.Context) {
response.BadRequest(c, "rate_multiplier must be >= 0")
return
}
// base_rpm 输入校验:负值归零,超过 10000 截断
sanitizeExtraBaseRPM(req.Extra)
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
Name: req.Name,
Notes: req.Notes,
Platform: req.Platform,
Type: req.Type,
Credentials: req.Credentials,
Extra: req.Extra,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
RateMultiplier: req.RateMultiplier,
GroupIDs: req.GroupIDs,
ExpiresAt: req.ExpiresAt,
AutoPauseOnExpired: req.AutoPauseOnExpired,
SkipMixedChannelCheck: skipCheck,
result, err := executeAdminIdempotent(c, "admin.accounts.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
account, execErr := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
Name: req.Name,
Notes: req.Notes,
Platform: req.Platform,
Type: req.Type,
Credentials: req.Credentials,
Extra: req.Extra,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
RateMultiplier: req.RateMultiplier,
GroupIDs: req.GroupIDs,
ExpiresAt: req.ExpiresAt,
AutoPauseOnExpired: req.AutoPauseOnExpired,
SkipMixedChannelCheck: skipCheck,
})
if execErr != nil {
return nil, execErr
}
return h.buildAccountResponseWithRuntime(ctx, account), nil
})
if err != nil {
// 检查是否为混合渠道错误
var mixedErr *service.MixedChannelError
if errors.As(err, &mixedErr) {
// 返回特殊错误码要求确认
// 创建接口仅返回最小必要字段,详细信息由专门检查接口提供
c.JSON(409, gin.H{
"error": "mixed_channel_warning",
"message": mixedErr.Error(),
"details": gin.H{
"group_id": mixedErr.GroupID,
"group_name": mixedErr.GroupName,
"current_platform": mixedErr.CurrentPlatform,
"other_platform": mixedErr.OtherPlatform,
},
"require_confirmation": true,
})
return
}
if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.AccountFromService(account))
if result != nil && result.Replayed {
c.Header("X-Idempotency-Replayed", "true")
}
response.Success(c, result.Data)
}
// Update handles updating an account
@@ -354,6 +554,8 @@ func (h *AccountHandler) Update(c *gin.Context) {
response.BadRequest(c, "rate_multiplier must be >= 0")
return
}
// base_rpm 输入校验:负值归零,超过 10000 截断
sanitizeExtraBaseRPM(req.Extra)
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
@@ -378,17 +580,10 @@ func (h *AccountHandler) Update(c *gin.Context) {
// 检查是否为混合渠道错误
var mixedErr *service.MixedChannelError
if errors.As(err, &mixedErr) {
// 返回特殊错误码要求确认
// 更新接口仅返回最小必要字段,详细信息由专门检查接口提供
c.JSON(409, gin.H{
"error": "mixed_channel_warning",
"message": mixedErr.Error(),
"details": gin.H{
"group_id": mixedErr.GroupID,
"group_name": mixedErr.GroupName,
"current_platform": mixedErr.CurrentPlatform,
"other_platform": mixedErr.OtherPlatform,
},
"require_confirmation": true,
})
return
}
@@ -397,7 +592,7 @@ func (h *AccountHandler) Update(c *gin.Context) {
return
}
response.Success(c, dto.AccountFromService(account))
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
}
// Delete handles deleting an account
@@ -424,10 +619,17 @@ type TestAccountRequest struct {
}
type SyncFromCRSRequest struct {
BaseURL string `json:"base_url" binding:"required"`
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
SyncProxies *bool `json:"sync_proxies"`
BaseURL string `json:"base_url" binding:"required"`
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
SyncProxies *bool `json:"sync_proxies"`
SelectedAccountIDs []string `json:"selected_account_ids"`
}
type PreviewFromCRSRequest struct {
BaseURL string `json:"base_url" binding:"required"`
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
}
// Test handles testing account connectivity with SSE streaming
@@ -466,10 +668,11 @@ func (h *AccountHandler) SyncFromCRS(c *gin.Context) {
}
result, err := h.crsSyncService.SyncFromCRS(c.Request.Context(), service.SyncFromCRSInput{
BaseURL: req.BaseURL,
Username: req.Username,
Password: req.Password,
SyncProxies: syncProxies,
BaseURL: req.BaseURL,
Username: req.Username,
Password: req.Password,
SyncProxies: syncProxies,
SelectedAccountIDs: req.SelectedAccountIDs,
})
if err != nil {
// Provide detailed error message for CRS sync failures
@@ -480,6 +683,28 @@ func (h *AccountHandler) SyncFromCRS(c *gin.Context) {
response.Success(c, result)
}
// PreviewFromCRS handles previewing accounts from CRS before sync
// POST /api/v1/admin/accounts/sync/crs/preview
func (h *AccountHandler) PreviewFromCRS(c *gin.Context) {
var req PreviewFromCRSRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
result, err := h.crsSyncService.PreviewFromCRS(c.Request.Context(), service.SyncFromCRSInput{
BaseURL: req.BaseURL,
Username: req.Username,
Password: req.Password,
})
if err != nil {
response.InternalError(c, "CRS preview failed: "+err.Error())
return
}
response.Success(c, result)
}
// Refresh handles refreshing account credentials
// POST /api/v1/admin/accounts/:id/refresh
func (h *AccountHandler) Refresh(c *gin.Context) {
@@ -625,7 +850,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
}
}
response.Success(c, dto.AccountFromService(updatedAccount))
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), updatedAccount))
}
// GetStats handles getting account statistics
@@ -683,7 +908,7 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
}
}
response.Success(c, dto.AccountFromService(account))
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
}
// BatchCreate handles batch creating accounts
@@ -697,61 +922,65 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
return
}
ctx := c.Request.Context()
success := 0
failed := 0
results := make([]gin.H, 0, len(req.Accounts))
executeAdminIdempotentJSON(c, "admin.accounts.batch_create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
success := 0
failed := 0
results := make([]gin.H, 0, len(req.Accounts))
for _, item := range req.Accounts {
if item.RateMultiplier != nil && *item.RateMultiplier < 0 {
failed++
for _, item := range req.Accounts {
if item.RateMultiplier != nil && *item.RateMultiplier < 0 {
failed++
results = append(results, gin.H{
"name": item.Name,
"success": false,
"error": "rate_multiplier must be >= 0",
})
continue
}
// base_rpm 输入校验:负值归零,超过 10000 截断
sanitizeExtraBaseRPM(item.Extra)
skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk
account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
Name: item.Name,
Notes: item.Notes,
Platform: item.Platform,
Type: item.Type,
Credentials: item.Credentials,
Extra: item.Extra,
ProxyID: item.ProxyID,
Concurrency: item.Concurrency,
Priority: item.Priority,
RateMultiplier: item.RateMultiplier,
GroupIDs: item.GroupIDs,
ExpiresAt: item.ExpiresAt,
AutoPauseOnExpired: item.AutoPauseOnExpired,
SkipMixedChannelCheck: skipCheck,
})
if err != nil {
failed++
results = append(results, gin.H{
"name": item.Name,
"success": false,
"error": err.Error(),
})
continue
}
success++
results = append(results, gin.H{
"name": item.Name,
"success": false,
"error": "rate_multiplier must be >= 0",
"id": account.ID,
"success": true,
})
continue
}
skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk
account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
Name: item.Name,
Notes: item.Notes,
Platform: item.Platform,
Type: item.Type,
Credentials: item.Credentials,
Extra: item.Extra,
ProxyID: item.ProxyID,
Concurrency: item.Concurrency,
Priority: item.Priority,
RateMultiplier: item.RateMultiplier,
GroupIDs: item.GroupIDs,
ExpiresAt: item.ExpiresAt,
AutoPauseOnExpired: item.AutoPauseOnExpired,
SkipMixedChannelCheck: skipCheck,
})
if err != nil {
failed++
results = append(results, gin.H{
"name": item.Name,
"success": false,
"error": err.Error(),
})
continue
}
success++
results = append(results, gin.H{
"name": item.Name,
"id": account.ID,
"success": true,
})
}
response.Success(c, gin.H{
"success": success,
"failed": failed,
"results": results,
return gin.H{
"success": success,
"failed": failed,
"results": results,
}, nil
})
}
@@ -789,57 +1018,58 @@ func (h *AccountHandler) BatchUpdateCredentials(c *gin.Context) {
}
ctx := c.Request.Context()
success := 0
failed := 0
results := []gin.H{}
// 阶段一:预验证所有账号存在,收集 credentials
type accountUpdate struct {
ID int64
Credentials map[string]any
}
updates := make([]accountUpdate, 0, len(req.AccountIDs))
for _, accountID := range req.AccountIDs {
// Get account
account, err := h.adminService.GetAccount(ctx, accountID)
if err != nil {
failed++
results = append(results, gin.H{
"account_id": accountID,
"success": false,
"error": "Account not found",
})
continue
response.Error(c, 404, fmt.Sprintf("Account %d not found", accountID))
return
}
// Update credentials field
if account.Credentials == nil {
account.Credentials = make(map[string]any)
}
account.Credentials[req.Field] = req.Value
updates = append(updates, accountUpdate{ID: accountID, Credentials: account.Credentials})
}
// Update account
updateInput := &service.UpdateAccountInput{
Credentials: account.Credentials,
}
_, err = h.adminService.UpdateAccount(ctx, accountID, updateInput)
if err != nil {
// 阶段二:依次更新,返回每个账号的成功/失败明细,便于调用方重试
success := 0
failed := 0
successIDs := make([]int64, 0, len(updates))
failedIDs := make([]int64, 0, len(updates))
results := make([]gin.H, 0, len(updates))
for _, u := range updates {
updateInput := &service.UpdateAccountInput{Credentials: u.Credentials}
if _, err := h.adminService.UpdateAccount(ctx, u.ID, updateInput); err != nil {
failed++
failedIDs = append(failedIDs, u.ID)
results = append(results, gin.H{
"account_id": accountID,
"account_id": u.ID,
"success": false,
"error": err.Error(),
})
continue
}
success++
successIDs = append(successIDs, u.ID)
results = append(results, gin.H{
"account_id": accountID,
"account_id": u.ID,
"success": true,
})
}
response.Success(c, gin.H{
"success": success,
"failed": failed,
"results": results,
"success": success,
"failed": failed,
"success_ids": successIDs,
"failed_ids": failedIDs,
"results": results,
})
}
@@ -855,6 +1085,8 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
response.BadRequest(c, "rate_multiplier must be >= 0")
return
}
// base_rpm 输入校验:负值归零,超过 10000 截断
sanitizeExtraBaseRPM(req.Extra)
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
@@ -890,6 +1122,14 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
SkipMixedChannelCheck: skipCheck,
})
if err != nil {
var mixedErr *service.MixedChannelError
if errors.As(err, &mixedErr) {
c.JSON(409, gin.H{
"error": "mixed_channel_warning",
"message": mixedErr.Error(),
})
return
}
response.ErrorFrom(c, err)
return
}
@@ -1074,7 +1314,13 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) {
return
}
response.Success(c, gin.H{"message": "Rate limit cleared successfully"})
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
}
// GetTempUnschedulable handles getting temporary unschedulable status
@@ -1138,6 +1384,34 @@ func (h *AccountHandler) GetTodayStats(c *gin.Context) {
response.Success(c, stats)
}
// BatchTodayStatsRequest 批量今日统计请求体。
type BatchTodayStatsRequest struct {
AccountIDs []int64 `json:"account_ids" binding:"required"`
}
// GetBatchTodayStats 批量获取多个账号的今日统计。
// POST /api/v1/admin/accounts/today-stats/batch
func (h *AccountHandler) GetBatchTodayStats(c *gin.Context) {
var req BatchTodayStatsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if len(req.AccountIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]any{}})
return
}
stats, err := h.accountUsageService.GetTodayStatsBatch(c.Request.Context(), req.AccountIDs)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"stats": stats})
}
// SetSchedulableRequest represents the request body for setting schedulable status
type SetSchedulableRequest struct {
Schedulable bool `json:"schedulable"`
@@ -1164,7 +1438,7 @@ func (h *AccountHandler) SetSchedulable(c *gin.Context) {
return
}
response.Success(c, dto.AccountFromService(account))
response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account))
}
// GetAvailableModels handles getting available models for an account
@@ -1261,32 +1535,14 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
// Handle Antigravity accounts: return Claude + Gemini models
if account.Platform == service.PlatformAntigravity {
// Antigravity 支持 Claude 和部分 Gemini 模型
type UnifiedModel struct {
ID string `json:"id"`
Type string `json:"type"`
DisplayName string `json:"display_name"`
}
// 直接复用 antigravity.DefaultModels(),与 /v1/models 端点保持同步
response.Success(c, antigravity.DefaultModels())
return
}
var models []UnifiedModel
// 添加 Claude 模型
for _, m := range claude.DefaultModels {
models = append(models, UnifiedModel{
ID: m.ID,
Type: m.Type,
DisplayName: m.DisplayName,
})
}
// 添加 Gemini 3 系列模型用于测试
geminiTestModels := []UnifiedModel{
{ID: "gemini-3-flash", Type: "model", DisplayName: "Gemini 3 Flash"},
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview"},
}
models = append(models, geminiTestModels...)
response.Success(c, models)
// Handle Sora accounts
if account.Platform == service.PlatformSora {
response.Success(c, service.DefaultSoraModels(nil))
return
}
@@ -1399,7 +1655,7 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
accounts := make([]*service.Account, 0)
if len(req.AccountIDs) == 0 {
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "")
allAccounts, _, err := h.adminService.ListAccounts(ctx, 1, 10000, "gemini", "oauth", "", "", 0)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -1497,3 +1753,22 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
func (h *AccountHandler) GetAntigravityDefaultModelMapping(c *gin.Context) {
response.Success(c, domain.DefaultAntigravityModelMapping)
}
// sanitizeExtraBaseRPM 对 extra map 中的 base_rpm 值进行范围校验和归一化。
// 负值归零,超过 10000 截断为 10000。extra 为 nil 或不含 base_rpm 时无操作。
func sanitizeExtraBaseRPM(extra map[string]any) {
if extra == nil {
return
}
raw, ok := extra["base_rpm"]
if !ok {
return
}
v := service.ParseExtraInt(raw)
if v < 0 {
v = 0
} else if v > 10000 {
v = 10000
}
extra["base_rpm"] = v
}

View File

@@ -0,0 +1,198 @@
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func setupAccountMixedChannelRouter(adminSvc *stubAdminService) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
accountHandler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
router.POST("/api/v1/admin/accounts/check-mixed-channel", accountHandler.CheckMixedChannel)
router.POST("/api/v1/admin/accounts", accountHandler.Create)
router.PUT("/api/v1/admin/accounts/:id", accountHandler.Update)
router.POST("/api/v1/admin/accounts/bulk-update", accountHandler.BulkUpdate)
return router
}
func TestAccountHandlerCheckMixedChannelNoRisk(t *testing.T) {
adminSvc := newStubAdminService()
router := setupAccountMixedChannelRouter(adminSvc)
body, _ := json.Marshal(map[string]any{
"platform": "antigravity",
"group_ids": []int64{27},
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/check-mixed-channel", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, float64(0), resp["code"])
data, ok := resp["data"].(map[string]any)
require.True(t, ok)
require.Equal(t, false, data["has_risk"])
require.Equal(t, int64(0), adminSvc.lastMixedCheck.accountID)
require.Equal(t, "antigravity", adminSvc.lastMixedCheck.platform)
require.Equal(t, []int64{27}, adminSvc.lastMixedCheck.groupIDs)
}
func TestAccountHandlerCheckMixedChannelWithRisk(t *testing.T) {
adminSvc := newStubAdminService()
adminSvc.checkMixedErr = &service.MixedChannelError{
GroupID: 27,
GroupName: "claude-max",
CurrentPlatform: "Antigravity",
OtherPlatform: "Anthropic",
}
router := setupAccountMixedChannelRouter(adminSvc)
body, _ := json.Marshal(map[string]any{
"platform": "antigravity",
"group_ids": []int64{27},
"account_id": 99,
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/check-mixed-channel", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, float64(0), resp["code"])
data, ok := resp["data"].(map[string]any)
require.True(t, ok)
require.Equal(t, true, data["has_risk"])
require.Equal(t, "mixed_channel_warning", data["error"])
details, ok := data["details"].(map[string]any)
require.True(t, ok)
require.Equal(t, float64(27), details["group_id"])
require.Equal(t, "claude-max", details["group_name"])
require.Equal(t, "Antigravity", details["current_platform"])
require.Equal(t, "Anthropic", details["other_platform"])
require.Equal(t, int64(99), adminSvc.lastMixedCheck.accountID)
}
func TestAccountHandlerCreateMixedChannelConflictSimplifiedResponse(t *testing.T) {
adminSvc := newStubAdminService()
adminSvc.createAccountErr = &service.MixedChannelError{
GroupID: 27,
GroupName: "claude-max",
CurrentPlatform: "Antigravity",
OtherPlatform: "Anthropic",
}
router := setupAccountMixedChannelRouter(adminSvc)
body, _ := json.Marshal(map[string]any{
"name": "ag-oauth-1",
"platform": "antigravity",
"type": "oauth",
"credentials": map[string]any{"refresh_token": "rt"},
"group_ids": []int64{27},
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusConflict, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, "mixed_channel_warning", resp["error"])
require.Contains(t, resp["message"], "mixed_channel_warning")
_, hasDetails := resp["details"]
_, hasRequireConfirmation := resp["require_confirmation"]
require.False(t, hasDetails)
require.False(t, hasRequireConfirmation)
}
func TestAccountHandlerUpdateMixedChannelConflictSimplifiedResponse(t *testing.T) {
adminSvc := newStubAdminService()
adminSvc.updateAccountErr = &service.MixedChannelError{
GroupID: 27,
GroupName: "claude-max",
CurrentPlatform: "Antigravity",
OtherPlatform: "Anthropic",
}
router := setupAccountMixedChannelRouter(adminSvc)
body, _ := json.Marshal(map[string]any{
"group_ids": []int64{27},
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/accounts/3", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusConflict, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, "mixed_channel_warning", resp["error"])
require.Contains(t, resp["message"], "mixed_channel_warning")
_, hasDetails := resp["details"]
_, hasRequireConfirmation := resp["require_confirmation"]
require.False(t, hasDetails)
require.False(t, hasRequireConfirmation)
}
func TestAccountHandlerBulkUpdateMixedChannelConflict(t *testing.T) {
adminSvc := newStubAdminService()
adminSvc.bulkUpdateAccountErr = &service.MixedChannelError{
GroupID: 27,
GroupName: "claude-max",
CurrentPlatform: "Antigravity",
OtherPlatform: "Anthropic",
}
router := setupAccountMixedChannelRouter(adminSvc)
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1, 2, 3},
"group_ids": []int64{27},
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusConflict, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, "mixed_channel_warning", resp["error"])
require.Contains(t, resp["message"], "claude-max")
}
func TestAccountHandlerBulkUpdateMixedChannelConfirmSkips(t *testing.T) {
adminSvc := newStubAdminService()
router := setupAccountMixedChannelRouter(adminSvc)
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1, 2},
"group_ids": []int64{27},
"confirm_mixed_channel_risk": true,
})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, float64(0), resp["code"])
data, ok := resp["data"].(map[string]any)
require.True(t, ok)
require.Equal(t, float64(2), data["success"])
require.Equal(t, float64(0), data["failed"])
}

View File

@@ -0,0 +1,67 @@
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestAccountHandler_Create_AnthropicAPIKeyPassthroughExtraForwarded(t *testing.T) {
gin.SetMode(gin.TestMode)
adminSvc := newStubAdminService()
handler := NewAccountHandler(
adminSvc,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
nil,
)
router := gin.New()
router.POST("/api/v1/admin/accounts", handler.Create)
body := map[string]any{
"name": "anthropic-key-1",
"platform": "anthropic",
"type": "apikey",
"credentials": map[string]any{
"api_key": "sk-ant-xxx",
"base_url": "https://api.anthropic.com",
},
"extra": map[string]any{
"anthropic_passthrough": true,
},
"concurrency": 1,
"priority": 1,
}
raw, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts", bytes.NewReader(raw))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Len(t, adminSvc.createdAccounts, 1)
created := adminSvc.createdAccounts[0]
require.Equal(t, "anthropic", created.Platform)
require.Equal(t, "apikey", created.Type)
require.NotNil(t, created.Extra)
require.Equal(t, true, created.Extra["anthropic_passthrough"])
}

View File

@@ -16,10 +16,10 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
router := gin.New()
adminSvc := newStubAdminService()
userHandler := NewUserHandler(adminSvc)
userHandler := NewUserHandler(adminSvc, nil)
groupHandler := NewGroupHandler(adminSvc)
proxyHandler := NewProxyHandler(adminSvc)
redeemHandler := NewRedeemHandler(adminSvc)
redeemHandler := NewRedeemHandler(adminSvc, nil)
router.GET("/api/v1/admin/users", userHandler.List)
router.GET("/api/v1/admin/users/:id", userHandler.GetByID)
@@ -47,6 +47,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
router.DELETE("/api/v1/admin/proxies/:id", proxyHandler.Delete)
router.POST("/api/v1/admin/proxies/batch-delete", proxyHandler.BatchDelete)
router.POST("/api/v1/admin/proxies/:id/test", proxyHandler.Test)
router.POST("/api/v1/admin/proxies/:id/quality-check", proxyHandler.CheckQuality)
router.GET("/api/v1/admin/proxies/:id/stats", proxyHandler.GetStats)
router.GET("/api/v1/admin/proxies/:id/accounts", proxyHandler.GetProxyAccounts)
@@ -208,6 +209,11 @@ func TestProxyHandlerEndpoints(t *testing.T) {
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/4/quality-check", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
rec = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/stats", nil)
router.ServeHTTP(rec, req)

View File

@@ -58,6 +58,96 @@ func TestParseOpsDuration(t *testing.T) {
require.False(t, ok)
}
func TestParseOpsOpenAITokenStatsDuration(t *testing.T) {
tests := []struct {
input string
want time.Duration
ok bool
}{
{input: "30m", want: 30 * time.Minute, ok: true},
{input: "1h", want: time.Hour, ok: true},
{input: "1d", want: 24 * time.Hour, ok: true},
{input: "15d", want: 15 * 24 * time.Hour, ok: true},
{input: "30d", want: 30 * 24 * time.Hour, ok: true},
{input: "7d", want: 0, ok: false},
}
for _, tt := range tests {
got, ok := parseOpsOpenAITokenStatsDuration(tt.input)
require.Equal(t, tt.ok, ok, "input=%s", tt.input)
require.Equal(t, tt.want, got, "input=%s", tt.input)
}
}
func TestParseOpsOpenAITokenStatsFilter_Defaults(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
before := time.Now().UTC()
filter, err := parseOpsOpenAITokenStatsFilter(c)
after := time.Now().UTC()
require.NoError(t, err)
require.NotNil(t, filter)
require.Equal(t, "30d", filter.TimeRange)
require.Equal(t, 1, filter.Page)
require.Equal(t, 20, filter.PageSize)
require.Equal(t, 0, filter.TopN)
require.Nil(t, filter.GroupID)
require.Equal(t, "", filter.Platform)
require.True(t, filter.StartTime.Before(filter.EndTime))
require.WithinDuration(t, before.Add(-30*24*time.Hour), filter.StartTime, 2*time.Second)
require.WithinDuration(t, after, filter.EndTime, 2*time.Second)
}
func TestParseOpsOpenAITokenStatsFilter_WithTopN(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(
http.MethodGet,
"/?time_range=1h&platform=openai&group_id=12&top_n=50",
nil,
)
filter, err := parseOpsOpenAITokenStatsFilter(c)
require.NoError(t, err)
require.Equal(t, "1h", filter.TimeRange)
require.Equal(t, "openai", filter.Platform)
require.NotNil(t, filter.GroupID)
require.Equal(t, int64(12), *filter.GroupID)
require.Equal(t, 50, filter.TopN)
require.Equal(t, 0, filter.Page)
require.Equal(t, 0, filter.PageSize)
}
func TestParseOpsOpenAITokenStatsFilter_InvalidParams(t *testing.T) {
tests := []string{
"/?time_range=7d",
"/?group_id=0",
"/?group_id=abc",
"/?top_n=0",
"/?top_n=101",
"/?top_n=10&page=1",
"/?top_n=10&page_size=20",
"/?page=0",
"/?page_size=0",
"/?page_size=101",
}
gin.SetMode(gin.TestMode)
for _, rawURL := range tests {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, rawURL, nil)
_, err := parseOpsOpenAITokenStatsFilter(c)
require.Error(t, err, "url=%s", rawURL)
}
}
func TestParseOpsTimeRange(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()

View File

@@ -10,19 +10,28 @@ import (
)
type stubAdminService struct {
users []service.User
apiKeys []service.APIKey
groups []service.Group
accounts []service.Account
proxies []service.Proxy
proxyCounts []service.ProxyWithAccountCount
redeems []service.RedeemCode
createdAccounts []*service.CreateAccountInput
createdProxies []*service.CreateProxyInput
updatedProxyIDs []int64
updatedProxies []*service.UpdateProxyInput
testedProxyIDs []int64
mu sync.Mutex
users []service.User
apiKeys []service.APIKey
groups []service.Group
accounts []service.Account
proxies []service.Proxy
proxyCounts []service.ProxyWithAccountCount
redeems []service.RedeemCode
createdAccounts []*service.CreateAccountInput
createdProxies []*service.CreateProxyInput
updatedProxyIDs []int64
updatedProxies []*service.UpdateProxyInput
testedProxyIDs []int64
createAccountErr error
updateAccountErr error
bulkUpdateAccountErr error
checkMixedErr error
lastMixedCheck struct {
accountID int64
platform string
groupIDs []int64
}
mu sync.Mutex
}
func newStubAdminService() *stubAdminService {
@@ -166,7 +175,7 @@ func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, p
return s.apiKeys, int64(len(s.apiKeys)), nil
}
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]service.Account, int64, error) {
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]service.Account, int64, error) {
return s.accounts, int64(len(s.accounts)), nil
}
@@ -188,11 +197,17 @@ func (s *stubAdminService) CreateAccount(ctx context.Context, input *service.Cre
s.mu.Lock()
s.createdAccounts = append(s.createdAccounts, input)
s.mu.Unlock()
if s.createAccountErr != nil {
return nil, s.createAccountErr
}
account := service.Account{ID: 300, Name: input.Name, Status: service.StatusActive}
return &account, nil
}
func (s *stubAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) {
if s.updateAccountErr != nil {
return nil, s.updateAccountErr
}
account := service.Account{ID: id, Name: input.Name, Status: service.StatusActive}
return &account, nil
}
@@ -221,7 +236,17 @@ func (s *stubAdminService) SetAccountSchedulable(ctx context.Context, id int64,
}
func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *service.BulkUpdateAccountsInput) (*service.BulkUpdateAccountsResult, error) {
return &service.BulkUpdateAccountsResult{Success: 1, Failed: 0, SuccessIDs: []int64{1}}, nil
if s.bulkUpdateAccountErr != nil {
return nil, s.bulkUpdateAccountErr
}
return &service.BulkUpdateAccountsResult{Success: len(input.AccountIDs), Failed: 0, SuccessIDs: input.AccountIDs}, nil
}
func (s *stubAdminService) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
s.lastMixedCheck.accountID = currentAccountID
s.lastMixedCheck.platform = currentAccountPlatform
s.lastMixedCheck.groupIDs = append([]int64(nil), groupIDs...)
return s.checkMixedErr
}
func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) {
@@ -327,6 +352,27 @@ func (s *stubAdminService) TestProxy(ctx context.Context, id int64) (*service.Pr
return &service.ProxyTestResult{Success: true, Message: "ok"}, nil
}
func (s *stubAdminService) CheckProxyQuality(ctx context.Context, id int64) (*service.ProxyQualityCheckResult, error) {
return &service.ProxyQualityCheckResult{
ProxyID: id,
Score: 95,
Grade: "A",
Summary: "通过 5 项,告警 0 项,失败 0 项,挑战 0 项",
PassedCount: 5,
WarnCount: 0,
FailedCount: 0,
ChallengeCount: 0,
CheckedAt: time.Now().Unix(),
Items: []service.ProxyQualityCheckItem{
{Target: "base_connectivity", Status: "pass", Message: "ok"},
{Target: "openai", Status: "pass", HTTPStatus: 401},
{Target: "anthropic", Status: "pass", HTTPStatus: 401},
{Target: "gemini", Status: "pass", HTTPStatus: 200},
{Target: "sora", Status: "pass", HTTPStatus: 401},
},
}, nil
}
func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]service.RedeemCode, int64, error) {
return s.redeems, int64(len(s.redeems)), nil
}
@@ -357,5 +403,27 @@ func (s *stubAdminService) GetUserBalanceHistory(ctx context.Context, userID int
return s.redeems, int64(len(s.redeems)), 100.0, nil
}
func (s *stubAdminService) UpdateGroupSortOrders(ctx context.Context, updates []service.GroupSortOrderUpdate) error {
return nil
}
func (s *stubAdminService) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*service.AdminUpdateAPIKeyGroupIDResult, error) {
for i := range s.apiKeys {
if s.apiKeys[i].ID == keyID {
k := s.apiKeys[i]
if groupID != nil {
if *groupID == 0 {
k.GroupID = nil
} else {
gid := *groupID
k.GroupID = &gid
}
}
return &service.AdminUpdateAPIKeyGroupIDResult{APIKey: &k}, nil
}
}
return nil, service.ErrAPIKeyNotFound
}
// Ensure stub implements interface.
var _ service.AdminService = (*stubAdminService)(nil)

View File

@@ -65,3 +65,27 @@ func (h *AntigravityOAuthHandler) ExchangeCode(c *gin.Context) {
response.Success(c, tokenInfo)
}
// AntigravityRefreshTokenRequest represents the request for validating Antigravity refresh token
type AntigravityRefreshTokenRequest struct {
RefreshToken string `json:"refresh_token" binding:"required"`
ProxyID *int64 `json:"proxy_id"`
}
// RefreshToken validates an Antigravity refresh token and returns full token info
// POST /api/v1/admin/antigravity/oauth/refresh-token
func (h *AntigravityOAuthHandler) RefreshToken(c *gin.Context) {
var req AntigravityRefreshTokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "请求无效: "+err.Error())
return
}
tokenInfo, err := h.antigravityOAuthService.ValidateRefreshToken(c.Request.Context(), req.RefreshToken, req.ProxyID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, tokenInfo)
}

View File

@@ -0,0 +1,63 @@
package admin
import (
"strconv"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// AdminAPIKeyHandler handles admin API key management
type AdminAPIKeyHandler struct {
adminService service.AdminService
}
// NewAdminAPIKeyHandler creates a new admin API key handler
func NewAdminAPIKeyHandler(adminService service.AdminService) *AdminAPIKeyHandler {
return &AdminAPIKeyHandler{
adminService: adminService,
}
}
// AdminUpdateAPIKeyGroupRequest represents the request to update an API key's group
type AdminUpdateAPIKeyGroupRequest struct {
GroupID *int64 `json:"group_id"` // nil=不修改, 0=解绑, >0=绑定到目标分组
}
// UpdateGroup handles updating an API key's group binding
// PUT /api/v1/admin/api-keys/:id
func (h *AdminAPIKeyHandler) UpdateGroup(c *gin.Context) {
keyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid API key ID")
return
}
var req AdminUpdateAPIKeyGroupRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
result, err := h.adminService.AdminUpdateAPIKeyGroupID(c.Request.Context(), keyID, req.GroupID)
if err != nil {
response.ErrorFrom(c, err)
return
}
resp := struct {
APIKey *dto.APIKey `json:"api_key"`
AutoGrantedGroupAccess bool `json:"auto_granted_group_access"`
GrantedGroupID *int64 `json:"granted_group_id,omitempty"`
GrantedGroupName string `json:"granted_group_name,omitempty"`
}{
APIKey: dto.APIKeyFromService(result.APIKey),
AutoGrantedGroupAccess: result.AutoGrantedGroupAccess,
GrantedGroupID: result.GrantedGroupID,
GrantedGroupName: result.GrantedGroupName,
}
response.Success(c, resp)
}

View File

@@ -0,0 +1,202 @@
package admin
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func setupAPIKeyHandler(adminSvc service.AdminService) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
h := NewAdminAPIKeyHandler(adminSvc)
router.PUT("/api/v1/admin/api-keys/:id", h.UpdateGroup)
return router
}
func TestAdminAPIKeyHandler_UpdateGroup_InvalidID(t *testing.T) {
router := setupAPIKeyHandler(newStubAdminService())
body := `{"group_id": 2}`
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/abc", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
require.Contains(t, rec.Body.String(), "Invalid API key ID")
}
func TestAdminAPIKeyHandler_UpdateGroup_InvalidJSON(t *testing.T) {
router := setupAPIKeyHandler(newStubAdminService())
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{bad json`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
require.Contains(t, rec.Body.String(), "Invalid request")
}
func TestAdminAPIKeyHandler_UpdateGroup_KeyNotFound(t *testing.T) {
router := setupAPIKeyHandler(newStubAdminService())
body := `{"group_id": 2}`
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/999", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
// ErrAPIKeyNotFound maps to 404
require.Equal(t, http.StatusNotFound, rec.Code)
}
func TestAdminAPIKeyHandler_UpdateGroup_BindGroup(t *testing.T) {
router := setupAPIKeyHandler(newStubAdminService())
body := `{"group_id": 2}`
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Code int `json:"code"`
Data json.RawMessage `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
var data struct {
APIKey struct {
ID int64 `json:"id"`
GroupID *int64 `json:"group_id"`
} `json:"api_key"`
AutoGrantedGroupAccess bool `json:"auto_granted_group_access"`
}
require.NoError(t, json.Unmarshal(resp.Data, &data))
require.Equal(t, int64(10), data.APIKey.ID)
require.NotNil(t, data.APIKey.GroupID)
require.Equal(t, int64(2), *data.APIKey.GroupID)
}
func TestAdminAPIKeyHandler_UpdateGroup_Unbind(t *testing.T) {
svc := newStubAdminService()
gid := int64(2)
svc.apiKeys[0].GroupID = &gid
router := setupAPIKeyHandler(svc)
body := `{"group_id": 0}`
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Data struct {
APIKey struct {
GroupID *int64 `json:"group_id"`
} `json:"api_key"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Nil(t, resp.Data.APIKey.GroupID)
}
func TestAdminAPIKeyHandler_UpdateGroup_ServiceError(t *testing.T) {
svc := &failingUpdateGroupService{
stubAdminService: newStubAdminService(),
err: errors.New("internal failure"),
}
router := setupAPIKeyHandler(svc)
body := `{"group_id": 2}`
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusInternalServerError, rec.Code)
}
// H2: empty body → group_id is nil → no-op, returns original key
func TestAdminAPIKeyHandler_UpdateGroup_EmptyBody_NoChange(t *testing.T) {
router := setupAPIKeyHandler(newStubAdminService())
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var resp struct {
Code int `json:"code"`
Data struct {
APIKey struct {
ID int64 `json:"id"`
} `json:"api_key"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, 0, resp.Code)
require.Equal(t, int64(10), resp.Data.APIKey.ID)
}
// M2: service returns GROUP_NOT_ACTIVE → handler maps to 400
func TestAdminAPIKeyHandler_UpdateGroup_GroupNotActive(t *testing.T) {
svc := &failingUpdateGroupService{
stubAdminService: newStubAdminService(),
err: infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active"),
}
router := setupAPIKeyHandler(svc)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{"group_id": 5}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
require.Contains(t, rec.Body.String(), "GROUP_NOT_ACTIVE")
}
// M2: service returns INVALID_GROUP_ID → handler maps to 400
func TestAdminAPIKeyHandler_UpdateGroup_NegativeGroupID(t *testing.T) {
svc := &failingUpdateGroupService{
stubAdminService: newStubAdminService(),
err: infraerrors.BadRequest("INVALID_GROUP_ID", "group_id must be non-negative"),
}
router := setupAPIKeyHandler(svc)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{"group_id": -5}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
require.Contains(t, rec.Body.String(), "INVALID_GROUP_ID")
}
// failingUpdateGroupService overrides AdminUpdateAPIKeyGroupID to return an error.
type failingUpdateGroupService struct {
*stubAdminService
err error
}
func (f *failingUpdateGroupService) AdminUpdateAPIKeyGroupID(_ context.Context, _ int64, _ *int64) (*service.AdminUpdateAPIKeyGroupIDResult, error) {
return nil, f.err
}

View File

@@ -0,0 +1,208 @@
//go:build unit
package admin
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// failingAdminService 嵌入 stubAdminService可配置 UpdateAccount 在指定 ID 时失败。
type failingAdminService struct {
*stubAdminService
failOnAccountID int64
updateCallCount atomic.Int64
}
func (f *failingAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) {
f.updateCallCount.Add(1)
if id == f.failOnAccountID {
return nil, errors.New("database error")
}
return f.stubAdminService.UpdateAccount(ctx, id, input)
}
func setupAccountHandlerWithService(adminSvc service.AdminService) (*gin.Engine, *AccountHandler) {
gin.SetMode(gin.TestMode)
router := gin.New()
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
router.POST("/api/v1/admin/accounts/batch-update-credentials", handler.BatchUpdateCredentials)
return router, handler
}
func TestBatchUpdateCredentials_AllSuccess(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
AccountIDs: []int64{1, 2, 3},
Field: "account_uuid",
Value: "test-uuid",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code, "全部成功时应返回 200")
require.Equal(t, int64(3), svc.updateCallCount.Load(), "应调用 3 次 UpdateAccount")
}
func TestBatchUpdateCredentials_PartialFailure(t *testing.T) {
// 让第 2 个账号ID=2更新时失败
svc := &failingAdminService{
stubAdminService: newStubAdminService(),
failOnAccountID: 2,
}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
AccountIDs: []int64{1, 2, 3},
Field: "org_uuid",
Value: "test-org",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
// 实现采用"部分成功"模式:总是返回 200 + 成功/失败明细
require.Equal(t, http.StatusOK, w.Code, "批量更新返回 200 + 成功/失败明细")
var resp map[string]any
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
data := resp["data"].(map[string]any)
require.Equal(t, float64(2), data["success"], "应有 2 个成功")
require.Equal(t, float64(1), data["failed"], "应有 1 个失败")
// 所有 3 个账号都会被尝试更新(非 fail-fast
require.Equal(t, int64(3), svc.updateCallCount.Load(),
"应调用 3 次 UpdateAccount逐个尝试失败后继续")
}
func TestBatchUpdateCredentials_FirstAccountNotFound(t *testing.T) {
// GetAccount 在 stubAdminService 中总是成功的,需要创建一个 GetAccount 会失败的 stub
svc := &getAccountFailingService{
stubAdminService: newStubAdminService(),
failOnAccountID: 1,
}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(BatchUpdateCredentialsRequest{
AccountIDs: []int64{1, 2, 3},
Field: "account_uuid",
Value: "test",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusNotFound, w.Code, "第一阶段验证失败应返回 404")
}
// getAccountFailingService 模拟 GetAccount 在特定 ID 时返回 not found。
type getAccountFailingService struct {
*stubAdminService
failOnAccountID int64
}
func (f *getAccountFailingService) GetAccount(ctx context.Context, id int64) (*service.Account, error) {
if id == f.failOnAccountID {
return nil, errors.New("not found")
}
return f.stubAdminService.GetAccount(ctx, id)
}
func TestBatchUpdateCredentials_InterceptWarmupRequests_NonBool(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
// intercept_warmup_requests 传入非 bool 类型string应返回 400
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "intercept_warmup_requests",
"value": "not-a-bool",
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusBadRequest, w.Code,
"intercept_warmup_requests 传入非 bool 值应返回 400")
}
func TestBatchUpdateCredentials_InterceptWarmupRequests_ValidBool(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "intercept_warmup_requests",
"value": true,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code,
"intercept_warmup_requests 传入合法 bool 值应返回 200")
}
func TestBatchUpdateCredentials_AccountUUID_NonString(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
// account_uuid 传入非 string 类型number应返回 400
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "account_uuid",
"value": 12345,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusBadRequest, w.Code,
"account_uuid 传入非 string 值应返回 400")
}
func TestBatchUpdateCredentials_AccountUUID_NullValue(t *testing.T) {
svc := &failingAdminService{stubAdminService: newStubAdminService()}
router, _ := setupAccountHandlerWithService(svc)
// account_uuid 传入 null设置为空应正常通过
body, _ := json.Marshal(map[string]any{
"account_ids": []int64{1},
"field": "account_uuid",
"value": nil,
})
w := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code,
"account_uuid 传入 null 应返回 200")
}

View File

@@ -3,6 +3,7 @@ package admin
import (
"errors"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
@@ -186,7 +187,7 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
// GetUsageTrend handles getting usage trend data
// GET /api/v1/admin/dashboard/trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream, billing_type
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, request_type, stream, billing_type
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
granularity := c.DefaultQuery("granularity", "day")
@@ -194,6 +195,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
// Parse optional filter params
var userID, apiKeyID, accountID, groupID int64
var model string
var requestType *int16
var stream *bool
var billingType *int8
@@ -220,9 +222,20 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
if modelStr := c.Query("model"); modelStr != "" {
model = modelStr
}
if streamStr := c.Query("stream"); streamStr != "" {
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
parsed, err := service.ParseUsageRequestType(requestTypeStr)
if err != nil {
response.BadRequest(c, err.Error())
return
}
value := int16(parsed)
requestType = &value
} else if streamStr := c.Query("stream"); streamStr != "" {
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
stream = &streamVal
} else {
response.BadRequest(c, "Invalid stream value, use true or false")
return
}
}
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
@@ -235,7 +248,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
}
}
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType)
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
if err != nil {
response.Error(c, 500, "Failed to get usage trend")
return
@@ -251,12 +264,13 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
// GetModelStats handles getting model usage statistics
// GET /api/v1/admin/dashboard/models
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream, billing_type
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, request_type, stream, billing_type
func (h *DashboardHandler) GetModelStats(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
// Parse optional filter params
var userID, apiKeyID, accountID, groupID int64
var requestType *int16
var stream *bool
var billingType *int8
@@ -280,9 +294,20 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
groupID = id
}
}
if streamStr := c.Query("stream"); streamStr != "" {
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
parsed, err := service.ParseUsageRequestType(requestTypeStr)
if err != nil {
response.BadRequest(c, err.Error())
return
}
value := int16(parsed)
requestType = &value
} else if streamStr := c.Query("stream"); streamStr != "" {
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
stream = &streamVal
} else {
response.BadRequest(c, "Invalid stream value, use true or false")
return
}
}
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
@@ -295,7 +320,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
}
}
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType)
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
if err != nil {
response.Error(c, 500, "Failed to get model statistics")
return
@@ -308,6 +333,76 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
})
}
// GetGroupStats handles getting group usage statistics
// GET /api/v1/admin/dashboard/groups
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, request_type, stream, billing_type
func (h *DashboardHandler) GetGroupStats(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
var userID, apiKeyID, accountID, groupID int64
var requestType *int16
var stream *bool
var billingType *int8
if userIDStr := c.Query("user_id"); userIDStr != "" {
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
userID = id
}
}
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
if id, err := strconv.ParseInt(apiKeyIDStr, 10, 64); err == nil {
apiKeyID = id
}
}
if accountIDStr := c.Query("account_id"); accountIDStr != "" {
if id, err := strconv.ParseInt(accountIDStr, 10, 64); err == nil {
accountID = id
}
}
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
if id, err := strconv.ParseInt(groupIDStr, 10, 64); err == nil {
groupID = id
}
}
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
parsed, err := service.ParseUsageRequestType(requestTypeStr)
if err != nil {
response.BadRequest(c, err.Error())
return
}
value := int16(parsed)
requestType = &value
} else if streamStr := c.Query("stream"); streamStr != "" {
if streamVal, err := strconv.ParseBool(streamStr); err == nil {
stream = &streamVal
} else {
response.BadRequest(c, "Invalid stream value, use true or false")
return
}
}
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
if v, err := strconv.ParseInt(billingTypeStr, 10, 8); err == nil {
bt := int8(v)
billingType = &bt
} else {
response.BadRequest(c, "Invalid billing_type")
return
}
}
stats, err := h.dashboardService.GetGroupStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
if err != nil {
response.Error(c, 500, "Failed to get group statistics")
return
}
response.Success(c, gin.H{
"groups": stats,
"start_date": startTime.Format("2006-01-02"),
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
})
}
// GetAPIKeyUsageTrend handles getting API key usage trend data
// GET /api/v1/admin/dashboard/api-keys-trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), limit (default 5)
@@ -379,7 +474,7 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
return
}
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs)
stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs, time.Time{}, time.Time{})
if err != nil {
response.Error(c, 500, "Failed to get user usage stats")
return
@@ -407,7 +502,7 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) {
return
}
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs)
stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs, time.Time{}, time.Time{})
if err != nil {
response.Error(c, 500, "Failed to get API key usage stats")
return

View File

@@ -0,0 +1,132 @@
package admin
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type dashboardUsageRepoCapture struct {
service.UsageLogRepository
trendRequestType *int16
trendStream *bool
modelRequestType *int16
modelStream *bool
}
func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters(
ctx context.Context,
startTime, endTime time.Time,
granularity string,
userID, apiKeyID, accountID, groupID int64,
model string,
requestType *int16,
stream *bool,
billingType *int8,
) ([]usagestats.TrendDataPoint, error) {
s.trendRequestType = requestType
s.trendStream = stream
return []usagestats.TrendDataPoint{}, nil
}
func (s *dashboardUsageRepoCapture) GetModelStatsWithFilters(
ctx context.Context,
startTime, endTime time.Time,
userID, apiKeyID, accountID, groupID int64,
requestType *int16,
stream *bool,
billingType *int8,
) ([]usagestats.ModelStat, error) {
s.modelRequestType = requestType
s.modelStream = stream
return []usagestats.ModelStat{}, nil
}
func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine {
gin.SetMode(gin.TestMode)
dashboardSvc := service.NewDashboardService(repo, nil, nil, nil)
handler := NewDashboardHandler(dashboardSvc, nil)
router := gin.New()
router.GET("/admin/dashboard/trend", handler.GetUsageTrend)
router.GET("/admin/dashboard/models", handler.GetModelStats)
return router
}
func TestDashboardTrendRequestTypePriority(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?request_type=ws_v2&stream=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.NotNil(t, repo.trendRequestType)
require.Equal(t, int16(service.RequestTypeWSV2), *repo.trendRequestType)
require.Nil(t, repo.trendStream)
}
func TestDashboardTrendInvalidRequestType(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?request_type=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestDashboardTrendInvalidStream(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?stream=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestDashboardModelStatsRequestTypePriority(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?request_type=sync&stream=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.NotNil(t, repo.modelRequestType)
require.Equal(t, int16(service.RequestTypeSync), *repo.modelRequestType)
require.Nil(t, repo.modelStream)
}
func TestDashboardModelStatsInvalidRequestType(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?request_type=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestDashboardModelStatsInvalidStream(t *testing.T) {
repo := &dashboardUsageRepoCapture{}
router := newDashboardRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?stream=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}

View File

@@ -0,0 +1,545 @@
package admin
import (
"context"
"strconv"
"strings"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type DataManagementHandler struct {
dataManagementService dataManagementService
}
func NewDataManagementHandler(dataManagementService *service.DataManagementService) *DataManagementHandler {
return &DataManagementHandler{dataManagementService: dataManagementService}
}
type dataManagementService interface {
GetConfig(ctx context.Context) (service.DataManagementConfig, error)
UpdateConfig(ctx context.Context, cfg service.DataManagementConfig) (service.DataManagementConfig, error)
ValidateS3(ctx context.Context, cfg service.DataManagementS3Config) (service.DataManagementTestS3Result, error)
CreateBackupJob(ctx context.Context, input service.DataManagementCreateBackupJobInput) (service.DataManagementBackupJob, error)
ListSourceProfiles(ctx context.Context, sourceType string) ([]service.DataManagementSourceProfile, error)
CreateSourceProfile(ctx context.Context, input service.DataManagementCreateSourceProfileInput) (service.DataManagementSourceProfile, error)
UpdateSourceProfile(ctx context.Context, input service.DataManagementUpdateSourceProfileInput) (service.DataManagementSourceProfile, error)
DeleteSourceProfile(ctx context.Context, sourceType, profileID string) error
SetActiveSourceProfile(ctx context.Context, sourceType, profileID string) (service.DataManagementSourceProfile, error)
ListS3Profiles(ctx context.Context) ([]service.DataManagementS3Profile, error)
CreateS3Profile(ctx context.Context, input service.DataManagementCreateS3ProfileInput) (service.DataManagementS3Profile, error)
UpdateS3Profile(ctx context.Context, input service.DataManagementUpdateS3ProfileInput) (service.DataManagementS3Profile, error)
DeleteS3Profile(ctx context.Context, profileID string) error
SetActiveS3Profile(ctx context.Context, profileID string) (service.DataManagementS3Profile, error)
ListBackupJobs(ctx context.Context, input service.DataManagementListBackupJobsInput) (service.DataManagementListBackupJobsResult, error)
GetBackupJob(ctx context.Context, jobID string) (service.DataManagementBackupJob, error)
EnsureAgentEnabled(ctx context.Context) error
GetAgentHealth(ctx context.Context) service.DataManagementAgentHealth
}
type TestS3ConnectionRequest struct {
Endpoint string `json:"endpoint"`
Region string `json:"region" binding:"required"`
Bucket string `json:"bucket" binding:"required"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key"`
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
UseSSL bool `json:"use_ssl"`
}
type CreateBackupJobRequest struct {
BackupType string `json:"backup_type" binding:"required,oneof=postgres redis full"`
UploadToS3 bool `json:"upload_to_s3"`
S3ProfileID string `json:"s3_profile_id"`
PostgresID string `json:"postgres_profile_id"`
RedisID string `json:"redis_profile_id"`
IdempotencyKey string `json:"idempotency_key"`
}
type CreateSourceProfileRequest struct {
ProfileID string `json:"profile_id" binding:"required"`
Name string `json:"name" binding:"required"`
Config service.DataManagementSourceConfig `json:"config" binding:"required"`
SetActive bool `json:"set_active"`
}
type UpdateSourceProfileRequest struct {
Name string `json:"name" binding:"required"`
Config service.DataManagementSourceConfig `json:"config" binding:"required"`
}
type CreateS3ProfileRequest struct {
ProfileID string `json:"profile_id" binding:"required"`
Name string `json:"name" binding:"required"`
Enabled bool `json:"enabled"`
Endpoint string `json:"endpoint"`
Region string `json:"region"`
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key"`
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
UseSSL bool `json:"use_ssl"`
SetActive bool `json:"set_active"`
}
type UpdateS3ProfileRequest struct {
Name string `json:"name" binding:"required"`
Enabled bool `json:"enabled"`
Endpoint string `json:"endpoint"`
Region string `json:"region"`
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key"`
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
UseSSL bool `json:"use_ssl"`
}
func (h *DataManagementHandler) GetAgentHealth(c *gin.Context) {
health := h.getAgentHealth(c)
payload := gin.H{
"enabled": health.Enabled,
"reason": health.Reason,
"socket_path": health.SocketPath,
}
if health.Agent != nil {
payload["agent"] = gin.H{
"status": health.Agent.Status,
"version": health.Agent.Version,
"uptime_seconds": health.Agent.UptimeSeconds,
}
}
response.Success(c, payload)
}
func (h *DataManagementHandler) GetConfig(c *gin.Context) {
if !h.requireAgentEnabled(c) {
return
}
cfg, err := h.dataManagementService.GetConfig(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, cfg)
}
func (h *DataManagementHandler) UpdateConfig(c *gin.Context) {
var req service.DataManagementConfig
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !h.requireAgentEnabled(c) {
return
}
cfg, err := h.dataManagementService.UpdateConfig(c.Request.Context(), req)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, cfg)
}
func (h *DataManagementHandler) TestS3(c *gin.Context) {
var req TestS3ConnectionRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !h.requireAgentEnabled(c) {
return
}
result, err := h.dataManagementService.ValidateS3(c.Request.Context(), service.DataManagementS3Config{
Enabled: true,
Endpoint: req.Endpoint,
Region: req.Region,
Bucket: req.Bucket,
AccessKeyID: req.AccessKeyID,
SecretAccessKey: req.SecretAccessKey,
Prefix: req.Prefix,
ForcePathStyle: req.ForcePathStyle,
UseSSL: req.UseSSL,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"ok": result.OK, "message": result.Message})
}
func (h *DataManagementHandler) CreateBackupJob(c *gin.Context) {
var req CreateBackupJobRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
req.IdempotencyKey = normalizeBackupIdempotencyKey(c.GetHeader("X-Idempotency-Key"), req.IdempotencyKey)
if !h.requireAgentEnabled(c) {
return
}
triggeredBy := "admin:unknown"
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
triggeredBy = "admin:" + strconv.FormatInt(subject.UserID, 10)
}
job, err := h.dataManagementService.CreateBackupJob(c.Request.Context(), service.DataManagementCreateBackupJobInput{
BackupType: req.BackupType,
UploadToS3: req.UploadToS3,
S3ProfileID: req.S3ProfileID,
PostgresID: req.PostgresID,
RedisID: req.RedisID,
TriggeredBy: triggeredBy,
IdempotencyKey: req.IdempotencyKey,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"job_id": job.JobID, "status": job.Status})
}
func (h *DataManagementHandler) ListSourceProfiles(c *gin.Context) {
sourceType := strings.TrimSpace(c.Param("source_type"))
if sourceType == "" {
response.BadRequest(c, "Invalid source_type")
return
}
if sourceType != "postgres" && sourceType != "redis" {
response.BadRequest(c, "source_type must be postgres or redis")
return
}
if !h.requireAgentEnabled(c) {
return
}
items, err := h.dataManagementService.ListSourceProfiles(c.Request.Context(), sourceType)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"items": items})
}
func (h *DataManagementHandler) CreateSourceProfile(c *gin.Context) {
sourceType := strings.TrimSpace(c.Param("source_type"))
if sourceType != "postgres" && sourceType != "redis" {
response.BadRequest(c, "source_type must be postgres or redis")
return
}
var req CreateSourceProfileRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.CreateSourceProfile(c.Request.Context(), service.DataManagementCreateSourceProfileInput{
SourceType: sourceType,
ProfileID: req.ProfileID,
Name: req.Name,
Config: req.Config,
SetActive: req.SetActive,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) UpdateSourceProfile(c *gin.Context) {
sourceType := strings.TrimSpace(c.Param("source_type"))
if sourceType != "postgres" && sourceType != "redis" {
response.BadRequest(c, "source_type must be postgres or redis")
return
}
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
var req UpdateSourceProfileRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.UpdateSourceProfile(c.Request.Context(), service.DataManagementUpdateSourceProfileInput{
SourceType: sourceType,
ProfileID: profileID,
Name: req.Name,
Config: req.Config,
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) DeleteSourceProfile(c *gin.Context) {
sourceType := strings.TrimSpace(c.Param("source_type"))
if sourceType != "postgres" && sourceType != "redis" {
response.BadRequest(c, "source_type must be postgres or redis")
return
}
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
if !h.requireAgentEnabled(c) {
return
}
if err := h.dataManagementService.DeleteSourceProfile(c.Request.Context(), sourceType, profileID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"deleted": true})
}
func (h *DataManagementHandler) SetActiveSourceProfile(c *gin.Context) {
sourceType := strings.TrimSpace(c.Param("source_type"))
if sourceType != "postgres" && sourceType != "redis" {
response.BadRequest(c, "source_type must be postgres or redis")
return
}
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.SetActiveSourceProfile(c.Request.Context(), sourceType, profileID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) ListS3Profiles(c *gin.Context) {
if !h.requireAgentEnabled(c) {
return
}
items, err := h.dataManagementService.ListS3Profiles(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"items": items})
}
func (h *DataManagementHandler) CreateS3Profile(c *gin.Context) {
var req CreateS3ProfileRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.CreateS3Profile(c.Request.Context(), service.DataManagementCreateS3ProfileInput{
ProfileID: req.ProfileID,
Name: req.Name,
SetActive: req.SetActive,
S3: service.DataManagementS3Config{
Enabled: req.Enabled,
Endpoint: req.Endpoint,
Region: req.Region,
Bucket: req.Bucket,
AccessKeyID: req.AccessKeyID,
SecretAccessKey: req.SecretAccessKey,
Prefix: req.Prefix,
ForcePathStyle: req.ForcePathStyle,
UseSSL: req.UseSSL,
},
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) UpdateS3Profile(c *gin.Context) {
var req UpdateS3ProfileRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.UpdateS3Profile(c.Request.Context(), service.DataManagementUpdateS3ProfileInput{
ProfileID: profileID,
Name: req.Name,
S3: service.DataManagementS3Config{
Enabled: req.Enabled,
Endpoint: req.Endpoint,
Region: req.Region,
Bucket: req.Bucket,
AccessKeyID: req.AccessKeyID,
SecretAccessKey: req.SecretAccessKey,
Prefix: req.Prefix,
ForcePathStyle: req.ForcePathStyle,
UseSSL: req.UseSSL,
},
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) DeleteS3Profile(c *gin.Context) {
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
if !h.requireAgentEnabled(c) {
return
}
if err := h.dataManagementService.DeleteS3Profile(c.Request.Context(), profileID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"deleted": true})
}
func (h *DataManagementHandler) SetActiveS3Profile(c *gin.Context) {
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Invalid profile_id")
return
}
if !h.requireAgentEnabled(c) {
return
}
profile, err := h.dataManagementService.SetActiveS3Profile(c.Request.Context(), profileID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, profile)
}
func (h *DataManagementHandler) ListBackupJobs(c *gin.Context) {
if !h.requireAgentEnabled(c) {
return
}
pageSize := int32(20)
if raw := strings.TrimSpace(c.Query("page_size")); raw != "" {
v, err := strconv.Atoi(raw)
if err != nil || v <= 0 {
response.BadRequest(c, "Invalid page_size")
return
}
pageSize = int32(v)
}
result, err := h.dataManagementService.ListBackupJobs(c.Request.Context(), service.DataManagementListBackupJobsInput{
PageSize: pageSize,
PageToken: c.Query("page_token"),
Status: c.Query("status"),
BackupType: c.Query("backup_type"),
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
func (h *DataManagementHandler) GetBackupJob(c *gin.Context) {
jobID := strings.TrimSpace(c.Param("job_id"))
if jobID == "" {
response.BadRequest(c, "Invalid backup job ID")
return
}
if !h.requireAgentEnabled(c) {
return
}
job, err := h.dataManagementService.GetBackupJob(c.Request.Context(), jobID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, job)
}
func (h *DataManagementHandler) requireAgentEnabled(c *gin.Context) bool {
if h.dataManagementService == nil {
err := infraerrors.ServiceUnavailable(
service.DataManagementAgentUnavailableReason,
"data management agent service is not configured",
).WithMetadata(map[string]string{"socket_path": service.DefaultDataManagementAgentSocketPath})
response.ErrorFrom(c, err)
return false
}
if err := h.dataManagementService.EnsureAgentEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return false
}
return true
}
func (h *DataManagementHandler) getAgentHealth(c *gin.Context) service.DataManagementAgentHealth {
if h.dataManagementService == nil {
return service.DataManagementAgentHealth{
Enabled: false,
Reason: service.DataManagementAgentUnavailableReason,
SocketPath: service.DefaultDataManagementAgentSocketPath,
}
}
return h.dataManagementService.GetAgentHealth(c.Request.Context())
}
func normalizeBackupIdempotencyKey(headerValue, bodyValue string) string {
headerKey := strings.TrimSpace(headerValue)
if headerKey != "" {
return headerKey
}
return strings.TrimSpace(bodyValue)
}

View File

@@ -0,0 +1,78 @@
package admin
import (
"encoding/json"
"net/http"
"net/http/httptest"
"path/filepath"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type apiEnvelope struct {
Code int `json:"code"`
Message string `json:"message"`
Reason string `json:"reason"`
Data json.RawMessage `json:"data"`
}
func TestDataManagementHandler_AgentHealthAlways200(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := service.NewDataManagementServiceWithOptions(filepath.Join(t.TempDir(), "missing.sock"), 50*time.Millisecond)
h := NewDataManagementHandler(svc)
r := gin.New()
r.GET("/api/v1/admin/data-management/agent/health", h.GetAgentHealth)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/data-management/agent/health", nil)
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
var envelope apiEnvelope
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &envelope))
require.Equal(t, 0, envelope.Code)
var data struct {
Enabled bool `json:"enabled"`
Reason string `json:"reason"`
SocketPath string `json:"socket_path"`
}
require.NoError(t, json.Unmarshal(envelope.Data, &data))
require.False(t, data.Enabled)
require.Equal(t, service.DataManagementDeprecatedReason, data.Reason)
require.Equal(t, svc.SocketPath(), data.SocketPath)
}
func TestDataManagementHandler_NonHealthRouteReturns503WhenDisabled(t *testing.T) {
gin.SetMode(gin.TestMode)
svc := service.NewDataManagementServiceWithOptions(filepath.Join(t.TempDir(), "missing.sock"), 50*time.Millisecond)
h := NewDataManagementHandler(svc)
r := gin.New()
r.GET("/api/v1/admin/data-management/config", h.GetConfig)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/data-management/config", nil)
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
var envelope apiEnvelope
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &envelope))
require.Equal(t, http.StatusServiceUnavailable, envelope.Code)
require.Equal(t, service.DataManagementDeprecatedReason, envelope.Reason)
}
func TestNormalizeBackupIdempotencyKey(t *testing.T) {
require.Equal(t, "from-header", normalizeBackupIdempotencyKey("from-header", "from-body"))
require.Equal(t, "from-body", normalizeBackupIdempotencyKey(" ", " from-body "))
require.Equal(t, "", normalizeBackupIdempotencyKey("", ""))
}

View File

@@ -32,6 +32,7 @@ type CreateErrorPassthroughRuleRequest struct {
ResponseCode *int `json:"response_code"`
PassthroughBody *bool `json:"passthrough_body"`
CustomMessage *string `json:"custom_message"`
SkipMonitoring *bool `json:"skip_monitoring"`
Description *string `json:"description"`
}
@@ -48,6 +49,7 @@ type UpdateErrorPassthroughRuleRequest struct {
ResponseCode *int `json:"response_code"`
PassthroughBody *bool `json:"passthrough_body"`
CustomMessage *string `json:"custom_message"`
SkipMonitoring *bool `json:"skip_monitoring"`
Description *string `json:"description"`
}
@@ -122,6 +124,9 @@ func (h *ErrorPassthroughHandler) Create(c *gin.Context) {
} else {
rule.PassthroughBody = true
}
if req.SkipMonitoring != nil {
rule.SkipMonitoring = *req.SkipMonitoring
}
rule.ResponseCode = req.ResponseCode
rule.CustomMessage = req.CustomMessage
rule.Description = req.Description
@@ -190,6 +195,7 @@ func (h *ErrorPassthroughHandler) Update(c *gin.Context) {
ResponseCode: existing.ResponseCode,
PassthroughBody: existing.PassthroughBody,
CustomMessage: existing.CustomMessage,
SkipMonitoring: existing.SkipMonitoring,
Description: existing.Description,
}
@@ -230,6 +236,9 @@ func (h *ErrorPassthroughHandler) Update(c *gin.Context) {
if req.Description != nil {
rule.Description = req.Description
}
if req.SkipMonitoring != nil {
rule.SkipMonitoring = *req.SkipMonitoring
}
// 确保切片不为 nil
if rule.ErrorCodes == nil {

View File

@@ -61,7 +61,11 @@ func (h *GeminiOAuthHandler) GenerateAuthURL(c *gin.Context) {
if err != nil {
msg := err.Error()
// Treat missing/invalid OAuth client configuration as a user/config error.
if strings.Contains(msg, "OAuth client not configured") || strings.Contains(msg, "requires your own OAuth Client") {
if strings.Contains(msg, "OAuth client not configured") ||
strings.Contains(msg, "requires your own OAuth Client") ||
strings.Contains(msg, "requires a custom OAuth Client") ||
strings.Contains(msg, "GEMINI_CLI_OAUTH_CLIENT_SECRET_MISSING") ||
strings.Contains(msg, "built-in Gemini CLI OAuth client_secret is not configured") {
response.BadRequest(c, "Failed to generate auth URL: "+msg)
return
}

View File

@@ -27,7 +27,7 @@ func NewGroupHandler(adminService service.AdminService) *GroupHandler {
type CreateGroupRequest struct {
Name string `json:"name" binding:"required"`
Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"`
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
@@ -38,6 +38,10 @@ type CreateGroupRequest struct {
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
@@ -47,6 +51,8 @@ type CreateGroupRequest struct {
MCPXMLInject *bool `json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes"`
// Sora 存储配额
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
// 从指定分组复制账号(创建后自动绑定)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
}
@@ -55,7 +61,7 @@ type CreateGroupRequest struct {
type UpdateGroupRequest struct {
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"`
RateMultiplier *float64 `json:"rate_multiplier"`
IsExclusive *bool `json:"is_exclusive"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
@@ -67,6 +73,10 @@ type UpdateGroupRequest struct {
ImagePrice1K *float64 `json:"image_price_1k"`
ImagePrice2K *float64 `json:"image_price_2k"`
ImagePrice4K *float64 `json:"image_price_4k"`
SoraImagePrice360 *float64 `json:"sora_image_price_360"`
SoraImagePrice540 *float64 `json:"sora_image_price_540"`
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"`
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"`
ClaudeCodeOnly *bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"`
@@ -76,6 +86,8 @@ type UpdateGroupRequest struct {
MCPXMLInject *bool `json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes *[]string `json:"supported_model_scopes"`
// Sora 存储配额
SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"`
// 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号)
CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"`
}
@@ -179,6 +191,10 @@ func (h *GroupHandler) Create(c *gin.Context) {
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
SoraImagePrice360: req.SoraImagePrice360,
SoraImagePrice540: req.SoraImagePrice540,
SoraVideoPricePerRequest: req.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
@@ -186,6 +202,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
ModelRoutingEnabled: req.ModelRoutingEnabled,
MCPXMLInject: req.MCPXMLInject,
SupportedModelScopes: req.SupportedModelScopes,
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
})
if err != nil {
@@ -225,6 +242,10 @@ func (h *GroupHandler) Update(c *gin.Context) {
ImagePrice1K: req.ImagePrice1K,
ImagePrice2K: req.ImagePrice2K,
ImagePrice4K: req.ImagePrice4K,
SoraImagePrice360: req.SoraImagePrice360,
SoraImagePrice540: req.SoraImagePrice540,
SoraVideoPricePerRequest: req.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: req.ClaudeCodeOnly,
FallbackGroupID: req.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest,
@@ -232,6 +253,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
ModelRoutingEnabled: req.ModelRoutingEnabled,
MCPXMLInject: req.MCPXMLInject,
SupportedModelScopes: req.SupportedModelScopes,
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs,
})
if err != nil {
@@ -302,3 +324,36 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
}
response.Paginated(c, outKeys, total, page, pageSize)
}
// UpdateSortOrderRequest represents the request to update group sort orders
type UpdateSortOrderRequest struct {
Updates []struct {
ID int64 `json:"id" binding:"required"`
SortOrder int `json:"sort_order"`
} `json:"updates" binding:"required,min=1"`
}
// UpdateSortOrder handles updating group sort orders
// PUT /api/v1/admin/groups/sort-order
func (h *GroupHandler) UpdateSortOrder(c *gin.Context) {
var req UpdateSortOrderRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
updates := make([]service.GroupSortOrderUpdate, 0, len(req.Updates))
for _, u := range req.Updates {
updates = append(updates, service.GroupSortOrderUpdate{
ID: u.ID,
SortOrder: u.SortOrder,
})
}
if err := h.adminService.UpdateGroupSortOrders(c.Request.Context(), updates); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Sort order updated successfully"})
}

View File

@@ -0,0 +1,115 @@
package admin
import (
"context"
"strconv"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type idempotencyStoreUnavailableMode int
const (
idempotencyStoreUnavailableFailClose idempotencyStoreUnavailableMode = iota
idempotencyStoreUnavailableFailOpen
)
func executeAdminIdempotent(
c *gin.Context,
scope string,
payload any,
ttl time.Duration,
execute func(context.Context) (any, error),
) (*service.IdempotencyExecuteResult, error) {
coordinator := service.DefaultIdempotencyCoordinator()
if coordinator == nil {
data, err := execute(c.Request.Context())
if err != nil {
return nil, err
}
return &service.IdempotencyExecuteResult{Data: data}, nil
}
actorScope := "admin:0"
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
actorScope = "admin:" + strconv.FormatInt(subject.UserID, 10)
}
return coordinator.Execute(c.Request.Context(), service.IdempotencyExecuteOptions{
Scope: scope,
ActorScope: actorScope,
Method: c.Request.Method,
Route: c.FullPath(),
IdempotencyKey: c.GetHeader("Idempotency-Key"),
Payload: payload,
RequireKey: true,
TTL: ttl,
}, execute)
}
func executeAdminIdempotentJSON(
c *gin.Context,
scope string,
payload any,
ttl time.Duration,
execute func(context.Context) (any, error),
) {
executeAdminIdempotentJSONWithMode(c, scope, payload, ttl, idempotencyStoreUnavailableFailClose, execute)
}
func executeAdminIdempotentJSONFailOpenOnStoreUnavailable(
c *gin.Context,
scope string,
payload any,
ttl time.Duration,
execute func(context.Context) (any, error),
) {
executeAdminIdempotentJSONWithMode(c, scope, payload, ttl, idempotencyStoreUnavailableFailOpen, execute)
}
func executeAdminIdempotentJSONWithMode(
c *gin.Context,
scope string,
payload any,
ttl time.Duration,
mode idempotencyStoreUnavailableMode,
execute func(context.Context) (any, error),
) {
result, err := executeAdminIdempotent(c, scope, payload, ttl, execute)
if err != nil {
if infraerrors.Code(err) == infraerrors.Code(service.ErrIdempotencyStoreUnavail) {
strategy := "fail_close"
if mode == idempotencyStoreUnavailableFailOpen {
strategy = "fail_open"
}
service.RecordIdempotencyStoreUnavailable(c.FullPath(), scope, "handler_"+strategy)
logger.LegacyPrintf("handler.idempotency", "[Idempotency] store unavailable: method=%s route=%s scope=%s strategy=%s", c.Request.Method, c.FullPath(), scope, strategy)
if mode == idempotencyStoreUnavailableFailOpen {
data, fallbackErr := execute(c.Request.Context())
if fallbackErr != nil {
response.ErrorFrom(c, fallbackErr)
return
}
c.Header("X-Idempotency-Degraded", "store-unavailable")
response.Success(c, data)
return
}
}
if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 {
c.Header("Retry-After", strconv.Itoa(retryAfter))
}
response.ErrorFrom(c, err)
return
}
if result != nil && result.Replayed {
c.Header("X-Idempotency-Replayed", "true")
}
response.Success(c, result.Data)
}

View File

@@ -0,0 +1,285 @@
package admin
import (
"bytes"
"context"
"errors"
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type storeUnavailableRepoStub struct{}
func (storeUnavailableRepoStub) CreateProcessing(context.Context, *service.IdempotencyRecord) (bool, error) {
return false, errors.New("store unavailable")
}
func (storeUnavailableRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*service.IdempotencyRecord, error) {
return nil, errors.New("store unavailable")
}
func (storeUnavailableRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) {
return false, errors.New("store unavailable")
}
func (storeUnavailableRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) {
return false, errors.New("store unavailable")
}
func (storeUnavailableRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error {
return errors.New("store unavailable")
}
func (storeUnavailableRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error {
return errors.New("store unavailable")
}
func (storeUnavailableRepoStub) DeleteExpired(context.Context, time.Time, int) (int64, error) {
return 0, errors.New("store unavailable")
}
func TestExecuteAdminIdempotentJSONFailCloseOnStoreUnavailable(t *testing.T) {
gin.SetMode(gin.TestMode)
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(storeUnavailableRepoStub{}, service.DefaultIdempotencyConfig()))
t.Cleanup(func() {
service.SetDefaultIdempotencyCoordinator(nil)
})
var executed int
router := gin.New()
router.POST("/idempotent", func(c *gin.Context) {
executeAdminIdempotentJSON(c, "admin.test.high", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
executed++
return gin.H{"ok": true}, nil
})
})
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Idempotency-Key", "test-key-1")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusServiceUnavailable, rec.Code)
require.Equal(t, 0, executed, "fail-close should block business execution when idempotency store is unavailable")
}
func TestExecuteAdminIdempotentJSONFailOpenOnStoreUnavailable(t *testing.T) {
gin.SetMode(gin.TestMode)
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(storeUnavailableRepoStub{}, service.DefaultIdempotencyConfig()))
t.Cleanup(func() {
service.SetDefaultIdempotencyCoordinator(nil)
})
var executed int
router := gin.New()
router.POST("/idempotent", func(c *gin.Context) {
executeAdminIdempotentJSONFailOpenOnStoreUnavailable(c, "admin.test.medium", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
executed++
return gin.H{"ok": true}, nil
})
})
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Idempotency-Key", "test-key-2")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "store-unavailable", rec.Header().Get("X-Idempotency-Degraded"))
require.Equal(t, 1, executed, "fail-open strategy should allow semantic idempotent path to continue")
}
type memoryIdempotencyRepoStub struct {
mu sync.Mutex
nextID int64
data map[string]*service.IdempotencyRecord
}
func newMemoryIdempotencyRepoStub() *memoryIdempotencyRepoStub {
return &memoryIdempotencyRepoStub{
nextID: 1,
data: make(map[string]*service.IdempotencyRecord),
}
}
func (r *memoryIdempotencyRepoStub) key(scope, keyHash string) string {
return scope + "|" + keyHash
}
func (r *memoryIdempotencyRepoStub) clone(in *service.IdempotencyRecord) *service.IdempotencyRecord {
if in == nil {
return nil
}
out := *in
if in.LockedUntil != nil {
v := *in.LockedUntil
out.LockedUntil = &v
}
if in.ResponseBody != nil {
v := *in.ResponseBody
out.ResponseBody = &v
}
if in.ResponseStatus != nil {
v := *in.ResponseStatus
out.ResponseStatus = &v
}
if in.ErrorReason != nil {
v := *in.ErrorReason
out.ErrorReason = &v
}
return &out
}
func (r *memoryIdempotencyRepoStub) CreateProcessing(_ context.Context, record *service.IdempotencyRecord) (bool, error) {
r.mu.Lock()
defer r.mu.Unlock()
k := r.key(record.Scope, record.IdempotencyKeyHash)
if _, ok := r.data[k]; ok {
return false, nil
}
cp := r.clone(record)
cp.ID = r.nextID
r.nextID++
r.data[k] = cp
record.ID = cp.ID
return true, nil
}
func (r *memoryIdempotencyRepoStub) GetByScopeAndKeyHash(_ context.Context, scope, keyHash string) (*service.IdempotencyRecord, error) {
r.mu.Lock()
defer r.mu.Unlock()
return r.clone(r.data[r.key(scope, keyHash)]), nil
}
func (r *memoryIdempotencyRepoStub) TryReclaim(_ context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
if rec.Status != fromStatus {
return false, nil
}
if rec.LockedUntil != nil && rec.LockedUntil.After(now) {
return false, nil
}
rec.Status = service.IdempotencyStatusProcessing
rec.LockedUntil = &newLockedUntil
rec.ExpiresAt = newExpiresAt
rec.ErrorReason = nil
return true, nil
}
return false, nil
}
func (r *memoryIdempotencyRepoStub) ExtendProcessingLock(_ context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
if rec.Status != service.IdempotencyStatusProcessing || rec.RequestFingerprint != requestFingerprint {
return false, nil
}
rec.LockedUntil = &newLockedUntil
rec.ExpiresAt = newExpiresAt
return true, nil
}
return false, nil
}
func (r *memoryIdempotencyRepoStub) MarkSucceeded(_ context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
rec.Status = service.IdempotencyStatusSucceeded
rec.LockedUntil = nil
rec.ExpiresAt = expiresAt
rec.ResponseStatus = &responseStatus
rec.ResponseBody = &responseBody
rec.ErrorReason = nil
return nil
}
return nil
}
func (r *memoryIdempotencyRepoStub) MarkFailedRetryable(_ context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error {
r.mu.Lock()
defer r.mu.Unlock()
for _, rec := range r.data {
if rec.ID != id {
continue
}
rec.Status = service.IdempotencyStatusFailedRetryable
rec.LockedUntil = &lockedUntil
rec.ExpiresAt = expiresAt
rec.ErrorReason = &errorReason
return nil
}
return nil
}
func (r *memoryIdempotencyRepoStub) DeleteExpired(_ context.Context, _ time.Time, _ int) (int64, error) {
return 0, nil
}
func TestExecuteAdminIdempotentJSONConcurrentRetryOnlyOneSideEffect(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := newMemoryIdempotencyRepoStub()
cfg := service.DefaultIdempotencyConfig()
cfg.ProcessingTimeout = 2 * time.Second
service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(repo, cfg))
t.Cleanup(func() {
service.SetDefaultIdempotencyCoordinator(nil)
})
var executed atomic.Int32
router := gin.New()
router.POST("/idempotent", func(c *gin.Context) {
executeAdminIdempotentJSON(c, "admin.test.concurrent", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) {
executed.Add(1)
time.Sleep(120 * time.Millisecond)
return gin.H{"ok": true}, nil
})
})
call := func() (int, http.Header) {
req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Idempotency-Key", "same-key")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
return rec.Code, rec.Header()
}
var status1, status2 int
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
status1, _ = call()
}()
go func() {
defer wg.Done()
status2, _ = call()
}()
wg.Wait()
require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status1)
require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status2)
require.Equal(t, int32(1), executed.Load(), "same idempotency key should execute side-effect only once")
status3, headers3 := call()
require.Equal(t, http.StatusOK, status3)
require.Equal(t, "true", headers3.Get("X-Idempotency-Replayed"))
require.Equal(t, int32(1), executed.Load())
}

View File

@@ -2,8 +2,10 @@ package admin
import (
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -16,6 +18,13 @@ type OpenAIOAuthHandler struct {
adminService service.AdminService
}
func oauthPlatformFromPath(c *gin.Context) string {
if strings.Contains(c.FullPath(), "/admin/sora/") {
return service.PlatformSora
}
return service.PlatformOpenAI
}
// NewOpenAIOAuthHandler creates a new OpenAI OAuth handler
func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler {
return &OpenAIOAuthHandler{
@@ -39,7 +48,12 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
req = OpenAIGenerateAuthURLRequest{}
}
result, err := h.openaiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, req.RedirectURI)
result, err := h.openaiOAuthService.GenerateAuthURL(
c.Request.Context(),
req.ProxyID,
req.RedirectURI,
oauthPlatformFromPath(c),
)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -52,6 +66,7 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) {
type OpenAIExchangeCodeRequest struct {
SessionID string `json:"session_id" binding:"required"`
Code string `json:"code" binding:"required"`
State string `json:"state" binding:"required"`
RedirectURI string `json:"redirect_uri"`
ProxyID *int64 `json:"proxy_id"`
}
@@ -68,6 +83,7 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
SessionID: req.SessionID,
Code: req.Code,
State: req.State,
RedirectURI: req.RedirectURI,
ProxyID: req.ProxyID,
})
@@ -81,18 +97,29 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) {
// OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token
type OpenAIRefreshTokenRequest struct {
RefreshToken string `json:"refresh_token" binding:"required"`
RefreshToken string `json:"refresh_token"`
RT string `json:"rt"`
ClientID string `json:"client_id"`
ProxyID *int64 `json:"proxy_id"`
}
// RefreshToken refreshes an OpenAI OAuth token
// POST /api/v1/admin/openai/refresh-token
// POST /api/v1/admin/sora/rt2at
func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
var req OpenAIRefreshTokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
refreshToken := strings.TrimSpace(req.RefreshToken)
if refreshToken == "" {
refreshToken = strings.TrimSpace(req.RT)
}
if refreshToken == "" {
response.BadRequest(c, "refresh_token is required")
return
}
var proxyURL string
if req.ProxyID != nil {
@@ -102,7 +129,14 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
}
}
tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL)
// 未指定 client_id 时,根据请求路径平台自动设置默认值,避免 repository 层盲猜
clientID := strings.TrimSpace(req.ClientID)
if clientID == "" {
platform := oauthPlatformFromPath(c)
clientID, _ = openai.OAuthClientConfigByPlatform(platform)
}
tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, clientID)
if err != nil {
response.ErrorFrom(c, err)
return
@@ -111,8 +145,39 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) {
response.Success(c, tokenInfo)
}
// RefreshAccountToken refreshes token for a specific OpenAI account
// ExchangeSoraSessionToken exchanges Sora session token to access token
// POST /api/v1/admin/sora/st2at
func (h *OpenAIOAuthHandler) ExchangeSoraSessionToken(c *gin.Context) {
var req struct {
SessionToken string `json:"session_token"`
ST string `json:"st"`
ProxyID *int64 `json:"proxy_id"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
sessionToken := strings.TrimSpace(req.SessionToken)
if sessionToken == "" {
sessionToken = strings.TrimSpace(req.ST)
}
if sessionToken == "" {
response.BadRequest(c, "session_token is required")
return
}
tokenInfo, err := h.openaiOAuthService.ExchangeSoraSessionToken(c.Request.Context(), sessionToken, req.ProxyID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, tokenInfo)
}
// RefreshAccountToken refreshes token for a specific OpenAI/Sora account
// POST /api/v1/admin/openai/accounts/:id/refresh
// POST /api/v1/admin/sora/accounts/:id/refresh
func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
@@ -127,9 +192,9 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
return
}
// Ensure account is OpenAI platform
if !account.IsOpenAI() {
response.BadRequest(c, "Account is not an OpenAI account")
platform := oauthPlatformFromPath(c)
if account.Platform != platform {
response.BadRequest(c, "Account platform does not match OAuth endpoint")
return
}
@@ -167,12 +232,14 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) {
response.Success(c, dto.AccountFromService(updatedAccount))
}
// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info
// CreateAccountFromOAuth creates a new OpenAI/Sora OAuth account from token info
// POST /api/v1/admin/openai/create-from-oauth
// POST /api/v1/admin/sora/create-from-oauth
func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
var req struct {
SessionID string `json:"session_id" binding:"required"`
Code string `json:"code" binding:"required"`
State string `json:"state" binding:"required"`
RedirectURI string `json:"redirect_uri"`
ProxyID *int64 `json:"proxy_id"`
Name string `json:"name"`
@@ -189,6 +256,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{
SessionID: req.SessionID,
Code: req.Code,
State: req.State,
RedirectURI: req.RedirectURI,
ProxyID: req.ProxyID,
})
@@ -200,19 +268,25 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) {
// Build credentials from token info
credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo)
platform := oauthPlatformFromPath(c)
// Use email as default name if not provided
name := req.Name
if name == "" && tokenInfo.Email != "" {
name = tokenInfo.Email
}
if name == "" {
name = "OpenAI OAuth Account"
if platform == service.PlatformSora {
name = "Sora OAuth Account"
} else {
name = "OpenAI OAuth Account"
}
}
// Create account
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
Name: name,
Platform: "openai",
Platform: platform,
Type: "oauth",
Credentials: credentials,
ProxyID: req.ProxyID,

View File

@@ -1,6 +1,7 @@
package admin
import (
"fmt"
"net/http"
"strconv"
"strings"
@@ -218,6 +219,115 @@ func (h *OpsHandler) GetDashboardErrorDistribution(c *gin.Context) {
response.Success(c, data)
}
// GetDashboardOpenAITokenStats returns OpenAI token efficiency stats grouped by model.
// GET /api/v1/admin/ops/dashboard/openai-token-stats
func (h *OpsHandler) GetDashboardOpenAITokenStats(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
filter, err := parseOpsOpenAITokenStatsFilter(c)
if err != nil {
response.BadRequest(c, err.Error())
return
}
data, err := h.opsService.GetOpenAITokenStats(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, data)
}
func parseOpsOpenAITokenStatsFilter(c *gin.Context) (*service.OpsOpenAITokenStatsFilter, error) {
if c == nil {
return nil, fmt.Errorf("invalid request")
}
timeRange := strings.TrimSpace(c.Query("time_range"))
if timeRange == "" {
timeRange = "30d"
}
dur, ok := parseOpsOpenAITokenStatsDuration(timeRange)
if !ok {
return nil, fmt.Errorf("invalid time_range")
}
end := time.Now().UTC()
start := end.Add(-dur)
filter := &service.OpsOpenAITokenStatsFilter{
TimeRange: timeRange,
StartTime: start,
EndTime: end,
Platform: strings.TrimSpace(c.Query("platform")),
}
if v := strings.TrimSpace(c.Query("group_id")); v != "" {
id, err := strconv.ParseInt(v, 10, 64)
if err != nil || id <= 0 {
return nil, fmt.Errorf("invalid group_id")
}
filter.GroupID = &id
}
topNRaw := strings.TrimSpace(c.Query("top_n"))
pageRaw := strings.TrimSpace(c.Query("page"))
pageSizeRaw := strings.TrimSpace(c.Query("page_size"))
if topNRaw != "" && (pageRaw != "" || pageSizeRaw != "") {
return nil, fmt.Errorf("invalid query: top_n cannot be used with page/page_size")
}
if topNRaw != "" {
topN, err := strconv.Atoi(topNRaw)
if err != nil || topN < 1 || topN > 100 {
return nil, fmt.Errorf("invalid top_n")
}
filter.TopN = topN
return filter, nil
}
filter.Page = 1
filter.PageSize = 20
if pageRaw != "" {
page, err := strconv.Atoi(pageRaw)
if err != nil || page < 1 {
return nil, fmt.Errorf("invalid page")
}
filter.Page = page
}
if pageSizeRaw != "" {
pageSize, err := strconv.Atoi(pageSizeRaw)
if err != nil || pageSize < 1 || pageSize > 100 {
return nil, fmt.Errorf("invalid page_size")
}
filter.PageSize = pageSize
}
return filter, nil
}
func parseOpsOpenAITokenStatsDuration(v string) (time.Duration, bool) {
switch strings.TrimSpace(v) {
case "30m":
return 30 * time.Minute, true
case "1h":
return time.Hour, true
case "1d":
return 24 * time.Hour, true
case "15d":
return 15 * 24 * time.Hour, true
case "30d":
return 30 * 24 * time.Hour, true
default:
return 0, false
}
}
func pickThroughputBucketSeconds(window time.Duration) int {
// Keep buckets predictable and avoid huge responses.
switch {

View File

@@ -0,0 +1,173 @@
package admin
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type testSettingRepo struct {
values map[string]string
}
func newTestSettingRepo() *testSettingRepo {
return &testSettingRepo{values: map[string]string{}}
}
func (s *testSettingRepo) Get(ctx context.Context, key string) (*service.Setting, error) {
v, err := s.GetValue(ctx, key)
if err != nil {
return nil, err
}
return &service.Setting{Key: key, Value: v}, nil
}
func (s *testSettingRepo) GetValue(ctx context.Context, key string) (string, error) {
v, ok := s.values[key]
if !ok {
return "", service.ErrSettingNotFound
}
return v, nil
}
func (s *testSettingRepo) Set(ctx context.Context, key, value string) error {
s.values[key] = value
return nil
}
func (s *testSettingRepo) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
out := make(map[string]string, len(keys))
for _, k := range keys {
if v, ok := s.values[k]; ok {
out[k] = v
}
}
return out, nil
}
func (s *testSettingRepo) SetMultiple(ctx context.Context, settings map[string]string) error {
for k, v := range settings {
s.values[k] = v
}
return nil
}
func (s *testSettingRepo) GetAll(ctx context.Context) (map[string]string, error) {
out := make(map[string]string, len(s.values))
for k, v := range s.values {
out[k] = v
}
return out, nil
}
func (s *testSettingRepo) Delete(ctx context.Context, key string) error {
delete(s.values, key)
return nil
}
func newOpsRuntimeRouter(handler *OpsHandler, withUser bool) *gin.Engine {
gin.SetMode(gin.TestMode)
r := gin.New()
if withUser {
r.Use(func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: 7})
c.Next()
})
}
r.GET("/runtime/logging", handler.GetRuntimeLogConfig)
r.PUT("/runtime/logging", handler.UpdateRuntimeLogConfig)
r.POST("/runtime/logging/reset", handler.ResetRuntimeLogConfig)
return r
}
func newRuntimeOpsService(t *testing.T) *service.OpsService {
t.Helper()
if err := logger.Init(logger.InitOptions{
Level: "info",
Format: "json",
ServiceName: "sub2api",
Environment: "test",
Output: logger.OutputOptions{
ToStdout: false,
ToFile: false,
},
}); err != nil {
t.Fatalf("init logger: %v", err)
}
settingRepo := newTestSettingRepo()
cfg := &config.Config{
Ops: config.OpsConfig{Enabled: true},
Log: config.LogConfig{
Level: "info",
Caller: true,
StacktraceLevel: "error",
Sampling: config.LogSamplingConfig{
Enabled: false,
Initial: 100,
Thereafter: 100,
},
},
}
return service.NewOpsService(nil, settingRepo, cfg, nil, nil, nil, nil, nil, nil, nil, nil)
}
func TestOpsRuntimeLoggingHandler_GetConfig(t *testing.T) {
h := NewOpsHandler(newRuntimeOpsService(t))
r := newOpsRuntimeRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/runtime/logging", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status=%d, want 200", w.Code)
}
}
func TestOpsRuntimeLoggingHandler_UpdateUnauthorized(t *testing.T) {
h := NewOpsHandler(newRuntimeOpsService(t))
r := newOpsRuntimeRouter(h, false)
body := `{"level":"debug","enable_sampling":false,"sampling_initial":100,"sampling_thereafter":100,"caller":true,"stacktrace_level":"error","retention_days":30}`
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/runtime/logging", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("status=%d, want 401", w.Code)
}
}
func TestOpsRuntimeLoggingHandler_UpdateAndResetSuccess(t *testing.T) {
h := NewOpsHandler(newRuntimeOpsService(t))
r := newOpsRuntimeRouter(h, true)
payload := map[string]any{
"level": "debug",
"enable_sampling": false,
"sampling_initial": 100,
"sampling_thereafter": 100,
"caller": true,
"stacktrace_level": "error",
"retention_days": 30,
}
raw, _ := json.Marshal(payload)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, "/runtime/logging", bytes.NewReader(raw))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("update status=%d, want 200, body=%s", w.Code, w.Body.String())
}
w = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, "/runtime/logging/reset", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("reset status=%d, want 200, body=%s", w.Code, w.Body.String())
}
}

View File

@@ -4,6 +4,7 @@ import (
"net/http"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
@@ -101,6 +102,84 @@ func (h *OpsHandler) UpdateAlertRuntimeSettings(c *gin.Context) {
response.Success(c, updated)
}
// GetRuntimeLogConfig returns runtime log config (DB-backed).
// GET /api/v1/admin/ops/runtime/logging
func (h *OpsHandler) GetRuntimeLogConfig(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
cfg, err := h.opsService.GetRuntimeLogConfig(c.Request.Context())
if err != nil {
response.Error(c, http.StatusInternalServerError, "Failed to get runtime log config")
return
}
response.Success(c, cfg)
}
// UpdateRuntimeLogConfig updates runtime log config and applies changes immediately.
// PUT /api/v1/admin/ops/runtime/logging
func (h *OpsHandler) UpdateRuntimeLogConfig(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
var req service.OpsRuntimeLogConfig
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
updated, err := h.opsService.UpdateRuntimeLogConfig(c.Request.Context(), &req, subject.UserID)
if err != nil {
response.Error(c, http.StatusBadRequest, err.Error())
return
}
response.Success(c, updated)
}
// ResetRuntimeLogConfig removes runtime override and falls back to env/yaml baseline.
// POST /api/v1/admin/ops/runtime/logging/reset
func (h *OpsHandler) ResetRuntimeLogConfig(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
updated, err := h.opsService.ResetRuntimeLogConfig(c.Request.Context(), subject.UserID)
if err != nil {
response.Error(c, http.StatusBadRequest, err.Error())
return
}
response.Success(c, updated)
}
// GetAdvancedSettings returns Ops advanced settings (DB-backed).
// GET /api/v1/admin/ops/advanced-settings
func (h *OpsHandler) GetAdvancedSettings(c *gin.Context) {

View File

@@ -0,0 +1,174 @@
package admin
import (
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type opsSystemLogCleanupRequest struct {
StartTime string `json:"start_time"`
EndTime string `json:"end_time"`
Level string `json:"level"`
Component string `json:"component"`
RequestID string `json:"request_id"`
ClientRequestID string `json:"client_request_id"`
UserID *int64 `json:"user_id"`
AccountID *int64 `json:"account_id"`
Platform string `json:"platform"`
Model string `json:"model"`
Query string `json:"q"`
}
// ListSystemLogs returns indexed system logs.
// GET /api/v1/admin/ops/system-logs
func (h *OpsHandler) ListSystemLogs(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
page, pageSize := response.ParsePagination(c)
if pageSize > 200 {
pageSize = 200
}
start, end, err := parseOpsTimeRange(c, "1h")
if err != nil {
response.BadRequest(c, err.Error())
return
}
filter := &service.OpsSystemLogFilter{
Page: page,
PageSize: pageSize,
StartTime: &start,
EndTime: &end,
Level: strings.TrimSpace(c.Query("level")),
Component: strings.TrimSpace(c.Query("component")),
RequestID: strings.TrimSpace(c.Query("request_id")),
ClientRequestID: strings.TrimSpace(c.Query("client_request_id")),
Platform: strings.TrimSpace(c.Query("platform")),
Model: strings.TrimSpace(c.Query("model")),
Query: strings.TrimSpace(c.Query("q")),
}
if v := strings.TrimSpace(c.Query("user_id")); v != "" {
id, parseErr := strconv.ParseInt(v, 10, 64)
if parseErr != nil || id <= 0 {
response.BadRequest(c, "Invalid user_id")
return
}
filter.UserID = &id
}
if v := strings.TrimSpace(c.Query("account_id")); v != "" {
id, parseErr := strconv.ParseInt(v, 10, 64)
if parseErr != nil || id <= 0 {
response.BadRequest(c, "Invalid account_id")
return
}
filter.AccountID = &id
}
result, err := h.opsService.ListSystemLogs(c.Request.Context(), filter)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Paginated(c, result.Logs, int64(result.Total), result.Page, result.PageSize)
}
// CleanupSystemLogs deletes indexed system logs by filter.
// POST /api/v1/admin/ops/system-logs/cleanup
func (h *OpsHandler) CleanupSystemLogs(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
var req opsSystemLogCleanupRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request body")
return
}
parseTS := func(raw string) (*time.Time, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil, nil
}
if t, err := time.Parse(time.RFC3339Nano, raw); err == nil {
return &t, nil
}
t, err := time.Parse(time.RFC3339, raw)
if err != nil {
return nil, err
}
return &t, nil
}
start, err := parseTS(req.StartTime)
if err != nil {
response.BadRequest(c, "Invalid start_time")
return
}
end, err := parseTS(req.EndTime)
if err != nil {
response.BadRequest(c, "Invalid end_time")
return
}
filter := &service.OpsSystemLogCleanupFilter{
StartTime: start,
EndTime: end,
Level: strings.TrimSpace(req.Level),
Component: strings.TrimSpace(req.Component),
RequestID: strings.TrimSpace(req.RequestID),
ClientRequestID: strings.TrimSpace(req.ClientRequestID),
UserID: req.UserID,
AccountID: req.AccountID,
Platform: strings.TrimSpace(req.Platform),
Model: strings.TrimSpace(req.Model),
Query: strings.TrimSpace(req.Query),
}
deleted, err := h.opsService.CleanupSystemLogs(c.Request.Context(), filter, subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"deleted": deleted})
}
// GetSystemLogIngestionHealth returns sink health metrics.
// GET /api/v1/admin/ops/system-logs/health
func (h *OpsHandler) GetSystemLogIngestionHealth(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, h.opsService.GetSystemLogSinkHealth())
}

View File

@@ -0,0 +1,233 @@
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
type responseEnvelope struct {
Code int `json:"code"`
Message string `json:"message"`
Data json.RawMessage `json:"data"`
}
func newOpsSystemLogTestRouter(handler *OpsHandler, withUser bool) *gin.Engine {
gin.SetMode(gin.TestMode)
r := gin.New()
if withUser {
r.Use(func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: 99})
c.Next()
})
}
r.GET("/logs", handler.ListSystemLogs)
r.POST("/logs/cleanup", handler.CleanupSystemLogs)
r.GET("/logs/health", handler.GetSystemLogIngestionHealth)
return r
}
func TestOpsSystemLogHandler_ListUnavailable(t *testing.T) {
h := NewOpsHandler(nil)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Fatalf("status=%d, want 503", w.Code)
}
}
func TestOpsSystemLogHandler_ListInvalidUserID(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs?user_id=abc", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status=%d, want 400", w.Code)
}
}
func TestOpsSystemLogHandler_ListInvalidAccountID(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs?account_id=-1", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status=%d, want 400", w.Code)
}
}
func TestOpsSystemLogHandler_ListMonitoringDisabled(t *testing.T) {
svc := service.NewOpsService(nil, nil, &config.Config{
Ops: config.OpsConfig{Enabled: false},
}, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("status=%d, want 404", w.Code)
}
}
func TestOpsSystemLogHandler_ListSuccess(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs?time_range=30m&page=1&page_size=20", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status=%d, want 200", w.Code)
}
var resp responseEnvelope
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("unmarshal response: %v", err)
}
if resp.Code != 0 {
t.Fatalf("unexpected response code: %+v", resp)
}
}
func TestOpsSystemLogHandler_CleanupUnauthorized(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Fatalf("status=%d, want 401", w.Code)
}
}
func TestOpsSystemLogHandler_CleanupInvalidPayload(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{bad-json`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status=%d, want 400", w.Code)
}
}
func TestOpsSystemLogHandler_CleanupInvalidTime(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"start_time":"bad","request_id":"r1"}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status=%d, want 400", w.Code)
}
}
func TestOpsSystemLogHandler_CleanupInvalidEndTime(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"end_time":"bad","request_id":"r1"}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status=%d, want 400", w.Code)
}
}
func TestOpsSystemLogHandler_CleanupServiceUnavailable(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Fatalf("status=%d, want 503", w.Code)
}
}
func TestOpsSystemLogHandler_CleanupMonitoringDisabled(t *testing.T) {
svc := service.NewOpsService(nil, nil, &config.Config{
Ops: config.OpsConfig{Enabled: false},
}, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`))
req.Header.Set("Content-Type", "application/json")
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("status=%d, want 404", w.Code)
}
}
func TestOpsSystemLogHandler_Health(t *testing.T) {
sink := service.NewOpsSystemLogSink(nil)
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, sink)
h := NewOpsHandler(svc)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs/health", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status=%d, want 200", w.Code)
}
}
func TestOpsSystemLogHandler_HealthUnavailableAndMonitoringDisabled(t *testing.T) {
h := NewOpsHandler(nil)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/logs/health", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Fatalf("status=%d, want 503", w.Code)
}
svc := service.NewOpsService(nil, nil, &config.Config{
Ops: config.OpsConfig{Enabled: false},
}, nil, nil, nil, nil, nil, nil, nil, nil)
h = NewOpsHandler(svc)
r = newOpsSystemLogTestRouter(h, false)
w = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/logs/health", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Fatalf("status=%d, want 404", w.Code)
}
}

View File

@@ -3,7 +3,6 @@ package admin
import (
"context"
"encoding/json"
"log"
"math"
"net"
"net/http"
@@ -16,6 +15,7 @@ import (
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
@@ -62,7 +62,8 @@ const (
)
var wsConnCount atomic.Int32
var wsConnCountByIP sync.Map // map[string]*atomic.Int32
var wsConnCountByIPMu sync.Mutex
var wsConnCountByIP = make(map[string]int32)
const qpsWSIdleStopDelay = 30 * time.Second
@@ -252,7 +253,7 @@ func (c *opsWSQPSCache) refresh(parentCtx context.Context) {
stats, err := opsService.GetWindowStats(ctx, now.Add(-c.requestCountWindow), now)
if err != nil || stats == nil {
if err != nil {
log.Printf("[OpsWS] refresh: get window stats failed: %v", err)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] refresh: get window stats failed: %v", err)
}
return
}
@@ -278,7 +279,7 @@ func (c *opsWSQPSCache) refresh(parentCtx context.Context) {
msg, err := json.Marshal(payload)
if err != nil {
log.Printf("[OpsWS] refresh: marshal payload failed: %v", err)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] refresh: marshal payload failed: %v", err)
return
}
@@ -338,7 +339,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) {
// Reserve a global slot before upgrading the connection to keep the limit strict.
if !tryAcquireOpsWSTotalSlot(opsWSLimits.MaxConns) {
log.Printf("[OpsWS] connection limit reached: %d/%d", wsConnCount.Load(), opsWSLimits.MaxConns)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] connection limit reached: %d/%d", wsConnCount.Load(), opsWSLimits.MaxConns)
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"})
return
}
@@ -350,7 +351,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) {
if opsWSLimits.MaxConnsPerIP > 0 && clientIP != "" {
if !tryAcquireOpsWSIPSlot(clientIP, opsWSLimits.MaxConnsPerIP) {
log.Printf("[OpsWS] per-ip connection limit reached: ip=%s limit=%d", clientIP, opsWSLimits.MaxConnsPerIP)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] per-ip connection limit reached: ip=%s limit=%d", clientIP, opsWSLimits.MaxConnsPerIP)
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"})
return
}
@@ -359,7 +360,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) {
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
log.Printf("[OpsWS] upgrade failed: %v", err)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] upgrade failed: %v", err)
return
}
@@ -389,42 +390,31 @@ func tryAcquireOpsWSIPSlot(clientIP string, limit int32) bool {
if strings.TrimSpace(clientIP) == "" || limit <= 0 {
return true
}
v, _ := wsConnCountByIP.LoadOrStore(clientIP, &atomic.Int32{})
counter, ok := v.(*atomic.Int32)
if !ok {
wsConnCountByIPMu.Lock()
defer wsConnCountByIPMu.Unlock()
current := wsConnCountByIP[clientIP]
if current >= limit {
return false
}
for {
current := counter.Load()
if current >= limit {
return false
}
if counter.CompareAndSwap(current, current+1) {
return true
}
}
wsConnCountByIP[clientIP] = current + 1
return true
}
func releaseOpsWSIPSlot(clientIP string) {
if strings.TrimSpace(clientIP) == "" {
return
}
v, ok := wsConnCountByIP.Load(clientIP)
wsConnCountByIPMu.Lock()
defer wsConnCountByIPMu.Unlock()
current, ok := wsConnCountByIP[clientIP]
if !ok {
return
}
counter, ok := v.(*atomic.Int32)
if !ok {
if current <= 1 {
delete(wsConnCountByIP, clientIP)
return
}
next := counter.Add(-1)
if next <= 0 {
// Best-effort cleanup; safe even if a new slot was acquired concurrently.
wsConnCountByIP.Delete(clientIP)
}
wsConnCountByIP[clientIP] = current - 1
}
func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
@@ -452,7 +442,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
conn.SetReadLimit(qpsWSMaxReadBytes)
if err := conn.SetReadDeadline(time.Now().Add(qpsWSPongWait)); err != nil {
log.Printf("[OpsWS] set read deadline failed: %v", err)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] set read deadline failed: %v", err)
return
}
conn.SetPongHandler(func(string) error {
@@ -471,7 +461,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
_, _, err := conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
log.Printf("[OpsWS] read failed: %v", err)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] read failed: %v", err)
}
return
}
@@ -508,7 +498,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
continue
}
if err := writeWithTimeout(websocket.TextMessage, msg); err != nil {
log.Printf("[OpsWS] write failed: %v", err)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] write failed: %v", err)
cancel()
closeConn()
wg.Wait()
@@ -517,7 +507,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) {
case <-pingTicker.C:
if err := writeWithTimeout(websocket.PingMessage, nil); err != nil {
log.Printf("[OpsWS] ping failed: %v", err)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] ping failed: %v", err)
cancel()
closeConn()
wg.Wait()
@@ -666,14 +656,14 @@ func loadOpsWSProxyConfigFromEnv() OpsWSProxyConfig {
if parsed, err := strconv.ParseBool(v); err == nil {
cfg.TrustProxy = parsed
} else {
log.Printf("[OpsWS] invalid %s=%q (expected bool); using default=%v", envOpsWSTrustProxy, v, cfg.TrustProxy)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected bool); using default=%v", envOpsWSTrustProxy, v, cfg.TrustProxy)
}
}
if raw := strings.TrimSpace(os.Getenv(envOpsWSTrustedProxies)); raw != "" {
prefixes, invalid := parseTrustedProxyList(raw)
if len(invalid) > 0 {
log.Printf("[OpsWS] invalid %s entries ignored: %s", envOpsWSTrustedProxies, strings.Join(invalid, ", "))
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s entries ignored: %s", envOpsWSTrustedProxies, strings.Join(invalid, ", "))
}
cfg.TrustedProxies = prefixes
}
@@ -684,7 +674,7 @@ func loadOpsWSProxyConfigFromEnv() OpsWSProxyConfig {
case OriginPolicyStrict, OriginPolicyPermissive:
cfg.OriginPolicy = normalized
default:
log.Printf("[OpsWS] invalid %s=%q (expected %q or %q); using default=%q", envOpsWSOriginPolicy, v, OriginPolicyStrict, OriginPolicyPermissive, cfg.OriginPolicy)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected %q or %q); using default=%q", envOpsWSOriginPolicy, v, OriginPolicyStrict, OriginPolicyPermissive, cfg.OriginPolicy)
}
}
@@ -701,14 +691,14 @@ func loadOpsWSRuntimeLimitsFromEnv() opsWSRuntimeLimits {
if parsed, err := strconv.Atoi(v); err == nil && parsed > 0 {
cfg.MaxConns = int32(parsed)
} else {
log.Printf("[OpsWS] invalid %s=%q (expected int>0); using default=%d", envOpsWSMaxConns, v, cfg.MaxConns)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected int>0); using default=%d", envOpsWSMaxConns, v, cfg.MaxConns)
}
}
if v := strings.TrimSpace(os.Getenv(envOpsWSMaxConnsPerIP)); v != "" {
if parsed, err := strconv.Atoi(v); err == nil && parsed >= 0 {
cfg.MaxConnsPerIP = int32(parsed)
} else {
log.Printf("[OpsWS] invalid %s=%q (expected int>=0); using default=%d", envOpsWSMaxConnsPerIP, v, cfg.MaxConnsPerIP)
logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected int>=0); using default=%d", envOpsWSMaxConnsPerIP, v, cfg.MaxConnsPerIP)
}
}
return cfg

View File

@@ -1,6 +1,7 @@
package admin
import (
"context"
"strconv"
"strings"
@@ -63,9 +64,9 @@ func (h *ProxyHandler) List(c *gin.Context) {
return
}
out := make([]dto.ProxyWithAccountCount, 0, len(proxies))
out := make([]dto.AdminProxyWithAccountCount, 0, len(proxies))
for i := range proxies {
out = append(out, *dto.ProxyWithAccountCountFromService(&proxies[i]))
out = append(out, *dto.ProxyWithAccountCountFromServiceAdmin(&proxies[i]))
}
response.Paginated(c, out, total, page, pageSize)
}
@@ -82,9 +83,9 @@ func (h *ProxyHandler) GetAll(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
out := make([]dto.ProxyWithAccountCount, 0, len(proxies))
out := make([]dto.AdminProxyWithAccountCount, 0, len(proxies))
for i := range proxies {
out = append(out, *dto.ProxyWithAccountCountFromService(&proxies[i]))
out = append(out, *dto.ProxyWithAccountCountFromServiceAdmin(&proxies[i]))
}
response.Success(c, out)
return
@@ -96,9 +97,9 @@ func (h *ProxyHandler) GetAll(c *gin.Context) {
return
}
out := make([]dto.Proxy, 0, len(proxies))
out := make([]dto.AdminProxy, 0, len(proxies))
for i := range proxies {
out = append(out, *dto.ProxyFromService(&proxies[i]))
out = append(out, *dto.ProxyFromServiceAdmin(&proxies[i]))
}
response.Success(c, out)
}
@@ -118,7 +119,7 @@ func (h *ProxyHandler) GetByID(c *gin.Context) {
return
}
response.Success(c, dto.ProxyFromService(proxy))
response.Success(c, dto.ProxyFromServiceAdmin(proxy))
}
// Create handles creating a new proxy
@@ -130,20 +131,20 @@ func (h *ProxyHandler) Create(c *gin.Context) {
return
}
proxy, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
Name: strings.TrimSpace(req.Name),
Protocol: strings.TrimSpace(req.Protocol),
Host: strings.TrimSpace(req.Host),
Port: req.Port,
Username: strings.TrimSpace(req.Username),
Password: strings.TrimSpace(req.Password),
executeAdminIdempotentJSON(c, "admin.proxies.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
proxy, err := h.adminService.CreateProxy(ctx, &service.CreateProxyInput{
Name: strings.TrimSpace(req.Name),
Protocol: strings.TrimSpace(req.Protocol),
Host: strings.TrimSpace(req.Host),
Port: req.Port,
Username: strings.TrimSpace(req.Username),
Password: strings.TrimSpace(req.Password),
})
if err != nil {
return nil, err
}
return dto.ProxyFromServiceAdmin(proxy), nil
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.ProxyFromService(proxy))
}
// Update handles updating a proxy
@@ -175,7 +176,7 @@ func (h *ProxyHandler) Update(c *gin.Context) {
return
}
response.Success(c, dto.ProxyFromService(proxy))
response.Success(c, dto.ProxyFromServiceAdmin(proxy))
}
// Delete handles deleting a proxy
@@ -236,6 +237,24 @@ func (h *ProxyHandler) Test(c *gin.Context) {
response.Success(c, result)
}
// CheckQuality handles checking proxy quality across common AI targets.
// POST /api/v1/admin/proxies/:id/quality-check
func (h *ProxyHandler) CheckQuality(c *gin.Context) {
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid proxy ID")
return
}
result, err := h.adminService.CheckProxyQuality(c.Request.Context(), proxyID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// GetStats handles getting proxy statistics
// GET /api/v1/admin/proxies/:id/stats
func (h *ProxyHandler) GetStats(c *gin.Context) {

View File

@@ -2,12 +2,15 @@ package admin
import (
"bytes"
"context"
"encoding/csv"
"errors"
"fmt"
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -16,13 +19,15 @@ import (
// RedeemHandler handles admin redeem code management
type RedeemHandler struct {
adminService service.AdminService
adminService service.AdminService
redeemService *service.RedeemService
}
// NewRedeemHandler creates a new admin redeem handler
func NewRedeemHandler(adminService service.AdminService) *RedeemHandler {
func NewRedeemHandler(adminService service.AdminService, redeemService *service.RedeemService) *RedeemHandler {
return &RedeemHandler{
adminService: adminService,
adminService: adminService,
redeemService: redeemService,
}
}
@@ -35,6 +40,15 @@ type GenerateRedeemCodesRequest struct {
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // 订阅类型使用默认30天最大100年
}
// CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user.
type CreateAndRedeemCodeRequest struct {
Code string `json:"code" binding:"required,min=3,max=128"`
Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
Value float64 `json:"value" binding:"required,gt=0"`
UserID int64 `json:"user_id" binding:"required,gt=0"`
Notes string `json:"notes"`
}
// List handles listing all redeem codes with pagination
// GET /api/v1/admin/redeem-codes
func (h *RedeemHandler) List(c *gin.Context) {
@@ -88,23 +102,99 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
return
}
codes, err := h.adminService.GenerateRedeemCodes(c.Request.Context(), &service.GenerateRedeemCodesInput{
Count: req.Count,
Type: req.Type,
Value: req.Value,
GroupID: req.GroupID,
ValidityDays: req.ValidityDays,
executeAdminIdempotentJSON(c, "admin.redeem_codes.generate", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
codes, execErr := h.adminService.GenerateRedeemCodes(ctx, &service.GenerateRedeemCodesInput{
Count: req.Count,
Type: req.Type,
Value: req.Value,
GroupID: req.GroupID,
ValidityDays: req.ValidityDays,
})
if execErr != nil {
return nil, execErr
}
out := make([]dto.AdminRedeemCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
}
return out, nil
})
if err != nil {
response.ErrorFrom(c, err)
}
// CreateAndRedeem creates a fixed redeem code and redeems it for a target user in one step.
// POST /api/v1/admin/redeem-codes/create-and-redeem
func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
if h.redeemService == nil {
response.InternalError(c, "redeem service not configured")
return
}
out := make([]dto.AdminRedeemCode, 0, len(codes))
for i := range codes {
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
var req CreateAndRedeemCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
response.Success(c, out)
req.Code = strings.TrimSpace(req.Code)
executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
existing, err := h.redeemService.GetByCode(ctx, req.Code)
if err == nil {
return h.resolveCreateAndRedeemExisting(ctx, existing, req.UserID)
}
if !errors.Is(err, service.ErrRedeemCodeNotFound) {
return nil, err
}
createErr := h.redeemService.CreateCode(ctx, &service.RedeemCode{
Code: req.Code,
Type: req.Type,
Value: req.Value,
Status: service.StatusUnused,
Notes: req.Notes,
})
if createErr != nil {
// Unique code race: if code now exists, use idempotent semantics by used_by.
existingAfterCreateErr, getErr := h.redeemService.GetByCode(ctx, req.Code)
if getErr == nil {
return h.resolveCreateAndRedeemExisting(ctx, existingAfterCreateErr, req.UserID)
}
return nil, createErr
}
redeemed, redeemErr := h.redeemService.Redeem(ctx, req.UserID, req.Code)
if redeemErr != nil {
return nil, redeemErr
}
return gin.H{"redeem_code": dto.RedeemCodeFromServiceAdmin(redeemed)}, nil
})
}
func (h *RedeemHandler) resolveCreateAndRedeemExisting(ctx context.Context, existing *service.RedeemCode, userID int64) (any, error) {
if existing == nil {
return nil, infraerrors.Conflict("REDEEM_CODE_CONFLICT", "redeem code conflict")
}
// If previous run created the code but crashed before redeem, redeem it now.
if existing.CanUse() {
redeemed, err := h.redeemService.Redeem(ctx, userID, existing.Code)
if err == nil {
return gin.H{"redeem_code": dto.RedeemCodeFromServiceAdmin(redeemed)}, nil
}
if !errors.Is(err, service.ErrRedeemCodeUsed) {
return nil, err
}
latest, getErr := h.redeemService.GetByCode(ctx, existing.Code)
if getErr == nil {
existing = latest
}
}
if existing.UsedBy != nil && *existing.UsedBy == userID {
return gin.H{"redeem_code": dto.RedeemCodeFromServiceAdmin(existing)}, nil
}
return nil, infraerrors.Conflict("REDEEM_CODE_CONFLICT", "redeem code already used by another user")
}
// Delete handles deleting a redeem code
@@ -202,7 +292,7 @@ func (h *RedeemHandler) Export(c *gin.Context) {
writer := csv.NewWriter(&buf)
// Write header
if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_at", "created_at"}); err != nil {
if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_by_email", "used_at", "created_at"}); err != nil {
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
return
}
@@ -213,6 +303,10 @@ func (h *RedeemHandler) Export(c *gin.Context) {
if code.UsedBy != nil {
usedBy = fmt.Sprintf("%d", *code.UsedBy)
}
usedByEmail := ""
if code.User != nil {
usedByEmail = code.User.Email
}
usedAt := ""
if code.UsedAt != nil {
usedAt = code.UsedAt.Format("2006-01-02 15:04:05")
@@ -224,6 +318,7 @@ func (h *RedeemHandler) Export(c *gin.Context) {
fmt.Sprintf("%.2f", code.Value),
code.Status,
usedBy,
usedByEmail,
usedAt,
code.CreatedAt.Format("2006-01-02 15:04:05"),
}); err != nil {

View File

@@ -0,0 +1,97 @@
//go:build unit
package admin
import (
"testing"
"github.com/stretchr/testify/require"
)
// truncateSearchByRune 模拟 user_handler.go 中的 search 截断逻辑
func truncateSearchByRune(search string, maxRunes int) string {
if runes := []rune(search); len(runes) > maxRunes {
return string(runes[:maxRunes])
}
return search
}
func TestTruncateSearchByRune(t *testing.T) {
tests := []struct {
name string
input string
maxRunes int
wantLen int // 期望的 rune 长度
}{
{
name: "纯中文超长",
input: string(make([]rune, 150)),
maxRunes: 100,
wantLen: 100,
},
{
name: "纯 ASCII 超长",
input: string(make([]byte, 150)),
maxRunes: 100,
wantLen: 100,
},
{
name: "空字符串",
input: "",
maxRunes: 100,
wantLen: 0,
},
{
name: "恰好 100 个字符",
input: string(make([]rune, 100)),
maxRunes: 100,
wantLen: 100,
},
{
name: "不足 100 字符不截断",
input: "hello世界",
maxRunes: 100,
wantLen: 7,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := truncateSearchByRune(tc.input, tc.maxRunes)
require.Equal(t, tc.wantLen, len([]rune(result)))
})
}
}
func TestTruncateSearchByRune_PreservesMultibyte(t *testing.T) {
// 101 个中文字符,截断到 100 个后应该仍然是有效 UTF-8
input := ""
for i := 0; i < 101; i++ {
input += "中"
}
result := truncateSearchByRune(input, 100)
require.Equal(t, 100, len([]rune(result)))
// 验证截断结果是有效的 UTF-8每个中文字符 3 字节)
require.Equal(t, 300, len(result))
}
func TestTruncateSearchByRune_MixedASCIIAndMultibyte(t *testing.T) {
// 50 个 ASCII + 51 个中文 = 101 个 rune
input := ""
for i := 0; i < 50; i++ {
input += "a"
}
for i := 0; i < 51; i++ {
input += "中"
}
result := truncateSearchByRune(input, 100)
runes := []rune(result)
require.Equal(t, 100, len(runes))
// 前 50 个应该是 'a',后 50 个应该是 '中'
require.Equal(t, 'a', runes[0])
require.Equal(t, 'a', runes[49])
require.Equal(t, '中', runes[50])
require.Equal(t, '中', runes[99])
}

View File

@@ -1,7 +1,13 @@
package admin
import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"log"
"net/http"
"regexp"
"strings"
"time"
@@ -14,21 +20,38 @@ import (
"github.com/gin-gonic/gin"
)
// semverPattern 预编译 semver 格式校验正则
var semverPattern = regexp.MustCompile(`^\d+\.\d+\.\d+$`)
// menuItemIDPattern validates custom menu item IDs: alphanumeric, hyphens, underscores only.
var menuItemIDPattern = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
// generateMenuItemID generates a short random hex ID for a custom menu item.
func generateMenuItemID() (string, error) {
b := make([]byte, 8)
if _, err := rand.Read(b); err != nil {
return "", fmt.Errorf("generate menu item ID: %w", err)
}
return hex.EncodeToString(b), nil
}
// SettingHandler 系统设置处理器
type SettingHandler struct {
settingService *service.SettingService
emailService *service.EmailService
turnstileService *service.TurnstileService
opsService *service.OpsService
soraS3Storage *service.SoraS3Storage
}
// NewSettingHandler 创建系统设置处理器
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService) *SettingHandler {
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, soraS3Storage *service.SoraS3Storage) *SettingHandler {
return &SettingHandler{
settingService: settingService,
emailService: emailService,
turnstileService: turnstileService,
opsService: opsService,
soraS3Storage: soraS3Storage,
}
}
@@ -43,6 +66,13 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
// Check if ops monitoring is enabled (respects config.ops.enabled)
opsEnabled := h.opsService != nil && h.opsService.IsMonitoringEnabled(c.Request.Context())
defaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(settings.DefaultSubscriptions))
for _, sub := range settings.DefaultSubscriptions {
defaultSubscriptions = append(defaultSubscriptions, dto.DefaultSubscriptionSetting{
GroupID: sub.GroupID,
ValidityDays: sub.ValidityDays,
})
}
response.Success(c, dto.SystemSettings{
RegistrationEnabled: settings.RegistrationEnabled,
@@ -76,8 +106,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
HideCcsImportButton: settings.HideCcsImportButton,
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
SoraClientEnabled: settings.SoraClientEnabled,
CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems),
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: settings.EnableModelFallback,
FallbackModelAnthropic: settings.FallbackModelAnthropic,
FallbackModelOpenAI: settings.FallbackModelOpenAI,
@@ -89,6 +122,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
OpsRealtimeMonitoringEnabled: settings.OpsRealtimeMonitoringEnabled,
OpsQueryModeDefault: settings.OpsQueryModeDefault,
OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
})
}
@@ -123,20 +157,23 @@ type UpdateSettingsRequest struct {
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
// OEM设置
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
APIBaseURL string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
DocURL string `json:"doc_url"`
HomeContent string `json:"home_content"`
HideCcsImportButton bool `json:"hide_ccs_import_button"`
PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"`
PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
SoraClientEnabled bool `json:"sora_client_enabled"`
CustomMenuItems *[]dto.CustomMenuItem `json:"custom_menu_items"`
// 默认配置
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
@@ -154,6 +191,8 @@ type UpdateSettingsRequest struct {
OpsRealtimeMonitoringEnabled *bool `json:"ops_realtime_monitoring_enabled"`
OpsQueryModeDefault *string `json:"ops_query_mode_default"`
OpsMetricsIntervalSeconds *int `json:"ops_metrics_interval_seconds"`
MinClaudeCodeVersion string `json:"min_claude_code_version"`
}
// UpdateSettings 更新系统设置
@@ -181,6 +220,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
if req.SMTPPort <= 0 {
req.SMTPPort = 587
}
req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions)
// Turnstile 参数验证
if req.TurnstileEnabled {
@@ -276,6 +316,84 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
// 自定义菜单项验证
const (
maxCustomMenuItems = 20
maxMenuItemLabelLen = 50
maxMenuItemURLLen = 2048
maxMenuItemIconSVGLen = 10 * 1024 // 10KB
maxMenuItemIDLen = 32
)
customMenuJSON := previousSettings.CustomMenuItems
if req.CustomMenuItems != nil {
items := *req.CustomMenuItems
if len(items) > maxCustomMenuItems {
response.BadRequest(c, "Too many custom menu items (max 20)")
return
}
for i, item := range items {
if strings.TrimSpace(item.Label) == "" {
response.BadRequest(c, "Custom menu item label is required")
return
}
if len(item.Label) > maxMenuItemLabelLen {
response.BadRequest(c, "Custom menu item label is too long (max 50 characters)")
return
}
if strings.TrimSpace(item.URL) == "" {
response.BadRequest(c, "Custom menu item URL is required")
return
}
if len(item.URL) > maxMenuItemURLLen {
response.BadRequest(c, "Custom menu item URL is too long (max 2048 characters)")
return
}
if err := config.ValidateAbsoluteHTTPURL(strings.TrimSpace(item.URL)); err != nil {
response.BadRequest(c, "Custom menu item URL must be an absolute http(s) URL")
return
}
if item.Visibility != "user" && item.Visibility != "admin" {
response.BadRequest(c, "Custom menu item visibility must be 'user' or 'admin'")
return
}
if len(item.IconSVG) > maxMenuItemIconSVGLen {
response.BadRequest(c, "Custom menu item icon SVG is too large (max 10KB)")
return
}
// Auto-generate ID if missing
if strings.TrimSpace(item.ID) == "" {
id, err := generateMenuItemID()
if err != nil {
response.Error(c, http.StatusInternalServerError, "Failed to generate menu item ID")
return
}
items[i].ID = id
} else if len(item.ID) > maxMenuItemIDLen {
response.BadRequest(c, "Custom menu item ID is too long (max 32 characters)")
return
} else if !menuItemIDPattern.MatchString(item.ID) {
response.BadRequest(c, "Custom menu item ID contains invalid characters (only a-z, A-Z, 0-9, - and _ are allowed)")
return
}
}
// ID uniqueness check
seen := make(map[string]struct{}, len(items))
for _, item := range items {
if _, exists := seen[item.ID]; exists {
response.BadRequest(c, "Duplicate custom menu item ID: "+item.ID)
return
}
seen[item.ID] = struct{}{}
}
menuBytes, err := json.Marshal(items)
if err != nil {
response.BadRequest(c, "Failed to serialize custom menu items")
return
}
customMenuJSON = string(menuBytes)
}
// Ops metrics collector interval validation (seconds).
if req.OpsMetricsIntervalSeconds != nil {
v := *req.OpsMetricsIntervalSeconds
@@ -287,6 +405,21 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
req.OpsMetricsIntervalSeconds = &v
}
defaultSubscriptions := make([]service.DefaultSubscriptionSetting, 0, len(req.DefaultSubscriptions))
for _, sub := range req.DefaultSubscriptions {
defaultSubscriptions = append(defaultSubscriptions, service.DefaultSubscriptionSetting{
GroupID: sub.GroupID,
ValidityDays: sub.ValidityDays,
})
}
// 验证最低版本号格式(空字符串=禁用,或合法 semver
if req.MinClaudeCodeVersion != "" {
if !semverPattern.MatchString(req.MinClaudeCodeVersion) {
response.Error(c, http.StatusBadRequest, "min_claude_code_version must be empty or a valid semver (e.g. 2.1.63)")
return
}
}
settings := &service.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled,
@@ -319,8 +452,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
HideCcsImportButton: req.HideCcsImportButton,
PurchaseSubscriptionEnabled: purchaseEnabled,
PurchaseSubscriptionURL: purchaseURL,
SoraClientEnabled: req.SoraClientEnabled,
CustomMenuItems: customMenuJSON,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: req.EnableModelFallback,
FallbackModelAnthropic: req.FallbackModelAnthropic,
FallbackModelOpenAI: req.FallbackModelOpenAI,
@@ -328,6 +464,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
FallbackModelAntigravity: req.FallbackModelAntigravity,
EnableIdentityPatch: req.EnableIdentityPatch,
IdentityPatchPrompt: req.IdentityPatchPrompt,
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
OpsMonitoringEnabled: func() bool {
if req.OpsMonitoringEnabled != nil {
return *req.OpsMonitoringEnabled
@@ -367,6 +504,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
updatedDefaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(updatedSettings.DefaultSubscriptions))
for _, sub := range updatedSettings.DefaultSubscriptions {
updatedDefaultSubscriptions = append(updatedDefaultSubscriptions, dto.DefaultSubscriptionSetting{
GroupID: sub.GroupID,
ValidityDays: sub.ValidityDays,
})
}
response.Success(c, dto.SystemSettings{
RegistrationEnabled: updatedSettings.RegistrationEnabled,
@@ -400,8 +544,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
HideCcsImportButton: updatedSettings.HideCcsImportButton,
PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
SoraClientEnabled: updatedSettings.SoraClientEnabled,
CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems),
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
DefaultSubscriptions: updatedDefaultSubscriptions,
EnableModelFallback: updatedSettings.EnableModelFallback,
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
@@ -413,6 +560,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
OpsRealtimeMonitoringEnabled: updatedSettings.OpsRealtimeMonitoringEnabled,
OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault,
OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
})
}
@@ -522,6 +670,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.DefaultBalance != after.DefaultBalance {
changed = append(changed, "default_balance")
}
if !equalDefaultSubscriptions(before.DefaultSubscriptions, after.DefaultSubscriptions) {
changed = append(changed, "default_subscriptions")
}
if before.EnableModelFallback != after.EnableModelFallback {
changed = append(changed, "enable_model_fallback")
}
@@ -555,9 +706,50 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.OpsMetricsIntervalSeconds != after.OpsMetricsIntervalSeconds {
changed = append(changed, "ops_metrics_interval_seconds")
}
if before.MinClaudeCodeVersion != after.MinClaudeCodeVersion {
changed = append(changed, "min_claude_code_version")
}
if before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled {
changed = append(changed, "purchase_subscription_enabled")
}
if before.PurchaseSubscriptionURL != after.PurchaseSubscriptionURL {
changed = append(changed, "purchase_subscription_url")
}
if before.CustomMenuItems != after.CustomMenuItems {
changed = append(changed, "custom_menu_items")
}
return changed
}
func normalizeDefaultSubscriptions(input []dto.DefaultSubscriptionSetting) []dto.DefaultSubscriptionSetting {
if len(input) == 0 {
return nil
}
normalized := make([]dto.DefaultSubscriptionSetting, 0, len(input))
for _, item := range input {
if item.GroupID <= 0 || item.ValidityDays <= 0 {
continue
}
if item.ValidityDays > service.MaxValidityDays {
item.ValidityDays = service.MaxValidityDays
}
normalized = append(normalized, item)
}
return normalized
}
func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i].GroupID != b[i].GroupID || a[i].ValidityDays != b[i].ValidityDays {
return false
}
}
return true
}
// TestSMTPRequest 测试SMTP连接请求
type TestSMTPRequest struct {
SMTPHost string `json:"smtp_host" binding:"required"`
@@ -750,6 +942,384 @@ func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) {
})
}
func toSoraS3SettingsDTO(settings *service.SoraS3Settings) dto.SoraS3Settings {
if settings == nil {
return dto.SoraS3Settings{}
}
return dto.SoraS3Settings{
Enabled: settings.Enabled,
Endpoint: settings.Endpoint,
Region: settings.Region,
Bucket: settings.Bucket,
AccessKeyID: settings.AccessKeyID,
SecretAccessKeyConfigured: settings.SecretAccessKeyConfigured,
Prefix: settings.Prefix,
ForcePathStyle: settings.ForcePathStyle,
CDNURL: settings.CDNURL,
DefaultStorageQuotaBytes: settings.DefaultStorageQuotaBytes,
}
}
func toSoraS3ProfileDTO(profile service.SoraS3Profile) dto.SoraS3Profile {
return dto.SoraS3Profile{
ProfileID: profile.ProfileID,
Name: profile.Name,
IsActive: profile.IsActive,
Enabled: profile.Enabled,
Endpoint: profile.Endpoint,
Region: profile.Region,
Bucket: profile.Bucket,
AccessKeyID: profile.AccessKeyID,
SecretAccessKeyConfigured: profile.SecretAccessKeyConfigured,
Prefix: profile.Prefix,
ForcePathStyle: profile.ForcePathStyle,
CDNURL: profile.CDNURL,
DefaultStorageQuotaBytes: profile.DefaultStorageQuotaBytes,
UpdatedAt: profile.UpdatedAt,
}
}
func validateSoraS3RequiredWhenEnabled(enabled bool, endpoint, bucket, accessKeyID, secretAccessKey string, hasStoredSecret bool) error {
if !enabled {
return nil
}
if strings.TrimSpace(endpoint) == "" {
return fmt.Errorf("S3 Endpoint is required when enabled")
}
if strings.TrimSpace(bucket) == "" {
return fmt.Errorf("S3 Bucket is required when enabled")
}
if strings.TrimSpace(accessKeyID) == "" {
return fmt.Errorf("S3 Access Key ID is required when enabled")
}
if strings.TrimSpace(secretAccessKey) != "" || hasStoredSecret {
return nil
}
return fmt.Errorf("S3 Secret Access Key is required when enabled")
}
func findSoraS3ProfileByID(items []service.SoraS3Profile, profileID string) *service.SoraS3Profile {
for idx := range items {
if items[idx].ProfileID == profileID {
return &items[idx]
}
}
return nil
}
// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置接口)
// GET /api/v1/admin/settings/sora-s3
func (h *SettingHandler) GetSoraS3Settings(c *gin.Context) {
settings, err := h.settingService.GetSoraS3Settings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, toSoraS3SettingsDTO(settings))
}
// ListSoraS3Profiles 获取 Sora S3 多配置
// GET /api/v1/admin/settings/sora-s3/profiles
func (h *SettingHandler) ListSoraS3Profiles(c *gin.Context) {
result, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
items := make([]dto.SoraS3Profile, 0, len(result.Items))
for idx := range result.Items {
items = append(items, toSoraS3ProfileDTO(result.Items[idx]))
}
response.Success(c, dto.ListSoraS3ProfilesResponse{
ActiveProfileID: result.ActiveProfileID,
Items: items,
})
}
// UpdateSoraS3SettingsRequest 更新/测试 Sora S3 配置请求(兼容旧接口)
type UpdateSoraS3SettingsRequest struct {
ProfileID string `json:"profile_id"`
Enabled bool `json:"enabled"`
Endpoint string `json:"endpoint"`
Region string `json:"region"`
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key"`
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
CDNURL string `json:"cdn_url"`
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
}
type CreateSoraS3ProfileRequest struct {
ProfileID string `json:"profile_id"`
Name string `json:"name"`
SetActive bool `json:"set_active"`
Enabled bool `json:"enabled"`
Endpoint string `json:"endpoint"`
Region string `json:"region"`
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key"`
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
CDNURL string `json:"cdn_url"`
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
}
type UpdateSoraS3ProfileRequest struct {
Name string `json:"name"`
Enabled bool `json:"enabled"`
Endpoint string `json:"endpoint"`
Region string `json:"region"`
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key"`
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
CDNURL string `json:"cdn_url"`
DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"`
}
// CreateSoraS3Profile 创建 Sora S3 配置
// POST /api/v1/admin/settings/sora-s3/profiles
func (h *SettingHandler) CreateSoraS3Profile(c *gin.Context) {
var req CreateSoraS3ProfileRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if req.DefaultStorageQuotaBytes < 0 {
req.DefaultStorageQuotaBytes = 0
}
if strings.TrimSpace(req.Name) == "" {
response.BadRequest(c, "Name is required")
return
}
if strings.TrimSpace(req.ProfileID) == "" {
response.BadRequest(c, "Profile ID is required")
return
}
if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, false); err != nil {
response.BadRequest(c, err.Error())
return
}
created, err := h.settingService.CreateSoraS3Profile(c.Request.Context(), &service.SoraS3Profile{
ProfileID: req.ProfileID,
Name: req.Name,
Enabled: req.Enabled,
Endpoint: req.Endpoint,
Region: req.Region,
Bucket: req.Bucket,
AccessKeyID: req.AccessKeyID,
SecretAccessKey: req.SecretAccessKey,
Prefix: req.Prefix,
ForcePathStyle: req.ForcePathStyle,
CDNURL: req.CDNURL,
DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
}, req.SetActive)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, toSoraS3ProfileDTO(*created))
}
// UpdateSoraS3Profile 更新 Sora S3 配置
// PUT /api/v1/admin/settings/sora-s3/profiles/:profile_id
func (h *SettingHandler) UpdateSoraS3Profile(c *gin.Context) {
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Profile ID is required")
return
}
var req UpdateSoraS3ProfileRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if req.DefaultStorageQuotaBytes < 0 {
req.DefaultStorageQuotaBytes = 0
}
if strings.TrimSpace(req.Name) == "" {
response.BadRequest(c, "Name is required")
return
}
existingList, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
existing := findSoraS3ProfileByID(existingList.Items, profileID)
if existing == nil {
response.ErrorFrom(c, service.ErrSoraS3ProfileNotFound)
return
}
if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil {
response.BadRequest(c, err.Error())
return
}
updated, updateErr := h.settingService.UpdateSoraS3Profile(c.Request.Context(), profileID, &service.SoraS3Profile{
Name: req.Name,
Enabled: req.Enabled,
Endpoint: req.Endpoint,
Region: req.Region,
Bucket: req.Bucket,
AccessKeyID: req.AccessKeyID,
SecretAccessKey: req.SecretAccessKey,
Prefix: req.Prefix,
ForcePathStyle: req.ForcePathStyle,
CDNURL: req.CDNURL,
DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
})
if updateErr != nil {
response.ErrorFrom(c, updateErr)
return
}
response.Success(c, toSoraS3ProfileDTO(*updated))
}
// DeleteSoraS3Profile 删除 Sora S3 配置
// DELETE /api/v1/admin/settings/sora-s3/profiles/:profile_id
func (h *SettingHandler) DeleteSoraS3Profile(c *gin.Context) {
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Profile ID is required")
return
}
if err := h.settingService.DeleteSoraS3Profile(c.Request.Context(), profileID); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"deleted": true})
}
// SetActiveSoraS3Profile 切换激活 Sora S3 配置
// POST /api/v1/admin/settings/sora-s3/profiles/:profile_id/activate
func (h *SettingHandler) SetActiveSoraS3Profile(c *gin.Context) {
profileID := strings.TrimSpace(c.Param("profile_id"))
if profileID == "" {
response.BadRequest(c, "Profile ID is required")
return
}
active, err := h.settingService.SetActiveSoraS3Profile(c.Request.Context(), profileID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, toSoraS3ProfileDTO(*active))
}
// UpdateSoraS3Settings 更新 Sora S3 存储配置(兼容旧单配置接口)
// PUT /api/v1/admin/settings/sora-s3
func (h *SettingHandler) UpdateSoraS3Settings(c *gin.Context) {
var req UpdateSoraS3SettingsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
existing, err := h.settingService.GetSoraS3Settings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
if req.DefaultStorageQuotaBytes < 0 {
req.DefaultStorageQuotaBytes = 0
}
if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil {
response.BadRequest(c, err.Error())
return
}
settings := &service.SoraS3Settings{
Enabled: req.Enabled,
Endpoint: req.Endpoint,
Region: req.Region,
Bucket: req.Bucket,
AccessKeyID: req.AccessKeyID,
SecretAccessKey: req.SecretAccessKey,
Prefix: req.Prefix,
ForcePathStyle: req.ForcePathStyle,
CDNURL: req.CDNURL,
DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes,
}
if err := h.settingService.SetSoraS3Settings(c.Request.Context(), settings); err != nil {
response.ErrorFrom(c, err)
return
}
updatedSettings, err := h.settingService.GetSoraS3Settings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, toSoraS3SettingsDTO(updatedSettings))
}
// TestSoraS3Connection 测试 Sora S3 连接HeadBucket
// POST /api/v1/admin/settings/sora-s3/test
func (h *SettingHandler) TestSoraS3Connection(c *gin.Context) {
if h.soraS3Storage == nil {
response.Error(c, 500, "S3 存储服务未初始化")
return
}
var req UpdateSoraS3SettingsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if !req.Enabled {
response.BadRequest(c, "S3 未启用,无法测试连接")
return
}
if req.SecretAccessKey == "" {
if req.ProfileID != "" {
profiles, err := h.settingService.ListSoraS3Profiles(c.Request.Context())
if err == nil {
profile := findSoraS3ProfileByID(profiles.Items, req.ProfileID)
if profile != nil {
req.SecretAccessKey = profile.SecretAccessKey
}
}
}
if req.SecretAccessKey == "" {
existing, err := h.settingService.GetSoraS3Settings(c.Request.Context())
if err == nil {
req.SecretAccessKey = existing.SecretAccessKey
}
}
}
testCfg := &service.SoraS3Settings{
Enabled: true,
Endpoint: req.Endpoint,
Region: req.Region,
Bucket: req.Bucket,
AccessKeyID: req.AccessKeyID,
SecretAccessKey: req.SecretAccessKey,
Prefix: req.Prefix,
ForcePathStyle: req.ForcePathStyle,
CDNURL: req.CDNURL,
}
if err := h.soraS3Storage.TestConnectionWithSettings(c.Request.Context(), testCfg); err != nil {
response.Error(c, 400, "S3 连接测试失败: "+err.Error())
return
}
response.Success(c, gin.H{"message": "S3 连接成功"})
}
// UpdateStreamTimeoutSettingsRequest 更新流超时配置请求
type UpdateStreamTimeoutSettingsRequest struct {
Enabled bool `json:"enabled"`

View File

@@ -1,6 +1,7 @@
package admin
import (
"context"
"strconv"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
@@ -199,13 +200,20 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
return
}
subscription, err := h.subscriptionService.ExtendSubscription(c.Request.Context(), subscriptionID, req.Days)
if err != nil {
response.ErrorFrom(c, err)
return
idempotencyPayload := struct {
SubscriptionID int64 `json:"subscription_id"`
Body AdjustSubscriptionRequest `json:"body"`
}{
SubscriptionID: subscriptionID,
Body: req,
}
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription))
executeAdminIdempotentJSON(c, "admin.subscriptions.extend", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
subscription, execErr := h.subscriptionService.ExtendSubscription(ctx, subscriptionID, req.Days)
if execErr != nil {
return nil, execErr
}
return dto.UserSubscriptionFromServiceAdmin(subscription), nil
})
}
// Revoke handles revoking a subscription

View File

@@ -1,11 +1,15 @@
package admin
import (
"context"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/sysutil"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -14,12 +18,14 @@ import (
// SystemHandler handles system-related operations
type SystemHandler struct {
updateSvc *service.UpdateService
lockSvc *service.SystemOperationLockService
}
// NewSystemHandler creates a new SystemHandler
func NewSystemHandler(updateSvc *service.UpdateService) *SystemHandler {
func NewSystemHandler(updateSvc *service.UpdateService, lockSvc *service.SystemOperationLockService) *SystemHandler {
return &SystemHandler{
updateSvc: updateSvc,
lockSvc: lockSvc,
}
}
@@ -47,41 +53,125 @@ func (h *SystemHandler) CheckUpdates(c *gin.Context) {
// PerformUpdate downloads and applies the update
// POST /api/v1/admin/system/update
func (h *SystemHandler) PerformUpdate(c *gin.Context) {
if err := h.updateSvc.PerformUpdate(c.Request.Context()); err != nil {
response.Error(c, http.StatusInternalServerError, err.Error())
return
}
response.Success(c, gin.H{
"message": "Update completed. Please restart the service.",
"need_restart": true,
operationID := buildSystemOperationID(c, "update")
payload := gin.H{"operation_id": operationID}
executeAdminIdempotentJSON(c, "admin.system.update", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) {
lock, release, err := h.acquireSystemLock(ctx, operationID)
if err != nil {
return nil, err
}
var releaseReason string
succeeded := false
defer func() {
release(releaseReason, succeeded)
}()
if err := h.updateSvc.PerformUpdate(ctx); err != nil {
releaseReason = "SYSTEM_UPDATE_FAILED"
return nil, err
}
succeeded = true
return gin.H{
"message": "Update completed. Please restart the service.",
"need_restart": true,
"operation_id": lock.OperationID(),
}, nil
})
}
// Rollback restores the previous version
// POST /api/v1/admin/system/rollback
func (h *SystemHandler) Rollback(c *gin.Context) {
if err := h.updateSvc.Rollback(); err != nil {
response.Error(c, http.StatusInternalServerError, err.Error())
return
}
response.Success(c, gin.H{
"message": "Rollback completed. Please restart the service.",
"need_restart": true,
operationID := buildSystemOperationID(c, "rollback")
payload := gin.H{"operation_id": operationID}
executeAdminIdempotentJSON(c, "admin.system.rollback", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) {
lock, release, err := h.acquireSystemLock(ctx, operationID)
if err != nil {
return nil, err
}
var releaseReason string
succeeded := false
defer func() {
release(releaseReason, succeeded)
}()
if err := h.updateSvc.Rollback(); err != nil {
releaseReason = "SYSTEM_ROLLBACK_FAILED"
return nil, err
}
succeeded = true
return gin.H{
"message": "Rollback completed. Please restart the service.",
"need_restart": true,
"operation_id": lock.OperationID(),
}, nil
})
}
// RestartService restarts the systemd service
// POST /api/v1/admin/system/restart
func (h *SystemHandler) RestartService(c *gin.Context) {
// Schedule service restart in background after sending response
// This ensures the client receives the success response before the service restarts
go func() {
// Wait a moment to ensure the response is sent
time.Sleep(500 * time.Millisecond)
sysutil.RestartServiceAsync()
}()
operationID := buildSystemOperationID(c, "restart")
payload := gin.H{"operation_id": operationID}
executeAdminIdempotentJSON(c, "admin.system.restart", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) {
lock, release, err := h.acquireSystemLock(ctx, operationID)
if err != nil {
return nil, err
}
succeeded := false
defer func() {
release("", succeeded)
}()
response.Success(c, gin.H{
"message": "Service restart initiated",
// Schedule service restart in background after sending response
// This ensures the client receives the success response before the service restarts
go func() {
// Wait a moment to ensure the response is sent
time.Sleep(500 * time.Millisecond)
sysutil.RestartServiceAsync()
}()
succeeded = true
return gin.H{
"message": "Service restart initiated",
"operation_id": lock.OperationID(),
}, nil
})
}
func (h *SystemHandler) acquireSystemLock(
ctx context.Context,
operationID string,
) (*service.SystemOperationLock, func(string, bool), error) {
if h.lockSvc == nil {
return nil, nil, service.ErrIdempotencyStoreUnavail
}
lock, err := h.lockSvc.Acquire(ctx, operationID)
if err != nil {
return nil, nil, err
}
release := func(reason string, succeeded bool) {
releaseCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_ = h.lockSvc.Release(releaseCtx, lock, succeeded, reason)
}
return lock, release, nil
}
func buildSystemOperationID(c *gin.Context, operation string) string {
key := strings.TrimSpace(c.GetHeader("Idempotency-Key"))
if key == "" {
return "sysop-" + operation + "-" + strconv.FormatInt(time.Now().UnixNano(), 36)
}
actorScope := "admin:0"
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
actorScope = "admin:" + strconv.FormatInt(subject.UserID, 10)
}
seed := operation + "|" + actorScope + "|" + c.FullPath() + "|" + key
hash := service.HashIdempotencyKey(seed)
if len(hash) > 24 {
hash = hash[:24]
}
return "sysop-" + hash
}

View File

@@ -225,6 +225,92 @@ func TestUsageHandlerCreateCleanupTaskInvalidEndDate(t *testing.T) {
require.Equal(t, http.StatusBadRequest, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskInvalidRequestType(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 88)
payload := map[string]any{
"start_date": "2024-01-01",
"end_date": "2024-01-02",
"timezone": "UTC",
"request_type": "invalid",
}
body, err := json.Marshal(payload)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusBadRequest, recorder.Code)
}
func TestUsageHandlerCreateCleanupTaskRequestTypePriority(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 99)
payload := map[string]any{
"start_date": "2024-01-01",
"end_date": "2024-01-02",
"timezone": "UTC",
"request_type": "ws_v2",
"stream": false,
}
body, err := json.Marshal(payload)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
repo.mu.Lock()
defer repo.mu.Unlock()
require.Len(t, repo.created, 1)
created := repo.created[0]
require.NotNil(t, created.Filters.RequestType)
require.Equal(t, int16(service.RequestTypeWSV2), *created.Filters.RequestType)
require.Nil(t, created.Filters.Stream)
}
func TestUsageHandlerCreateCleanupTaskWithLegacyStream(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
router := setupCleanupRouter(cleanupService, 99)
payload := map[string]any{
"start_date": "2024-01-01",
"end_date": "2024-01-02",
"timezone": "UTC",
"stream": true,
}
body, err := json.Marshal(payload)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
require.Equal(t, http.StatusOK, recorder.Code)
repo.mu.Lock()
defer repo.mu.Unlock()
require.Len(t, repo.created, 1)
created := repo.created[0]
require.Nil(t, created.Filters.RequestType)
require.NotNil(t, created.Filters.Stream)
require.True(t, *created.Filters.Stream)
}
func TestUsageHandlerCreateCleanupTaskSuccess(t *testing.T) {
repo := &cleanupRepoStub{}
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}

View File

@@ -1,13 +1,14 @@
package admin
import (
"log"
"context"
"net/http"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
@@ -50,6 +51,7 @@ type CreateUsageCleanupTaskRequest struct {
AccountID *int64 `json:"account_id"`
GroupID *int64 `json:"group_id"`
Model *string `json:"model"`
RequestType *string `json:"request_type"`
Stream *bool `json:"stream"`
BillingType *int8 `json:"billing_type"`
Timezone string `json:"timezone"`
@@ -100,8 +102,17 @@ func (h *UsageHandler) List(c *gin.Context) {
model := c.Query("model")
var requestType *int16
var stream *bool
if streamStr := c.Query("stream"); streamStr != "" {
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
parsed, err := service.ParseUsageRequestType(requestTypeStr)
if err != nil {
response.BadRequest(c, err.Error())
return
}
value := int16(parsed)
requestType = &value
} else if streamStr := c.Query("stream"); streamStr != "" {
val, err := strconv.ParseBool(streamStr)
if err != nil {
response.BadRequest(c, "Invalid stream value, use true or false")
@@ -151,6 +162,7 @@ func (h *UsageHandler) List(c *gin.Context) {
AccountID: accountID,
GroupID: groupID,
Model: model,
RequestType: requestType,
Stream: stream,
BillingType: billingType,
StartTime: startTime,
@@ -213,8 +225,17 @@ func (h *UsageHandler) Stats(c *gin.Context) {
model := c.Query("model")
var requestType *int16
var stream *bool
if streamStr := c.Query("stream"); streamStr != "" {
if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
parsed, err := service.ParseUsageRequestType(requestTypeStr)
if err != nil {
response.BadRequest(c, err.Error())
return
}
value := int16(parsed)
requestType = &value
} else if streamStr := c.Query("stream"); streamStr != "" {
val, err := strconv.ParseBool(streamStr)
if err != nil {
response.BadRequest(c, "Invalid stream value, use true or false")
@@ -277,6 +298,7 @@ func (h *UsageHandler) Stats(c *gin.Context) {
AccountID: accountID,
GroupID: groupID,
Model: model,
RequestType: requestType,
Stream: stream,
BillingType: billingType,
StartTime: &startTime,
@@ -378,11 +400,11 @@ func (h *UsageHandler) ListCleanupTasks(c *gin.Context) {
operator = subject.UserID
}
page, pageSize := response.ParsePagination(c)
log.Printf("[UsageCleanup] 请求清理任务列表: operator=%d page=%d page_size=%d", operator, page, pageSize)
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求清理任务列表: operator=%d page=%d page_size=%d", operator, page, pageSize)
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
tasks, result, err := h.cleanupService.ListTasks(c.Request.Context(), params)
if err != nil {
log.Printf("[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d page_size=%d err=%v", operator, page, pageSize, err)
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d page_size=%d err=%v", operator, page, pageSize, err)
response.ErrorFrom(c, err)
return
}
@@ -390,7 +412,7 @@ func (h *UsageHandler) ListCleanupTasks(c *gin.Context) {
for i := range tasks {
out = append(out, *dto.UsageCleanupTaskFromService(&tasks[i]))
}
log.Printf("[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d", operator, result.Total, len(out), page, pageSize)
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d", operator, result.Total, len(out), page, pageSize)
response.Paginated(c, out, result.Total, page, pageSize)
}
@@ -431,6 +453,19 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
}
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
var requestType *int16
stream := req.Stream
if req.RequestType != nil {
parsed, err := service.ParseUsageRequestType(*req.RequestType)
if err != nil {
response.BadRequest(c, err.Error())
return
}
value := int16(parsed)
requestType = &value
stream = nil
}
filters := service.UsageCleanupFilters{
StartTime: startTime,
EndTime: endTime,
@@ -439,7 +474,8 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
AccountID: req.AccountID,
GroupID: req.GroupID,
Model: req.Model,
Stream: req.Stream,
RequestType: requestType,
Stream: stream,
BillingType: req.BillingType,
}
@@ -463,38 +499,50 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
if filters.Model != nil {
model = *filters.Model
}
var stream any
var streamValue any
if filters.Stream != nil {
stream = *filters.Stream
streamValue = *filters.Stream
}
var requestTypeName any
if filters.RequestType != nil {
requestTypeName = service.RequestTypeFromInt16(*filters.RequestType).String()
}
var billingType any
if filters.BillingType != nil {
billingType = *filters.BillingType
}
log.Printf("[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q",
subject.UserID,
filters.StartTime.Format(time.RFC3339),
filters.EndTime.Format(time.RFC3339),
userID,
apiKeyID,
accountID,
groupID,
model,
stream,
billingType,
req.Timezone,
)
task, err := h.cleanupService.CreateTask(c.Request.Context(), filters, subject.UserID)
if err != nil {
log.Printf("[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err)
response.ErrorFrom(c, err)
return
idempotencyPayload := struct {
OperatorID int64 `json:"operator_id"`
Body CreateUsageCleanupTaskRequest `json:"body"`
}{
OperatorID: subject.UserID,
Body: req,
}
executeAdminIdempotentJSON(c, "admin.usage.cleanup_tasks.create", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v request_type=%v stream=%v billing_type=%v tz=%q",
subject.UserID,
filters.StartTime.Format(time.RFC3339),
filters.EndTime.Format(time.RFC3339),
userID,
apiKeyID,
accountID,
groupID,
model,
requestTypeName,
streamValue,
billingType,
req.Timezone,
)
log.Printf("[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status)
response.Success(c, dto.UsageCleanupTaskFromService(task))
task, err := h.cleanupService.CreateTask(ctx, filters, subject.UserID)
if err != nil {
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err)
return nil, err
}
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status)
return dto.UsageCleanupTaskFromService(task), nil
})
}
// CancelCleanupTask handles canceling a usage cleanup task
@@ -515,12 +563,12 @@ func (h *UsageHandler) CancelCleanupTask(c *gin.Context) {
response.BadRequest(c, "Invalid task id")
return
}
log.Printf("[UsageCleanup] 请求取消清理任务: task=%d operator=%d", taskID, subject.UserID)
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求取消清理任务: task=%d operator=%d", taskID, subject.UserID)
if err := h.cleanupService.CancelTask(c.Request.Context(), taskID, subject.UserID); err != nil {
log.Printf("[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v", taskID, subject.UserID, err)
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v", taskID, subject.UserID, err)
response.ErrorFrom(c, err)
return
}
log.Printf("[UsageCleanup] 清理任务已取消: task=%d operator=%d", taskID, subject.UserID)
logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 清理任务已取消: task=%d operator=%d", taskID, subject.UserID)
response.Success(c, gin.H{"id": taskID, "status": service.UsageCleanupStatusCanceled})
}

View File

@@ -0,0 +1,117 @@
package admin
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type adminUsageRepoCapture struct {
service.UsageLogRepository
listFilters usagestats.UsageLogFilters
statsFilters usagestats.UsageLogFilters
}
func (s *adminUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) {
s.listFilters = filters
return []service.UsageLog{}, &pagination.PaginationResult{
Total: 0,
Page: params.Page,
PageSize: params.PageSize,
Pages: 0,
}, nil
}
func (s *adminUsageRepoCapture) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) {
s.statsFilters = filters
return &usagestats.UsageStats{}, nil
}
func newAdminUsageRequestTypeTestRouter(repo *adminUsageRepoCapture) *gin.Engine {
gin.SetMode(gin.TestMode)
usageSvc := service.NewUsageService(repo, nil, nil, nil)
handler := NewUsageHandler(usageSvc, nil, nil, nil)
router := gin.New()
router.GET("/admin/usage", handler.List)
router.GET("/admin/usage/stats", handler.Stats)
return router
}
func TestAdminUsageListRequestTypePriority(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/usage?request_type=ws_v2&stream=false", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.NotNil(t, repo.listFilters.RequestType)
require.Equal(t, int16(service.RequestTypeWSV2), *repo.listFilters.RequestType)
require.Nil(t, repo.listFilters.Stream)
}
func TestAdminUsageListInvalidRequestType(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/usage?request_type=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestAdminUsageListInvalidStream(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/usage?stream=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestAdminUsageStatsRequestTypePriority(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?request_type=stream&stream=bad", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
require.NotNil(t, repo.statsFilters.RequestType)
require.Equal(t, int16(service.RequestTypeStream), *repo.statsFilters.RequestType)
require.Nil(t, repo.statsFilters.Stream)
}
func TestAdminUsageStatsInvalidRequestType(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?request_type=oops", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestAdminUsageStatsInvalidStream(t *testing.T) {
repo := &adminUsageRepoCapture{}
router := newAdminUsageRequestTypeTestRouter(repo)
req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?stream=oops", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
}

View File

@@ -1,6 +1,7 @@
package admin
import (
"context"
"strconv"
"strings"
@@ -11,27 +12,36 @@ import (
"github.com/gin-gonic/gin"
)
// UserWithConcurrency wraps AdminUser with current concurrency info
type UserWithConcurrency struct {
dto.AdminUser
CurrentConcurrency int `json:"current_concurrency"`
}
// UserHandler handles admin user management
type UserHandler struct {
adminService service.AdminService
adminService service.AdminService
concurrencyService *service.ConcurrencyService
}
// NewUserHandler creates a new admin user handler
func NewUserHandler(adminService service.AdminService) *UserHandler {
func NewUserHandler(adminService service.AdminService, concurrencyService *service.ConcurrencyService) *UserHandler {
return &UserHandler{
adminService: adminService,
adminService: adminService,
concurrencyService: concurrencyService,
}
}
// CreateUserRequest represents admin create user request
type CreateUserRequest struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=6"`
Username string `json:"username"`
Notes string `json:"notes"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
AllowedGroups []int64 `json:"allowed_groups"`
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=6"`
Username string `json:"username"`
Notes string `json:"notes"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
AllowedGroups []int64 `json:"allowed_groups"`
SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"`
}
// UpdateUserRequest represents admin update user request
@@ -47,7 +57,8 @@ type UpdateUserRequest struct {
AllowedGroups *[]int64 `json:"allowed_groups"`
// GroupRates 用户专属分组倍率配置
// map[groupID]*ratenil 表示删除该分组的专属倍率
GroupRates map[int64]*float64 `json:"group_rates"`
GroupRates map[int64]*float64 `json:"group_rates"`
SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"`
}
// UpdateBalanceRequest represents balance update request
@@ -70,8 +81,8 @@ func (h *UserHandler) List(c *gin.Context) {
search := c.Query("search")
// 标准化和验证 search 参数
search = strings.TrimSpace(search)
if len(search) > 100 {
search = search[:100]
if runes := []rune(search); len(runes) > 100 {
search = string(runes[:100])
}
filters := service.UserListFilters{
@@ -87,10 +98,30 @@ func (h *UserHandler) List(c *gin.Context) {
return
}
out := make([]dto.AdminUser, 0, len(users))
for i := range users {
out = append(out, *dto.UserFromServiceAdmin(&users[i]))
// Batch get current concurrency (nil map if unavailable)
var loadInfo map[int64]*service.UserLoadInfo
if len(users) > 0 && h.concurrencyService != nil {
usersConcurrency := make([]service.UserWithConcurrency, len(users))
for i := range users {
usersConcurrency[i] = service.UserWithConcurrency{
ID: users[i].ID,
MaxConcurrency: users[i].Concurrency,
}
}
loadInfo, _ = h.concurrencyService.GetUsersLoadBatch(c.Request.Context(), usersConcurrency)
}
// Build response with concurrency info
out := make([]UserWithConcurrency, len(users))
for i := range users {
out[i] = UserWithConcurrency{
AdminUser: *dto.UserFromServiceAdmin(&users[i]),
}
if info := loadInfo[users[i].ID]; info != nil {
out[i].CurrentConcurrency = info.CurrentConcurrency
}
}
response.Paginated(c, out, total, page, pageSize)
}
@@ -145,13 +176,14 @@ func (h *UserHandler) Create(c *gin.Context) {
}
user, err := h.adminService.CreateUser(c.Request.Context(), &service.CreateUserInput{
Email: req.Email,
Password: req.Password,
Username: req.Username,
Notes: req.Notes,
Balance: req.Balance,
Concurrency: req.Concurrency,
AllowedGroups: req.AllowedGroups,
Email: req.Email,
Password: req.Password,
Username: req.Username,
Notes: req.Notes,
Balance: req.Balance,
Concurrency: req.Concurrency,
AllowedGroups: req.AllowedGroups,
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
})
if err != nil {
response.ErrorFrom(c, err)
@@ -178,15 +210,16 @@ func (h *UserHandler) Update(c *gin.Context) {
// 使用指针类型直接传递nil 表示未提供该字段
user, err := h.adminService.UpdateUser(c.Request.Context(), userID, &service.UpdateUserInput{
Email: req.Email,
Password: req.Password,
Username: req.Username,
Notes: req.Notes,
Balance: req.Balance,
Concurrency: req.Concurrency,
Status: req.Status,
AllowedGroups: req.AllowedGroups,
GroupRates: req.GroupRates,
Email: req.Email,
Password: req.Password,
Username: req.Username,
Notes: req.Notes,
Balance: req.Balance,
Concurrency: req.Concurrency,
Status: req.Status,
AllowedGroups: req.AllowedGroups,
GroupRates: req.GroupRates,
SoraStorageQuotaBytes: req.SoraStorageQuotaBytes,
})
if err != nil {
response.ErrorFrom(c, err)
@@ -229,13 +262,20 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) {
return
}
user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation, req.Notes)
if err != nil {
response.ErrorFrom(c, err)
return
idempotencyPayload := struct {
UserID int64 `json:"user_id"`
Body UpdateBalanceRequest `json:"body"`
}{
UserID: userID,
Body: req,
}
response.Success(c, dto.UserFromServiceAdmin(user))
executeAdminIdempotentJSON(c, "admin.users.balance.update", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
user, execErr := h.adminService.UpdateUserBalance(ctx, userID, req.Balance, req.Operation, req.Notes)
if execErr != nil {
return nil, execErr
}
return dto.UserFromServiceAdmin(user), nil
})
}
// GetUserAPIKeys handles getting user's API keys