First commit

This commit is contained in:
shaw
2025-12-18 13:50:39 +08:00
parent 569f4882e5
commit 642842c29e
218 changed files with 86902 additions and 0 deletions

View File

@@ -0,0 +1,537 @@
package admin
import (
"strconv"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// OAuthHandler handles OAuth-related operations for accounts
type OAuthHandler struct {
oauthService *service.OAuthService
adminService service.AdminService
}
// NewOAuthHandler creates a new OAuth handler
func NewOAuthHandler(oauthService *service.OAuthService, adminService service.AdminService) *OAuthHandler {
return &OAuthHandler{
oauthService: oauthService,
adminService: adminService,
}
}
// AccountHandler handles admin account management
type AccountHandler struct {
adminService service.AdminService
oauthService *service.OAuthService
rateLimitService *service.RateLimitService
accountUsageService *service.AccountUsageService
accountTestService *service.AccountTestService
}
// NewAccountHandler creates a new admin account handler
func NewAccountHandler(adminService service.AdminService, oauthService *service.OAuthService, rateLimitService *service.RateLimitService, accountUsageService *service.AccountUsageService, accountTestService *service.AccountTestService) *AccountHandler {
return &AccountHandler{
adminService: adminService,
oauthService: oauthService,
rateLimitService: rateLimitService,
accountUsageService: accountUsageService,
accountTestService: accountTestService,
}
}
// CreateAccountRequest represents create account request
type CreateAccountRequest struct {
Name string `json:"name" binding:"required"`
Platform string `json:"platform" binding:"required"`
Type string `json:"type" binding:"required,oneof=oauth setup-token apikey"`
Credentials map[string]interface{} `json:"credentials" binding:"required"`
Extra map[string]interface{} `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
GroupIDs []int64 `json:"group_ids"`
}
// UpdateAccountRequest represents update account request
// 使用指针类型来区分"未提供"和"设置为0"
type UpdateAccountRequest struct {
Name string `json:"name"`
Type string `json:"type" binding:"omitempty,oneof=oauth setup-token apikey"`
Credentials map[string]interface{} `json:"credentials"`
Extra map[string]interface{} `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
GroupIDs *[]int64 `json:"group_ids"`
}
// List handles listing all accounts with pagination
// GET /api/v1/admin/accounts
func (h *AccountHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
platform := c.Query("platform")
accountType := c.Query("type")
status := c.Query("status")
search := c.Query("search")
accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search)
if err != nil {
response.InternalError(c, "Failed to list accounts: "+err.Error())
return
}
response.Paginated(c, accounts, total, page, pageSize)
}
// GetByID handles getting an account by ID
// GET /api/v1/admin/accounts/:id
func (h *AccountHandler) GetByID(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.NotFound(c, "Account not found")
return
}
response.Success(c, account)
}
// Create handles creating a new account
// POST /api/v1/admin/accounts
func (h *AccountHandler) Create(c *gin.Context) {
var req CreateAccountRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{
Name: req.Name,
Platform: req.Platform,
Type: req.Type,
Credentials: req.Credentials,
Extra: req.Extra,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
GroupIDs: req.GroupIDs,
})
if err != nil {
response.BadRequest(c, "Failed to create account: "+err.Error())
return
}
response.Success(c, account)
}
// Update handles updating an account
// PUT /api/v1/admin/accounts/:id
func (h *AccountHandler) Update(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
var req UpdateAccountRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
account, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
Name: req.Name,
Type: req.Type,
Credentials: req.Credentials,
Extra: req.Extra,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency, // 指针类型nil 表示未提供
Priority: req.Priority, // 指针类型nil 表示未提供
Status: req.Status,
GroupIDs: req.GroupIDs,
})
if err != nil {
response.InternalError(c, "Failed to update account: "+err.Error())
return
}
response.Success(c, account)
}
// Delete handles deleting an account
// DELETE /api/v1/admin/accounts/:id
func (h *AccountHandler) Delete(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
err = h.adminService.DeleteAccount(c.Request.Context(), accountID)
if err != nil {
response.InternalError(c, "Failed to delete account: "+err.Error())
return
}
response.Success(c, gin.H{"message": "Account deleted successfully"})
}
// Test handles testing account connectivity with SSE streaming
// POST /api/v1/admin/accounts/:id/test
func (h *AccountHandler) Test(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
// Use AccountTestService to test the account with SSE streaming
if err := h.accountTestService.TestAccountConnection(c, accountID); err != nil {
// Error already sent via SSE, just log
return
}
}
// Refresh handles refreshing account credentials
// POST /api/v1/admin/accounts/:id/refresh
func (h *AccountHandler) Refresh(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
// Get account
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.NotFound(c, "Account not found")
return
}
// Only refresh OAuth-based accounts (oauth and setup-token)
if !account.IsOAuth() {
response.BadRequest(c, "Cannot refresh non-OAuth account credentials")
return
}
// Use OAuth service to refresh token
tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
if err != nil {
response.InternalError(c, "Failed to refresh credentials: "+err.Error())
return
}
// Update account credentials
newCredentials := map[string]interface{}{
"access_token": tokenInfo.AccessToken,
"token_type": tokenInfo.TokenType,
"expires_in": tokenInfo.ExpiresIn,
"expires_at": tokenInfo.ExpiresAt,
"refresh_token": tokenInfo.RefreshToken,
"scope": tokenInfo.Scope,
}
updatedAccount, err := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
Credentials: newCredentials,
})
if err != nil {
response.InternalError(c, "Failed to update account credentials: "+err.Error())
return
}
response.Success(c, updatedAccount)
}
// GetStats handles getting account statistics
// GET /api/v1/admin/accounts/:id/stats
func (h *AccountHandler) GetStats(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
// Return mock data for now
_ = accountID
response.Success(c, gin.H{
"total_requests": 0,
"successful_requests": 0,
"failed_requests": 0,
"total_tokens": 0,
"average_response_time": 0,
})
}
// ClearError handles clearing account error
// POST /api/v1/admin/accounts/:id/clear-error
func (h *AccountHandler) ClearError(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
account, err := h.adminService.ClearAccountError(c.Request.Context(), accountID)
if err != nil {
response.InternalError(c, "Failed to clear error: "+err.Error())
return
}
response.Success(c, account)
}
// BatchCreate handles batch creating accounts
// POST /api/v1/admin/accounts/batch
func (h *AccountHandler) BatchCreate(c *gin.Context) {
var req struct {
Accounts []CreateAccountRequest `json:"accounts" binding:"required,min=1"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Return mock data for now
response.Success(c, gin.H{
"success": len(req.Accounts),
"failed": 0,
"results": []gin.H{},
})
}
// ========== OAuth Handlers ==========
// GenerateAuthURLRequest represents the request for generating auth URL
type GenerateAuthURLRequest struct {
ProxyID *int64 `json:"proxy_id"`
}
// GenerateAuthURL generates OAuth authorization URL with full scope
// POST /api/v1/admin/accounts/generate-auth-url
func (h *OAuthHandler) GenerateAuthURL(c *gin.Context) {
var req GenerateAuthURLRequest
if err := c.ShouldBindJSON(&req); err != nil {
// Allow empty body
req = GenerateAuthURLRequest{}
}
result, err := h.oauthService.GenerateAuthURL(c.Request.Context(), req.ProxyID)
if err != nil {
response.InternalError(c, "Failed to generate auth URL: "+err.Error())
return
}
response.Success(c, result)
}
// GenerateSetupTokenURL generates OAuth authorization URL for setup token (inference only)
// POST /api/v1/admin/accounts/generate-setup-token-url
func (h *OAuthHandler) GenerateSetupTokenURL(c *gin.Context) {
var req GenerateAuthURLRequest
if err := c.ShouldBindJSON(&req); err != nil {
// Allow empty body
req = GenerateAuthURLRequest{}
}
result, err := h.oauthService.GenerateSetupTokenURL(c.Request.Context(), req.ProxyID)
if err != nil {
response.InternalError(c, "Failed to generate setup token URL: "+err.Error())
return
}
response.Success(c, result)
}
// ExchangeCodeRequest represents the request for exchanging auth code
type ExchangeCodeRequest struct {
SessionID string `json:"session_id" binding:"required"`
Code string `json:"code" binding:"required"`
ProxyID *int64 `json:"proxy_id"`
}
// ExchangeCode exchanges authorization code for tokens
// POST /api/v1/admin/accounts/exchange-code
func (h *OAuthHandler) ExchangeCode(c *gin.Context) {
var req ExchangeCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
tokenInfo, err := h.oauthService.ExchangeCode(c.Request.Context(), &service.ExchangeCodeInput{
SessionID: req.SessionID,
Code: req.Code,
ProxyID: req.ProxyID,
})
if err != nil {
response.BadRequest(c, "Failed to exchange code: "+err.Error())
return
}
response.Success(c, tokenInfo)
}
// ExchangeSetupTokenCode exchanges authorization code for setup token
// POST /api/v1/admin/accounts/exchange-setup-token-code
func (h *OAuthHandler) ExchangeSetupTokenCode(c *gin.Context) {
var req ExchangeCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
tokenInfo, err := h.oauthService.ExchangeCode(c.Request.Context(), &service.ExchangeCodeInput{
SessionID: req.SessionID,
Code: req.Code,
ProxyID: req.ProxyID,
})
if err != nil {
response.BadRequest(c, "Failed to exchange code: "+err.Error())
return
}
response.Success(c, tokenInfo)
}
// CookieAuthRequest represents the request for cookie-based authentication
type CookieAuthRequest struct {
SessionKey string `json:"code" binding:"required"` // Using 'code' field as sessionKey (frontend sends it this way)
ProxyID *int64 `json:"proxy_id"`
}
// CookieAuth performs OAuth using sessionKey (cookie-based auto-auth)
// POST /api/v1/admin/accounts/cookie-auth
func (h *OAuthHandler) CookieAuth(c *gin.Context) {
var req CookieAuthRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
tokenInfo, err := h.oauthService.CookieAuth(c.Request.Context(), &service.CookieAuthInput{
SessionKey: req.SessionKey,
ProxyID: req.ProxyID,
Scope: "full",
})
if err != nil {
response.BadRequest(c, "Cookie auth failed: "+err.Error())
return
}
response.Success(c, tokenInfo)
}
// SetupTokenCookieAuth performs OAuth using sessionKey for setup token (inference only)
// POST /api/v1/admin/accounts/setup-token-cookie-auth
func (h *OAuthHandler) SetupTokenCookieAuth(c *gin.Context) {
var req CookieAuthRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
tokenInfo, err := h.oauthService.CookieAuth(c.Request.Context(), &service.CookieAuthInput{
SessionKey: req.SessionKey,
ProxyID: req.ProxyID,
Scope: "inference",
})
if err != nil {
response.BadRequest(c, "Cookie auth failed: "+err.Error())
return
}
response.Success(c, tokenInfo)
}
// GetUsage handles getting account usage information
// GET /api/v1/admin/accounts/:id/usage
func (h *AccountHandler) GetUsage(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
usage, err := h.accountUsageService.GetUsage(c.Request.Context(), accountID)
if err != nil {
response.InternalError(c, "Failed to get usage: "+err.Error())
return
}
response.Success(c, usage)
}
// ClearRateLimit handles clearing account rate limit status
// POST /api/v1/admin/accounts/:id/clear-rate-limit
func (h *AccountHandler) ClearRateLimit(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
err = h.rateLimitService.ClearRateLimit(c.Request.Context(), accountID)
if err != nil {
response.InternalError(c, "Failed to clear rate limit: "+err.Error())
return
}
response.Success(c, gin.H{"message": "Rate limit cleared successfully"})
}
// GetTodayStats handles getting account today statistics
// GET /api/v1/admin/accounts/:id/today-stats
func (h *AccountHandler) GetTodayStats(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
stats, err := h.accountUsageService.GetTodayStats(c.Request.Context(), accountID)
if err != nil {
response.InternalError(c, "Failed to get today stats: "+err.Error())
return
}
response.Success(c, stats)
}
// SetSchedulableRequest represents the request body for setting schedulable status
type SetSchedulableRequest struct {
Schedulable bool `json:"schedulable"`
}
// SetSchedulable handles toggling account schedulable status
// POST /api/v1/admin/accounts/:id/schedulable
func (h *AccountHandler) SetSchedulable(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
var req SetSchedulableRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
account, err := h.adminService.SetAccountSchedulable(c.Request.Context(), accountID, req.Schedulable)
if err != nil {
response.InternalError(c, "Failed to update schedulable status: "+err.Error())
return
}
response.Success(c, account)
}

View File

@@ -0,0 +1,274 @@
package admin
import (
"strconv"
"sub2api/internal/pkg/response"
"sub2api/internal/pkg/timezone"
"sub2api/internal/repository"
"sub2api/internal/service"
"time"
"github.com/gin-gonic/gin"
)
// DashboardHandler handles admin dashboard statistics
type DashboardHandler struct {
adminService service.AdminService
usageRepo *repository.UsageLogRepository
startTime time.Time // Server start time for uptime calculation
}
// NewDashboardHandler creates a new admin dashboard handler
func NewDashboardHandler(adminService service.AdminService, usageRepo *repository.UsageLogRepository) *DashboardHandler {
return &DashboardHandler{
adminService: adminService,
usageRepo: usageRepo,
startTime: time.Now(),
}
}
// parseTimeRange parses start_date, end_date query parameters
func parseTimeRange(c *gin.Context) (time.Time, time.Time) {
now := timezone.Now()
startDate := c.Query("start_date")
endDate := c.Query("end_date")
var startTime, endTime time.Time
if startDate != "" {
if t, err := timezone.ParseInLocation("2006-01-02", startDate); err == nil {
startTime = t
} else {
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7))
}
} else {
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7))
}
if endDate != "" {
if t, err := timezone.ParseInLocation("2006-01-02", endDate); err == nil {
endTime = t.Add(24 * time.Hour) // Include the end date
} else {
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1))
}
} else {
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1))
}
return startTime, endTime
}
// GetStats handles getting dashboard statistics
// GET /api/v1/admin/dashboard/stats
func (h *DashboardHandler) GetStats(c *gin.Context) {
stats, err := h.usageRepo.GetDashboardStats(c.Request.Context())
if err != nil {
response.Error(c, 500, "Failed to get dashboard statistics")
return
}
// Calculate uptime in seconds
uptime := int64(time.Since(h.startTime).Seconds())
response.Success(c, gin.H{
// 用户统计
"total_users": stats.TotalUsers,
"today_new_users": stats.TodayNewUsers,
"active_users": stats.ActiveUsers,
// API Key 统计
"total_api_keys": stats.TotalApiKeys,
"active_api_keys": stats.ActiveApiKeys,
// 账户统计
"total_accounts": stats.TotalAccounts,
"normal_accounts": stats.NormalAccounts,
"error_accounts": stats.ErrorAccounts,
"ratelimit_accounts": stats.RateLimitAccounts,
"overload_accounts": stats.OverloadAccounts,
// 累计 Token 使用统计
"total_requests": stats.TotalRequests,
"total_input_tokens": stats.TotalInputTokens,
"total_output_tokens": stats.TotalOutputTokens,
"total_cache_creation_tokens": stats.TotalCacheCreationTokens,
"total_cache_read_tokens": stats.TotalCacheReadTokens,
"total_tokens": stats.TotalTokens,
"total_cost": stats.TotalCost, // 标准计费
"total_actual_cost": stats.TotalActualCost, // 实际扣除
// 今日 Token 使用统计
"today_requests": stats.TodayRequests,
"today_input_tokens": stats.TodayInputTokens,
"today_output_tokens": stats.TodayOutputTokens,
"today_cache_creation_tokens": stats.TodayCacheCreationTokens,
"today_cache_read_tokens": stats.TodayCacheReadTokens,
"today_tokens": stats.TodayTokens,
"today_cost": stats.TodayCost, // 今日标准计费
"today_actual_cost": stats.TodayActualCost, // 今日实际扣除
// 系统运行统计
"average_duration_ms": stats.AverageDurationMs,
"uptime": uptime,
})
}
// GetRealtimeMetrics handles getting real-time system metrics
// GET /api/v1/admin/dashboard/realtime
func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
// Return mock data for now
response.Success(c, gin.H{
"active_requests": 0,
"requests_per_minute": 0,
"average_response_time": 0,
"error_rate": 0.0,
})
}
// GetUsageTrend handles getting usage trend data
// GET /api/v1/admin/dashboard/trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour)
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
granularity := c.DefaultQuery("granularity", "day")
trend, err := h.usageRepo.GetUsageTrend(c.Request.Context(), startTime, endTime, granularity)
if err != nil {
response.Error(c, 500, "Failed to get usage trend")
return
}
response.Success(c, gin.H{
"trend": trend,
"start_date": startTime.Format("2006-01-02"),
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
"granularity": granularity,
})
}
// GetModelStats handles getting model usage statistics
// GET /api/v1/admin/dashboard/models
// Query params: start_date, end_date (YYYY-MM-DD)
func (h *DashboardHandler) GetModelStats(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
stats, err := h.usageRepo.GetModelStats(c.Request.Context(), startTime, endTime)
if err != nil {
response.Error(c, 500, "Failed to get model statistics")
return
}
response.Success(c, gin.H{
"models": stats,
"start_date": startTime.Format("2006-01-02"),
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
})
}
// GetApiKeyUsageTrend handles getting API key usage trend data
// GET /api/v1/admin/dashboard/api-keys-trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), limit (default 5)
func (h *DashboardHandler) GetApiKeyUsageTrend(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
granularity := c.DefaultQuery("granularity", "day")
limitStr := c.DefaultQuery("limit", "5")
limit, err := strconv.Atoi(limitStr)
if err != nil || limit <= 0 {
limit = 5
}
trend, err := h.usageRepo.GetApiKeyUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
if err != nil {
response.Error(c, 500, "Failed to get API key usage trend")
return
}
response.Success(c, gin.H{
"trend": trend,
"start_date": startTime.Format("2006-01-02"),
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
"granularity": granularity,
})
}
// GetUserUsageTrend handles getting user usage trend data
// GET /api/v1/admin/dashboard/users-trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), limit (default 12)
func (h *DashboardHandler) GetUserUsageTrend(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
granularity := c.DefaultQuery("granularity", "day")
limitStr := c.DefaultQuery("limit", "12")
limit, err := strconv.Atoi(limitStr)
if err != nil || limit <= 0 {
limit = 12
}
trend, err := h.usageRepo.GetUserUsageTrend(c.Request.Context(), startTime, endTime, granularity, limit)
if err != nil {
response.Error(c, 500, "Failed to get user usage trend")
return
}
response.Success(c, gin.H{
"trend": trend,
"start_date": startTime.Format("2006-01-02"),
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
"granularity": granularity,
})
}
// BatchUsersUsageRequest represents the request body for batch user usage stats
type BatchUsersUsageRequest struct {
UserIDs []int64 `json:"user_ids" binding:"required"`
}
// GetBatchUsersUsage handles getting usage stats for multiple users
// POST /api/v1/admin/dashboard/users-usage
func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
var req BatchUsersUsageRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if len(req.UserIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]interface{}{}})
return
}
stats, err := h.usageRepo.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs)
if err != nil {
response.Error(c, 500, "Failed to get user usage stats")
return
}
response.Success(c, gin.H{"stats": stats})
}
// BatchApiKeysUsageRequest represents the request body for batch api key usage stats
type BatchApiKeysUsageRequest struct {
ApiKeyIDs []int64 `json:"api_key_ids" binding:"required"`
}
// GetBatchApiKeysUsage handles getting usage stats for multiple API keys
// POST /api/v1/admin/dashboard/api-keys-usage
func (h *DashboardHandler) GetBatchApiKeysUsage(c *gin.Context) {
var req BatchApiKeysUsageRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if len(req.ApiKeyIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]interface{}{}})
return
}
stats, err := h.usageRepo.GetBatchApiKeyUsageStats(c.Request.Context(), req.ApiKeyIDs)
if err != nil {
response.Error(c, 500, "Failed to get API key usage stats")
return
}
response.Success(c, gin.H{"stats": stats})
}

View File

@@ -0,0 +1,233 @@
package admin
import (
"strconv"
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// GroupHandler handles admin group management
type GroupHandler struct {
adminService service.AdminService
}
// NewGroupHandler creates a new admin group handler
func NewGroupHandler(adminService service.AdminService) *GroupHandler {
return &GroupHandler{
adminService: adminService,
}
}
// CreateGroupRequest represents create group request
type CreateGroupRequest struct {
Name string `json:"name" binding:"required"`
Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini"`
RateMultiplier float64 `json:"rate_multiplier"`
IsExclusive bool `json:"is_exclusive"`
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
DailyLimitUSD *float64 `json:"daily_limit_usd"`
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
}
// UpdateGroupRequest represents update group request
type UpdateGroupRequest struct {
Name string `json:"name"`
Description string `json:"description"`
Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini"`
RateMultiplier *float64 `json:"rate_multiplier"`
IsExclusive *bool `json:"is_exclusive"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"`
DailyLimitUSD *float64 `json:"daily_limit_usd"`
WeeklyLimitUSD *float64 `json:"weekly_limit_usd"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd"`
}
// List handles listing all groups with pagination
// GET /api/v1/admin/groups
func (h *GroupHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
platform := c.Query("platform")
status := c.Query("status")
isExclusiveStr := c.Query("is_exclusive")
var isExclusive *bool
if isExclusiveStr != "" {
val := isExclusiveStr == "true"
isExclusive = &val
}
groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, isExclusive)
if err != nil {
response.InternalError(c, "Failed to list groups: "+err.Error())
return
}
response.Paginated(c, groups, total, page, pageSize)
}
// GetAll handles getting all active groups without pagination
// GET /api/v1/admin/groups/all
func (h *GroupHandler) GetAll(c *gin.Context) {
platform := c.Query("platform")
var groups []model.Group
var err error
if platform != "" {
groups, err = h.adminService.GetAllGroupsByPlatform(c.Request.Context(), platform)
} else {
groups, err = h.adminService.GetAllGroups(c.Request.Context())
}
if err != nil {
response.InternalError(c, "Failed to get groups: "+err.Error())
return
}
response.Success(c, groups)
}
// GetByID handles getting a group by ID
// GET /api/v1/admin/groups/:id
func (h *GroupHandler) GetByID(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group ID")
return
}
group, err := h.adminService.GetGroup(c.Request.Context(), groupID)
if err != nil {
response.NotFound(c, "Group not found")
return
}
response.Success(c, group)
}
// Create handles creating a new group
// POST /api/v1/admin/groups
func (h *GroupHandler) Create(c *gin.Context) {
var req CreateGroupRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
group, err := h.adminService.CreateGroup(c.Request.Context(), &service.CreateGroupInput{
Name: req.Name,
Description: req.Description,
Platform: req.Platform,
RateMultiplier: req.RateMultiplier,
IsExclusive: req.IsExclusive,
SubscriptionType: req.SubscriptionType,
DailyLimitUSD: req.DailyLimitUSD,
WeeklyLimitUSD: req.WeeklyLimitUSD,
MonthlyLimitUSD: req.MonthlyLimitUSD,
})
if err != nil {
response.BadRequest(c, "Failed to create group: "+err.Error())
return
}
response.Success(c, group)
}
// Update handles updating a group
// PUT /api/v1/admin/groups/:id
func (h *GroupHandler) Update(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group ID")
return
}
var req UpdateGroupRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
group, err := h.adminService.UpdateGroup(c.Request.Context(), groupID, &service.UpdateGroupInput{
Name: req.Name,
Description: req.Description,
Platform: req.Platform,
RateMultiplier: req.RateMultiplier,
IsExclusive: req.IsExclusive,
Status: req.Status,
SubscriptionType: req.SubscriptionType,
DailyLimitUSD: req.DailyLimitUSD,
WeeklyLimitUSD: req.WeeklyLimitUSD,
MonthlyLimitUSD: req.MonthlyLimitUSD,
})
if err != nil {
response.InternalError(c, "Failed to update group: "+err.Error())
return
}
response.Success(c, group)
}
// Delete handles deleting a group
// DELETE /api/v1/admin/groups/:id
func (h *GroupHandler) Delete(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group ID")
return
}
err = h.adminService.DeleteGroup(c.Request.Context(), groupID)
if err != nil {
response.InternalError(c, "Failed to delete group: "+err.Error())
return
}
response.Success(c, gin.H{"message": "Group deleted successfully"})
}
// GetStats handles getting group statistics
// GET /api/v1/admin/groups/:id/stats
func (h *GroupHandler) GetStats(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group ID")
return
}
// Return mock data for now
response.Success(c, gin.H{
"total_api_keys": 0,
"active_api_keys": 0,
"total_requests": 0,
"total_cost": 0.0,
})
_ = groupID // TODO: implement actual stats
}
// GetGroupAPIKeys handles getting API keys in a group
// GET /api/v1/admin/groups/:id/api-keys
func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group ID")
return
}
page, pageSize := response.ParsePagination(c)
keys, total, err := h.adminService.GetGroupAPIKeys(c.Request.Context(), groupID, page, pageSize)
if err != nil {
response.InternalError(c, "Failed to get group API keys: "+err.Error())
return
}
response.Paginated(c, keys, total, page, pageSize)
}

View File

@@ -0,0 +1,300 @@
package admin
import (
"strconv"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// ProxyHandler handles admin proxy management
type ProxyHandler struct {
adminService service.AdminService
}
// NewProxyHandler creates a new admin proxy handler
func NewProxyHandler(adminService service.AdminService) *ProxyHandler {
return &ProxyHandler{
adminService: adminService,
}
}
// CreateProxyRequest represents create proxy request
type CreateProxyRequest struct {
Name string `json:"name" binding:"required"`
Protocol string `json:"protocol" binding:"required,oneof=http https socks5"`
Host string `json:"host" binding:"required"`
Port int `json:"port" binding:"required,min=1,max=65535"`
Username string `json:"username"`
Password string `json:"password"`
}
// UpdateProxyRequest represents update proxy request
type UpdateProxyRequest struct {
Name string `json:"name"`
Protocol string `json:"protocol" binding:"omitempty,oneof=http https socks5"`
Host string `json:"host"`
Port int `json:"port" binding:"omitempty,min=1,max=65535"`
Username string `json:"username"`
Password string `json:"password"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
}
// List handles listing all proxies with pagination
// GET /api/v1/admin/proxies
func (h *ProxyHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
protocol := c.Query("protocol")
status := c.Query("status")
search := c.Query("search")
proxies, total, err := h.adminService.ListProxies(c.Request.Context(), page, pageSize, protocol, status, search)
if err != nil {
response.InternalError(c, "Failed to list proxies: "+err.Error())
return
}
response.Paginated(c, proxies, total, page, pageSize)
}
// GetAll handles getting all active proxies without pagination
// GET /api/v1/admin/proxies/all
// Optional query param: with_count=true to include account count per proxy
func (h *ProxyHandler) GetAll(c *gin.Context) {
withCount := c.Query("with_count") == "true"
if withCount {
proxies, err := h.adminService.GetAllProxiesWithAccountCount(c.Request.Context())
if err != nil {
response.InternalError(c, "Failed to get proxies: "+err.Error())
return
}
response.Success(c, proxies)
return
}
proxies, err := h.adminService.GetAllProxies(c.Request.Context())
if err != nil {
response.InternalError(c, "Failed to get proxies: "+err.Error())
return
}
response.Success(c, proxies)
}
// GetByID handles getting a proxy by ID
// GET /api/v1/admin/proxies/:id
func (h *ProxyHandler) GetByID(c *gin.Context) {
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid proxy ID")
return
}
proxy, err := h.adminService.GetProxy(c.Request.Context(), proxyID)
if err != nil {
response.NotFound(c, "Proxy not found")
return
}
response.Success(c, proxy)
}
// Create handles creating a new proxy
// POST /api/v1/admin/proxies
func (h *ProxyHandler) Create(c *gin.Context) {
var req CreateProxyRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
proxy, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
Name: req.Name,
Protocol: req.Protocol,
Host: req.Host,
Port: req.Port,
Username: req.Username,
Password: req.Password,
})
if err != nil {
response.BadRequest(c, "Failed to create proxy: "+err.Error())
return
}
response.Success(c, proxy)
}
// Update handles updating a proxy
// PUT /api/v1/admin/proxies/:id
func (h *ProxyHandler) Update(c *gin.Context) {
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid proxy ID")
return
}
var req UpdateProxyRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
proxy, err := h.adminService.UpdateProxy(c.Request.Context(), proxyID, &service.UpdateProxyInput{
Name: req.Name,
Protocol: req.Protocol,
Host: req.Host,
Port: req.Port,
Username: req.Username,
Password: req.Password,
Status: req.Status,
})
if err != nil {
response.InternalError(c, "Failed to update proxy: "+err.Error())
return
}
response.Success(c, proxy)
}
// Delete handles deleting a proxy
// DELETE /api/v1/admin/proxies/:id
func (h *ProxyHandler) Delete(c *gin.Context) {
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid proxy ID")
return
}
err = h.adminService.DeleteProxy(c.Request.Context(), proxyID)
if err != nil {
response.InternalError(c, "Failed to delete proxy: "+err.Error())
return
}
response.Success(c, gin.H{"message": "Proxy deleted successfully"})
}
// Test handles testing proxy connectivity
// POST /api/v1/admin/proxies/:id/test
func (h *ProxyHandler) Test(c *gin.Context) {
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid proxy ID")
return
}
result, err := h.adminService.TestProxy(c.Request.Context(), proxyID)
if err != nil {
response.InternalError(c, "Failed to test proxy: "+err.Error())
return
}
response.Success(c, result)
}
// GetStats handles getting proxy statistics
// GET /api/v1/admin/proxies/:id/stats
func (h *ProxyHandler) GetStats(c *gin.Context) {
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid proxy ID")
return
}
// Return mock data for now
_ = proxyID
response.Success(c, gin.H{
"total_accounts": 0,
"active_accounts": 0,
"total_requests": 0,
"success_rate": 100.0,
"average_latency": 0,
})
}
// GetProxyAccounts handles getting accounts using a proxy
// GET /api/v1/admin/proxies/:id/accounts
func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) {
proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid proxy ID")
return
}
page, pageSize := response.ParsePagination(c)
accounts, total, err := h.adminService.GetProxyAccounts(c.Request.Context(), proxyID, page, pageSize)
if err != nil {
response.InternalError(c, "Failed to get proxy accounts: "+err.Error())
return
}
response.Paginated(c, accounts, total, page, pageSize)
}
// BatchCreateProxyItem represents a single proxy in batch create request
type BatchCreateProxyItem struct {
Protocol string `json:"protocol" binding:"required,oneof=http https socks5"`
Host string `json:"host" binding:"required"`
Port int `json:"port" binding:"required,min=1,max=65535"`
Username string `json:"username"`
Password string `json:"password"`
}
// BatchCreateRequest represents batch create proxies request
type BatchCreateRequest struct {
Proxies []BatchCreateProxyItem `json:"proxies" binding:"required,min=1"`
}
// BatchCreate handles batch creating proxies
// POST /api/v1/admin/proxies/batch
func (h *ProxyHandler) BatchCreate(c *gin.Context) {
var req BatchCreateRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
created := 0
skipped := 0
for _, item := range req.Proxies {
// Check for duplicates (same host, port, username, password)
exists, err := h.adminService.CheckProxyExists(c.Request.Context(), item.Host, item.Port, item.Username, item.Password)
if err != nil {
response.InternalError(c, "Failed to check proxy existence: "+err.Error())
return
}
if exists {
skipped++
continue
}
// Create proxy with default name
_, err = h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{
Name: "default",
Protocol: item.Protocol,
Host: item.Host,
Port: item.Port,
Username: item.Username,
Password: item.Password,
})
if err != nil {
// If creation fails due to duplicate, count as skipped
skipped++
continue
}
created++
}
response.Success(c, gin.H{
"created": created,
"skipped": skipped,
})
}

View File

@@ -0,0 +1,219 @@
package admin
import (
"bytes"
"encoding/csv"
"fmt"
"strconv"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// RedeemHandler handles admin redeem code management
type RedeemHandler struct {
adminService service.AdminService
}
// NewRedeemHandler creates a new admin redeem handler
func NewRedeemHandler(adminService service.AdminService) *RedeemHandler {
return &RedeemHandler{
adminService: adminService,
}
}
// GenerateRedeemCodesRequest represents generate redeem codes request
type GenerateRedeemCodesRequest struct {
Count int `json:"count" binding:"required,min=1,max=100"`
Type string `json:"type" binding:"required,oneof=balance concurrency subscription"`
Value float64 `json:"value" binding:"min=0"`
GroupID *int64 `json:"group_id"` // 订阅类型必填
ValidityDays int `json:"validity_days"` // 订阅类型使用默认30天
}
// List handles listing all redeem codes with pagination
// GET /api/v1/admin/redeem-codes
func (h *RedeemHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
codeType := c.Query("type")
status := c.Query("status")
search := c.Query("search")
codes, total, err := h.adminService.ListRedeemCodes(c.Request.Context(), page, pageSize, codeType, status, search)
if err != nil {
response.InternalError(c, "Failed to list redeem codes: "+err.Error())
return
}
response.Paginated(c, codes, total, page, pageSize)
}
// GetByID handles getting a redeem code by ID
// GET /api/v1/admin/redeem-codes/:id
func (h *RedeemHandler) GetByID(c *gin.Context) {
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid redeem code ID")
return
}
code, err := h.adminService.GetRedeemCode(c.Request.Context(), codeID)
if err != nil {
response.NotFound(c, "Redeem code not found")
return
}
response.Success(c, code)
}
// Generate handles generating new redeem codes
// POST /api/v1/admin/redeem-codes/generate
func (h *RedeemHandler) Generate(c *gin.Context) {
var req GenerateRedeemCodesRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
codes, err := h.adminService.GenerateRedeemCodes(c.Request.Context(), &service.GenerateRedeemCodesInput{
Count: req.Count,
Type: req.Type,
Value: req.Value,
GroupID: req.GroupID,
ValidityDays: req.ValidityDays,
})
if err != nil {
response.InternalError(c, "Failed to generate redeem codes: "+err.Error())
return
}
response.Success(c, codes)
}
// Delete handles deleting a redeem code
// DELETE /api/v1/admin/redeem-codes/:id
func (h *RedeemHandler) Delete(c *gin.Context) {
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid redeem code ID")
return
}
err = h.adminService.DeleteRedeemCode(c.Request.Context(), codeID)
if err != nil {
response.InternalError(c, "Failed to delete redeem code: "+err.Error())
return
}
response.Success(c, gin.H{"message": "Redeem code deleted successfully"})
}
// BatchDelete handles batch deleting redeem codes
// POST /api/v1/admin/redeem-codes/batch-delete
func (h *RedeemHandler) BatchDelete(c *gin.Context) {
var req struct {
IDs []int64 `json:"ids" binding:"required,min=1"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
deleted, err := h.adminService.BatchDeleteRedeemCodes(c.Request.Context(), req.IDs)
if err != nil {
response.InternalError(c, "Failed to batch delete redeem codes: "+err.Error())
return
}
response.Success(c, gin.H{
"deleted": deleted,
"message": "Redeem codes deleted successfully",
})
}
// Expire handles expiring a redeem code
// POST /api/v1/admin/redeem-codes/:id/expire
func (h *RedeemHandler) Expire(c *gin.Context) {
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid redeem code ID")
return
}
code, err := h.adminService.ExpireRedeemCode(c.Request.Context(), codeID)
if err != nil {
response.InternalError(c, "Failed to expire redeem code: "+err.Error())
return
}
response.Success(c, code)
}
// GetStats handles getting redeem code statistics
// GET /api/v1/admin/redeem-codes/stats
func (h *RedeemHandler) GetStats(c *gin.Context) {
// Return mock data for now
response.Success(c, gin.H{
"total_codes": 0,
"active_codes": 0,
"used_codes": 0,
"expired_codes": 0,
"total_value_distributed": 0.0,
"by_type": gin.H{
"balance": 0,
"concurrency": 0,
"trial": 0,
},
})
}
// Export handles exporting redeem codes to CSV
// GET /api/v1/admin/redeem-codes/export
func (h *RedeemHandler) Export(c *gin.Context) {
codeType := c.Query("type")
status := c.Query("status")
// Get all codes without pagination (use large page size)
codes, _, err := h.adminService.ListRedeemCodes(c.Request.Context(), 1, 10000, codeType, status, "")
if err != nil {
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
return
}
// Create CSV buffer
var buf bytes.Buffer
writer := csv.NewWriter(&buf)
// Write header
writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_at", "created_at"})
// Write data rows
for _, code := range codes {
usedBy := ""
if code.UsedBy != nil {
usedBy = fmt.Sprintf("%d", *code.UsedBy)
}
usedAt := ""
if code.UsedAt != nil {
usedAt = code.UsedAt.Format("2006-01-02 15:04:05")
}
writer.Write([]string{
fmt.Sprintf("%d", code.ID),
code.Code,
code.Type,
fmt.Sprintf("%.2f", code.Value),
code.Status,
usedBy,
usedAt,
code.CreatedAt.Format("2006-01-02 15:04:05"),
})
}
writer.Flush()
c.Header("Content-Type", "text/csv")
c.Header("Content-Disposition", "attachment; filename=redeem_codes.csv")
c.Data(200, "text/csv", buf.Bytes())
}

View File

@@ -0,0 +1,258 @@
package admin
import (
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// SettingHandler 系统设置处理器
type SettingHandler struct {
settingService *service.SettingService
emailService *service.EmailService
}
// NewSettingHandler 创建系统设置处理器
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService) *SettingHandler {
return &SettingHandler{
settingService: settingService,
emailService: emailService,
}
}
// GetSettings 获取所有系统设置
// GET /api/v1/admin/settings
func (h *SettingHandler) GetSettings(c *gin.Context) {
settings, err := h.settingService.GetAllSettings(c.Request.Context())
if err != nil {
response.InternalError(c, "Failed to get settings: "+err.Error())
return
}
response.Success(c, settings)
}
// UpdateSettingsRequest 更新设置请求
type UpdateSettingsRequest struct {
// 注册设置
RegistrationEnabled bool `json:"registration_enabled"`
EmailVerifyEnabled bool `json:"email_verify_enabled"`
// 邮件服务设置
SmtpHost string `json:"smtp_host"`
SmtpPort int `json:"smtp_port"`
SmtpUsername string `json:"smtp_username"`
SmtpPassword string `json:"smtp_password"`
SmtpFrom string `json:"smtp_from_email"`
SmtpFromName string `json:"smtp_from_name"`
SmtpUseTLS bool `json:"smtp_use_tls"`
// Cloudflare Turnstile 设置
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
TurnstileSecretKey string `json:"turnstile_secret_key"`
// OEM设置
SiteName string `json:"site_name"`
SiteLogo string `json:"site_logo"`
SiteSubtitle string `json:"site_subtitle"`
ApiBaseUrl string `json:"api_base_url"`
ContactInfo string `json:"contact_info"`
// 默认配置
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
}
// UpdateSettings 更新系统设置
// PUT /api/v1/admin/settings
func (h *SettingHandler) UpdateSettings(c *gin.Context) {
var req UpdateSettingsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// 验证参数
if req.DefaultConcurrency < 1 {
req.DefaultConcurrency = 1
}
if req.DefaultBalance < 0 {
req.DefaultBalance = 0
}
if req.SmtpPort <= 0 {
req.SmtpPort = 587
}
settings := &model.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
SmtpHost: req.SmtpHost,
SmtpPort: req.SmtpPort,
SmtpUsername: req.SmtpUsername,
SmtpPassword: req.SmtpPassword,
SmtpFrom: req.SmtpFrom,
SmtpFromName: req.SmtpFromName,
SmtpUseTLS: req.SmtpUseTLS,
TurnstileEnabled: req.TurnstileEnabled,
TurnstileSiteKey: req.TurnstileSiteKey,
TurnstileSecretKey: req.TurnstileSecretKey,
SiteName: req.SiteName,
SiteLogo: req.SiteLogo,
SiteSubtitle: req.SiteSubtitle,
ApiBaseUrl: req.ApiBaseUrl,
ContactInfo: req.ContactInfo,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
}
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
response.InternalError(c, "Failed to update settings: "+err.Error())
return
}
// 重新获取设置返回
updatedSettings, err := h.settingService.GetAllSettings(c.Request.Context())
if err != nil {
response.InternalError(c, "Failed to get updated settings: "+err.Error())
return
}
response.Success(c, updatedSettings)
}
// TestSmtpRequest 测试SMTP连接请求
type TestSmtpRequest struct {
SmtpHost string `json:"smtp_host" binding:"required"`
SmtpPort int `json:"smtp_port"`
SmtpUsername string `json:"smtp_username"`
SmtpPassword string `json:"smtp_password"`
SmtpUseTLS bool `json:"smtp_use_tls"`
}
// TestSmtpConnection 测试SMTP连接
// POST /api/v1/admin/settings/test-smtp
func (h *SettingHandler) TestSmtpConnection(c *gin.Context) {
var req TestSmtpRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if req.SmtpPort <= 0 {
req.SmtpPort = 587
}
// 如果未提供密码,从数据库获取已保存的密码
password := req.SmtpPassword
if password == "" {
savedConfig, err := h.emailService.GetSmtpConfig(c.Request.Context())
if err == nil && savedConfig != nil {
password = savedConfig.Password
}
}
config := &service.SmtpConfig{
Host: req.SmtpHost,
Port: req.SmtpPort,
Username: req.SmtpUsername,
Password: password,
UseTLS: req.SmtpUseTLS,
}
err := h.emailService.TestSmtpConnectionWithConfig(config)
if err != nil {
response.BadRequest(c, "SMTP connection test failed: "+err.Error())
return
}
response.Success(c, gin.H{"message": "SMTP connection successful"})
}
// SendTestEmailRequest 发送测试邮件请求
type SendTestEmailRequest struct {
Email string `json:"email" binding:"required,email"`
SmtpHost string `json:"smtp_host" binding:"required"`
SmtpPort int `json:"smtp_port"`
SmtpUsername string `json:"smtp_username"`
SmtpPassword string `json:"smtp_password"`
SmtpFrom string `json:"smtp_from_email"`
SmtpFromName string `json:"smtp_from_name"`
SmtpUseTLS bool `json:"smtp_use_tls"`
}
// SendTestEmail 发送测试邮件
// POST /api/v1/admin/settings/send-test-email
func (h *SettingHandler) SendTestEmail(c *gin.Context) {
var req SendTestEmailRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if req.SmtpPort <= 0 {
req.SmtpPort = 587
}
// 如果未提供密码,从数据库获取已保存的密码
password := req.SmtpPassword
if password == "" {
savedConfig, err := h.emailService.GetSmtpConfig(c.Request.Context())
if err == nil && savedConfig != nil {
password = savedConfig.Password
}
}
config := &service.SmtpConfig{
Host: req.SmtpHost,
Port: req.SmtpPort,
Username: req.SmtpUsername,
Password: password,
From: req.SmtpFrom,
FromName: req.SmtpFromName,
UseTLS: req.SmtpUseTLS,
}
siteName := h.settingService.GetSiteName(c.Request.Context())
subject := "[" + siteName + "] Test Email"
body := `
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
.container { max-width: 600px; margin: 0 auto; background-color: #ffffff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
.header { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 30px; text-align: center; }
.content { padding: 40px 30px; text-align: center; }
.success { color: #10b981; font-size: 48px; margin-bottom: 20px; }
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>` + siteName + `</h1>
</div>
<div class="content">
<div class="success">✓</div>
<h2>Email Configuration Successful!</h2>
<p>This is a test email to verify your SMTP settings are working correctly.</p>
</div>
<div class="footer">
<p>This is an automated test message.</p>
</div>
</div>
</body>
</html>
`
if err := h.emailService.SendEmailWithConfig(config, req.Email, subject, body); err != nil {
response.BadRequest(c, "Failed to send test email: "+err.Error())
return
}
response.Success(c, gin.H{"message": "Test email sent successfully"})
}

View File

@@ -0,0 +1,266 @@
package admin
import (
"strconv"
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/repository"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// toResponsePagination converts repository.PaginationResult to response.PaginationResult
func toResponsePagination(p *repository.PaginationResult) *response.PaginationResult {
if p == nil {
return nil
}
return &response.PaginationResult{
Total: p.Total,
Page: p.Page,
PageSize: p.PageSize,
Pages: p.Pages,
}
}
// SubscriptionHandler handles admin subscription management
type SubscriptionHandler struct {
subscriptionService *service.SubscriptionService
}
// NewSubscriptionHandler creates a new admin subscription handler
func NewSubscriptionHandler(subscriptionService *service.SubscriptionService) *SubscriptionHandler {
return &SubscriptionHandler{
subscriptionService: subscriptionService,
}
}
// AssignSubscriptionRequest represents assign subscription request
type AssignSubscriptionRequest struct {
UserID int64 `json:"user_id" binding:"required"`
GroupID int64 `json:"group_id" binding:"required"`
ValidityDays int `json:"validity_days"`
Notes string `json:"notes"`
}
// BulkAssignSubscriptionRequest represents bulk assign subscription request
type BulkAssignSubscriptionRequest struct {
UserIDs []int64 `json:"user_ids" binding:"required,min=1"`
GroupID int64 `json:"group_id" binding:"required"`
ValidityDays int `json:"validity_days"`
Notes string `json:"notes"`
}
// ExtendSubscriptionRequest represents extend subscription request
type ExtendSubscriptionRequest struct {
Days int `json:"days" binding:"required,min=1"`
}
// List handles listing all subscriptions with pagination and filters
// GET /api/v1/admin/subscriptions
func (h *SubscriptionHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
// Parse optional filters
var userID, groupID *int64
if userIDStr := c.Query("user_id"); userIDStr != "" {
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
userID = &id
}
}
if groupIDStr := c.Query("group_id"); groupIDStr != "" {
if id, err := strconv.ParseInt(groupIDStr, 10, 64); err == nil {
groupID = &id
}
}
status := c.Query("status")
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status)
if err != nil {
response.InternalError(c, "Failed to list subscriptions: "+err.Error())
return
}
response.PaginatedWithResult(c, subscriptions, toResponsePagination(pagination))
}
// GetByID handles getting a subscription by ID
// GET /api/v1/admin/subscriptions/:id
func (h *SubscriptionHandler) GetByID(c *gin.Context) {
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid subscription ID")
return
}
subscription, err := h.subscriptionService.GetByID(c.Request.Context(), subscriptionID)
if err != nil {
response.NotFound(c, "Subscription not found")
return
}
response.Success(c, subscription)
}
// GetProgress handles getting subscription usage progress
// GET /api/v1/admin/subscriptions/:id/progress
func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid subscription ID")
return
}
progress, err := h.subscriptionService.GetSubscriptionProgress(c.Request.Context(), subscriptionID)
if err != nil {
response.NotFound(c, "Subscription not found")
return
}
response.Success(c, progress)
}
// Assign handles assigning a subscription to a user
// POST /api/v1/admin/subscriptions/assign
func (h *SubscriptionHandler) Assign(c *gin.Context) {
var req AssignSubscriptionRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Get admin user ID from context
adminID := getAdminIDFromContext(c)
subscription, err := h.subscriptionService.AssignSubscription(c.Request.Context(), &service.AssignSubscriptionInput{
UserID: req.UserID,
GroupID: req.GroupID,
ValidityDays: req.ValidityDays,
AssignedBy: adminID,
Notes: req.Notes,
})
if err != nil {
response.BadRequest(c, "Failed to assign subscription: "+err.Error())
return
}
response.Success(c, subscription)
}
// BulkAssign handles bulk assigning subscriptions to multiple users
// POST /api/v1/admin/subscriptions/bulk-assign
func (h *SubscriptionHandler) BulkAssign(c *gin.Context) {
var req BulkAssignSubscriptionRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Get admin user ID from context
adminID := getAdminIDFromContext(c)
result, err := h.subscriptionService.BulkAssignSubscription(c.Request.Context(), &service.BulkAssignSubscriptionInput{
UserIDs: req.UserIDs,
GroupID: req.GroupID,
ValidityDays: req.ValidityDays,
AssignedBy: adminID,
Notes: req.Notes,
})
if err != nil {
response.InternalError(c, "Failed to bulk assign subscriptions: "+err.Error())
return
}
response.Success(c, result)
}
// Extend handles extending a subscription
// POST /api/v1/admin/subscriptions/:id/extend
func (h *SubscriptionHandler) Extend(c *gin.Context) {
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid subscription ID")
return
}
var req ExtendSubscriptionRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
subscription, err := h.subscriptionService.ExtendSubscription(c.Request.Context(), subscriptionID, req.Days)
if err != nil {
response.InternalError(c, "Failed to extend subscription: "+err.Error())
return
}
response.Success(c, subscription)
}
// Revoke handles revoking a subscription
// DELETE /api/v1/admin/subscriptions/:id
func (h *SubscriptionHandler) Revoke(c *gin.Context) {
subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid subscription ID")
return
}
err = h.subscriptionService.RevokeSubscription(c.Request.Context(), subscriptionID)
if err != nil {
response.InternalError(c, "Failed to revoke subscription: "+err.Error())
return
}
response.Success(c, gin.H{"message": "Subscription revoked successfully"})
}
// ListByGroup handles listing subscriptions for a specific group
// GET /api/v1/admin/groups/:id/subscriptions
func (h *SubscriptionHandler) ListByGroup(c *gin.Context) {
groupID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid group ID")
return
}
page, pageSize := response.ParsePagination(c)
subscriptions, pagination, err := h.subscriptionService.ListGroupSubscriptions(c.Request.Context(), groupID, page, pageSize)
if err != nil {
response.InternalError(c, "Failed to list group subscriptions: "+err.Error())
return
}
response.PaginatedWithResult(c, subscriptions, toResponsePagination(pagination))
}
// ListByUser handles listing subscriptions for a specific user
// GET /api/v1/admin/users/:id/subscriptions
func (h *SubscriptionHandler) ListByUser(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), userID)
if err != nil {
response.InternalError(c, "Failed to list user subscriptions: "+err.Error())
return
}
response.Success(c, subscriptions)
}
// Helper function to get admin ID from context
func getAdminIDFromContext(c *gin.Context) int64 {
if user, exists := c.Get("user"); exists {
if u, ok := user.(*model.User); ok && u != nil {
return u.ID
}
}
return 0
}

View File

@@ -0,0 +1,82 @@
package admin
import (
"net/http"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
)
// SystemHandler handles system-related operations
type SystemHandler struct {
updateSvc *service.UpdateService
}
// NewSystemHandler creates a new SystemHandler
func NewSystemHandler(rdb *redis.Client, version, buildType string) *SystemHandler {
return &SystemHandler{
updateSvc: service.NewUpdateService(rdb, version, buildType),
}
}
// GetVersion returns the current version
// GET /api/v1/admin/system/version
func (h *SystemHandler) GetVersion(c *gin.Context) {
info, _ := h.updateSvc.CheckUpdate(c.Request.Context(), false)
response.Success(c, gin.H{
"version": info.CurrentVersion,
})
}
// CheckUpdates checks for available updates
// GET /api/v1/admin/system/check-updates
func (h *SystemHandler) CheckUpdates(c *gin.Context) {
force := c.Query("force") == "true"
info, err := h.updateSvc.CheckUpdate(c.Request.Context(), force)
if err != nil {
response.Error(c, http.StatusInternalServerError, err.Error())
return
}
response.Success(c, info)
}
// PerformUpdate downloads and applies the update
// POST /api/v1/admin/system/update
func (h *SystemHandler) PerformUpdate(c *gin.Context) {
if err := h.updateSvc.PerformUpdate(c.Request.Context()); err != nil {
response.Error(c, http.StatusInternalServerError, err.Error())
return
}
response.Success(c, gin.H{
"message": "Update completed. Please restart the service.",
"need_restart": true,
})
}
// Rollback restores the previous version
// POST /api/v1/admin/system/rollback
func (h *SystemHandler) Rollback(c *gin.Context) {
if err := h.updateSvc.Rollback(); err != nil {
response.Error(c, http.StatusInternalServerError, err.Error())
return
}
response.Success(c, gin.H{
"message": "Rollback completed. Please restart the service.",
"need_restart": true,
})
}
// RestartService restarts the systemd service
// POST /api/v1/admin/system/restart
func (h *SystemHandler) RestartService(c *gin.Context) {
if err := h.updateSvc.RestartService(); err != nil {
response.Error(c, http.StatusInternalServerError, err.Error())
return
}
response.Success(c, gin.H{
"message": "Service restart initiated",
})
}

View File

@@ -0,0 +1,262 @@
package admin
import (
"strconv"
"time"
"sub2api/internal/pkg/response"
"sub2api/internal/pkg/timezone"
"sub2api/internal/repository"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// UsageHandler handles admin usage-related requests
type UsageHandler struct {
usageRepo *repository.UsageLogRepository
apiKeyRepo *repository.ApiKeyRepository
usageService *service.UsageService
adminService service.AdminService
}
// NewUsageHandler creates a new admin usage handler
func NewUsageHandler(
usageRepo *repository.UsageLogRepository,
apiKeyRepo *repository.ApiKeyRepository,
usageService *service.UsageService,
adminService service.AdminService,
) *UsageHandler {
return &UsageHandler{
usageRepo: usageRepo,
apiKeyRepo: apiKeyRepo,
usageService: usageService,
adminService: adminService,
}
}
// List handles listing all usage records with filters
// GET /api/v1/admin/usage
func (h *UsageHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
// Parse filters
var userID, apiKeyID int64
if userIDStr := c.Query("user_id"); userIDStr != "" {
id, err := strconv.ParseInt(userIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user_id")
return
}
userID = id
}
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "Invalid api_key_id")
return
}
apiKeyID = id
}
// Parse date range
var startTime, endTime *time.Time
if startDateStr := c.Query("start_date"); startDateStr != "" {
t, err := timezone.ParseInLocation("2006-01-02", startDateStr)
if err != nil {
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
return
}
startTime = &t
}
if endDateStr := c.Query("end_date"); endDateStr != "" {
t, err := timezone.ParseInLocation("2006-01-02", endDateStr)
if err != nil {
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
return
}
// Set end time to end of day
t = t.Add(24*time.Hour - time.Nanosecond)
endTime = &t
}
params := repository.PaginationParams{Page: page, PageSize: pageSize}
filters := repository.UsageLogFilters{
UserID: userID,
ApiKeyID: apiKeyID,
StartTime: startTime,
EndTime: endTime,
}
records, result, err := h.usageRepo.ListWithFilters(c.Request.Context(), params, filters)
if err != nil {
response.InternalError(c, "Failed to list usage records: "+err.Error())
return
}
response.Paginated(c, records, result.Total, page, pageSize)
}
// Stats handles getting usage statistics with filters
// GET /api/v1/admin/usage/stats
func (h *UsageHandler) Stats(c *gin.Context) {
// Parse filters
var userID, apiKeyID int64
if userIDStr := c.Query("user_id"); userIDStr != "" {
id, err := strconv.ParseInt(userIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user_id")
return
}
userID = id
}
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "Invalid api_key_id")
return
}
apiKeyID = id
}
// Parse date range
now := timezone.Now()
var startTime, endTime time.Time
startDateStr := c.Query("start_date")
endDateStr := c.Query("end_date")
if startDateStr != "" && endDateStr != "" {
var err error
startTime, err = timezone.ParseInLocation("2006-01-02", startDateStr)
if err != nil {
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
return
}
endTime, err = timezone.ParseInLocation("2006-01-02", endDateStr)
if err != nil {
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
return
}
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
} else {
period := c.DefaultQuery("period", "today")
switch period {
case "today":
startTime = timezone.StartOfDay(now)
case "week":
startTime = now.AddDate(0, 0, -7)
case "month":
startTime = now.AddDate(0, -1, 0)
default:
startTime = timezone.StartOfDay(now)
}
endTime = now
}
if apiKeyID > 0 {
stats, err := h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
if err != nil {
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
return
}
response.Success(c, stats)
return
}
if userID > 0 {
stats, err := h.usageService.GetStatsByUser(c.Request.Context(), userID, startTime, endTime)
if err != nil {
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
return
}
response.Success(c, stats)
return
}
// Get global stats
stats, err := h.usageRepo.GetGlobalStats(c.Request.Context(), startTime, endTime)
if err != nil {
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
return
}
response.Success(c, stats)
}
// SearchUsers handles searching users by email keyword
// GET /api/v1/admin/usage/search-users
func (h *UsageHandler) SearchUsers(c *gin.Context) {
keyword := c.Query("q")
if keyword == "" {
response.Success(c, []interface{}{})
return
}
// Limit to 30 results
users, _, err := h.adminService.ListUsers(c.Request.Context(), 1, 30, "", "", keyword)
if err != nil {
response.InternalError(c, "Failed to search users: "+err.Error())
return
}
// Return simplified user list (only id and email)
type SimpleUser struct {
ID int64 `json:"id"`
Email string `json:"email"`
}
result := make([]SimpleUser, len(users))
for i, u := range users {
result[i] = SimpleUser{
ID: u.ID,
Email: u.Email,
}
}
response.Success(c, result)
}
// SearchApiKeys handles searching API keys by user
// GET /api/v1/admin/usage/search-api-keys
func (h *UsageHandler) SearchApiKeys(c *gin.Context) {
userIDStr := c.Query("user_id")
keyword := c.Query("q")
var userID int64
if userIDStr != "" {
id, err := strconv.ParseInt(userIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user_id")
return
}
userID = id
}
keys, err := h.apiKeyRepo.SearchApiKeys(c.Request.Context(), userID, keyword, 30)
if err != nil {
response.InternalError(c, "Failed to search API keys: "+err.Error())
return
}
// Return simplified API key list (only id and name)
type SimpleApiKey struct {
ID int64 `json:"id"`
Name string `json:"name"`
UserID int64 `json:"user_id"`
}
result := make([]SimpleApiKey, len(keys))
for i, k := range keys {
result[i] = SimpleApiKey{
ID: k.ID,
Name: k.Name,
UserID: k.UserID,
}
}
response.Success(c, result)
}

View File

@@ -0,0 +1,221 @@
package admin
import (
"strconv"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// UserHandler handles admin user management
type UserHandler struct {
adminService service.AdminService
}
// NewUserHandler creates a new admin user handler
func NewUserHandler(adminService service.AdminService) *UserHandler {
return &UserHandler{
adminService: adminService,
}
}
// CreateUserRequest represents admin create user request
type CreateUserRequest struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=6"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
AllowedGroups []int64 `json:"allowed_groups"`
}
// UpdateUserRequest represents admin update user request
// 使用指针类型来区分"未提供"和"设置为0"
type UpdateUserRequest struct {
Email string `json:"email" binding:"omitempty,email"`
Password string `json:"password" binding:"omitempty,min=6"`
Balance *float64 `json:"balance"`
Concurrency *int `json:"concurrency"`
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
AllowedGroups *[]int64 `json:"allowed_groups"`
}
// UpdateBalanceRequest represents balance update request
type UpdateBalanceRequest struct {
Balance float64 `json:"balance" binding:"required"`
Operation string `json:"operation" binding:"required,oneof=set add subtract"`
}
// List handles listing all users with pagination
// GET /api/v1/admin/users
func (h *UserHandler) List(c *gin.Context) {
page, pageSize := response.ParsePagination(c)
status := c.Query("status")
role := c.Query("role")
search := c.Query("search")
users, total, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, status, role, search)
if err != nil {
response.InternalError(c, "Failed to list users: "+err.Error())
return
}
response.Paginated(c, users, total, page, pageSize)
}
// GetByID handles getting a user by ID
// GET /api/v1/admin/users/:id
func (h *UserHandler) GetByID(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
user, err := h.adminService.GetUser(c.Request.Context(), userID)
if err != nil {
response.NotFound(c, "User not found")
return
}
response.Success(c, user)
}
// Create handles creating a new user
// POST /api/v1/admin/users
func (h *UserHandler) Create(c *gin.Context) {
var req CreateUserRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
user, err := h.adminService.CreateUser(c.Request.Context(), &service.CreateUserInput{
Email: req.Email,
Password: req.Password,
Balance: req.Balance,
Concurrency: req.Concurrency,
AllowedGroups: req.AllowedGroups,
})
if err != nil {
response.BadRequest(c, "Failed to create user: "+err.Error())
return
}
response.Success(c, user)
}
// Update handles updating a user
// PUT /api/v1/admin/users/:id
func (h *UserHandler) Update(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
var req UpdateUserRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// 使用指针类型直接传递nil 表示未提供该字段
user, err := h.adminService.UpdateUser(c.Request.Context(), userID, &service.UpdateUserInput{
Email: req.Email,
Password: req.Password,
Balance: req.Balance,
Concurrency: req.Concurrency,
Status: req.Status,
AllowedGroups: req.AllowedGroups,
})
if err != nil {
response.InternalError(c, "Failed to update user: "+err.Error())
return
}
response.Success(c, user)
}
// Delete handles deleting a user
// DELETE /api/v1/admin/users/:id
func (h *UserHandler) Delete(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
err = h.adminService.DeleteUser(c.Request.Context(), userID)
if err != nil {
response.InternalError(c, "Failed to delete user: "+err.Error())
return
}
response.Success(c, gin.H{"message": "User deleted successfully"})
}
// UpdateBalance handles updating user balance
// POST /api/v1/admin/users/:id/balance
func (h *UserHandler) UpdateBalance(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
var req UpdateBalanceRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation)
if err != nil {
response.InternalError(c, "Failed to update balance: "+err.Error())
return
}
response.Success(c, user)
}
// GetUserAPIKeys handles getting user's API keys
// GET /api/v1/admin/users/:id/api-keys
func (h *UserHandler) GetUserAPIKeys(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
page, pageSize := response.ParsePagination(c)
keys, total, err := h.adminService.GetUserAPIKeys(c.Request.Context(), userID, page, pageSize)
if err != nil {
response.InternalError(c, "Failed to get user API keys: "+err.Error())
return
}
response.Paginated(c, keys, total, page, pageSize)
}
// GetUserUsage handles getting user's usage statistics
// GET /api/v1/admin/users/:id/usage
func (h *UserHandler) GetUserUsage(c *gin.Context) {
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid user ID")
return
}
period := c.DefaultQuery("period", "month")
stats, err := h.adminService.GetUserUsageStats(c.Request.Context(), userID, period)
if err != nil {
response.InternalError(c, "Failed to get user usage: "+err.Error())
return
}
response.Success(c, stats)
}

View File

@@ -0,0 +1,235 @@
package handler
import (
"strconv"
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/repository"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// APIKeyHandler handles API key-related requests
type APIKeyHandler struct {
apiKeyService *service.ApiKeyService
}
// NewAPIKeyHandler creates a new APIKeyHandler
func NewAPIKeyHandler(apiKeyService *service.ApiKeyService) *APIKeyHandler {
return &APIKeyHandler{
apiKeyService: apiKeyService,
}
}
// CreateAPIKeyRequest represents the create API key request payload
type CreateAPIKeyRequest struct {
Name string `json:"name" binding:"required"`
GroupID *int64 `json:"group_id"` // nullable
CustomKey *string `json:"custom_key"` // 可选的自定义key
}
// UpdateAPIKeyRequest represents the update API key request payload
type UpdateAPIKeyRequest struct {
Name string `json:"name"`
GroupID *int64 `json:"group_id"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
}
// List handles listing user's API keys with pagination
// GET /api/v1/api-keys
func (h *APIKeyHandler) List(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
page, pageSize := response.ParsePagination(c)
params := repository.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := h.apiKeyService.List(c.Request.Context(), user.ID, params)
if err != nil {
response.InternalError(c, "Failed to list API keys: "+err.Error())
return
}
response.Paginated(c, keys, result.Total, page, pageSize)
}
// GetByID handles getting a single API key
// GET /api/v1/api-keys/:id
func (h *APIKeyHandler) GetByID(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
keyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid key ID")
return
}
key, err := h.apiKeyService.GetByID(c.Request.Context(), keyID)
if err != nil {
response.NotFound(c, "API key not found")
return
}
// 验证所有权
if key.UserID != user.ID {
response.Forbidden(c, "Not authorized to access this key")
return
}
response.Success(c, key)
}
// Create handles creating a new API key
// POST /api/v1/api-keys
func (h *APIKeyHandler) Create(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
var req CreateAPIKeyRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
svcReq := service.CreateApiKeyRequest{
Name: req.Name,
GroupID: req.GroupID,
CustomKey: req.CustomKey,
}
key, err := h.apiKeyService.Create(c.Request.Context(), user.ID, svcReq)
if err != nil {
response.InternalError(c, "Failed to create API key: "+err.Error())
return
}
response.Success(c, key)
}
// Update handles updating an API key
// PUT /api/v1/api-keys/:id
func (h *APIKeyHandler) Update(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
keyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid key ID")
return
}
var req UpdateAPIKeyRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
svcReq := service.UpdateApiKeyRequest{}
if req.Name != "" {
svcReq.Name = &req.Name
}
svcReq.GroupID = req.GroupID
if req.Status != "" {
svcReq.Status = &req.Status
}
key, err := h.apiKeyService.Update(c.Request.Context(), keyID, user.ID, svcReq)
if err != nil {
response.InternalError(c, "Failed to update API key: "+err.Error())
return
}
response.Success(c, key)
}
// Delete handles deleting an API key
// DELETE /api/v1/api-keys/:id
func (h *APIKeyHandler) Delete(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
keyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid key ID")
return
}
err = h.apiKeyService.Delete(c.Request.Context(), keyID, user.ID)
if err != nil {
response.InternalError(c, "Failed to delete API key: "+err.Error())
return
}
response.Success(c, gin.H{"message": "API key deleted successfully"})
}
// GetAvailableGroups 获取用户可以绑定的分组列表
// GET /api/v1/groups/available
func (h *APIKeyHandler) GetAvailableGroups(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
groups, err := h.apiKeyService.GetAvailableGroups(c.Request.Context(), user.ID)
if err != nil {
response.InternalError(c, "Failed to get available groups: "+err.Error())
return
}
response.Success(c, groups)
}

View File

@@ -0,0 +1,158 @@
package handler
import (
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// AuthHandler handles authentication-related requests
type AuthHandler struct {
authService *service.AuthService
}
// NewAuthHandler creates a new AuthHandler
func NewAuthHandler(authService *service.AuthService) *AuthHandler {
return &AuthHandler{
authService: authService,
}
}
// RegisterRequest represents the registration request payload
type RegisterRequest struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=6"`
VerifyCode string `json:"verify_code"`
TurnstileToken string `json:"turnstile_token"`
}
// SendVerifyCodeRequest 发送验证码请求
type SendVerifyCodeRequest struct {
Email string `json:"email" binding:"required,email"`
TurnstileToken string `json:"turnstile_token"`
}
// SendVerifyCodeResponse 发送验证码响应
type SendVerifyCodeResponse struct {
Message string `json:"message"`
Countdown int `json:"countdown"` // 倒计时秒数
}
// LoginRequest represents the login request payload
type LoginRequest struct {
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required"`
TurnstileToken string `json:"turnstile_token"`
}
// AuthResponse 认证响应格式(匹配前端期望)
type AuthResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
User *model.User `json:"user"`
}
// Register handles user registration
// POST /api/v1/auth/register
func (h *AuthHandler) Register(c *gin.Context) {
var req RegisterRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Turnstile 验证(当提供了邮箱验证码时跳过,因为发送验证码时已验证过)
if req.VerifyCode == "" {
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
response.BadRequest(c, "Turnstile verification failed: "+err.Error())
return
}
}
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode)
if err != nil {
response.BadRequest(c, "Registration failed: "+err.Error())
return
}
response.Success(c, AuthResponse{
AccessToken: token,
TokenType: "Bearer",
User: user,
})
}
// SendVerifyCode 发送邮箱验证码
// POST /api/v1/auth/send-verify-code
func (h *AuthHandler) SendVerifyCode(c *gin.Context) {
var req SendVerifyCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Turnstile 验证
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
response.BadRequest(c, "Turnstile verification failed: "+err.Error())
return
}
result, err := h.authService.SendVerifyCodeAsync(c.Request.Context(), req.Email)
if err != nil {
response.BadRequest(c, "Failed to send verification code: "+err.Error())
return
}
response.Success(c, SendVerifyCodeResponse{
Message: "Verification code sent successfully",
Countdown: result.Countdown,
})
}
// Login handles user login
// POST /api/v1/auth/login
func (h *AuthHandler) Login(c *gin.Context) {
var req LoginRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
// Turnstile 验证
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, c.ClientIP()); err != nil {
response.BadRequest(c, "Turnstile verification failed: "+err.Error())
return
}
token, user, err := h.authService.Login(c.Request.Context(), req.Email, req.Password)
if err != nil {
response.Unauthorized(c, "Login failed: "+err.Error())
return
}
response.Success(c, AuthResponse{
AccessToken: token,
TokenType: "Bearer",
User: user,
})
}
// GetCurrentUser handles getting current authenticated user
// GET /api/v1/auth/me
func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
response.Success(c, user)
}

View File

@@ -0,0 +1,445 @@
package handler
import (
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"time"
"sub2api/internal/middleware"
"sub2api/internal/model"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
const (
// Maximum wait time for concurrency slot
maxConcurrencyWait = 60 * time.Second
// Ping interval during wait
pingInterval = 5 * time.Second
)
// GatewayHandler handles API gateway requests
type GatewayHandler struct {
gatewayService *service.GatewayService
userService *service.UserService
concurrencyService *service.ConcurrencyService
billingCacheService *service.BillingCacheService
}
// NewGatewayHandler creates a new GatewayHandler
func NewGatewayHandler(gatewayService *service.GatewayService, userService *service.UserService, concurrencyService *service.ConcurrencyService, billingCacheService *service.BillingCacheService) *GatewayHandler {
return &GatewayHandler{
gatewayService: gatewayService,
userService: userService,
concurrencyService: concurrencyService,
billingCacheService: billingCacheService,
}
}
// Messages handles Claude API compatible messages endpoint
// POST /v1/messages
func (h *GatewayHandler) Messages(c *gin.Context) {
// 从context获取apiKey和userApiKeyAuth中间件已设置
apiKey, ok := middleware.GetApiKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
user, ok := middleware.GetUserFromContext(c)
if !ok {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
return
}
// 读取请求体
body, err := io.ReadAll(c.Request.Body)
if err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
return
}
if len(body) == 0 {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return
}
// 解析请求获取模型名和stream
var req struct {
Model string `json:"model"`
Stream bool `json:"stream"`
}
if err := json.Unmarshal(body, &req); err != nil {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// Track if we've started streaming (for error handling)
streamStarted := false
// 获取订阅信息可能为nil- 提前获取用于后续检查
subscription, _ := middleware.GetSubscriptionFromContext(c)
// 0. 检查wait队列是否已满
maxWait := service.CalculateMaxWait(user.Concurrency)
canWait, err := h.concurrencyService.IncrementWaitCount(c.Request.Context(), user.ID, maxWait)
if err != nil {
log.Printf("Increment wait count failed: %v", err)
// On error, allow request to proceed
} else if !canWait {
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
// 确保在函数退出时减少wait计数
defer h.concurrencyService.DecrementWaitCount(c.Request.Context(), user.ID)
// 1. 首先获取用户并发槽位
userReleaseFunc, err := h.acquireUserSlotWithWait(c, user, req.Stream, &streamStarted)
if err != nil {
log.Printf("User concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
if userReleaseFunc != nil {
defer userReleaseFunc()
}
// 2. 【新增】Wait后二次检查余额/订阅
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), user, apiKey, apiKey.Group, subscription); err != nil {
log.Printf("Billing eligibility check failed after wait: %v", err)
h.handleStreamingAwareError(c, http.StatusForbidden, "billing_error", err.Error(), streamStarted)
return
}
// 计算粘性会话hash
sessionHash := h.gatewayService.GenerateSessionHash(body)
// 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, req.Model)
if err != nil {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
return
}
// 3. 获取账号并发槽位
accountReleaseFunc, err := h.acquireAccountSlotWithWait(c, account, req.Stream, &streamStarted)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if accountReleaseFunc != nil {
defer accountReleaseFunc()
}
// 转发请求
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
if err != nil {
// 错误响应已在Forward中处理这里只记录日志
log.Printf("Forward request failed: %v", err)
return
}
// 异步记录使用量subscription已在函数开头获取
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
ApiKey: apiKey,
User: user,
Account: account,
Subscription: subscription,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}()
}
// acquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary
// For streaming requests, sends ping events during the wait
// streamStarted is updated if streaming response has begun
func (h *GatewayHandler) acquireUserSlotWithWait(c *gin.Context, user *model.User, isStream bool, streamStarted *bool) (func(), error) {
ctx := c.Request.Context()
// Try to acquire immediately
result, err := h.concurrencyService.AcquireUserSlot(ctx, user.ID, user.Concurrency)
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
}
// Need to wait - handle streaming ping if needed
return h.waitForSlotWithPing(c, "user", user.ID, user.Concurrency, isStream, streamStarted)
}
// acquireAccountSlotWithWait acquires an account concurrency slot, waiting if necessary
// For streaming requests, sends ping events during the wait
// streamStarted is updated if streaming response has begun
func (h *GatewayHandler) acquireAccountSlotWithWait(c *gin.Context, account *model.Account, isStream bool, streamStarted *bool) (func(), error) {
ctx := c.Request.Context()
// Try to acquire immediately
result, err := h.concurrencyService.AcquireAccountSlot(ctx, account.ID, account.Concurrency)
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
}
// Need to wait - handle streaming ping if needed
return h.waitForSlotWithPing(c, "account", account.ID, account.Concurrency, isStream, streamStarted)
}
// concurrencyError represents a concurrency limit error with context
type concurrencyError struct {
SlotType string
IsTimeout bool
}
func (e *concurrencyError) Error() string {
if e.IsTimeout {
return fmt.Sprintf("timeout waiting for %s concurrency slot", e.SlotType)
}
return fmt.Sprintf("%s concurrency limit reached", e.SlotType)
}
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests
// Note: For streaming requests, we send ping to keep the connection alive.
// streamStarted pointer is updated when streaming begins (for proper error handling by caller)
func (h *GatewayHandler) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
ctx, cancel := context.WithTimeout(c.Request.Context(), maxConcurrencyWait)
defer cancel()
// For streaming requests, set up SSE headers for ping
var flusher http.Flusher
if isStream {
var ok bool
flusher, ok = c.Writer.(http.Flusher)
if !ok {
return nil, fmt.Errorf("streaming not supported")
}
}
pingTicker := time.NewTicker(pingInterval)
defer pingTicker.Stop()
pollTicker := time.NewTicker(100 * time.Millisecond)
defer pollTicker.Stop()
for {
select {
case <-ctx.Done():
return nil, &concurrencyError{
SlotType: slotType,
IsTimeout: true,
}
case <-pingTicker.C:
// Send ping for streaming requests to keep connection alive
if isStream && flusher != nil {
// Set headers on first ping (lazy initialization)
if !*streamStarted {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
*streamStarted = true
}
fmt.Fprintf(c.Writer, "data: {\"type\": \"ping\"}\n\n")
flusher.Flush()
}
case <-pollTicker.C:
// Try to acquire slot
var result *service.AcquireResult
var err error
if slotType == "user" {
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
} else {
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
}
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
}
}
}
}
// Models handles listing available models
// GET /v1/models
func (h *GatewayHandler) Models(c *gin.Context) {
models := []gin.H{
{
"id": "claude-opus-4-5-20251101",
"type": "model",
"display_name": "Claude Opus 4.5",
"created_at": "2025-11-01T00:00:00Z",
},
{
"id": "claude-sonnet-4-5-20250929",
"type": "model",
"display_name": "Claude Sonnet 4.5",
"created_at": "2025-09-29T00:00:00Z",
},
{
"id": "claude-haiku-4-5-20251001",
"type": "model",
"display_name": "Claude Haiku 4.5",
"created_at": "2025-10-01T00:00:00Z",
},
}
c.JSON(http.StatusOK, gin.H{
"data": models,
"object": "list",
})
}
// Usage handles getting account balance for CC Switch integration
// GET /v1/usage
func (h *GatewayHandler) Usage(c *gin.Context) {
apiKey, ok := middleware.GetApiKeyFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
user, ok := middleware.GetUserFromContext(c)
if !ok {
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
return
}
// 订阅模式:返回订阅限额信息
if apiKey.Group != nil && apiKey.Group.IsSubscriptionType() {
subscription, ok := middleware.GetSubscriptionFromContext(c)
if !ok {
h.errorResponse(c, http.StatusForbidden, "subscription_error", "No active subscription")
return
}
remaining := h.calculateSubscriptionRemaining(apiKey.Group, subscription)
c.JSON(http.StatusOK, gin.H{
"isValid": true,
"planName": apiKey.Group.Name,
"remaining": remaining,
"unit": "USD",
})
return
}
// 余额模式:返回钱包余额
latestUser, err := h.userService.GetByID(c.Request.Context(), user.ID)
if err != nil {
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to get user info")
return
}
c.JSON(http.StatusOK, gin.H{
"isValid": true,
"planName": "钱包余额",
"remaining": latestUser.Balance,
"unit": "USD",
})
}
// calculateSubscriptionRemaining 计算订阅剩余可用额度
// 逻辑:
// 1. 如果日/周/月任一限额达到100%返回0
// 2. 否则返回所有已配置周期中剩余额度的最小值
func (h *GatewayHandler) calculateSubscriptionRemaining(group *model.Group, sub *model.UserSubscription) float64 {
var remainingValues []float64
// 检查日限额
if group.HasDailyLimit() {
remaining := *group.DailyLimitUSD - sub.DailyUsageUSD
if remaining <= 0 {
return 0
}
remainingValues = append(remainingValues, remaining)
}
// 检查周限额
if group.HasWeeklyLimit() {
remaining := *group.WeeklyLimitUSD - sub.WeeklyUsageUSD
if remaining <= 0 {
return 0
}
remainingValues = append(remainingValues, remaining)
}
// 检查月限额
if group.HasMonthlyLimit() {
remaining := *group.MonthlyLimitUSD - sub.MonthlyUsageUSD
if remaining <= 0 {
return 0
}
remainingValues = append(remainingValues, remaining)
}
// 如果没有配置任何限额,返回-1表示无限制
if len(remainingValues) == 0 {
return -1
}
// 返回最小值
min := remainingValues[0]
for _, v := range remainingValues[1:] {
if v < min {
min = v
}
}
return min
}
// handleConcurrencyError handles concurrency-related errors with proper 429 response
func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
}
// handleStreamingAwareError handles errors that may occur after streaming has started
func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
if streamStarted {
// Stream already started, send error as SSE event then close
flusher, ok := c.Writer.(http.Flusher)
if ok {
// Send error event in SSE format
errorEvent := fmt.Sprintf(`data: {"type": "error", "error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
fmt.Fprint(c.Writer, errorEvent)
flusher.Flush()
}
return
}
// Normal case: return JSON response with proper status code
h.errorResponse(c, status, errType, message)
}
// errorResponse 返回Claude API格式的错误响应
func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{
"type": "error",
"error": gin.H{
"type": errType,
"message": message,
},
})
}

View File

@@ -0,0 +1,70 @@
package handler
import (
"sub2api/internal/handler/admin"
"sub2api/internal/repository"
"sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
// AdminHandlers contains all admin-related HTTP handlers
type AdminHandlers struct {
Dashboard *admin.DashboardHandler
User *admin.UserHandler
Group *admin.GroupHandler
Account *admin.AccountHandler
OAuth *admin.OAuthHandler
Proxy *admin.ProxyHandler
Redeem *admin.RedeemHandler
Setting *admin.SettingHandler
System *admin.SystemHandler
Subscription *admin.SubscriptionHandler
Usage *admin.UsageHandler
}
// Handlers contains all HTTP handlers
type Handlers struct {
Auth *AuthHandler
User *UserHandler
APIKey *APIKeyHandler
Usage *UsageHandler
Redeem *RedeemHandler
Subscription *SubscriptionHandler
Admin *AdminHandlers
Gateway *GatewayHandler
Setting *SettingHandler
}
// BuildInfo contains build-time information
type BuildInfo struct {
Version string
BuildType string // "source" for manual builds, "release" for CI builds
}
// NewHandlers creates a new Handlers instance with all handlers initialized
func NewHandlers(services *service.Services, repos *repository.Repositories, rdb *redis.Client, buildInfo BuildInfo) *Handlers {
return &Handlers{
Auth: NewAuthHandler(services.Auth),
User: NewUserHandler(services.User),
APIKey: NewAPIKeyHandler(services.ApiKey),
Usage: NewUsageHandler(services.Usage, repos.UsageLog, services.ApiKey),
Redeem: NewRedeemHandler(services.Redeem),
Subscription: NewSubscriptionHandler(services.Subscription),
Admin: &AdminHandlers{
Dashboard: admin.NewDashboardHandler(services.Admin, repos.UsageLog),
User: admin.NewUserHandler(services.Admin),
Group: admin.NewGroupHandler(services.Admin),
Account: admin.NewAccountHandler(services.Admin, services.OAuth, services.RateLimit, services.AccountUsage, services.AccountTest),
OAuth: admin.NewOAuthHandler(services.OAuth, services.Admin),
Proxy: admin.NewProxyHandler(services.Admin),
Redeem: admin.NewRedeemHandler(services.Admin),
Setting: admin.NewSettingHandler(services.Setting, services.Email),
System: admin.NewSystemHandler(rdb, buildInfo.Version, buildInfo.BuildType),
Subscription: admin.NewSubscriptionHandler(services.Subscription),
Usage: admin.NewUsageHandler(repos.UsageLog, repos.ApiKey, services.Usage, services.Admin),
},
Gateway: NewGatewayHandler(services.Gateway, services.User, services.Concurrency, services.BillingCache),
Setting: NewSettingHandler(services.Setting, buildInfo.Version),
}
}

View File

@@ -0,0 +1,92 @@
package handler
import (
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// RedeemHandler handles redeem code-related requests
type RedeemHandler struct {
redeemService *service.RedeemService
}
// NewRedeemHandler creates a new RedeemHandler
func NewRedeemHandler(redeemService *service.RedeemService) *RedeemHandler {
return &RedeemHandler{
redeemService: redeemService,
}
}
// RedeemRequest represents the redeem code request payload
type RedeemRequest struct {
Code string `json:"code" binding:"required"`
}
// RedeemResponse represents the redeem response
type RedeemResponse struct {
Message string `json:"message"`
Type string `json:"type"`
Value float64 `json:"value"`
NewBalance *float64 `json:"new_balance,omitempty"`
NewConcurrency *int `json:"new_concurrency,omitempty"`
}
// Redeem handles redeeming a code
// POST /api/v1/redeem
func (h *RedeemHandler) Redeem(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
var req RedeemRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
result, err := h.redeemService.Redeem(c.Request.Context(), user.ID, req.Code)
if err != nil {
response.BadRequest(c, "Failed to redeem code: "+err.Error())
return
}
response.Success(c, result)
}
// GetHistory returns the user's redemption history
// GET /api/v1/redeem/history
func (h *RedeemHandler) GetHistory(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
// Default limit is 25
limit := 25
codes, err := h.redeemService.GetUserHistory(c.Request.Context(), user.ID, limit)
if err != nil {
response.InternalError(c, "Failed to get history: "+err.Error())
return
}
response.Success(c, codes)
}

View File

@@ -0,0 +1,35 @@
package handler
import (
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// SettingHandler 公开设置处理器(无需认证)
type SettingHandler struct {
settingService *service.SettingService
version string
}
// NewSettingHandler 创建公开设置处理器
func NewSettingHandler(settingService *service.SettingService, version string) *SettingHandler {
return &SettingHandler{
settingService: settingService,
version: version,
}
}
// GetPublicSettings 获取公开设置
// GET /api/v1/settings/public
func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
settings, err := h.settingService.GetPublicSettings(c.Request.Context())
if err != nil {
response.InternalError(c, "Failed to get settings: "+err.Error())
return
}
settings.Version = h.version
response.Success(c, settings)
}

View File

@@ -0,0 +1,203 @@
package handler
import (
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// SubscriptionSummaryItem represents a subscription item in summary
type SubscriptionSummaryItem struct {
ID int64 `json:"id"`
GroupID int64 `json:"group_id"`
GroupName string `json:"group_name"`
Status string `json:"status"`
DailyUsedUSD float64 `json:"daily_used_usd,omitempty"`
DailyLimitUSD float64 `json:"daily_limit_usd,omitempty"`
WeeklyUsedUSD float64 `json:"weekly_used_usd,omitempty"`
WeeklyLimitUSD float64 `json:"weekly_limit_usd,omitempty"`
MonthlyUsedUSD float64 `json:"monthly_used_usd,omitempty"`
MonthlyLimitUSD float64 `json:"monthly_limit_usd,omitempty"`
ExpiresAt *string `json:"expires_at,omitempty"`
}
// SubscriptionProgressInfo represents subscription with progress info
type SubscriptionProgressInfo struct {
Subscription *model.UserSubscription `json:"subscription"`
Progress *service.SubscriptionProgress `json:"progress"`
}
// SubscriptionHandler handles user subscription operations
type SubscriptionHandler struct {
subscriptionService *service.SubscriptionService
}
// NewSubscriptionHandler creates a new user subscription handler
func NewSubscriptionHandler(subscriptionService *service.SubscriptionService) *SubscriptionHandler {
return &SubscriptionHandler{
subscriptionService: subscriptionService,
}
}
// List handles listing current user's subscriptions
// GET /api/v1/subscriptions
func (h *SubscriptionHandler) List(c *gin.Context) {
user, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not found in context")
return
}
u, ok := user.(*model.User)
if !ok {
response.InternalError(c, "Invalid user in context")
return
}
subscriptions, err := h.subscriptionService.ListUserSubscriptions(c.Request.Context(), u.ID)
if err != nil {
response.InternalError(c, "Failed to list subscriptions: "+err.Error())
return
}
response.Success(c, subscriptions)
}
// GetActive handles getting current user's active subscriptions
// GET /api/v1/subscriptions/active
func (h *SubscriptionHandler) GetActive(c *gin.Context) {
user, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not found in context")
return
}
u, ok := user.(*model.User)
if !ok {
response.InternalError(c, "Invalid user in context")
return
}
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
if err != nil {
response.InternalError(c, "Failed to get active subscriptions: "+err.Error())
return
}
response.Success(c, subscriptions)
}
// GetProgress handles getting subscription progress for current user
// GET /api/v1/subscriptions/progress
func (h *SubscriptionHandler) GetProgress(c *gin.Context) {
user, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not found in context")
return
}
u, ok := user.(*model.User)
if !ok {
response.InternalError(c, "Invalid user in context")
return
}
// Get all active subscriptions with progress
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
if err != nil {
response.InternalError(c, "Failed to get subscriptions: "+err.Error())
return
}
result := make([]SubscriptionProgressInfo, 0, len(subscriptions))
for i := range subscriptions {
sub := &subscriptions[i]
progress, err := h.subscriptionService.GetSubscriptionProgress(c.Request.Context(), sub.ID)
if err != nil {
// Skip subscriptions with errors
continue
}
result = append(result, SubscriptionProgressInfo{
Subscription: sub,
Progress: progress,
})
}
response.Success(c, result)
}
// GetSummary handles getting a summary of current user's subscription status
// GET /api/v1/subscriptions/summary
func (h *SubscriptionHandler) GetSummary(c *gin.Context) {
user, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not found in context")
return
}
u, ok := user.(*model.User)
if !ok {
response.InternalError(c, "Invalid user in context")
return
}
// Get all active subscriptions
subscriptions, err := h.subscriptionService.ListActiveUserSubscriptions(c.Request.Context(), u.ID)
if err != nil {
response.InternalError(c, "Failed to get subscriptions: "+err.Error())
return
}
var totalUsed float64
items := make([]SubscriptionSummaryItem, 0, len(subscriptions))
for _, sub := range subscriptions {
item := SubscriptionSummaryItem{
ID: sub.ID,
GroupID: sub.GroupID,
Status: sub.Status,
DailyUsedUSD: sub.DailyUsageUSD,
WeeklyUsedUSD: sub.WeeklyUsageUSD,
MonthlyUsedUSD: sub.MonthlyUsageUSD,
}
// Add group info if preloaded
if sub.Group != nil {
item.GroupName = sub.Group.Name
if sub.Group.DailyLimitUSD != nil {
item.DailyLimitUSD = *sub.Group.DailyLimitUSD
}
if sub.Group.WeeklyLimitUSD != nil {
item.WeeklyLimitUSD = *sub.Group.WeeklyLimitUSD
}
if sub.Group.MonthlyLimitUSD != nil {
item.MonthlyLimitUSD = *sub.Group.MonthlyLimitUSD
}
}
// Format expiration time
if !sub.ExpiresAt.IsZero() {
formatted := sub.ExpiresAt.Format("2006-01-02T15:04:05Z07:00")
item.ExpiresAt = &formatted
}
// Track total usage (use monthly as the most comprehensive)
totalUsed += sub.MonthlyUsageUSD
items = append(items, item)
}
summary := struct {
ActiveCount int `json:"active_count"`
TotalUsedUSD float64 `json:"total_used_usd"`
Subscriptions []SubscriptionSummaryItem `json:"subscriptions"`
}{
ActiveCount: len(subscriptions),
TotalUsedUSD: totalUsed,
Subscriptions: items,
}
response.Success(c, summary)
}

View File

@@ -0,0 +1,396 @@
package handler
import (
"strconv"
"time"
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/pkg/timezone"
"sub2api/internal/repository"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// UsageHandler handles usage-related requests
type UsageHandler struct {
usageService *service.UsageService
usageRepo *repository.UsageLogRepository
apiKeyService *service.ApiKeyService
}
// NewUsageHandler creates a new UsageHandler
func NewUsageHandler(usageService *service.UsageService, usageRepo *repository.UsageLogRepository, apiKeyService *service.ApiKeyService) *UsageHandler {
return &UsageHandler{
usageService: usageService,
usageRepo: usageRepo,
apiKeyService: apiKeyService,
}
}
// List handles listing usage records with pagination
// GET /api/v1/usage
func (h *UsageHandler) List(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
page, pageSize := response.ParsePagination(c)
var apiKeyID int64
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "Invalid api_key_id")
return
}
// [Security Fix] Verify API Key ownership to prevent horizontal privilege escalation
apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), id)
if err != nil {
response.NotFound(c, "API key not found")
return
}
if apiKey.UserID != user.ID {
response.Forbidden(c, "Not authorized to access this API key's usage records")
return
}
apiKeyID = id
}
params := repository.PaginationParams{Page: page, PageSize: pageSize}
var records []model.UsageLog
var result *repository.PaginationResult
var err error
if apiKeyID > 0 {
records, result, err = h.usageService.ListByApiKey(c.Request.Context(), apiKeyID, params)
} else {
records, result, err = h.usageService.ListByUser(c.Request.Context(), user.ID, params)
}
if err != nil {
response.InternalError(c, "Failed to list usage records: "+err.Error())
return
}
response.Paginated(c, records, result.Total, page, pageSize)
}
// GetByID handles getting a single usage record
// GET /api/v1/usage/:id
func (h *UsageHandler) GetByID(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
usageID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid usage ID")
return
}
record, err := h.usageService.GetByID(c.Request.Context(), usageID)
if err != nil {
response.NotFound(c, "Usage record not found")
return
}
// 验证所有权
if record.UserID != user.ID {
response.Forbidden(c, "Not authorized to access this record")
return
}
response.Success(c, record)
}
// Stats handles getting usage statistics
// GET /api/v1/usage/stats
func (h *UsageHandler) Stats(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
var apiKeyID int64
if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
id, err := strconv.ParseInt(apiKeyIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "Invalid api_key_id")
return
}
// [Security Fix] Verify API Key ownership to prevent horizontal privilege escalation
apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), id)
if err != nil {
response.NotFound(c, "API key not found")
return
}
if apiKey.UserID != user.ID {
response.Forbidden(c, "Not authorized to access this API key's statistics")
return
}
apiKeyID = id
}
// 获取时间范围参数
now := timezone.Now()
var startTime, endTime time.Time
// 优先使用 start_date 和 end_date 参数
startDateStr := c.Query("start_date")
endDateStr := c.Query("end_date")
if startDateStr != "" && endDateStr != "" {
// 使用自定义日期范围
var err error
startTime, err = timezone.ParseInLocation("2006-01-02", startDateStr)
if err != nil {
response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
return
}
endTime, err = timezone.ParseInLocation("2006-01-02", endDateStr)
if err != nil {
response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
return
}
// 设置结束时间为当天结束
endTime = endTime.Add(24*time.Hour - time.Nanosecond)
} else {
// 使用 period 参数
period := c.DefaultQuery("period", "today")
switch period {
case "today":
startTime = timezone.StartOfDay(now)
case "week":
startTime = now.AddDate(0, 0, -7)
case "month":
startTime = now.AddDate(0, -1, 0)
default:
startTime = timezone.StartOfDay(now)
}
endTime = now
}
var stats *service.UsageStats
var err error
if apiKeyID > 0 {
stats, err = h.usageService.GetStatsByApiKey(c.Request.Context(), apiKeyID, startTime, endTime)
} else {
stats, err = h.usageService.GetStatsByUser(c.Request.Context(), user.ID, startTime, endTime)
}
if err != nil {
response.InternalError(c, "Failed to get usage statistics: "+err.Error())
return
}
response.Success(c, stats)
}
// parseUserTimeRange parses start_date, end_date query parameters for user dashboard
func parseUserTimeRange(c *gin.Context) (time.Time, time.Time) {
now := timezone.Now()
startDate := c.Query("start_date")
endDate := c.Query("end_date")
var startTime, endTime time.Time
if startDate != "" {
if t, err := timezone.ParseInLocation("2006-01-02", startDate); err == nil {
startTime = t
} else {
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7))
}
} else {
startTime = timezone.StartOfDay(now.AddDate(0, 0, -7))
}
if endDate != "" {
if t, err := timezone.ParseInLocation("2006-01-02", endDate); err == nil {
endTime = t.Add(24 * time.Hour) // Include the end date
} else {
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1))
}
} else {
endTime = timezone.StartOfDay(now.AddDate(0, 0, 1))
}
return startTime, endTime
}
// DashboardStats handles getting user dashboard statistics
// GET /api/v1/usage/dashboard/stats
func (h *UsageHandler) DashboardStats(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
stats, err := h.usageRepo.GetUserDashboardStats(c.Request.Context(), user.ID)
if err != nil {
response.InternalError(c, "Failed to get dashboard statistics")
return
}
response.Success(c, stats)
}
// DashboardTrend handles getting user usage trend data
// GET /api/v1/usage/dashboard/trend
func (h *UsageHandler) DashboardTrend(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
startTime, endTime := parseUserTimeRange(c)
granularity := c.DefaultQuery("granularity", "day")
trend, err := h.usageRepo.GetUserUsageTrendByUserID(c.Request.Context(), user.ID, startTime, endTime, granularity)
if err != nil {
response.InternalError(c, "Failed to get usage trend")
return
}
response.Success(c, gin.H{
"trend": trend,
"start_date": startTime.Format("2006-01-02"),
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
"granularity": granularity,
})
}
// DashboardModels handles getting user model usage statistics
// GET /api/v1/usage/dashboard/models
func (h *UsageHandler) DashboardModels(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
startTime, endTime := parseUserTimeRange(c)
stats, err := h.usageRepo.GetUserModelStats(c.Request.Context(), user.ID, startTime, endTime)
if err != nil {
response.InternalError(c, "Failed to get model statistics")
return
}
response.Success(c, gin.H{
"models": stats,
"start_date": startTime.Format("2006-01-02"),
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
})
}
// BatchApiKeysUsageRequest represents the request for batch API keys usage
type BatchApiKeysUsageRequest struct {
ApiKeyIDs []int64 `json:"api_key_ids" binding:"required"`
}
// DashboardApiKeysUsage handles getting usage stats for user's own API keys
// POST /api/v1/usage/dashboard/api-keys-usage
func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
var req BatchApiKeysUsageRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if len(req.ApiKeyIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]interface{}{}})
return
}
// Verify ownership of all requested API keys
userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, repository.PaginationParams{Page: 1, PageSize: 1000})
if err != nil {
response.InternalError(c, "Failed to verify API key ownership")
return
}
userApiKeyIDs := make(map[int64]bool)
for _, key := range userApiKeys {
userApiKeyIDs[key.ID] = true
}
// Filter to only include user's own API keys
validApiKeyIDs := make([]int64, 0)
for _, id := range req.ApiKeyIDs {
if userApiKeyIDs[id] {
validApiKeyIDs = append(validApiKeyIDs, id)
}
}
if len(validApiKeyIDs) == 0 {
response.Success(c, gin.H{"stats": map[string]interface{}{}})
return
}
stats, err := h.usageRepo.GetBatchApiKeyUsageStats(c.Request.Context(), validApiKeyIDs)
if err != nil {
response.InternalError(c, "Failed to get API key usage stats")
return
}
response.Success(c, gin.H{"stats": stats})
}

View File

@@ -0,0 +1,85 @@
package handler
import (
"sub2api/internal/model"
"sub2api/internal/pkg/response"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// UserHandler handles user-related requests
type UserHandler struct {
userService *service.UserService
}
// NewUserHandler creates a new UserHandler
func NewUserHandler(userService *service.UserService) *UserHandler {
return &UserHandler{
userService: userService,
}
}
// ChangePasswordRequest represents the change password request payload
type ChangePasswordRequest struct {
OldPassword string `json:"old_password" binding:"required"`
NewPassword string `json:"new_password" binding:"required,min=6"`
}
// GetProfile handles getting user profile
// GET /api/v1/users/me
func (h *UserHandler) GetProfile(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
userData, err := h.userService.GetByID(c.Request.Context(), user.ID)
if err != nil {
response.InternalError(c, "Failed to get user profile: "+err.Error())
return
}
response.Success(c, userData)
}
// ChangePassword handles changing user password
// POST /api/v1/users/me/password
func (h *UserHandler) ChangePassword(c *gin.Context) {
userValue, exists := c.Get("user")
if !exists {
response.Unauthorized(c, "User not authenticated")
return
}
user, ok := userValue.(*model.User)
if !ok {
response.InternalError(c, "Invalid user context")
return
}
var req ChangePasswordRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
svcReq := service.ChangePasswordRequest{
CurrentPassword: req.OldPassword,
NewPassword: req.NewPassword,
}
err := h.userService.ChangePassword(c.Request.Context(), user.ID, svcReq)
if err != nil {
response.BadRequest(c, "Failed to change password: "+err.Error())
return
}
response.Success(c, gin.H{"message": "Password changed successfully"})
}