refactor: 调整 server 目录结构

This commit is contained in:
Forest
2025-12-26 10:42:08 +08:00
parent 8d7a497553
commit 57fd172287
27 changed files with 548 additions and 472 deletions

View File

@@ -0,0 +1,133 @@
package middleware
import (
"crypto/subtle"
"errors"
"strings"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// NewAdminAuthMiddleware 创建管理员认证中间件
func NewAdminAuthMiddleware(
authService *service.AuthService,
userService *service.UserService,
settingService *service.SettingService,
) AdminAuthMiddleware {
return AdminAuthMiddleware(adminAuth(authService, userService, settingService))
}
// adminAuth 管理员认证中间件实现
// 支持两种认证方式(通过不同的 header 区分):
// 1. Admin API Key: x-api-key: <admin-api-key>
// 2. JWT Token: Authorization: Bearer <jwt-token> (需要管理员角色)
func adminAuth(
authService *service.AuthService,
userService *service.UserService,
settingService *service.SettingService,
) gin.HandlerFunc {
return func(c *gin.Context) {
// 检查 x-api-key headerAdmin API Key 认证)
apiKey := c.GetHeader("x-api-key")
if apiKey != "" {
if !validateAdminApiKey(c, apiKey, settingService, userService) {
return
}
c.Next()
return
}
// 检查 Authorization headerJWT 认证)
authHeader := c.GetHeader("Authorization")
if authHeader != "" {
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) == 2 && parts[0] == "Bearer" {
if !validateJWTForAdmin(c, parts[1], authService, userService) {
return
}
c.Next()
return
}
}
// 无有效认证信息
AbortWithError(c, 401, "UNAUTHORIZED", "Authorization required")
}
}
// validateAdminApiKey 验证管理员 API Key
func validateAdminApiKey(
c *gin.Context,
key string,
settingService *service.SettingService,
userService *service.UserService,
) bool {
storedKey, err := settingService.GetAdminApiKey(c.Request.Context())
if err != nil {
AbortWithError(c, 500, "INTERNAL_ERROR", "Internal server error")
return false
}
// 未配置或不匹配,统一返回相同错误(避免信息泄露)
if storedKey == "" || subtle.ConstantTimeCompare([]byte(key), []byte(storedKey)) != 1 {
AbortWithError(c, 401, "INVALID_ADMIN_KEY", "Invalid admin API key")
return false
}
// 获取真实的管理员用户
admin, err := userService.GetFirstAdmin(c.Request.Context())
if err != nil {
AbortWithError(c, 500, "INTERNAL_ERROR", "No admin user found")
return false
}
c.Set(string(ContextKeyUser), admin)
c.Set("auth_method", "admin_api_key")
return true
}
// validateJWTForAdmin 验证 JWT 并检查管理员权限
func validateJWTForAdmin(
c *gin.Context,
token string,
authService *service.AuthService,
userService *service.UserService,
) bool {
// 验证 JWT token
claims, err := authService.ValidateToken(token)
if err != nil {
if errors.Is(err, service.ErrTokenExpired) {
AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired")
return false
}
AbortWithError(c, 401, "INVALID_TOKEN", "Invalid token")
return false
}
// 从数据库获取用户
user, err := userService.GetByID(c.Request.Context(), claims.UserID)
if err != nil {
AbortWithError(c, 401, "USER_NOT_FOUND", "User not found")
return false
}
// 检查用户状态
if !user.IsActive() {
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
return false
}
// 检查管理员权限
if user.Role != model.RoleAdmin {
AbortWithError(c, 403, "FORBIDDEN", "Admin access required")
return false
}
c.Set(string(ContextKeyUser), user)
c.Set("auth_method", "jwt")
return true
}

View File

@@ -0,0 +1,28 @@
package middleware
import (
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/gin-gonic/gin"
)
// AdminOnly 管理员权限中间件
// 必须在JWTAuth中间件之后使用
func AdminOnly() gin.HandlerFunc {
return func(c *gin.Context) {
// 从上下文获取用户
user, exists := GetUserFromContext(c)
if !exists {
AbortWithError(c, 401, "UNAUTHORIZED", "User not found in context")
return
}
// 检查是否为管理员
if user.Role != model.RoleAdmin {
AbortWithError(c, 403, "FORBIDDEN", "Admin access required")
return
}
c.Next()
}
}

View File

@@ -0,0 +1,148 @@
package middleware
import (
"errors"
"log"
"strings"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
// NewApiKeyAuthMiddleware 创建 API Key 认证中间件
func NewApiKeyAuthMiddleware(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService) ApiKeyAuthMiddleware {
return ApiKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService))
}
// apiKeyAuthWithSubscription API Key认证中间件支持订阅验证
func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService) gin.HandlerFunc {
return func(c *gin.Context) {
// 尝试从Authorization header中提取API key (Bearer scheme)
authHeader := c.GetHeader("Authorization")
var apiKeyString string
if authHeader != "" {
// 验证Bearer scheme
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) == 2 && parts[0] == "Bearer" {
apiKeyString = parts[1]
}
}
// 如果Authorization header中没有尝试从x-api-key header中提取
if apiKeyString == "" {
apiKeyString = c.GetHeader("x-api-key")
}
// 如果两个header都没有API key
if apiKeyString == "" {
AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme) or x-api-key header")
return
}
// 从数据库验证API key
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key")
return
}
AbortWithError(c, 500, "INTERNAL_ERROR", "Failed to validate API key")
return
}
// 检查API key是否激活
if !apiKey.IsActive() {
AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled")
return
}
// 检查关联的用户
if apiKey.User == nil {
AbortWithError(c, 401, "USER_NOT_FOUND", "User associated with API key not found")
return
}
// 检查用户状态
if !apiKey.User.IsActive() {
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
return
}
// 判断计费方式:订阅模式 vs 余额模式
isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
if isSubscriptionType && subscriptionService != nil {
// 订阅模式:验证订阅
subscription, err := subscriptionService.GetActiveSubscription(
c.Request.Context(),
apiKey.User.ID,
apiKey.Group.ID,
)
if err != nil {
AbortWithError(c, 403, "SUBSCRIPTION_NOT_FOUND", "No active subscription found for this group")
return
}
// 验证订阅状态(是否过期、暂停等)
if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil {
AbortWithError(c, 403, "SUBSCRIPTION_INVALID", err.Error())
return
}
// 激活滑动窗口(首次使用时)
if err := subscriptionService.CheckAndActivateWindow(c.Request.Context(), subscription); err != nil {
log.Printf("Failed to activate subscription windows: %v", err)
}
// 检查并重置过期窗口
if err := subscriptionService.CheckAndResetWindows(c.Request.Context(), subscription); err != nil {
log.Printf("Failed to reset subscription windows: %v", err)
}
// 预检查用量限制使用0作为额外费用进行预检查
if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil {
AbortWithError(c, 429, "USAGE_LIMIT_EXCEEDED", err.Error())
return
}
// 将订阅信息存入上下文
c.Set(string(ContextKeySubscription), subscription)
} else {
// 余额模式:检查用户余额
if apiKey.User.Balance <= 0 {
AbortWithError(c, 403, "INSUFFICIENT_BALANCE", "Insufficient account balance")
return
}
}
// 将API key和用户信息存入上下文
c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyUser), apiKey.User)
c.Next()
}
}
// GetApiKeyFromContext 从上下文中获取API key
func GetApiKeyFromContext(c *gin.Context) (*model.ApiKey, bool) {
value, exists := c.Get(string(ContextKeyApiKey))
if !exists {
return nil, false
}
apiKey, ok := value.(*model.ApiKey)
return apiKey, ok
}
// GetSubscriptionFromContext 从上下文中获取订阅信息
func GetSubscriptionFromContext(c *gin.Context) (*model.UserSubscription, bool) {
value, exists := c.Get(string(ContextKeySubscription))
if !exists {
return nil, false
}
subscription, ok := value.(*model.UserSubscription)
return subscription, ok
}

View File

@@ -0,0 +1,24 @@
package middleware
import (
"github.com/gin-gonic/gin"
)
// CORS 跨域中间件
func CORS() gin.HandlerFunc {
return func(c *gin.Context) {
// 设置允许跨域的响应头
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key")
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
// 处理预检请求
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(204)
return
}
c.Next()
}
}

View File

@@ -0,0 +1,80 @@
package middleware
import (
"errors"
"strings"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// NewJWTAuthMiddleware 创建 JWT 认证中间件
func NewJWTAuthMiddleware(authService *service.AuthService, userService *service.UserService) JWTAuthMiddleware {
return JWTAuthMiddleware(jwtAuth(authService, userService))
}
// jwtAuth JWT认证中间件实现
func jwtAuth(authService *service.AuthService, userService *service.UserService) gin.HandlerFunc {
return func(c *gin.Context) {
// 从Authorization header中提取token
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
AbortWithError(c, 401, "UNAUTHORIZED", "Authorization header is required")
return
}
// 验证Bearer scheme
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 || parts[0] != "Bearer" {
AbortWithError(c, 401, "INVALID_AUTH_HEADER", "Authorization header format must be 'Bearer {token}'")
return
}
tokenString := parts[1]
if tokenString == "" {
AbortWithError(c, 401, "EMPTY_TOKEN", "Token cannot be empty")
return
}
// 验证token
claims, err := authService.ValidateToken(tokenString)
if err != nil {
if errors.Is(err, service.ErrTokenExpired) {
AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired")
return
}
AbortWithError(c, 401, "INVALID_TOKEN", "Invalid token")
return
}
// 从数据库获取最新的用户信息
user, err := userService.GetByID(c.Request.Context(), claims.UserID)
if err != nil {
AbortWithError(c, 401, "USER_NOT_FOUND", "User not found")
return
}
// 检查用户状态
if !user.IsActive() {
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
return
}
// 将用户信息存入上下文
c.Set(string(ContextKeyUser), user)
c.Next()
}
}
// GetUserFromContext 从上下文中获取用户
func GetUserFromContext(c *gin.Context) (*model.User, bool) {
value, exists := c.Get(string(ContextKeyUser))
if !exists {
return nil, false
}
user, ok := value.(*model.User)
return user, ok
}

View File

@@ -0,0 +1,52 @@
package middleware
import (
"log"
"time"
"github.com/gin-gonic/gin"
)
// Logger 请求日志中间件
func Logger() gin.HandlerFunc {
return func(c *gin.Context) {
// 开始时间
startTime := time.Now()
// 处理请求
c.Next()
// 结束时间
endTime := time.Now()
// 执行时间
latency := endTime.Sub(startTime)
// 请求方法
method := c.Request.Method
// 请求路径
path := c.Request.URL.Path
// 状态码
statusCode := c.Writer.Status()
// 客户端IP
clientIP := c.ClientIP()
// 日志格式: [时间] 状态码 | 延迟 | IP | 方法 路径
log.Printf("[GIN] %v | %3d | %13v | %15s | %-7s %s",
endTime.Format("2006/01/02 - 15:04:05"),
statusCode,
latency,
clientIP,
method,
path,
)
// 如果有错误,额外记录错误信息
if len(c.Errors) > 0 {
log.Printf("[GIN] Errors: %v", c.Errors.String())
}
}
}

View File

@@ -0,0 +1,35 @@
package middleware
import "github.com/gin-gonic/gin"
// ContextKey 定义上下文键类型
type ContextKey string
const (
// ContextKeyUser 用户上下文键
ContextKeyUser ContextKey = "user"
// ContextKeyApiKey API密钥上下文键
ContextKeyApiKey ContextKey = "api_key"
// ContextKeySubscription 订阅上下文键
ContextKeySubscription ContextKey = "subscription"
)
// ErrorResponse 标准错误响应结构
type ErrorResponse struct {
Code string `json:"code"`
Message string `json:"message"`
}
// NewErrorResponse 创建错误响应
func NewErrorResponse(code, message string) ErrorResponse {
return ErrorResponse{
Code: code,
Message: message,
}
}
// AbortWithError 中断请求并返回JSON错误
func AbortWithError(c *gin.Context, statusCode int, code, message string) {
c.JSON(statusCode, NewErrorResponse(code, message))
c.Abort()
}

View File

@@ -0,0 +1,64 @@
package middleware
import (
"errors"
"net"
"net/http"
"os"
"strings"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/gin-gonic/gin"
)
// Recovery converts panics into the project's standard JSON error envelope.
//
// It preserves Gin's broken-pipe handling by not attempting to write a response
// when the client connection is already gone.
func Recovery() gin.HandlerFunc {
return gin.CustomRecoveryWithWriter(gin.DefaultErrorWriter, func(c *gin.Context, recovered any) {
recoveredErr, _ := recovered.(error)
if isBrokenPipe(recoveredErr) {
if recoveredErr != nil {
_ = c.Error(recoveredErr)
}
c.Abort()
return
}
if c.Writer.Written() {
c.Abort()
return
}
response.ErrorWithDetails(
c,
http.StatusInternalServerError,
infraerrors.UnknownMessage,
infraerrors.UnknownReason,
nil,
)
c.Abort()
})
}
func isBrokenPipe(err error) bool {
if err == nil {
return false
}
var opErr *net.OpError
if !errors.As(err, &opErr) {
return false
}
var syscallErr *os.SyscallError
if !errors.As(opErr.Err, &syscallErr) {
return false
}
msg := strings.ToLower(syscallErr.Error())
return strings.Contains(msg, "broken pipe") || strings.Contains(msg, "connection reset by peer")
}

View File

@@ -0,0 +1,81 @@
//go:build unit
package middleware
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestRecovery(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
handler gin.HandlerFunc
wantHTTPCode int
wantBody response.Response
}{
{
name: "panic_returns_standard_json_500",
handler: func(c *gin.Context) {
panic("boom")
},
wantHTTPCode: http.StatusInternalServerError,
wantBody: response.Response{
Code: http.StatusInternalServerError,
Message: infraerrors.UnknownMessage,
},
},
{
name: "no_panic_passthrough",
handler: func(c *gin.Context) {
response.Success(c, gin.H{"ok": true})
},
wantHTTPCode: http.StatusOK,
wantBody: response.Response{
Code: 0,
Message: "success",
Data: map[string]any{"ok": true},
},
},
{
name: "panic_after_write_does_not_override_body",
handler: func(c *gin.Context) {
response.Success(c, gin.H{"ok": true})
panic("boom")
},
wantHTTPCode: http.StatusOK,
wantBody: response.Response{
Code: 0,
Message: "success",
Data: map[string]any{"ok": true},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := gin.New()
r.Use(Recovery())
r.GET("/t", tt.handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
r.ServeHTTP(w, req)
require.Equal(t, tt.wantHTTPCode, w.Code)
var got response.Response
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
require.Equal(t, tt.wantBody, got)
})
}
}

View File

@@ -0,0 +1,22 @@
package middleware
import (
"github.com/gin-gonic/gin"
"github.com/google/wire"
)
// JWTAuthMiddleware JWT 认证中间件类型
type JWTAuthMiddleware gin.HandlerFunc
// AdminAuthMiddleware 管理员认证中间件类型
type AdminAuthMiddleware gin.HandlerFunc
// ApiKeyAuthMiddleware API Key 认证中间件类型
type ApiKeyAuthMiddleware gin.HandlerFunc
// ProviderSet 中间件层的依赖注入
var ProviderSet = wire.NewSet(
NewJWTAuthMiddleware,
NewAdminAuthMiddleware,
NewApiKeyAuthMiddleware,
)