Merge branch 'main' of github.com:Wei-Shaw/sub2api
This commit is contained in:
@@ -45,6 +45,7 @@ type AccountHandler struct {
|
||||
concurrencyService *service.ConcurrencyService
|
||||
crsSyncService *service.CRSSyncService
|
||||
sessionLimitCache service.SessionLimitCache
|
||||
tokenCacheInvalidator service.TokenCacheInvalidator
|
||||
}
|
||||
|
||||
// NewAccountHandler creates a new admin account handler
|
||||
@@ -60,6 +61,7 @@ func NewAccountHandler(
|
||||
concurrencyService *service.ConcurrencyService,
|
||||
crsSyncService *service.CRSSyncService,
|
||||
sessionLimitCache service.SessionLimitCache,
|
||||
tokenCacheInvalidator service.TokenCacheInvalidator,
|
||||
) *AccountHandler {
|
||||
return &AccountHandler{
|
||||
adminService: adminService,
|
||||
@@ -73,6 +75,7 @@ func NewAccountHandler(
|
||||
concurrencyService: concurrencyService,
|
||||
crsSyncService: crsSyncService,
|
||||
sessionLimitCache: sessionLimitCache,
|
||||
tokenCacheInvalidator: tokenCacheInvalidator,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -173,6 +176,7 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
// 识别需要查询窗口费用和会话数的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
|
||||
windowCostAccountIDs := make([]int64, 0)
|
||||
sessionLimitAccountIDs := make([]int64, 0)
|
||||
sessionIdleTimeouts := make(map[int64]time.Duration) // 各账号的会话空闲超时配置
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if acc.IsAnthropicOAuthOrSetupToken() {
|
||||
@@ -181,6 +185,7 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
}
|
||||
if acc.GetMaxSessions() > 0 {
|
||||
sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
|
||||
sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -189,9 +194,9 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
var windowCosts map[int64]float64
|
||||
var activeSessions map[int64]int
|
||||
|
||||
// 获取活跃会话数(批量查询)
|
||||
// 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置)
|
||||
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
|
||||
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs)
|
||||
activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts)
|
||||
if activeSessions == nil {
|
||||
activeSessions = make(map[int64]int)
|
||||
}
|
||||
@@ -211,12 +216,8 @@ func (h *AccountHandler) List(c *gin.Context) {
|
||||
}
|
||||
accCopy := acc // 闭包捕获
|
||||
g.Go(func() error {
|
||||
var startTime time.Time
|
||||
if accCopy.SessionWindowStart != nil {
|
||||
startTime = *accCopy.SessionWindowStart
|
||||
} else {
|
||||
startTime = time.Now().Add(-5 * time.Hour)
|
||||
}
|
||||
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
|
||||
startTime := accCopy.GetCurrentWindowStartTime()
|
||||
stats, err := h.accountUsageService.GetAccountWindowStats(gctx, accCopy.ID, startTime)
|
||||
if err == nil && stats != nil {
|
||||
mu.Lock()
|
||||
@@ -545,6 +546,36 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// 如果 project_id 获取失败,先更新凭证,再标记账户为 error
|
||||
if tokenInfo.ProjectIDMissing {
|
||||
// 先更新凭证
|
||||
_, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||
Credentials: newCredentials,
|
||||
})
|
||||
if updateErr != nil {
|
||||
response.InternalError(c, "Failed to update credentials: "+updateErr.Error())
|
||||
return
|
||||
}
|
||||
// 标记账户为 error
|
||||
if setErr := h.adminService.SetAccountError(c.Request.Context(), accountID, "missing_project_id: 账户缺少project id,可能无法使用Antigravity"); setErr != nil {
|
||||
response.InternalError(c, "Failed to set account error: "+setErr.Error())
|
||||
return
|
||||
}
|
||||
response.Success(c, gin.H{
|
||||
"message": "Token refreshed but project_id is missing, account marked as error",
|
||||
"warning": "missing_project_id",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 成功获取到 project_id,如果之前是 missing_project_id 错误则清除
|
||||
if account.Status == service.StatusError && strings.Contains(account.ErrorMessage, "missing_project_id:") {
|
||||
if _, clearErr := h.adminService.ClearAccountError(c.Request.Context(), accountID); clearErr != nil {
|
||||
response.InternalError(c, "Failed to clear account error: "+clearErr.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Use Anthropic/Claude OAuth service to refresh token
|
||||
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
|
||||
@@ -580,6 +611,14 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 刷新成功后,清除 token 缓存,确保下次请求使用新 token
|
||||
if h.tokenCacheInvalidator != nil {
|
||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), updatedAccount); invalidateErr != nil {
|
||||
// 缓存失效失败只记录日志,不影响主流程
|
||||
_ = c.Error(invalidateErr)
|
||||
}
|
||||
}
|
||||
|
||||
response.Success(c, dto.AccountFromService(updatedAccount))
|
||||
}
|
||||
|
||||
|
||||
262
backend/internal/handler/admin/admin_basic_handlers_test.go
Normal file
262
backend/internal/handler/admin/admin_basic_handlers_test.go
Normal file
@@ -0,0 +1,262 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func setupAdminRouter() (*gin.Engine, *stubAdminService) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
adminSvc := newStubAdminService()
|
||||
|
||||
userHandler := NewUserHandler(adminSvc)
|
||||
groupHandler := NewGroupHandler(adminSvc)
|
||||
proxyHandler := NewProxyHandler(adminSvc)
|
||||
redeemHandler := NewRedeemHandler(adminSvc)
|
||||
|
||||
router.GET("/api/v1/admin/users", userHandler.List)
|
||||
router.GET("/api/v1/admin/users/:id", userHandler.GetByID)
|
||||
router.POST("/api/v1/admin/users", userHandler.Create)
|
||||
router.PUT("/api/v1/admin/users/:id", userHandler.Update)
|
||||
router.DELETE("/api/v1/admin/users/:id", userHandler.Delete)
|
||||
router.POST("/api/v1/admin/users/:id/balance", userHandler.UpdateBalance)
|
||||
router.GET("/api/v1/admin/users/:id/api-keys", userHandler.GetUserAPIKeys)
|
||||
router.GET("/api/v1/admin/users/:id/usage", userHandler.GetUserUsage)
|
||||
|
||||
router.GET("/api/v1/admin/groups", groupHandler.List)
|
||||
router.GET("/api/v1/admin/groups/all", groupHandler.GetAll)
|
||||
router.GET("/api/v1/admin/groups/:id", groupHandler.GetByID)
|
||||
router.POST("/api/v1/admin/groups", groupHandler.Create)
|
||||
router.PUT("/api/v1/admin/groups/:id", groupHandler.Update)
|
||||
router.DELETE("/api/v1/admin/groups/:id", groupHandler.Delete)
|
||||
router.GET("/api/v1/admin/groups/:id/stats", groupHandler.GetStats)
|
||||
router.GET("/api/v1/admin/groups/:id/api-keys", groupHandler.GetGroupAPIKeys)
|
||||
|
||||
router.GET("/api/v1/admin/proxies", proxyHandler.List)
|
||||
router.GET("/api/v1/admin/proxies/all", proxyHandler.GetAll)
|
||||
router.GET("/api/v1/admin/proxies/:id", proxyHandler.GetByID)
|
||||
router.POST("/api/v1/admin/proxies", proxyHandler.Create)
|
||||
router.PUT("/api/v1/admin/proxies/:id", proxyHandler.Update)
|
||||
router.DELETE("/api/v1/admin/proxies/:id", proxyHandler.Delete)
|
||||
router.POST("/api/v1/admin/proxies/batch-delete", proxyHandler.BatchDelete)
|
||||
router.POST("/api/v1/admin/proxies/:id/test", proxyHandler.Test)
|
||||
router.GET("/api/v1/admin/proxies/:id/stats", proxyHandler.GetStats)
|
||||
router.GET("/api/v1/admin/proxies/:id/accounts", proxyHandler.GetProxyAccounts)
|
||||
|
||||
router.GET("/api/v1/admin/redeem-codes", redeemHandler.List)
|
||||
router.GET("/api/v1/admin/redeem-codes/:id", redeemHandler.GetByID)
|
||||
router.POST("/api/v1/admin/redeem-codes", redeemHandler.Generate)
|
||||
router.DELETE("/api/v1/admin/redeem-codes/:id", redeemHandler.Delete)
|
||||
router.POST("/api/v1/admin/redeem-codes/batch-delete", redeemHandler.BatchDelete)
|
||||
router.POST("/api/v1/admin/redeem-codes/:id/expire", redeemHandler.Expire)
|
||||
router.GET("/api/v1/admin/redeem-codes/:id/stats", redeemHandler.GetStats)
|
||||
|
||||
return router, adminSvc
|
||||
}
|
||||
|
||||
func TestUserHandlerEndpoints(t *testing.T) {
|
||||
router, _ := setupAdminRouter()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/users?page=1&page_size=20", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
createBody := map[string]any{"email": "new@example.com", "password": "pass123", "balance": 1, "concurrency": 2}
|
||||
body, _ := json.Marshal(createBody)
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
updateBody := map[string]any{"email": "updated@example.com"}
|
||||
body, _ = json.Marshal(updateBody)
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPut, "/api/v1/admin/users/1", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/users/1", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/1/balance", bytes.NewBufferString(`{"balance":1,"operation":"add"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1/api-keys", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1/usage?period=today", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestGroupHandlerEndpoints(t *testing.T) {
|
||||
router, _ := setupAdminRouter()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/all", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/2", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{"name": "new", "platform": "anthropic", "subscription_type": "standard"})
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/groups", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
body, _ = json.Marshal(map[string]any{"name": "update"})
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPut, "/api/v1/admin/groups/2", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/groups/2", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/2/stats", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/2/api-keys", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestProxyHandlerEndpoints(t *testing.T) {
|
||||
router, _ := setupAdminRouter()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/all", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{"name": "proxy", "protocol": "http", "host": "localhost", "port": 8080})
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
body, _ = json.Marshal(map[string]any{"name": "proxy2"})
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPut, "/api/v1/admin/proxies/4", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/proxies/4", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/batch-delete", bytes.NewBufferString(`{"ids":[1,2]}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/4/test", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/stats", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/accounts", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestRedeemHandlerEndpoints(t *testing.T) {
|
||||
router, _ := setupAdminRouter()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes/5", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{"count": 1, "type": "balance", "value": 10})
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/redeem-codes/5", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/batch-delete", bytes.NewBufferString(`{"ids":[1,2]}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/5/expire", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes/5/stats", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
134
backend/internal/handler/admin/admin_helpers_test.go
Normal file
134
backend/internal/handler/admin/admin_helpers_test.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParseTimeRange(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
req := httptest.NewRequest(http.MethodGet, "/?start_date=2024-01-01&end_date=2024-01-02&timezone=UTC", nil)
|
||||
c.Request = req
|
||||
|
||||
start, end := parseTimeRange(c)
|
||||
require.Equal(t, time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), start)
|
||||
require.Equal(t, time.Date(2024, 1, 3, 0, 0, 0, 0, time.UTC), end)
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/?start_date=bad&timezone=UTC", nil)
|
||||
c.Request = req
|
||||
start, end = parseTimeRange(c)
|
||||
require.False(t, start.IsZero())
|
||||
require.False(t, end.IsZero())
|
||||
}
|
||||
|
||||
func TestParseOpsViewParam(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/?view=excluded", nil)
|
||||
require.Equal(t, opsListViewExcluded, parseOpsViewParam(c))
|
||||
|
||||
c2, _ := gin.CreateTestContext(w)
|
||||
c2.Request = httptest.NewRequest(http.MethodGet, "/?view=all", nil)
|
||||
require.Equal(t, opsListViewAll, parseOpsViewParam(c2))
|
||||
|
||||
c3, _ := gin.CreateTestContext(w)
|
||||
c3.Request = httptest.NewRequest(http.MethodGet, "/?view=unknown", nil)
|
||||
require.Equal(t, opsListViewErrors, parseOpsViewParam(c3))
|
||||
|
||||
require.Equal(t, "", parseOpsViewParam(nil))
|
||||
}
|
||||
|
||||
func TestParseOpsDuration(t *testing.T) {
|
||||
dur, ok := parseOpsDuration("1h")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, time.Hour, dur)
|
||||
|
||||
_, ok = parseOpsDuration("invalid")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestParseOpsTimeRange(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
now := time.Now().UTC()
|
||||
startStr := now.Add(-time.Hour).Format(time.RFC3339)
|
||||
endStr := now.Format(time.RFC3339)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/?start_time="+startStr+"&end_time="+endStr, nil)
|
||||
start, end, err := parseOpsTimeRange(c, "1h")
|
||||
require.NoError(t, err)
|
||||
require.True(t, start.Before(end))
|
||||
|
||||
c2, _ := gin.CreateTestContext(w)
|
||||
c2.Request = httptest.NewRequest(http.MethodGet, "/?start_time=bad", nil)
|
||||
_, _, err = parseOpsTimeRange(c2, "1h")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestParseOpsRealtimeWindow(t *testing.T) {
|
||||
dur, label, ok := parseOpsRealtimeWindow("5m")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, 5*time.Minute, dur)
|
||||
require.Equal(t, "5min", label)
|
||||
|
||||
_, _, ok = parseOpsRealtimeWindow("invalid")
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestPickThroughputBucketSeconds(t *testing.T) {
|
||||
require.Equal(t, 60, pickThroughputBucketSeconds(30*time.Minute))
|
||||
require.Equal(t, 300, pickThroughputBucketSeconds(6*time.Hour))
|
||||
require.Equal(t, 3600, pickThroughputBucketSeconds(48*time.Hour))
|
||||
}
|
||||
|
||||
func TestParseOpsQueryMode(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/?mode=raw", nil)
|
||||
require.Equal(t, service.ParseOpsQueryMode("raw"), parseOpsQueryMode(c))
|
||||
require.Equal(t, service.OpsQueryMode(""), parseOpsQueryMode(nil))
|
||||
}
|
||||
|
||||
func TestOpsAlertRuleValidation(t *testing.T) {
|
||||
raw := map[string]json.RawMessage{
|
||||
"name": json.RawMessage(`"High error rate"`),
|
||||
"metric_type": json.RawMessage(`"error_rate"`),
|
||||
"operator": json.RawMessage(`">"`),
|
||||
"threshold": json.RawMessage(`90`),
|
||||
}
|
||||
|
||||
validated, err := validateOpsAlertRulePayload(raw)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "High error rate", validated.Name)
|
||||
|
||||
_, err = validateOpsAlertRulePayload(map[string]json.RawMessage{})
|
||||
require.Error(t, err)
|
||||
|
||||
require.True(t, isPercentOrRateMetric("error_rate"))
|
||||
require.False(t, isPercentOrRateMetric("concurrency_queue_depth"))
|
||||
}
|
||||
|
||||
func TestOpsWSHelpers(t *testing.T) {
|
||||
prefixes, invalid := parseTrustedProxyList("10.0.0.0/8,invalid")
|
||||
require.Len(t, prefixes, 1)
|
||||
require.Len(t, invalid, 1)
|
||||
|
||||
host := hostWithoutPort("example.com:443")
|
||||
require.Equal(t, "example.com", host)
|
||||
|
||||
addr := netip.MustParseAddr("10.0.0.1")
|
||||
require.True(t, isAddrInTrustedProxies(addr, prefixes))
|
||||
require.False(t, isAddrInTrustedProxies(netip.MustParseAddr("192.168.0.1"), prefixes))
|
||||
}
|
||||
294
backend/internal/handler/admin/admin_service_stub_test.go
Normal file
294
backend/internal/handler/admin/admin_service_stub_test.go
Normal file
@@ -0,0 +1,294 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type stubAdminService struct {
|
||||
users []service.User
|
||||
apiKeys []service.APIKey
|
||||
groups []service.Group
|
||||
accounts []service.Account
|
||||
proxies []service.Proxy
|
||||
proxyCounts []service.ProxyWithAccountCount
|
||||
redeems []service.RedeemCode
|
||||
}
|
||||
|
||||
func newStubAdminService() *stubAdminService {
|
||||
now := time.Now().UTC()
|
||||
user := service.User{
|
||||
ID: 1,
|
||||
Email: "user@example.com",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
apiKey := service.APIKey{
|
||||
ID: 10,
|
||||
UserID: user.ID,
|
||||
Key: "sk-test",
|
||||
Name: "test",
|
||||
Status: service.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
group := service.Group{
|
||||
ID: 2,
|
||||
Name: "group",
|
||||
Platform: service.PlatformAnthropic,
|
||||
Status: service.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
account := service.Account{
|
||||
ID: 3,
|
||||
Name: "account",
|
||||
Platform: service.PlatformAnthropic,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
proxy := service.Proxy{
|
||||
ID: 4,
|
||||
Name: "proxy",
|
||||
Protocol: "http",
|
||||
Host: "127.0.0.1",
|
||||
Port: 8080,
|
||||
Status: service.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
redeem := service.RedeemCode{
|
||||
ID: 5,
|
||||
Code: "R-TEST",
|
||||
Type: service.RedeemTypeBalance,
|
||||
Value: 10,
|
||||
Status: service.StatusUnused,
|
||||
CreatedAt: now,
|
||||
}
|
||||
return &stubAdminService{
|
||||
users: []service.User{user},
|
||||
apiKeys: []service.APIKey{apiKey},
|
||||
groups: []service.Group{group},
|
||||
accounts: []service.Account{account},
|
||||
proxies: []service.Proxy{proxy},
|
||||
proxyCounts: []service.ProxyWithAccountCount{{Proxy: proxy, AccountCount: 1}},
|
||||
redeems: []service.RedeemCode{redeem},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListUsers(ctx context.Context, page, pageSize int, filters service.UserListFilters) ([]service.User, int64, error) {
|
||||
return s.users, int64(len(s.users)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetUser(ctx context.Context, id int64) (*service.User, error) {
|
||||
for i := range s.users {
|
||||
if s.users[i].ID == id {
|
||||
return &s.users[i], nil
|
||||
}
|
||||
}
|
||||
user := service.User{ID: id, Email: "user@example.com", Status: service.StatusActive}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) CreateUser(ctx context.Context, input *service.CreateUserInput) (*service.User, error) {
|
||||
user := service.User{ID: 100, Email: input.Email, Status: service.StatusActive}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) UpdateUser(ctx context.Context, id int64, input *service.UpdateUserInput) (*service.User, error) {
|
||||
user := service.User{ID: id, Email: "updated@example.com", Status: service.StatusActive}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) DeleteUser(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*service.User, error) {
|
||||
user := service.User{ID: userID, Balance: balance, Status: service.StatusActive}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]service.APIKey, int64, error) {
|
||||
return s.apiKeys, int64(len(s.apiKeys)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) {
|
||||
return map[string]any{"user_id": userID}, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]service.Group, int64, error) {
|
||||
return s.groups, int64(len(s.groups)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetAllGroups(ctx context.Context) ([]service.Group, error) {
|
||||
return s.groups, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
|
||||
return s.groups, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetGroup(ctx context.Context, id int64) (*service.Group, error) {
|
||||
group := service.Group{ID: id, Name: "group", Status: service.StatusActive}
|
||||
return &group, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) CreateGroup(ctx context.Context, input *service.CreateGroupInput) (*service.Group, error) {
|
||||
group := service.Group{ID: 200, Name: input.Name, Status: service.StatusActive}
|
||||
return &group, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) UpdateGroup(ctx context.Context, id int64, input *service.UpdateGroupInput) (*service.Group, error) {
|
||||
group := service.Group{ID: id, Name: input.Name, Status: service.StatusActive}
|
||||
return &group, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) DeleteGroup(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]service.APIKey, int64, error) {
|
||||
return s.apiKeys, int64(len(s.apiKeys)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]service.Account, int64, error) {
|
||||
return s.accounts, int64(len(s.accounts)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetAccount(ctx context.Context, id int64) (*service.Account, error) {
|
||||
account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) {
|
||||
out := make([]*service.Account, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
|
||||
out = append(out, &account)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) CreateAccount(ctx context.Context, input *service.CreateAccountInput) (*service.Account, error) {
|
||||
account := service.Account{ID: 300, Name: input.Name, Status: service.StatusActive}
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) {
|
||||
account := service.Account{ID: id, Name: input.Name, Status: service.StatusActive}
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) DeleteAccount(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) RefreshAccountCredentials(ctx context.Context, id int64) (*service.Account, error) {
|
||||
account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ClearAccountError(ctx context.Context, id int64) (*service.Account, error) {
|
||||
account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) SetAccountError(ctx context.Context, id int64, errorMsg string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*service.Account, error) {
|
||||
account := service.Account{ID: id, Name: "account", Status: service.StatusActive, Schedulable: schedulable}
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *service.BulkUpdateAccountsInput) (*service.BulkUpdateAccountsResult, error) {
|
||||
return &service.BulkUpdateAccountsResult{Success: 1, Failed: 0, SuccessIDs: []int64{1}}, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) {
|
||||
return s.proxies, int64(len(s.proxies)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.ProxyWithAccountCount, int64, error) {
|
||||
return s.proxyCounts, int64(len(s.proxyCounts)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetAllProxies(ctx context.Context) ([]service.Proxy, error) {
|
||||
return s.proxies, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetAllProxiesWithAccountCount(ctx context.Context) ([]service.ProxyWithAccountCount, error) {
|
||||
return s.proxyCounts, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetProxy(ctx context.Context, id int64) (*service.Proxy, error) {
|
||||
proxy := service.Proxy{ID: id, Name: "proxy", Status: service.StatusActive}
|
||||
return &proxy, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) CreateProxy(ctx context.Context, input *service.CreateProxyInput) (*service.Proxy, error) {
|
||||
proxy := service.Proxy{ID: 400, Name: input.Name, Status: service.StatusActive}
|
||||
return &proxy, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) UpdateProxy(ctx context.Context, id int64, input *service.UpdateProxyInput) (*service.Proxy, error) {
|
||||
proxy := service.Proxy{ID: id, Name: input.Name, Status: service.StatusActive}
|
||||
return &proxy, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) DeleteProxy(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) BatchDeleteProxies(ctx context.Context, ids []int64) (*service.ProxyBatchDeleteResult, error) {
|
||||
return &service.ProxyBatchDeleteResult{DeletedIDs: ids}, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetProxyAccounts(ctx context.Context, proxyID int64) ([]service.ProxyAccountSummary, error) {
|
||||
return []service.ProxyAccountSummary{{ID: 1, Name: "account"}}, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) TestProxy(ctx context.Context, id int64) (*service.ProxyTestResult, error) {
|
||||
return &service.ProxyTestResult{Success: true, Message: "ok"}, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]service.RedeemCode, int64, error) {
|
||||
return s.redeems, int64(len(s.redeems)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GetRedeemCode(ctx context.Context, id int64) (*service.RedeemCode, error) {
|
||||
code := service.RedeemCode{ID: id, Code: "R-TEST", Status: service.StatusUnused}
|
||||
return &code, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) GenerateRedeemCodes(ctx context.Context, input *service.GenerateRedeemCodesInput) ([]service.RedeemCode, error) {
|
||||
return s.redeems, nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) DeleteRedeemCode(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error) {
|
||||
return int64(len(ids)), nil
|
||||
}
|
||||
|
||||
func (s *stubAdminService) ExpireRedeemCode(ctx context.Context, id int64) (*service.RedeemCode, error) {
|
||||
code := service.RedeemCode{ID: id, Code: "R-TEST", Status: service.StatusUsed}
|
||||
return &code, nil
|
||||
}
|
||||
|
||||
// Ensure stub implements interface.
|
||||
var _ service.AdminService = (*stubAdminService)(nil)
|
||||
@@ -186,7 +186,7 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
|
||||
|
||||
// GetUsageTrend handles getting usage trend data
|
||||
// GET /api/v1/admin/dashboard/trend
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream, billing_type
|
||||
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
granularity := c.DefaultQuery("granularity", "day")
|
||||
@@ -195,6 +195,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
var userID, apiKeyID, accountID, groupID int64
|
||||
var model string
|
||||
var stream *bool
|
||||
var billingType *int8
|
||||
|
||||
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
||||
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
|
||||
@@ -224,8 +225,17 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
stream = &streamVal
|
||||
}
|
||||
}
|
||||
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
|
||||
if v, err := strconv.ParseInt(billingTypeStr, 10, 8); err == nil {
|
||||
bt := int8(v)
|
||||
billingType = &bt
|
||||
} else {
|
||||
response.BadRequest(c, "Invalid billing_type")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream)
|
||||
trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get usage trend")
|
||||
return
|
||||
@@ -241,13 +251,14 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
|
||||
|
||||
// GetModelStats handles getting model usage statistics
|
||||
// GET /api/v1/admin/dashboard/models
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream
|
||||
// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream, billing_type
|
||||
func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||
startTime, endTime := parseTimeRange(c)
|
||||
|
||||
// Parse optional filter params
|
||||
var userID, apiKeyID, accountID, groupID int64
|
||||
var stream *bool
|
||||
var billingType *int8
|
||||
|
||||
if userIDStr := c.Query("user_id"); userIDStr != "" {
|
||||
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
|
||||
@@ -274,8 +285,17 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
|
||||
stream = &streamVal
|
||||
}
|
||||
}
|
||||
if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
|
||||
if v, err := strconv.ParseInt(billingTypeStr, 10, 8); err == nil {
|
||||
bt := int8(v)
|
||||
billingType = &bt
|
||||
} else {
|
||||
response.BadRequest(c, "Invalid billing_type")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream)
|
||||
stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType)
|
||||
if err != nil {
|
||||
response.Error(c, 500, "Failed to get model statistics")
|
||||
return
|
||||
|
||||
@@ -94,9 +94,9 @@ func (h *GroupHandler) List(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
outGroups := make([]dto.Group, 0, len(groups))
|
||||
outGroups := make([]dto.AdminGroup, 0, len(groups))
|
||||
for i := range groups {
|
||||
outGroups = append(outGroups, *dto.GroupFromService(&groups[i]))
|
||||
outGroups = append(outGroups, *dto.GroupFromServiceAdmin(&groups[i]))
|
||||
}
|
||||
response.Paginated(c, outGroups, total, page, pageSize)
|
||||
}
|
||||
@@ -120,9 +120,9 @@ func (h *GroupHandler) GetAll(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
outGroups := make([]dto.Group, 0, len(groups))
|
||||
outGroups := make([]dto.AdminGroup, 0, len(groups))
|
||||
for i := range groups {
|
||||
outGroups = append(outGroups, *dto.GroupFromService(&groups[i]))
|
||||
outGroups = append(outGroups, *dto.GroupFromServiceAdmin(&groups[i]))
|
||||
}
|
||||
response.Success(c, outGroups)
|
||||
}
|
||||
@@ -142,7 +142,7 @@ func (h *GroupHandler) GetByID(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.GroupFromService(group))
|
||||
response.Success(c, dto.GroupFromServiceAdmin(group))
|
||||
}
|
||||
|
||||
// Create handles creating a new group
|
||||
@@ -177,7 +177,7 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.GroupFromService(group))
|
||||
response.Success(c, dto.GroupFromServiceAdmin(group))
|
||||
}
|
||||
|
||||
// Update handles updating a group
|
||||
@@ -219,7 +219,7 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.GroupFromService(group))
|
||||
response.Success(c, dto.GroupFromServiceAdmin(group))
|
||||
}
|
||||
|
||||
// Delete handles deleting a group
|
||||
|
||||
@@ -54,9 +54,9 @@ func (h *RedeemHandler) List(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.RedeemCode, 0, len(codes))
|
||||
out := make([]dto.AdminRedeemCode, 0, len(codes))
|
||||
for i := range codes {
|
||||
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
|
||||
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
|
||||
}
|
||||
response.Paginated(c, out, total, page, pageSize)
|
||||
}
|
||||
@@ -76,7 +76,7 @@ func (h *RedeemHandler) GetByID(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.RedeemCodeFromService(code))
|
||||
response.Success(c, dto.RedeemCodeFromServiceAdmin(code))
|
||||
}
|
||||
|
||||
// Generate handles generating new redeem codes
|
||||
@@ -100,9 +100,9 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.RedeemCode, 0, len(codes))
|
||||
out := make([]dto.AdminRedeemCode, 0, len(codes))
|
||||
for i := range codes {
|
||||
out = append(out, *dto.RedeemCodeFromService(&codes[i]))
|
||||
out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i]))
|
||||
}
|
||||
response.Success(c, out)
|
||||
}
|
||||
@@ -163,7 +163,7 @@ func (h *RedeemHandler) Expire(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.RedeemCodeFromService(code))
|
||||
response.Success(c, dto.RedeemCodeFromServiceAdmin(code))
|
||||
}
|
||||
|
||||
// GetStats handles getting redeem code statistics
|
||||
|
||||
@@ -68,6 +68,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
HomeContent: settings.HomeContent,
|
||||
HideCcsImportButton: settings.HideCcsImportButton,
|
||||
DefaultConcurrency: settings.DefaultConcurrency,
|
||||
DefaultBalance: settings.DefaultBalance,
|
||||
EnableModelFallback: settings.EnableModelFallback,
|
||||
@@ -111,13 +112,14 @@ type UpdateSettingsRequest struct {
|
||||
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
|
||||
|
||||
// OEM设置
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
|
||||
// 默认配置
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
@@ -259,6 +261,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocURL: req.DocURL,
|
||||
HomeContent: req.HomeContent,
|
||||
HideCcsImportButton: req.HideCcsImportButton,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
@@ -332,6 +335,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
ContactInfo: updatedSettings.ContactInfo,
|
||||
DocURL: updatedSettings.DocURL,
|
||||
HomeContent: updatedSettings.HomeContent,
|
||||
HideCcsImportButton: updatedSettings.HideCcsImportButton,
|
||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||
DefaultBalance: updatedSettings.DefaultBalance,
|
||||
EnableModelFallback: updatedSettings.EnableModelFallback,
|
||||
@@ -439,6 +443,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.HomeContent != after.HomeContent {
|
||||
changed = append(changed, "home_content")
|
||||
}
|
||||
if before.HideCcsImportButton != after.HideCcsImportButton {
|
||||
changed = append(changed, "hide_ccs_import_button")
|
||||
}
|
||||
if before.DefaultConcurrency != after.DefaultConcurrency {
|
||||
changed = append(changed, "default_concurrency")
|
||||
}
|
||||
|
||||
@@ -53,9 +53,9 @@ type BulkAssignSubscriptionRequest struct {
|
||||
Notes string `json:"notes"`
|
||||
}
|
||||
|
||||
// ExtendSubscriptionRequest represents extend subscription request
|
||||
type ExtendSubscriptionRequest struct {
|
||||
Days int `json:"days" binding:"required,min=1,max=36500"` // max 100 years
|
||||
// AdjustSubscriptionRequest represents adjust subscription request (extend or shorten)
|
||||
type AdjustSubscriptionRequest struct {
|
||||
Days int `json:"days" binding:"required,min=-36500,max=36500"` // negative to shorten, positive to extend
|
||||
}
|
||||
|
||||
// List handles listing all subscriptions with pagination and filters
|
||||
@@ -83,9 +83,9 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.UserSubscription, 0, len(subscriptions))
|
||||
out := make([]dto.AdminUserSubscription, 0, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
|
||||
out = append(out, *dto.UserSubscriptionFromServiceAdmin(&subscriptions[i]))
|
||||
}
|
||||
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
|
||||
}
|
||||
@@ -105,7 +105,7 @@ func (h *SubscriptionHandler) GetByID(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserSubscriptionFromService(subscription))
|
||||
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription))
|
||||
}
|
||||
|
||||
// GetProgress handles getting subscription usage progress
|
||||
@@ -150,7 +150,7 @@ func (h *SubscriptionHandler) Assign(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserSubscriptionFromService(subscription))
|
||||
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription))
|
||||
}
|
||||
|
||||
// BulkAssign handles bulk assigning subscriptions to multiple users
|
||||
@@ -180,7 +180,7 @@ func (h *SubscriptionHandler) BulkAssign(c *gin.Context) {
|
||||
response.Success(c, dto.BulkAssignResultFromService(result))
|
||||
}
|
||||
|
||||
// Extend handles extending a subscription
|
||||
// Extend handles adjusting a subscription (extend or shorten)
|
||||
// POST /api/v1/admin/subscriptions/:id/extend
|
||||
func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
||||
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
@@ -189,7 +189,7 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var req ExtendSubscriptionRequest
|
||||
var req AdjustSubscriptionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
@@ -201,7 +201,7 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserSubscriptionFromService(subscription))
|
||||
response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription))
|
||||
}
|
||||
|
||||
// Revoke handles revoking a subscription
|
||||
@@ -239,9 +239,9 @@ func (h *SubscriptionHandler) ListByGroup(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.UserSubscription, 0, len(subscriptions))
|
||||
out := make([]dto.AdminUserSubscription, 0, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
|
||||
out = append(out, *dto.UserSubscriptionFromServiceAdmin(&subscriptions[i]))
|
||||
}
|
||||
response.PaginatedWithResult(c, out, toResponsePagination(pagination))
|
||||
}
|
||||
@@ -261,9 +261,9 @@ func (h *SubscriptionHandler) ListByUser(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.UserSubscription, 0, len(subscriptions))
|
||||
out := make([]dto.AdminUserSubscription, 0, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i]))
|
||||
out = append(out, *dto.UserSubscriptionFromServiceAdmin(&subscriptions[i]))
|
||||
}
|
||||
response.Success(c, out)
|
||||
}
|
||||
|
||||
377
backend/internal/handler/admin/usage_cleanup_handler_test.go
Normal file
377
backend/internal/handler/admin/usage_cleanup_handler_test.go
Normal file
@@ -0,0 +1,377 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type cleanupRepoStub struct {
|
||||
mu sync.Mutex
|
||||
created []*service.UsageCleanupTask
|
||||
listTasks []service.UsageCleanupTask
|
||||
listResult *pagination.PaginationResult
|
||||
listErr error
|
||||
statusByID map[int64]string
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) CreateTask(ctx context.Context, task *service.UsageCleanupTask) error {
|
||||
if task == nil {
|
||||
return nil
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if task.ID == 0 {
|
||||
task.ID = int64(len(s.created) + 1)
|
||||
}
|
||||
if task.CreatedAt.IsZero() {
|
||||
task.CreatedAt = time.Now().UTC()
|
||||
}
|
||||
task.UpdatedAt = task.CreatedAt
|
||||
clone := *task
|
||||
s.created = append(s.created, &clone)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]service.UsageCleanupTask, *pagination.PaginationResult, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.listTasks, s.listResult, s.listErr
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*service.UsageCleanupTask, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) GetTaskStatus(ctx context.Context, taskID int64) (string, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.statusByID == nil {
|
||||
return "", sql.ErrNoRows
|
||||
}
|
||||
status, ok := s.statusByID[taskID]
|
||||
if !ok {
|
||||
return "", sql.ErrNoRows
|
||||
}
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.statusByID == nil {
|
||||
s.statusByID = map[int64]string{}
|
||||
}
|
||||
status := s.statusByID[taskID]
|
||||
if status != service.UsageCleanupStatusPending && status != service.UsageCleanupStatusRunning {
|
||||
return false, nil
|
||||
}
|
||||
s.statusByID[taskID] = service.UsageCleanupStatusCanceled
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) DeleteUsageLogsBatch(ctx context.Context, filters service.UsageCleanupFilters, limit int) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
var _ service.UsageCleanupRepository = (*cleanupRepoStub)(nil)
|
||||
|
||||
func setupCleanupRouter(cleanupService *service.UsageCleanupService, userID int64) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
if userID > 0 {
|
||||
router.Use(func(c *gin.Context) {
|
||||
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: userID})
|
||||
c.Next()
|
||||
})
|
||||
}
|
||||
|
||||
handler := NewUsageHandler(nil, nil, nil, cleanupService)
|
||||
router.POST("/api/v1/admin/usage/cleanup-tasks", handler.CreateCleanupTask)
|
||||
router.GET("/api/v1/admin/usage/cleanup-tasks", handler.ListCleanupTasks)
|
||||
router.POST("/api/v1/admin/usage/cleanup-tasks/:id/cancel", handler.CancelCleanupTask)
|
||||
return router
|
||||
}
|
||||
|
||||
func TestUsageHandlerCreateCleanupTaskUnauthorized(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 0)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewBufferString(`{}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, recorder.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCreateCleanupTaskUnavailable(t *testing.T) {
|
||||
router := setupCleanupRouter(nil, 1)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewBufferString(`{}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusServiceUnavailable, recorder.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCreateCleanupTaskBindError(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 88)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewBufferString("{bad-json"))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCreateCleanupTaskMissingRange(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 88)
|
||||
|
||||
payload := map[string]any{
|
||||
"start_date": "2024-01-01",
|
||||
"timezone": "UTC",
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCreateCleanupTaskInvalidDate(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 88)
|
||||
|
||||
payload := map[string]any{
|
||||
"start_date": "2024-13-01",
|
||||
"end_date": "2024-01-02",
|
||||
"timezone": "UTC",
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCreateCleanupTaskInvalidEndDate(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 88)
|
||||
|
||||
payload := map[string]any{
|
||||
"start_date": "2024-01-01",
|
||||
"end_date": "2024-02-40",
|
||||
"timezone": "UTC",
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, recorder.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCreateCleanupTaskSuccess(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 99)
|
||||
|
||||
payload := map[string]any{
|
||||
"start_date": " 2024-01-01 ",
|
||||
"end_date": "2024-01-02",
|
||||
"timezone": "UTC",
|
||||
"model": "gpt-4",
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
var resp response.Response
|
||||
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
|
||||
require.Equal(t, 0, resp.Code)
|
||||
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Len(t, repo.created, 1)
|
||||
created := repo.created[0]
|
||||
require.Equal(t, int64(99), created.CreatedBy)
|
||||
require.NotNil(t, created.Filters.Model)
|
||||
require.Equal(t, "gpt-4", *created.Filters.Model)
|
||||
|
||||
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
end := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC).Add(24*time.Hour - time.Nanosecond)
|
||||
require.True(t, created.Filters.StartTime.Equal(start))
|
||||
require.True(t, created.Filters.EndTime.Equal(end))
|
||||
}
|
||||
|
||||
func TestUsageHandlerListCleanupTasksUnavailable(t *testing.T) {
|
||||
router := setupCleanupRouter(nil, 0)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/usage/cleanup-tasks", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusServiceUnavailable, recorder.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerListCleanupTasksSuccess(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
repo.listTasks = []service.UsageCleanupTask{
|
||||
{
|
||||
ID: 7,
|
||||
Status: service.UsageCleanupStatusSucceeded,
|
||||
CreatedBy: 4,
|
||||
},
|
||||
}
|
||||
repo.listResult = &pagination.PaginationResult{Total: 1, Page: 1, PageSize: 20, Pages: 1}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 1)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/usage/cleanup-tasks", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, recorder.Code)
|
||||
|
||||
var resp struct {
|
||||
Code int `json:"code"`
|
||||
Data struct {
|
||||
Items []dto.UsageCleanupTask `json:"items"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
|
||||
require.Equal(t, 0, resp.Code)
|
||||
require.Len(t, resp.Data.Items, 1)
|
||||
require.Equal(t, int64(7), resp.Data.Items[0].ID)
|
||||
require.Equal(t, int64(1), resp.Data.Total)
|
||||
require.Equal(t, 1, resp.Data.Page)
|
||||
}
|
||||
|
||||
func TestUsageHandlerListCleanupTasksError(t *testing.T) {
|
||||
repo := &cleanupRepoStub{listErr: errors.New("boom")}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 1)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/usage/cleanup-tasks", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
require.Equal(t, http.StatusInternalServerError, recorder.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCancelCleanupTaskUnauthorized(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 0)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/1/cancel", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCancelCleanupTaskNotFound(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 1)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/999/cancel", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusNotFound, rec.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCancelCleanupTaskConflict(t *testing.T) {
|
||||
repo := &cleanupRepoStub{statusByID: map[int64]string{2: service.UsageCleanupStatusSucceeded}}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 1)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/2/cancel", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusConflict, rec.Code)
|
||||
}
|
||||
|
||||
func TestUsageHandlerCancelCleanupTaskSuccess(t *testing.T) {
|
||||
repo := &cleanupRepoStub{statusByID: map[int64]string{3: service.UsageCleanupStatusPending}}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
router := setupCleanupRouter(cleanupService, 1)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/3/cancel", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
@@ -1,7 +1,10 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
@@ -9,6 +12,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -16,9 +20,10 @@ import (
|
||||
|
||||
// UsageHandler handles admin usage-related requests
|
||||
type UsageHandler struct {
|
||||
usageService *service.UsageService
|
||||
apiKeyService *service.APIKeyService
|
||||
adminService service.AdminService
|
||||
usageService *service.UsageService
|
||||
apiKeyService *service.APIKeyService
|
||||
adminService service.AdminService
|
||||
cleanupService *service.UsageCleanupService
|
||||
}
|
||||
|
||||
// NewUsageHandler creates a new admin usage handler
|
||||
@@ -26,14 +31,30 @@ func NewUsageHandler(
|
||||
usageService *service.UsageService,
|
||||
apiKeyService *service.APIKeyService,
|
||||
adminService service.AdminService,
|
||||
cleanupService *service.UsageCleanupService,
|
||||
) *UsageHandler {
|
||||
return &UsageHandler{
|
||||
usageService: usageService,
|
||||
apiKeyService: apiKeyService,
|
||||
adminService: adminService,
|
||||
usageService: usageService,
|
||||
apiKeyService: apiKeyService,
|
||||
adminService: adminService,
|
||||
cleanupService: cleanupService,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateUsageCleanupTaskRequest represents cleanup task creation request
|
||||
type CreateUsageCleanupTaskRequest struct {
|
||||
StartDate string `json:"start_date"`
|
||||
EndDate string `json:"end_date"`
|
||||
UserID *int64 `json:"user_id"`
|
||||
APIKeyID *int64 `json:"api_key_id"`
|
||||
AccountID *int64 `json:"account_id"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Model *string `json:"model"`
|
||||
Stream *bool `json:"stream"`
|
||||
BillingType *int8 `json:"billing_type"`
|
||||
Timezone string `json:"timezone"`
|
||||
}
|
||||
|
||||
// List handles listing all usage records with filters
|
||||
// GET /api/v1/admin/usage
|
||||
func (h *UsageHandler) List(c *gin.Context) {
|
||||
@@ -142,7 +163,7 @@ func (h *UsageHandler) List(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.UsageLog, 0, len(records))
|
||||
out := make([]dto.AdminUsageLog, 0, len(records))
|
||||
for i := range records {
|
||||
out = append(out, *dto.UsageLogFromServiceAdmin(&records[i]))
|
||||
}
|
||||
@@ -344,3 +365,162 @@ func (h *UsageHandler) SearchAPIKeys(c *gin.Context) {
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// ListCleanupTasks handles listing usage cleanup tasks
|
||||
// GET /api/v1/admin/usage/cleanup-tasks
|
||||
func (h *UsageHandler) ListCleanupTasks(c *gin.Context) {
|
||||
if h.cleanupService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable")
|
||||
return
|
||||
}
|
||||
operator := int64(0)
|
||||
if subject, ok := middleware.GetAuthSubjectFromContext(c); ok {
|
||||
operator = subject.UserID
|
||||
}
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
log.Printf("[UsageCleanup] 请求清理任务列表: operator=%d page=%d page_size=%d", operator, page, pageSize)
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
tasks, result, err := h.cleanupService.ListTasks(c.Request.Context(), params)
|
||||
if err != nil {
|
||||
log.Printf("[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d page_size=%d err=%v", operator, page, pageSize, err)
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
out := make([]dto.UsageCleanupTask, 0, len(tasks))
|
||||
for i := range tasks {
|
||||
out = append(out, *dto.UsageCleanupTaskFromService(&tasks[i]))
|
||||
}
|
||||
log.Printf("[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d", operator, result.Total, len(out), page, pageSize)
|
||||
response.Paginated(c, out, result.Total, page, pageSize)
|
||||
}
|
||||
|
||||
// CreateCleanupTask handles creating a usage cleanup task
|
||||
// POST /api/v1/admin/usage/cleanup-tasks
|
||||
func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
|
||||
if h.cleanupService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable")
|
||||
return
|
||||
}
|
||||
subject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok || subject.UserID <= 0 {
|
||||
response.Unauthorized(c, "Unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
var req CreateUsageCleanupTaskRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
req.StartDate = strings.TrimSpace(req.StartDate)
|
||||
req.EndDate = strings.TrimSpace(req.EndDate)
|
||||
if req.StartDate == "" || req.EndDate == "" {
|
||||
response.BadRequest(c, "start_date and end_date are required")
|
||||
return
|
||||
}
|
||||
|
||||
startTime, err := timezone.ParseInUserLocation("2006-01-02", req.StartDate, req.Timezone)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
endTime, err := timezone.ParseInUserLocation("2006-01-02", req.EndDate, req.Timezone)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
|
||||
return
|
||||
}
|
||||
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
|
||||
|
||||
filters := service.UsageCleanupFilters{
|
||||
StartTime: startTime,
|
||||
EndTime: endTime,
|
||||
UserID: req.UserID,
|
||||
APIKeyID: req.APIKeyID,
|
||||
AccountID: req.AccountID,
|
||||
GroupID: req.GroupID,
|
||||
Model: req.Model,
|
||||
Stream: req.Stream,
|
||||
BillingType: req.BillingType,
|
||||
}
|
||||
|
||||
var userID any
|
||||
if filters.UserID != nil {
|
||||
userID = *filters.UserID
|
||||
}
|
||||
var apiKeyID any
|
||||
if filters.APIKeyID != nil {
|
||||
apiKeyID = *filters.APIKeyID
|
||||
}
|
||||
var accountID any
|
||||
if filters.AccountID != nil {
|
||||
accountID = *filters.AccountID
|
||||
}
|
||||
var groupID any
|
||||
if filters.GroupID != nil {
|
||||
groupID = *filters.GroupID
|
||||
}
|
||||
var model any
|
||||
if filters.Model != nil {
|
||||
model = *filters.Model
|
||||
}
|
||||
var stream any
|
||||
if filters.Stream != nil {
|
||||
stream = *filters.Stream
|
||||
}
|
||||
var billingType any
|
||||
if filters.BillingType != nil {
|
||||
billingType = *filters.BillingType
|
||||
}
|
||||
|
||||
log.Printf("[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q",
|
||||
subject.UserID,
|
||||
filters.StartTime.Format(time.RFC3339),
|
||||
filters.EndTime.Format(time.RFC3339),
|
||||
userID,
|
||||
apiKeyID,
|
||||
accountID,
|
||||
groupID,
|
||||
model,
|
||||
stream,
|
||||
billingType,
|
||||
req.Timezone,
|
||||
)
|
||||
|
||||
task, err := h.cleanupService.CreateTask(c.Request.Context(), filters, subject.UserID)
|
||||
if err != nil {
|
||||
log.Printf("[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err)
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status)
|
||||
response.Success(c, dto.UsageCleanupTaskFromService(task))
|
||||
}
|
||||
|
||||
// CancelCleanupTask handles canceling a usage cleanup task
|
||||
// POST /api/v1/admin/usage/cleanup-tasks/:id/cancel
|
||||
func (h *UsageHandler) CancelCleanupTask(c *gin.Context) {
|
||||
if h.cleanupService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable")
|
||||
return
|
||||
}
|
||||
subject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok || subject.UserID <= 0 {
|
||||
response.Unauthorized(c, "Unauthorized")
|
||||
return
|
||||
}
|
||||
idStr := strings.TrimSpace(c.Param("id"))
|
||||
taskID, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil || taskID <= 0 {
|
||||
response.BadRequest(c, "Invalid task id")
|
||||
return
|
||||
}
|
||||
log.Printf("[UsageCleanup] 请求取消清理任务: task=%d operator=%d", taskID, subject.UserID)
|
||||
if err := h.cleanupService.CancelTask(c.Request.Context(), taskID, subject.UserID); err != nil {
|
||||
log.Printf("[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v", taskID, subject.UserID, err)
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
log.Printf("[UsageCleanup] 清理任务已取消: task=%d operator=%d", taskID, subject.UserID)
|
||||
response.Success(c, gin.H{"id": taskID, "status": service.UsageCleanupStatusCanceled})
|
||||
}
|
||||
|
||||
@@ -84,9 +84,9 @@ func (h *UserHandler) List(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.User, 0, len(users))
|
||||
out := make([]dto.AdminUser, 0, len(users))
|
||||
for i := range users {
|
||||
out = append(out, *dto.UserFromService(&users[i]))
|
||||
out = append(out, *dto.UserFromServiceAdmin(&users[i]))
|
||||
}
|
||||
response.Paginated(c, out, total, page, pageSize)
|
||||
}
|
||||
@@ -129,7 +129,7 @@ func (h *UserHandler) GetByID(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserFromService(user))
|
||||
response.Success(c, dto.UserFromServiceAdmin(user))
|
||||
}
|
||||
|
||||
// Create handles creating a new user
|
||||
@@ -155,7 +155,7 @@ func (h *UserHandler) Create(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserFromService(user))
|
||||
response.Success(c, dto.UserFromServiceAdmin(user))
|
||||
}
|
||||
|
||||
// Update handles updating a user
|
||||
@@ -189,7 +189,7 @@ func (h *UserHandler) Update(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserFromService(user))
|
||||
response.Success(c, dto.UserFromServiceAdmin(user))
|
||||
}
|
||||
|
||||
// Delete handles deleting a user
|
||||
@@ -231,7 +231,7 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.UserFromService(user))
|
||||
response.Success(c, dto.UserFromServiceAdmin(user))
|
||||
}
|
||||
|
||||
// GetUserAPIKeys handles getting user's API keys
|
||||
|
||||
@@ -15,7 +15,6 @@ func UserFromServiceShallow(u *service.User) *User {
|
||||
ID: u.ID,
|
||||
Email: u.Email,
|
||||
Username: u.Username,
|
||||
Notes: u.Notes,
|
||||
Role: u.Role,
|
||||
Balance: u.Balance,
|
||||
Concurrency: u.Concurrency,
|
||||
@@ -48,6 +47,22 @@ func UserFromService(u *service.User) *User {
|
||||
return out
|
||||
}
|
||||
|
||||
// UserFromServiceAdmin converts a service User to DTO for admin users.
|
||||
// It includes notes - user-facing endpoints must not use this.
|
||||
func UserFromServiceAdmin(u *service.User) *AdminUser {
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
base := UserFromService(u)
|
||||
if base == nil {
|
||||
return nil
|
||||
}
|
||||
return &AdminUser{
|
||||
User: *base,
|
||||
Notes: u.Notes,
|
||||
}
|
||||
}
|
||||
|
||||
func APIKeyFromService(k *service.APIKey) *APIKey {
|
||||
if k == nil {
|
||||
return nil
|
||||
@@ -72,36 +87,29 @@ func GroupFromServiceShallow(g *service.Group) *Group {
|
||||
if g == nil {
|
||||
return nil
|
||||
}
|
||||
return &Group{
|
||||
ID: g.ID,
|
||||
Name: g.Name,
|
||||
Description: g.Description,
|
||||
Platform: g.Platform,
|
||||
RateMultiplier: g.RateMultiplier,
|
||||
IsExclusive: g.IsExclusive,
|
||||
Status: g.Status,
|
||||
SubscriptionType: g.SubscriptionType,
|
||||
DailyLimitUSD: g.DailyLimitUSD,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: g.MonthlyLimitUSD,
|
||||
ImagePrice1K: g.ImagePrice1K,
|
||||
ImagePrice2K: g.ImagePrice2K,
|
||||
ImagePrice4K: g.ImagePrice4K,
|
||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
AccountCount: g.AccountCount,
|
||||
}
|
||||
out := groupFromServiceBase(g)
|
||||
return &out
|
||||
}
|
||||
|
||||
func GroupFromService(g *service.Group) *Group {
|
||||
if g == nil {
|
||||
return nil
|
||||
}
|
||||
out := GroupFromServiceShallow(g)
|
||||
return GroupFromServiceShallow(g)
|
||||
}
|
||||
|
||||
// GroupFromServiceAdmin converts a service Group to DTO for admin users.
|
||||
// It includes internal fields like model_routing and account_count.
|
||||
func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
|
||||
if g == nil {
|
||||
return nil
|
||||
}
|
||||
out := &AdminGroup{
|
||||
Group: groupFromServiceBase(g),
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
AccountCount: g.AccountCount,
|
||||
}
|
||||
if len(g.AccountGroups) > 0 {
|
||||
out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
|
||||
for i := range g.AccountGroups {
|
||||
@@ -112,6 +120,29 @@ func GroupFromService(g *service.Group) *Group {
|
||||
return out
|
||||
}
|
||||
|
||||
func groupFromServiceBase(g *service.Group) Group {
|
||||
return Group{
|
||||
ID: g.ID,
|
||||
Name: g.Name,
|
||||
Description: g.Description,
|
||||
Platform: g.Platform,
|
||||
RateMultiplier: g.RateMultiplier,
|
||||
IsExclusive: g.IsExclusive,
|
||||
Status: g.Status,
|
||||
SubscriptionType: g.SubscriptionType,
|
||||
DailyLimitUSD: g.DailyLimitUSD,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: g.MonthlyLimitUSD,
|
||||
ImagePrice1K: g.ImagePrice1K,
|
||||
ImagePrice2K: g.ImagePrice2K,
|
||||
ImagePrice4K: g.ImagePrice4K,
|
||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
if a == nil {
|
||||
return nil
|
||||
@@ -161,6 +192,16 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
if idleTimeout := a.GetSessionIdleTimeoutMinutes(); idleTimeout > 0 {
|
||||
out.SessionIdleTimeoutMin = &idleTimeout
|
||||
}
|
||||
// TLS指纹伪装开关
|
||||
if a.IsTLSFingerprintEnabled() {
|
||||
enabled := true
|
||||
out.EnableTLSFingerprint = &enabled
|
||||
}
|
||||
// 会话ID伪装开关
|
||||
if a.IsSessionIDMaskingEnabled() {
|
||||
enabled := true
|
||||
out.EnableSessionIDMasking = &enabled
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
@@ -263,7 +304,24 @@ func RedeemCodeFromService(rc *service.RedeemCode) *RedeemCode {
|
||||
if rc == nil {
|
||||
return nil
|
||||
}
|
||||
return &RedeemCode{
|
||||
out := redeemCodeFromServiceBase(rc)
|
||||
return &out
|
||||
}
|
||||
|
||||
// RedeemCodeFromServiceAdmin converts a service RedeemCode to DTO for admin users.
|
||||
// It includes notes - user-facing endpoints must not use this.
|
||||
func RedeemCodeFromServiceAdmin(rc *service.RedeemCode) *AdminRedeemCode {
|
||||
if rc == nil {
|
||||
return nil
|
||||
}
|
||||
return &AdminRedeemCode{
|
||||
RedeemCode: redeemCodeFromServiceBase(rc),
|
||||
Notes: rc.Notes,
|
||||
}
|
||||
}
|
||||
|
||||
func redeemCodeFromServiceBase(rc *service.RedeemCode) RedeemCode {
|
||||
return RedeemCode{
|
||||
ID: rc.ID,
|
||||
Code: rc.Code,
|
||||
Type: rc.Type,
|
||||
@@ -271,7 +329,6 @@ func RedeemCodeFromService(rc *service.RedeemCode) *RedeemCode {
|
||||
Status: rc.Status,
|
||||
UsedBy: rc.UsedBy,
|
||||
UsedAt: rc.UsedAt,
|
||||
Notes: rc.Notes,
|
||||
CreatedAt: rc.CreatedAt,
|
||||
GroupID: rc.GroupID,
|
||||
ValidityDays: rc.ValidityDays,
|
||||
@@ -292,14 +349,9 @@ func AccountSummaryFromService(a *service.Account) *AccountSummary {
|
||||
}
|
||||
}
|
||||
|
||||
// usageLogFromServiceBase is a helper that converts service UsageLog to DTO.
|
||||
// The account parameter allows caller to control what Account info is included.
|
||||
// The includeIPAddress parameter controls whether to include the IP address (admin-only).
|
||||
func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, includeIPAddress bool) *UsageLog {
|
||||
if l == nil {
|
||||
return nil
|
||||
}
|
||||
result := &UsageLog{
|
||||
func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
||||
// 普通用户 DTO:严禁包含管理员字段(例如 account_rate_multiplier、ip_address、account)。
|
||||
return UsageLog{
|
||||
ID: l.ID,
|
||||
UserID: l.UserID,
|
||||
APIKeyID: l.APIKeyID,
|
||||
@@ -321,7 +373,6 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, inclu
|
||||
TotalCost: l.TotalCost,
|
||||
ActualCost: l.ActualCost,
|
||||
RateMultiplier: l.RateMultiplier,
|
||||
AccountRateMultiplier: l.AccountRateMultiplier,
|
||||
BillingType: l.BillingType,
|
||||
Stream: l.Stream,
|
||||
DurationMs: l.DurationMs,
|
||||
@@ -332,30 +383,63 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, inclu
|
||||
CreatedAt: l.CreatedAt,
|
||||
User: UserFromServiceShallow(l.User),
|
||||
APIKey: APIKeyFromService(l.APIKey),
|
||||
Account: account,
|
||||
Group: GroupFromServiceShallow(l.Group),
|
||||
Subscription: UserSubscriptionFromService(l.Subscription),
|
||||
}
|
||||
// IP 地址仅对管理员可见
|
||||
if includeIPAddress {
|
||||
result.IPAddress = l.IPAddress
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// UsageLogFromService converts a service UsageLog to DTO for regular users.
|
||||
// It excludes Account details and IP address - users should not see these.
|
||||
func UsageLogFromService(l *service.UsageLog) *UsageLog {
|
||||
return usageLogFromServiceBase(l, nil, false)
|
||||
if l == nil {
|
||||
return nil
|
||||
}
|
||||
u := usageLogFromServiceUser(l)
|
||||
return &u
|
||||
}
|
||||
|
||||
// UsageLogFromServiceAdmin converts a service UsageLog to DTO for admin users.
|
||||
// It includes minimal Account info (ID, Name only) and IP address.
|
||||
func UsageLogFromServiceAdmin(l *service.UsageLog) *UsageLog {
|
||||
func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog {
|
||||
if l == nil {
|
||||
return nil
|
||||
}
|
||||
return usageLogFromServiceBase(l, AccountSummaryFromService(l.Account), true)
|
||||
return &AdminUsageLog{
|
||||
UsageLog: usageLogFromServiceUser(l),
|
||||
AccountRateMultiplier: l.AccountRateMultiplier,
|
||||
IPAddress: l.IPAddress,
|
||||
Account: AccountSummaryFromService(l.Account),
|
||||
}
|
||||
}
|
||||
|
||||
func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTask {
|
||||
if task == nil {
|
||||
return nil
|
||||
}
|
||||
return &UsageCleanupTask{
|
||||
ID: task.ID,
|
||||
Status: task.Status,
|
||||
Filters: UsageCleanupFilters{
|
||||
StartTime: task.Filters.StartTime,
|
||||
EndTime: task.Filters.EndTime,
|
||||
UserID: task.Filters.UserID,
|
||||
APIKeyID: task.Filters.APIKeyID,
|
||||
AccountID: task.Filters.AccountID,
|
||||
GroupID: task.Filters.GroupID,
|
||||
Model: task.Filters.Model,
|
||||
Stream: task.Filters.Stream,
|
||||
BillingType: task.Filters.BillingType,
|
||||
},
|
||||
CreatedBy: task.CreatedBy,
|
||||
DeletedRows: task.DeletedRows,
|
||||
ErrorMessage: task.ErrorMsg,
|
||||
CanceledBy: task.CanceledBy,
|
||||
CanceledAt: task.CanceledAt,
|
||||
StartedAt: task.StartedAt,
|
||||
FinishedAt: task.FinishedAt,
|
||||
CreatedAt: task.CreatedAt,
|
||||
UpdatedAt: task.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func SettingFromService(s *service.Setting) *Setting {
|
||||
@@ -374,7 +458,27 @@ func UserSubscriptionFromService(sub *service.UserSubscription) *UserSubscriptio
|
||||
if sub == nil {
|
||||
return nil
|
||||
}
|
||||
return &UserSubscription{
|
||||
out := userSubscriptionFromServiceBase(sub)
|
||||
return &out
|
||||
}
|
||||
|
||||
// UserSubscriptionFromServiceAdmin converts a service UserSubscription to DTO for admin users.
|
||||
// It includes assignment metadata and notes.
|
||||
func UserSubscriptionFromServiceAdmin(sub *service.UserSubscription) *AdminUserSubscription {
|
||||
if sub == nil {
|
||||
return nil
|
||||
}
|
||||
return &AdminUserSubscription{
|
||||
UserSubscription: userSubscriptionFromServiceBase(sub),
|
||||
AssignedBy: sub.AssignedBy,
|
||||
AssignedAt: sub.AssignedAt,
|
||||
Notes: sub.Notes,
|
||||
AssignedByUser: UserFromServiceShallow(sub.AssignedByUser),
|
||||
}
|
||||
}
|
||||
|
||||
func userSubscriptionFromServiceBase(sub *service.UserSubscription) UserSubscription {
|
||||
return UserSubscription{
|
||||
ID: sub.ID,
|
||||
UserID: sub.UserID,
|
||||
GroupID: sub.GroupID,
|
||||
@@ -387,14 +491,10 @@ func UserSubscriptionFromService(sub *service.UserSubscription) *UserSubscriptio
|
||||
DailyUsageUSD: sub.DailyUsageUSD,
|
||||
WeeklyUsageUSD: sub.WeeklyUsageUSD,
|
||||
MonthlyUsageUSD: sub.MonthlyUsageUSD,
|
||||
AssignedBy: sub.AssignedBy,
|
||||
AssignedAt: sub.AssignedAt,
|
||||
Notes: sub.Notes,
|
||||
CreatedAt: sub.CreatedAt,
|
||||
UpdatedAt: sub.UpdatedAt,
|
||||
User: UserFromServiceShallow(sub.User),
|
||||
Group: GroupFromServiceShallow(sub.Group),
|
||||
AssignedByUser: UserFromServiceShallow(sub.AssignedByUser),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -402,9 +502,9 @@ func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
subs := make([]UserSubscription, 0, len(r.Subscriptions))
|
||||
subs := make([]AdminUserSubscription, 0, len(r.Subscriptions))
|
||||
for i := range r.Subscriptions {
|
||||
subs = append(subs, *UserSubscriptionFromService(&r.Subscriptions[i]))
|
||||
subs = append(subs, *UserSubscriptionFromServiceAdmin(&r.Subscriptions[i]))
|
||||
}
|
||||
return &BulkAssignResult{
|
||||
SuccessCount: r.SuccessCount,
|
||||
|
||||
@@ -22,13 +22,14 @@ type SystemSettings struct {
|
||||
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
|
||||
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
|
||||
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
@@ -63,6 +64,7 @@ type PublicSettings struct {
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ type User struct {
|
||||
ID int64 `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Notes string `json:"notes"`
|
||||
Role string `json:"role"`
|
||||
Balance float64 `json:"balance"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
@@ -19,6 +18,14 @@ type User struct {
|
||||
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
|
||||
}
|
||||
|
||||
// AdminUser 是管理员接口使用的 user DTO(包含敏感/内部字段)。
|
||||
// 注意:普通用户接口不得返回 notes 等管理员备注信息。
|
||||
type AdminUser struct {
|
||||
User
|
||||
|
||||
Notes string `json:"notes"`
|
||||
}
|
||||
|
||||
type APIKey struct {
|
||||
ID int64 `json:"id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
@@ -58,13 +65,19 @@ type Group struct {
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// AdminGroup 是管理员接口使用的 group DTO(包含敏感/内部字段)。
|
||||
// 注意:普通用户接口不得返回 model_routing/account_count/account_groups 等内部信息。
|
||||
type AdminGroup struct {
|
||||
Group
|
||||
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
AccountCount int64 `json:"account_count,omitempty"`
|
||||
}
|
||||
@@ -112,6 +125,15 @@ type Account struct {
|
||||
MaxSessions *int `json:"max_sessions,omitempty"`
|
||||
SessionIdleTimeoutMin *int `json:"session_idle_timeout_minutes,omitempty"`
|
||||
|
||||
// TLS指纹伪装(仅 Anthropic OAuth/SetupToken 账号有效)
|
||||
// 从 extra 字段提取,方便前端显示和编辑
|
||||
EnableTLSFingerprint *bool `json:"enable_tls_fingerprint,omitempty"`
|
||||
|
||||
// 会话ID伪装(仅 Anthropic OAuth/SetupToken 账号有效)
|
||||
// 启用后将在15分钟内固定 metadata.user_id 中的 session ID
|
||||
// 从 extra 字段提取,方便前端显示和编辑
|
||||
EnableSessionIDMasking *bool `json:"session_id_masking_enabled,omitempty"`
|
||||
|
||||
Proxy *Proxy `json:"proxy,omitempty"`
|
||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||
|
||||
@@ -171,7 +193,6 @@ type RedeemCode struct {
|
||||
Status string `json:"status"`
|
||||
UsedBy *int64 `json:"used_by"`
|
||||
UsedAt *time.Time `json:"used_at"`
|
||||
Notes string `json:"notes"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
GroupID *int64 `json:"group_id"`
|
||||
@@ -181,6 +202,15 @@ type RedeemCode struct {
|
||||
Group *Group `json:"group,omitempty"`
|
||||
}
|
||||
|
||||
// AdminRedeemCode 是管理员接口使用的 redeem code DTO(包含 notes 等字段)。
|
||||
// 注意:普通用户接口不得返回 notes 等内部信息。
|
||||
type AdminRedeemCode struct {
|
||||
RedeemCode
|
||||
|
||||
Notes string `json:"notes"`
|
||||
}
|
||||
|
||||
// UsageLog 是普通用户接口使用的 usage log DTO(不包含管理员字段)。
|
||||
type UsageLog struct {
|
||||
ID int64 `json:"id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
@@ -200,14 +230,13 @@ type UsageLog struct {
|
||||
CacheCreation5mTokens int `json:"cache_creation_5m_tokens"`
|
||||
CacheCreation1hTokens int `json:"cache_creation_1h_tokens"`
|
||||
|
||||
InputCost float64 `json:"input_cost"`
|
||||
OutputCost float64 `json:"output_cost"`
|
||||
CacheCreationCost float64 `json:"cache_creation_cost"`
|
||||
CacheReadCost float64 `json:"cache_read_cost"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
ActualCost float64 `json:"actual_cost"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
|
||||
InputCost float64 `json:"input_cost"`
|
||||
OutputCost float64 `json:"output_cost"`
|
||||
CacheCreationCost float64 `json:"cache_creation_cost"`
|
||||
CacheReadCost float64 `json:"cache_read_cost"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
ActualCost float64 `json:"actual_cost"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
|
||||
BillingType int8 `json:"billing_type"`
|
||||
Stream bool `json:"stream"`
|
||||
@@ -221,18 +250,55 @@ type UsageLog struct {
|
||||
// User-Agent
|
||||
UserAgent *string `json:"user_agent"`
|
||||
|
||||
// IP 地址(仅管理员可见)
|
||||
IPAddress *string `json:"ip_address,omitempty"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
APIKey *APIKey `json:"api_key,omitempty"`
|
||||
Account *AccountSummary `json:"account,omitempty"` // Use minimal AccountSummary to prevent data leakage
|
||||
Group *Group `json:"group,omitempty"`
|
||||
Subscription *UserSubscription `json:"subscription,omitempty"`
|
||||
}
|
||||
|
||||
// AdminUsageLog 是管理员接口使用的 usage log DTO(包含管理员字段)。
|
||||
type AdminUsageLog struct {
|
||||
UsageLog
|
||||
|
||||
// AccountRateMultiplier 账号计费倍率快照(nil 表示按 1.0 处理)
|
||||
AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
|
||||
|
||||
// IPAddress 用户请求 IP(仅管理员可见)
|
||||
IPAddress *string `json:"ip_address,omitempty"`
|
||||
|
||||
// Account 最小账号信息(避免泄露敏感字段)
|
||||
Account *AccountSummary `json:"account,omitempty"`
|
||||
}
|
||||
|
||||
type UsageCleanupFilters struct {
|
||||
StartTime time.Time `json:"start_time"`
|
||||
EndTime time.Time `json:"end_time"`
|
||||
UserID *int64 `json:"user_id,omitempty"`
|
||||
APIKeyID *int64 `json:"api_key_id,omitempty"`
|
||||
AccountID *int64 `json:"account_id,omitempty"`
|
||||
GroupID *int64 `json:"group_id,omitempty"`
|
||||
Model *string `json:"model,omitempty"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
BillingType *int8 `json:"billing_type,omitempty"`
|
||||
}
|
||||
|
||||
type UsageCleanupTask struct {
|
||||
ID int64 `json:"id"`
|
||||
Status string `json:"status"`
|
||||
Filters UsageCleanupFilters `json:"filters"`
|
||||
CreatedBy int64 `json:"created_by"`
|
||||
DeletedRows int64 `json:"deleted_rows"`
|
||||
ErrorMessage *string `json:"error_message,omitempty"`
|
||||
CanceledBy *int64 `json:"canceled_by,omitempty"`
|
||||
CanceledAt *time.Time `json:"canceled_at,omitempty"`
|
||||
StartedAt *time.Time `json:"started_at,omitempty"`
|
||||
FinishedAt *time.Time `json:"finished_at,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// AccountSummary is a minimal account info for usage log display.
|
||||
// It intentionally excludes sensitive fields like Credentials, Proxy, etc.
|
||||
type AccountSummary struct {
|
||||
@@ -264,23 +330,30 @@ type UserSubscription struct {
|
||||
WeeklyUsageUSD float64 `json:"weekly_usage_usd"`
|
||||
MonthlyUsageUSD float64 `json:"monthly_usage_usd"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
}
|
||||
|
||||
// AdminUserSubscription 是管理员接口使用的订阅 DTO(包含分配信息/备注等字段)。
|
||||
// 注意:普通用户接口不得返回 assigned_by/assigned_at/notes/assigned_by_user 等管理员字段。
|
||||
type AdminUserSubscription struct {
|
||||
UserSubscription
|
||||
|
||||
AssignedBy *int64 `json:"assigned_by"`
|
||||
AssignedAt time.Time `json:"assigned_at"`
|
||||
Notes string `json:"notes"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
Group *Group `json:"group,omitempty"`
|
||||
AssignedByUser *User `json:"assigned_by_user,omitempty"`
|
||||
AssignedByUser *User `json:"assigned_by_user,omitempty"`
|
||||
}
|
||||
|
||||
type BulkAssignResult struct {
|
||||
SuccessCount int `json:"success_count"`
|
||||
FailedCount int `json:"failed_count"`
|
||||
Subscriptions []UserSubscription `json:"subscriptions"`
|
||||
Errors []string `json:"errors"`
|
||||
SuccessCount int `json:"success_count"`
|
||||
FailedCount int `json:"failed_count"`
|
||||
Subscriptions []AdminUserSubscription `json:"subscriptions"`
|
||||
Errors []string `json:"errors"`
|
||||
}
|
||||
|
||||
// PromoCode 注册优惠码
|
||||
|
||||
@@ -31,6 +31,8 @@ type GatewayHandler struct {
|
||||
userService *service.UserService
|
||||
billingCacheService *service.BillingCacheService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
maxAccountSwitches int
|
||||
maxAccountSwitchesGemini int
|
||||
}
|
||||
|
||||
// NewGatewayHandler creates a new GatewayHandler
|
||||
@@ -44,8 +46,16 @@ func NewGatewayHandler(
|
||||
cfg *config.Config,
|
||||
) *GatewayHandler {
|
||||
pingInterval := time.Duration(0)
|
||||
maxAccountSwitches := 10
|
||||
maxAccountSwitchesGemini := 3
|
||||
if cfg != nil {
|
||||
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
|
||||
if cfg.Gateway.MaxAccountSwitches > 0 {
|
||||
maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
|
||||
}
|
||||
if cfg.Gateway.MaxAccountSwitchesGemini > 0 {
|
||||
maxAccountSwitchesGemini = cfg.Gateway.MaxAccountSwitchesGemini
|
||||
}
|
||||
}
|
||||
return &GatewayHandler{
|
||||
gatewayService: gatewayService,
|
||||
@@ -54,6 +64,8 @@ func NewGatewayHandler(
|
||||
userService: userService,
|
||||
billingCacheService: billingCacheService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -179,7 +191,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
|
||||
if platform == service.PlatformGemini {
|
||||
const maxAccountSwitches = 3
|
||||
maxAccountSwitches := h.maxAccountSwitchesGemini
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
@@ -313,7 +325,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
const maxAccountSwitches = 10
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
|
||||
@@ -220,7 +220,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
if sessionHash != "" {
|
||||
sessionKey = "gemini:" + sessionHash
|
||||
}
|
||||
const maxAccountSwitches = 3
|
||||
maxAccountSwitches := h.maxAccountSwitchesGemini
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
|
||||
@@ -25,6 +25,7 @@ type OpenAIGatewayHandler struct {
|
||||
gatewayService *service.OpenAIGatewayService
|
||||
billingCacheService *service.BillingCacheService
|
||||
concurrencyHelper *ConcurrencyHelper
|
||||
maxAccountSwitches int
|
||||
}
|
||||
|
||||
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
|
||||
@@ -35,13 +36,18 @@ func NewOpenAIGatewayHandler(
|
||||
cfg *config.Config,
|
||||
) *OpenAIGatewayHandler {
|
||||
pingInterval := time.Duration(0)
|
||||
maxAccountSwitches := 3
|
||||
if cfg != nil {
|
||||
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
|
||||
if cfg.Gateway.MaxAccountSwitches > 0 {
|
||||
maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
|
||||
}
|
||||
}
|
||||
return &OpenAIGatewayHandler{
|
||||
gatewayService: gatewayService,
|
||||
billingCacheService: billingCacheService,
|
||||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||||
maxAccountSwitches: maxAccountSwitches,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -186,10 +192,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Generate session hash (from header for OpenAI)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c)
|
||||
// Generate session hash (header first; fallback to prompt_cache_key)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c, reqBody)
|
||||
|
||||
const maxAccountSwitches = 3
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
lastFailoverStatus := 0
|
||||
|
||||
@@ -43,6 +43,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
HomeContent: settings.HomeContent,
|
||||
HideCcsImportButton: settings.HideCcsImportButton,
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
Version: h.version,
|
||||
})
|
||||
|
||||
@@ -47,9 +47,6 @@ func (h *UserHandler) GetProfile(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 清空notes字段,普通用户不应看到备注
|
||||
userData.Notes = ""
|
||||
|
||||
response.Success(c, dto.UserFromService(userData))
|
||||
}
|
||||
|
||||
@@ -105,8 +102,5 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 清空notes字段,普通用户不应看到备注
|
||||
updatedUser.Notes = ""
|
||||
|
||||
response.Success(c, dto.UserFromService(updatedUser))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user