feat: 品牌重命名 Sub2API -> TianShuAPI
- 前端: 所有界面显示、i18n 文本、组件中的品牌名称 - 后端: 服务层、设置默认值、邮件模板、安装向导 - 数据库: 迁移脚本注释 - 保持功能完全一致,仅更改品牌名称 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -1,140 +1,140 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// NewAdminAuthMiddleware 创建管理员认证中间件
|
||||
func NewAdminAuthMiddleware(
|
||||
authService *service.AuthService,
|
||||
userService *service.UserService,
|
||||
settingService *service.SettingService,
|
||||
) AdminAuthMiddleware {
|
||||
return AdminAuthMiddleware(adminAuth(authService, userService, settingService))
|
||||
}
|
||||
|
||||
// adminAuth 管理员认证中间件实现
|
||||
// 支持两种认证方式(通过不同的 header 区分):
|
||||
// 1. Admin API Key: x-api-key: <admin-api-key>
|
||||
// 2. JWT Token: Authorization: Bearer <jwt-token> (需要管理员角色)
|
||||
func adminAuth(
|
||||
authService *service.AuthService,
|
||||
userService *service.UserService,
|
||||
settingService *service.SettingService,
|
||||
) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 检查 x-api-key header(Admin API Key 认证)
|
||||
apiKey := c.GetHeader("x-api-key")
|
||||
if apiKey != "" {
|
||||
if !validateAdminApiKey(c, apiKey, settingService, userService) {
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查 Authorization header(JWT 认证)
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader != "" {
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) == 2 && parts[0] == "Bearer" {
|
||||
if !validateJWTForAdmin(c, parts[1], authService, userService) {
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 无有效认证信息
|
||||
AbortWithError(c, 401, "UNAUTHORIZED", "Authorization required")
|
||||
}
|
||||
}
|
||||
|
||||
// validateAdminApiKey 验证管理员 API Key
|
||||
func validateAdminApiKey(
|
||||
c *gin.Context,
|
||||
key string,
|
||||
settingService *service.SettingService,
|
||||
userService *service.UserService,
|
||||
) bool {
|
||||
storedKey, err := settingService.GetAdminApiKey(c.Request.Context())
|
||||
if err != nil {
|
||||
AbortWithError(c, 500, "INTERNAL_ERROR", "Internal server error")
|
||||
return false
|
||||
}
|
||||
|
||||
// 未配置或不匹配,统一返回相同错误(避免信息泄露)
|
||||
if storedKey == "" || subtle.ConstantTimeCompare([]byte(key), []byte(storedKey)) != 1 {
|
||||
AbortWithError(c, 401, "INVALID_ADMIN_KEY", "Invalid admin API key")
|
||||
return false
|
||||
}
|
||||
|
||||
// 获取真实的管理员用户
|
||||
admin, err := userService.GetFirstAdmin(c.Request.Context())
|
||||
if err != nil {
|
||||
AbortWithError(c, 500, "INTERNAL_ERROR", "No admin user found")
|
||||
return false
|
||||
}
|
||||
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: admin.ID,
|
||||
Concurrency: admin.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), admin.Role)
|
||||
c.Set("auth_method", "admin_api_key")
|
||||
return true
|
||||
}
|
||||
|
||||
// validateJWTForAdmin 验证 JWT 并检查管理员权限
|
||||
func validateJWTForAdmin(
|
||||
c *gin.Context,
|
||||
token string,
|
||||
authService *service.AuthService,
|
||||
userService *service.UserService,
|
||||
) bool {
|
||||
// 验证 JWT token
|
||||
claims, err := authService.ValidateToken(token)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrTokenExpired) {
|
||||
AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired")
|
||||
return false
|
||||
}
|
||||
AbortWithError(c, 401, "INVALID_TOKEN", "Invalid token")
|
||||
return false
|
||||
}
|
||||
|
||||
// 从数据库获取用户
|
||||
user, err := userService.GetByID(c.Request.Context(), claims.UserID)
|
||||
if err != nil {
|
||||
AbortWithError(c, 401, "USER_NOT_FOUND", "User not found")
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
if !user.IsActive() {
|
||||
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查管理员权限
|
||||
if !user.IsAdmin() {
|
||||
AbortWithError(c, 403, "FORBIDDEN", "Admin access required")
|
||||
return false
|
||||
}
|
||||
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: user.ID,
|
||||
Concurrency: user.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), user.Role)
|
||||
c.Set("auth_method", "jwt")
|
||||
|
||||
return true
|
||||
}
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// NewAdminAuthMiddleware 创建管理员认证中间件
|
||||
func NewAdminAuthMiddleware(
|
||||
authService *service.AuthService,
|
||||
userService *service.UserService,
|
||||
settingService *service.SettingService,
|
||||
) AdminAuthMiddleware {
|
||||
return AdminAuthMiddleware(adminAuth(authService, userService, settingService))
|
||||
}
|
||||
|
||||
// adminAuth 管理员认证中间件实现
|
||||
// 支持两种认证方式(通过不同的 header 区分):
|
||||
// 1. Admin API Key: x-api-key: <admin-api-key>
|
||||
// 2. JWT Token: Authorization: Bearer <jwt-token> (需要管理员角色)
|
||||
func adminAuth(
|
||||
authService *service.AuthService,
|
||||
userService *service.UserService,
|
||||
settingService *service.SettingService,
|
||||
) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 检查 x-api-key header(Admin API Key 认证)
|
||||
apiKey := c.GetHeader("x-api-key")
|
||||
if apiKey != "" {
|
||||
if !validateAdminApiKey(c, apiKey, settingService, userService) {
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查 Authorization header(JWT 认证)
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader != "" {
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) == 2 && parts[0] == "Bearer" {
|
||||
if !validateJWTForAdmin(c, parts[1], authService, userService) {
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 无有效认证信息
|
||||
AbortWithError(c, 401, "UNAUTHORIZED", "Authorization required")
|
||||
}
|
||||
}
|
||||
|
||||
// validateAdminApiKey 验证管理员 API Key
|
||||
func validateAdminApiKey(
|
||||
c *gin.Context,
|
||||
key string,
|
||||
settingService *service.SettingService,
|
||||
userService *service.UserService,
|
||||
) bool {
|
||||
storedKey, err := settingService.GetAdminApiKey(c.Request.Context())
|
||||
if err != nil {
|
||||
AbortWithError(c, 500, "INTERNAL_ERROR", "Internal server error")
|
||||
return false
|
||||
}
|
||||
|
||||
// 未配置或不匹配,统一返回相同错误(避免信息泄露)
|
||||
if storedKey == "" || subtle.ConstantTimeCompare([]byte(key), []byte(storedKey)) != 1 {
|
||||
AbortWithError(c, 401, "INVALID_ADMIN_KEY", "Invalid admin API key")
|
||||
return false
|
||||
}
|
||||
|
||||
// 获取真实的管理员用户
|
||||
admin, err := userService.GetFirstAdmin(c.Request.Context())
|
||||
if err != nil {
|
||||
AbortWithError(c, 500, "INTERNAL_ERROR", "No admin user found")
|
||||
return false
|
||||
}
|
||||
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: admin.ID,
|
||||
Concurrency: admin.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), admin.Role)
|
||||
c.Set("auth_method", "admin_api_key")
|
||||
return true
|
||||
}
|
||||
|
||||
// validateJWTForAdmin 验证 JWT 并检查管理员权限
|
||||
func validateJWTForAdmin(
|
||||
c *gin.Context,
|
||||
token string,
|
||||
authService *service.AuthService,
|
||||
userService *service.UserService,
|
||||
) bool {
|
||||
// 验证 JWT token
|
||||
claims, err := authService.ValidateToken(token)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrTokenExpired) {
|
||||
AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired")
|
||||
return false
|
||||
}
|
||||
AbortWithError(c, 401, "INVALID_TOKEN", "Invalid token")
|
||||
return false
|
||||
}
|
||||
|
||||
// 从数据库获取用户
|
||||
user, err := userService.GetByID(c.Request.Context(), claims.UserID)
|
||||
if err != nil {
|
||||
AbortWithError(c, 401, "USER_NOT_FOUND", "User not found")
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
if !user.IsActive() {
|
||||
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查管理员权限
|
||||
if !user.IsAdmin() {
|
||||
AbortWithError(c, 403, "FORBIDDEN", "Admin access required")
|
||||
return false
|
||||
}
|
||||
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: user.ID,
|
||||
Concurrency: user.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), user.Role)
|
||||
c.Set("auth_method", "jwt")
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -1,27 +1,27 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// AdminOnly 管理员权限中间件
|
||||
// 必须在JWTAuth中间件之后使用
|
||||
func AdminOnly() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
role, ok := GetUserRoleFromContext(c)
|
||||
if !ok {
|
||||
AbortWithError(c, 401, "UNAUTHORIZED", "User not found in context")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否为管理员
|
||||
if role != service.RoleAdmin {
|
||||
AbortWithError(c, 403, "FORBIDDEN", "Admin access required")
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// AdminOnly 管理员权限中间件
|
||||
// 必须在JWTAuth中间件之后使用
|
||||
func AdminOnly() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
role, ok := GetUserRoleFromContext(c)
|
||||
if !ok {
|
||||
AbortWithError(c, 401, "UNAUTHORIZED", "User not found in context")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否为管理员
|
||||
if role != service.RoleAdmin {
|
||||
AbortWithError(c, 403, "FORBIDDEN", "Admin access required")
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,178 +1,178 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// NewApiKeyAuthMiddleware 创建 API Key 认证中间件
|
||||
func NewApiKeyAuthMiddleware(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) ApiKeyAuthMiddleware {
|
||||
return ApiKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService, cfg))
|
||||
}
|
||||
|
||||
// apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证)
|
||||
func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 尝试从Authorization header中提取API key (Bearer scheme)
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
var apiKeyString string
|
||||
|
||||
if authHeader != "" {
|
||||
// 验证Bearer scheme
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) == 2 && parts[0] == "Bearer" {
|
||||
apiKeyString = parts[1]
|
||||
}
|
||||
}
|
||||
|
||||
// 如果Authorization header中没有,尝试从x-api-key header中提取
|
||||
if apiKeyString == "" {
|
||||
apiKeyString = c.GetHeader("x-api-key")
|
||||
}
|
||||
|
||||
// 如果x-api-key header中没有,尝试从x-goog-api-key header中提取(Gemini CLI兼容)
|
||||
if apiKeyString == "" {
|
||||
apiKeyString = c.GetHeader("x-goog-api-key")
|
||||
}
|
||||
|
||||
// 如果header中没有,尝试从query参数中提取(Google API key风格)
|
||||
if apiKeyString == "" {
|
||||
apiKeyString = c.Query("key")
|
||||
}
|
||||
|
||||
// 兼容常见别名
|
||||
if apiKeyString == "" {
|
||||
apiKeyString = c.Query("api_key")
|
||||
}
|
||||
|
||||
// 如果所有header都没有API key
|
||||
if apiKeyString == "" {
|
||||
AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme), x-api-key header, x-goog-api-key header, or key/api_key query parameter")
|
||||
return
|
||||
}
|
||||
|
||||
// 从数据库验证API key
|
||||
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrApiKeyNotFound) {
|
||||
AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key")
|
||||
return
|
||||
}
|
||||
AbortWithError(c, 500, "INTERNAL_ERROR", "Failed to validate API key")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查API key是否激活
|
||||
if !apiKey.IsActive() {
|
||||
AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查关联的用户
|
||||
if apiKey.User == nil {
|
||||
AbortWithError(c, 401, "USER_NOT_FOUND", "User associated with API key not found")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
if !apiKey.User.IsActive() {
|
||||
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
|
||||
return
|
||||
}
|
||||
|
||||
if cfg.RunMode == config.RunModeSimple {
|
||||
// 简易模式:跳过余额和订阅检查,但仍需设置必要的上下文
|
||||
c.Set(string(ContextKeyApiKey), apiKey)
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: apiKey.User.ID,
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 判断计费方式:订阅模式 vs 余额模式
|
||||
isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
|
||||
|
||||
if isSubscriptionType && subscriptionService != nil {
|
||||
// 订阅模式:验证订阅
|
||||
subscription, err := subscriptionService.GetActiveSubscription(
|
||||
c.Request.Context(),
|
||||
apiKey.User.ID,
|
||||
apiKey.Group.ID,
|
||||
)
|
||||
if err != nil {
|
||||
AbortWithError(c, 403, "SUBSCRIPTION_NOT_FOUND", "No active subscription found for this group")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证订阅状态(是否过期、暂停等)
|
||||
if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil {
|
||||
AbortWithError(c, 403, "SUBSCRIPTION_INVALID", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 激活滑动窗口(首次使用时)
|
||||
if err := subscriptionService.CheckAndActivateWindow(c.Request.Context(), subscription); err != nil {
|
||||
log.Printf("Failed to activate subscription windows: %v", err)
|
||||
}
|
||||
|
||||
// 检查并重置过期窗口
|
||||
if err := subscriptionService.CheckAndResetWindows(c.Request.Context(), subscription); err != nil {
|
||||
log.Printf("Failed to reset subscription windows: %v", err)
|
||||
}
|
||||
|
||||
// 预检查用量限制(使用0作为额外费用进行预检查)
|
||||
if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil {
|
||||
AbortWithError(c, 429, "USAGE_LIMIT_EXCEEDED", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 将订阅信息存入上下文
|
||||
c.Set(string(ContextKeySubscription), subscription)
|
||||
} else {
|
||||
// 余额模式:检查用户余额
|
||||
if apiKey.User.Balance <= 0 {
|
||||
AbortWithError(c, 403, "INSUFFICIENT_BALANCE", "Insufficient account balance")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 将API key和用户信息存入上下文
|
||||
c.Set(string(ContextKeyApiKey), apiKey)
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: apiKey.User.ID,
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// GetApiKeyFromContext 从上下文中获取API key
|
||||
func GetApiKeyFromContext(c *gin.Context) (*service.ApiKey, bool) {
|
||||
value, exists := c.Get(string(ContextKeyApiKey))
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
apiKey, ok := value.(*service.ApiKey)
|
||||
return apiKey, ok
|
||||
}
|
||||
|
||||
// GetSubscriptionFromContext 从上下文中获取订阅信息
|
||||
func GetSubscriptionFromContext(c *gin.Context) (*service.UserSubscription, bool) {
|
||||
value, exists := c.Get(string(ContextKeySubscription))
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
subscription, ok := value.(*service.UserSubscription)
|
||||
return subscription, ok
|
||||
}
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// NewApiKeyAuthMiddleware 创建 API Key 认证中间件
|
||||
func NewApiKeyAuthMiddleware(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) ApiKeyAuthMiddleware {
|
||||
return ApiKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService, cfg))
|
||||
}
|
||||
|
||||
// apiKeyAuthWithSubscription API Key认证中间件(支持订阅验证)
|
||||
func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 尝试从Authorization header中提取API key (Bearer scheme)
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
var apiKeyString string
|
||||
|
||||
if authHeader != "" {
|
||||
// 验证Bearer scheme
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) == 2 && parts[0] == "Bearer" {
|
||||
apiKeyString = parts[1]
|
||||
}
|
||||
}
|
||||
|
||||
// 如果Authorization header中没有,尝试从x-api-key header中提取
|
||||
if apiKeyString == "" {
|
||||
apiKeyString = c.GetHeader("x-api-key")
|
||||
}
|
||||
|
||||
// 如果x-api-key header中没有,尝试从x-goog-api-key header中提取(Gemini CLI兼容)
|
||||
if apiKeyString == "" {
|
||||
apiKeyString = c.GetHeader("x-goog-api-key")
|
||||
}
|
||||
|
||||
// 如果header中没有,尝试从query参数中提取(Google API key风格)
|
||||
if apiKeyString == "" {
|
||||
apiKeyString = c.Query("key")
|
||||
}
|
||||
|
||||
// 兼容常见别名
|
||||
if apiKeyString == "" {
|
||||
apiKeyString = c.Query("api_key")
|
||||
}
|
||||
|
||||
// 如果所有header都没有API key
|
||||
if apiKeyString == "" {
|
||||
AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme), x-api-key header, x-goog-api-key header, or key/api_key query parameter")
|
||||
return
|
||||
}
|
||||
|
||||
// 从数据库验证API key
|
||||
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrApiKeyNotFound) {
|
||||
AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key")
|
||||
return
|
||||
}
|
||||
AbortWithError(c, 500, "INTERNAL_ERROR", "Failed to validate API key")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查API key是否激活
|
||||
if !apiKey.IsActive() {
|
||||
AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查关联的用户
|
||||
if apiKey.User == nil {
|
||||
AbortWithError(c, 401, "USER_NOT_FOUND", "User associated with API key not found")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
if !apiKey.User.IsActive() {
|
||||
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
|
||||
return
|
||||
}
|
||||
|
||||
if cfg.RunMode == config.RunModeSimple {
|
||||
// 简易模式:跳过余额和订阅检查,但仍需设置必要的上下文
|
||||
c.Set(string(ContextKeyApiKey), apiKey)
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: apiKey.User.ID,
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 判断计费方式:订阅模式 vs 余额模式
|
||||
isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
|
||||
|
||||
if isSubscriptionType && subscriptionService != nil {
|
||||
// 订阅模式:验证订阅
|
||||
subscription, err := subscriptionService.GetActiveSubscription(
|
||||
c.Request.Context(),
|
||||
apiKey.User.ID,
|
||||
apiKey.Group.ID,
|
||||
)
|
||||
if err != nil {
|
||||
AbortWithError(c, 403, "SUBSCRIPTION_NOT_FOUND", "No active subscription found for this group")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证订阅状态(是否过期、暂停等)
|
||||
if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil {
|
||||
AbortWithError(c, 403, "SUBSCRIPTION_INVALID", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 激活滑动窗口(首次使用时)
|
||||
if err := subscriptionService.CheckAndActivateWindow(c.Request.Context(), subscription); err != nil {
|
||||
log.Printf("Failed to activate subscription windows: %v", err)
|
||||
}
|
||||
|
||||
// 检查并重置过期窗口
|
||||
if err := subscriptionService.CheckAndResetWindows(c.Request.Context(), subscription); err != nil {
|
||||
log.Printf("Failed to reset subscription windows: %v", err)
|
||||
}
|
||||
|
||||
// 预检查用量限制(使用0作为额外费用进行预检查)
|
||||
if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil {
|
||||
AbortWithError(c, 429, "USAGE_LIMIT_EXCEEDED", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 将订阅信息存入上下文
|
||||
c.Set(string(ContextKeySubscription), subscription)
|
||||
} else {
|
||||
// 余额模式:检查用户余额
|
||||
if apiKey.User.Balance <= 0 {
|
||||
AbortWithError(c, 403, "INSUFFICIENT_BALANCE", "Insufficient account balance")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 将API key和用户信息存入上下文
|
||||
c.Set(string(ContextKeyApiKey), apiKey)
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: apiKey.User.ID,
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// GetApiKeyFromContext 从上下文中获取API key
|
||||
func GetApiKeyFromContext(c *gin.Context) (*service.ApiKey, bool) {
|
||||
value, exists := c.Get(string(ContextKeyApiKey))
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
apiKey, ok := value.(*service.ApiKey)
|
||||
return apiKey, ok
|
||||
}
|
||||
|
||||
// GetSubscriptionFromContext 从上下文中获取订阅信息
|
||||
func GetSubscriptionFromContext(c *gin.Context) (*service.UserSubscription, bool) {
|
||||
value, exists := c.Get(string(ContextKeySubscription))
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
subscription, ok := value.(*service.UserSubscription)
|
||||
return subscription, ok
|
||||
}
|
||||
|
||||
@@ -1,137 +1,137 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ApiKeyAuthGoogle is a Google-style error wrapper for API key auth.
|
||||
func ApiKeyAuthGoogle(apiKeyService *service.ApiKeyService, cfg *config.Config) gin.HandlerFunc {
|
||||
return ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)
|
||||
}
|
||||
|
||||
// ApiKeyAuthWithSubscriptionGoogle behaves like ApiKeyAuthWithSubscription but returns Google-style errors:
|
||||
// {"error":{"code":401,"message":"...","status":"UNAUTHENTICATED"}}
|
||||
//
|
||||
// It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations.
|
||||
func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
apiKeyString := extractAPIKeyFromRequest(c)
|
||||
if apiKeyString == "" {
|
||||
abortWithGoogleError(c, 401, "API key is required")
|
||||
return
|
||||
}
|
||||
|
||||
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrApiKeyNotFound) {
|
||||
abortWithGoogleError(c, 401, "Invalid API key")
|
||||
return
|
||||
}
|
||||
abortWithGoogleError(c, 500, "Failed to validate API key")
|
||||
return
|
||||
}
|
||||
|
||||
if !apiKey.IsActive() {
|
||||
abortWithGoogleError(c, 401, "API key is disabled")
|
||||
return
|
||||
}
|
||||
if apiKey.User == nil {
|
||||
abortWithGoogleError(c, 401, "User associated with API key not found")
|
||||
return
|
||||
}
|
||||
if !apiKey.User.IsActive() {
|
||||
abortWithGoogleError(c, 401, "User account is not active")
|
||||
return
|
||||
}
|
||||
|
||||
// 简易模式:跳过余额和订阅检查
|
||||
if cfg.RunMode == config.RunModeSimple {
|
||||
c.Set(string(ContextKeyApiKey), apiKey)
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: apiKey.User.ID,
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
|
||||
if isSubscriptionType && subscriptionService != nil {
|
||||
subscription, err := subscriptionService.GetActiveSubscription(
|
||||
c.Request.Context(),
|
||||
apiKey.User.ID,
|
||||
apiKey.Group.ID,
|
||||
)
|
||||
if err != nil {
|
||||
abortWithGoogleError(c, 403, "No active subscription found for this group")
|
||||
return
|
||||
}
|
||||
if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil {
|
||||
abortWithGoogleError(c, 403, err.Error())
|
||||
return
|
||||
}
|
||||
_ = subscriptionService.CheckAndActivateWindow(c.Request.Context(), subscription)
|
||||
_ = subscriptionService.CheckAndResetWindows(c.Request.Context(), subscription)
|
||||
if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil {
|
||||
abortWithGoogleError(c, 429, err.Error())
|
||||
return
|
||||
}
|
||||
c.Set(string(ContextKeySubscription), subscription)
|
||||
} else {
|
||||
if apiKey.User.Balance <= 0 {
|
||||
abortWithGoogleError(c, 403, "Insufficient account balance")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.Set(string(ContextKeyApiKey), apiKey)
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: apiKey.User.ID,
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func extractAPIKeyFromRequest(c *gin.Context) string {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader != "" {
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) == 2 && parts[0] == "Bearer" && strings.TrimSpace(parts[1]) != "" {
|
||||
return strings.TrimSpace(parts[1])
|
||||
}
|
||||
}
|
||||
if v := strings.TrimSpace(c.GetHeader("x-api-key")); v != "" {
|
||||
return v
|
||||
}
|
||||
if v := strings.TrimSpace(c.GetHeader("x-goog-api-key")); v != "" {
|
||||
return v
|
||||
}
|
||||
if v := strings.TrimSpace(c.Query("key")); v != "" {
|
||||
return v
|
||||
}
|
||||
if v := strings.TrimSpace(c.Query("api_key")); v != "" {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func abortWithGoogleError(c *gin.Context, status int, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
"code": status,
|
||||
"message": message,
|
||||
"status": googleapi.HTTPStatusToGoogleStatus(status),
|
||||
},
|
||||
})
|
||||
c.Abort()
|
||||
}
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ApiKeyAuthGoogle is a Google-style error wrapper for API key auth.
|
||||
func ApiKeyAuthGoogle(apiKeyService *service.ApiKeyService, cfg *config.Config) gin.HandlerFunc {
|
||||
return ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)
|
||||
}
|
||||
|
||||
// ApiKeyAuthWithSubscriptionGoogle behaves like ApiKeyAuthWithSubscription but returns Google-style errors:
|
||||
// {"error":{"code":401,"message":"...","status":"UNAUTHENTICATED"}}
|
||||
//
|
||||
// It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations.
|
||||
func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
apiKeyString := extractAPIKeyFromRequest(c)
|
||||
if apiKeyString == "" {
|
||||
abortWithGoogleError(c, 401, "API key is required")
|
||||
return
|
||||
}
|
||||
|
||||
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrApiKeyNotFound) {
|
||||
abortWithGoogleError(c, 401, "Invalid API key")
|
||||
return
|
||||
}
|
||||
abortWithGoogleError(c, 500, "Failed to validate API key")
|
||||
return
|
||||
}
|
||||
|
||||
if !apiKey.IsActive() {
|
||||
abortWithGoogleError(c, 401, "API key is disabled")
|
||||
return
|
||||
}
|
||||
if apiKey.User == nil {
|
||||
abortWithGoogleError(c, 401, "User associated with API key not found")
|
||||
return
|
||||
}
|
||||
if !apiKey.User.IsActive() {
|
||||
abortWithGoogleError(c, 401, "User account is not active")
|
||||
return
|
||||
}
|
||||
|
||||
// 简易模式:跳过余额和订阅检查
|
||||
if cfg.RunMode == config.RunModeSimple {
|
||||
c.Set(string(ContextKeyApiKey), apiKey)
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: apiKey.User.ID,
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
|
||||
if isSubscriptionType && subscriptionService != nil {
|
||||
subscription, err := subscriptionService.GetActiveSubscription(
|
||||
c.Request.Context(),
|
||||
apiKey.User.ID,
|
||||
apiKey.Group.ID,
|
||||
)
|
||||
if err != nil {
|
||||
abortWithGoogleError(c, 403, "No active subscription found for this group")
|
||||
return
|
||||
}
|
||||
if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil {
|
||||
abortWithGoogleError(c, 403, err.Error())
|
||||
return
|
||||
}
|
||||
_ = subscriptionService.CheckAndActivateWindow(c.Request.Context(), subscription)
|
||||
_ = subscriptionService.CheckAndResetWindows(c.Request.Context(), subscription)
|
||||
if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil {
|
||||
abortWithGoogleError(c, 429, err.Error())
|
||||
return
|
||||
}
|
||||
c.Set(string(ContextKeySubscription), subscription)
|
||||
} else {
|
||||
if apiKey.User.Balance <= 0 {
|
||||
abortWithGoogleError(c, 403, "Insufficient account balance")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.Set(string(ContextKeyApiKey), apiKey)
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: apiKey.User.ID,
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func extractAPIKeyFromRequest(c *gin.Context) string {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader != "" {
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) == 2 && parts[0] == "Bearer" && strings.TrimSpace(parts[1]) != "" {
|
||||
return strings.TrimSpace(parts[1])
|
||||
}
|
||||
}
|
||||
if v := strings.TrimSpace(c.GetHeader("x-api-key")); v != "" {
|
||||
return v
|
||||
}
|
||||
if v := strings.TrimSpace(c.GetHeader("x-goog-api-key")); v != "" {
|
||||
return v
|
||||
}
|
||||
if v := strings.TrimSpace(c.Query("key")); v != "" {
|
||||
return v
|
||||
}
|
||||
if v := strings.TrimSpace(c.Query("api_key")); v != "" {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func abortWithGoogleError(c *gin.Context, status int, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
"error": gin.H{
|
||||
"code": status,
|
||||
"message": message,
|
||||
"status": googleapi.HTTPStatusToGoogleStatus(status),
|
||||
},
|
||||
})
|
||||
c.Abort()
|
||||
}
|
||||
|
||||
@@ -1,227 +1,227 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type fakeApiKeyRepo struct {
|
||||
getByKey func(ctx context.Context, key string) (*service.ApiKey, error)
|
||||
}
|
||||
|
||||
func (f fakeApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
if f.getByKey == nil {
|
||||
return nil, errors.New("unexpected call")
|
||||
}
|
||||
return f.getByKey(ctx, key)
|
||||
}
|
||||
func (f fakeApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) Delete(ctx context.Context, id int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
return false, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
type googleErrorResponse struct {
|
||||
Error struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Status string `json:"status"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
func newTestApiKeyService(repo service.ApiKeyRepository) *service.ApiKeyService {
|
||||
return service.NewApiKeyService(
|
||||
repo,
|
||||
nil, // userRepo (unused in GetByKey)
|
||||
nil, // groupRepo
|
||||
nil, // userSubRepo
|
||||
nil, // cache
|
||||
&config.Config{},
|
||||
)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
return nil, errors.New("should not be called")
|
||||
},
|
||||
})
|
||||
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
var resp googleErrorResponse
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, http.StatusUnauthorized, resp.Error.Code)
|
||||
require.Equal(t, "API key is required", resp.Error.Message)
|
||||
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
return nil, service.ErrApiKeyNotFound
|
||||
},
|
||||
})
|
||||
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer invalid")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
var resp googleErrorResponse
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, http.StatusUnauthorized, resp.Error.Code)
|
||||
require.Equal(t, "Invalid API key", resp.Error.Message)
|
||||
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
return nil, errors.New("db down")
|
||||
},
|
||||
})
|
||||
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer any")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
var resp googleErrorResponse
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, http.StatusInternalServerError, resp.Error.Code)
|
||||
require.Equal(t, "Failed to validate API key", resp.Error.Message)
|
||||
require.Equal(t, "INTERNAL", resp.Error.Status)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
return &service.ApiKey{
|
||||
ID: 1,
|
||||
Key: key,
|
||||
Status: service.StatusDisabled,
|
||||
User: &service.User{
|
||||
ID: 123,
|
||||
Status: service.StatusActive,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer disabled")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
var resp googleErrorResponse
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, http.StatusUnauthorized, resp.Error.Code)
|
||||
require.Equal(t, "API key is disabled", resp.Error.Message)
|
||||
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
return &service.ApiKey{
|
||||
ID: 1,
|
||||
Key: key,
|
||||
Status: service.StatusActive,
|
||||
User: &service.User{
|
||||
ID: 123,
|
||||
Status: service.StatusActive,
|
||||
Balance: 0,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer ok")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusForbidden, rec.Code)
|
||||
var resp googleErrorResponse
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, http.StatusForbidden, resp.Error.Code)
|
||||
require.Equal(t, "Insufficient account balance", resp.Error.Message)
|
||||
require.Equal(t, "PERMISSION_DENIED", resp.Error.Status)
|
||||
}
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type fakeApiKeyRepo struct {
|
||||
getByKey func(ctx context.Context, key string) (*service.ApiKey, error)
|
||||
}
|
||||
|
||||
func (f fakeApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
if f.getByKey == nil {
|
||||
return nil, errors.New("unexpected call")
|
||||
}
|
||||
return f.getByKey(ctx, key)
|
||||
}
|
||||
func (f fakeApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) Delete(ctx context.Context, id int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
return false, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
func (f fakeApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
type googleErrorResponse struct {
|
||||
Error struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Status string `json:"status"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
func newTestApiKeyService(repo service.ApiKeyRepository) *service.ApiKeyService {
|
||||
return service.NewApiKeyService(
|
||||
repo,
|
||||
nil, // userRepo (unused in GetByKey)
|
||||
nil, // groupRepo
|
||||
nil, // userSubRepo
|
||||
nil, // cache
|
||||
&config.Config{},
|
||||
)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
return nil, errors.New("should not be called")
|
||||
},
|
||||
})
|
||||
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
var resp googleErrorResponse
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, http.StatusUnauthorized, resp.Error.Code)
|
||||
require.Equal(t, "API key is required", resp.Error.Message)
|
||||
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
return nil, service.ErrApiKeyNotFound
|
||||
},
|
||||
})
|
||||
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer invalid")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
var resp googleErrorResponse
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, http.StatusUnauthorized, resp.Error.Code)
|
||||
require.Equal(t, "Invalid API key", resp.Error.Message)
|
||||
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
return nil, errors.New("db down")
|
||||
},
|
||||
})
|
||||
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer any")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
var resp googleErrorResponse
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, http.StatusInternalServerError, resp.Error.Code)
|
||||
require.Equal(t, "Failed to validate API key", resp.Error.Message)
|
||||
require.Equal(t, "INTERNAL", resp.Error.Status)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
return &service.ApiKey{
|
||||
ID: 1,
|
||||
Key: key,
|
||||
Status: service.StatusDisabled,
|
||||
User: &service.User{
|
||||
ID: 123,
|
||||
Status: service.StatusActive,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer disabled")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
var resp googleErrorResponse
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, http.StatusUnauthorized, resp.Error.Code)
|
||||
require.Equal(t, "API key is disabled", resp.Error.Message)
|
||||
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
r := gin.New()
|
||||
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
return &service.ApiKey{
|
||||
ID: 1,
|
||||
Key: key,
|
||||
Status: service.StatusActive,
|
||||
User: &service.User{
|
||||
ID: 123,
|
||||
Status: service.StatusActive,
|
||||
Balance: 0,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer ok")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusForbidden, rec.Code)
|
||||
var resp googleErrorResponse
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
require.Equal(t, http.StatusForbidden, resp.Error.Code)
|
||||
require.Equal(t, "Insufficient account balance", resp.Error.Message)
|
||||
require.Equal(t, "PERMISSION_DENIED", resp.Error.Status)
|
||||
}
|
||||
|
||||
@@ -1,290 +1,290 @@
|
||||
//go:build unit
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
limit := 1.0
|
||||
group := &service.Group{
|
||||
ID: 42,
|
||||
Name: "sub",
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeSubscription,
|
||||
DailyLimitUSD: &limit,
|
||||
}
|
||||
user := &service.User{
|
||||
ID: 7,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.ApiKey{
|
||||
ID: 100,
|
||||
UserID: user.ID,
|
||||
Key: "test-key",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
Group: group,
|
||||
}
|
||||
apiKey.GroupID = &group.ID
|
||||
|
||||
apiKeyRepo := &stubApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrApiKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
|
||||
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil)
|
||||
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("standard_mode_enforces_quota_check", func(t *testing.T) {
|
||||
cfg := &config.Config{RunMode: config.RunModeStandard}
|
||||
apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
|
||||
|
||||
now := time.Now()
|
||||
sub := &service.UserSubscription{
|
||||
ID: 55,
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: now.Add(24 * time.Hour),
|
||||
DailyWindowStart: &now,
|
||||
DailyUsageUSD: 10,
|
||||
}
|
||||
subscriptionRepo := &stubUserSubscriptionRepo{
|
||||
getActive: func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
|
||||
if userID != sub.UserID || groupID != sub.GroupID {
|
||||
return nil, service.ErrSubscriptionNotFound
|
||||
}
|
||||
clone := *sub
|
||||
return &clone, nil
|
||||
},
|
||||
updateStatus: func(ctx context.Context, subscriptionID int64, status string) error { return nil },
|
||||
activateWindow: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
resetDaily: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
}
|
||||
subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil)
|
||||
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusTooManyRequests, w.Code)
|
||||
require.Contains(t, w.Body.String(), "USAGE_LIMIT_EXCEEDED")
|
||||
})
|
||||
}
|
||||
|
||||
func newAuthTestRouter(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
|
||||
router := gin.New()
|
||||
router.Use(gin.HandlerFunc(NewApiKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
|
||||
router.GET("/t", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
return router
|
||||
}
|
||||
|
||||
type stubApiKeyRepo struct {
|
||||
getByKey func(ctx context.Context, key string) (*service.ApiKey, error)
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
if r.getByKey != nil {
|
||||
return r.getByKey(ctx, key)
|
||||
}
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
return false, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
type stubUserSubscriptionRepo struct {
|
||||
getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error)
|
||||
updateStatus func(ctx context.Context, subscriptionID int64, status string) error
|
||||
activateWindow func(ctx context.Context, id int64, start time.Time) error
|
||||
resetDaily func(ctx context.Context, id int64, start time.Time) error
|
||||
resetWeekly func(ctx context.Context, id int64, start time.Time) error
|
||||
resetMonthly func(ctx context.Context, id int64, start time.Time) error
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
|
||||
if r.getActive != nil {
|
||||
return r.getActive(ctx, userID, groupID)
|
||||
}
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) Update(ctx context.Context, sub *service.UserSubscription) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) Delete(ctx context.Context, id int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
|
||||
return false, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error {
|
||||
if r.updateStatus != nil {
|
||||
return r.updateStatus(ctx, subscriptionID, status)
|
||||
}
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ActivateWindows(ctx context.Context, id int64, start time.Time) error {
|
||||
if r.activateWindow != nil {
|
||||
return r.activateWindow(ctx, id, start)
|
||||
}
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
if r.resetDaily != nil {
|
||||
return r.resetDaily(ctx, id, newWindowStart)
|
||||
}
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
if r.resetWeekly != nil {
|
||||
return r.resetWeekly(ctx, id, newWindowStart)
|
||||
}
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
if r.resetMonthly != nil {
|
||||
return r.resetMonthly(ctx, id, newWindowStart)
|
||||
}
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
//go:build unit
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
limit := 1.0
|
||||
group := &service.Group{
|
||||
ID: 42,
|
||||
Name: "sub",
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeSubscription,
|
||||
DailyLimitUSD: &limit,
|
||||
}
|
||||
user := &service.User{
|
||||
ID: 7,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.ApiKey{
|
||||
ID: 100,
|
||||
UserID: user.ID,
|
||||
Key: "test-key",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
Group: group,
|
||||
}
|
||||
apiKey.GroupID = &group.ID
|
||||
|
||||
apiKeyRepo := &stubApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrApiKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
|
||||
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil)
|
||||
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
})
|
||||
|
||||
t.Run("standard_mode_enforces_quota_check", func(t *testing.T) {
|
||||
cfg := &config.Config{RunMode: config.RunModeStandard}
|
||||
apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
|
||||
|
||||
now := time.Now()
|
||||
sub := &service.UserSubscription{
|
||||
ID: 55,
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: now.Add(24 * time.Hour),
|
||||
DailyWindowStart: &now,
|
||||
DailyUsageUSD: 10,
|
||||
}
|
||||
subscriptionRepo := &stubUserSubscriptionRepo{
|
||||
getActive: func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
|
||||
if userID != sub.UserID || groupID != sub.GroupID {
|
||||
return nil, service.ErrSubscriptionNotFound
|
||||
}
|
||||
clone := *sub
|
||||
return &clone, nil
|
||||
},
|
||||
updateStatus: func(ctx context.Context, subscriptionID int64, status string) error { return nil },
|
||||
activateWindow: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
resetDaily: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil },
|
||||
}
|
||||
subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil)
|
||||
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusTooManyRequests, w.Code)
|
||||
require.Contains(t, w.Body.String(), "USAGE_LIMIT_EXCEEDED")
|
||||
})
|
||||
}
|
||||
|
||||
func newAuthTestRouter(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
|
||||
router := gin.New()
|
||||
router.Use(gin.HandlerFunc(NewApiKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
|
||||
router.GET("/t", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
return router
|
||||
}
|
||||
|
||||
type stubApiKeyRepo struct {
|
||||
getByKey func(ctx context.Context, key string) (*service.ApiKey, error)
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
|
||||
if r.getByKey != nil {
|
||||
return r.getByKey(ctx, key)
|
||||
}
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
return false, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
type stubUserSubscriptionRepo struct {
|
||||
getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error)
|
||||
updateStatus func(ctx context.Context, subscriptionID int64, status string) error
|
||||
activateWindow func(ctx context.Context, id int64, start time.Time) error
|
||||
resetDaily func(ctx context.Context, id int64, start time.Time) error
|
||||
resetWeekly func(ctx context.Context, id int64, start time.Time) error
|
||||
resetMonthly func(ctx context.Context, id int64, start time.Time) error
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
|
||||
if r.getActive != nil {
|
||||
return r.getActive(ctx, userID, groupID)
|
||||
}
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) Update(ctx context.Context, sub *service.UserSubscription) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) Delete(ctx context.Context, id int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
|
||||
return nil, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
|
||||
return false, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error {
|
||||
if r.updateStatus != nil {
|
||||
return r.updateStatus(ctx, subscriptionID, status)
|
||||
}
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ActivateWindows(ctx context.Context, id int64, start time.Time) error {
|
||||
if r.activateWindow != nil {
|
||||
return r.activateWindow(ctx, id, start)
|
||||
}
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
if r.resetDaily != nil {
|
||||
return r.resetDaily(ctx, id, newWindowStart)
|
||||
}
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
if r.resetWeekly != nil {
|
||||
return r.resetWeekly(ctx, id, newWindowStart)
|
||||
}
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
|
||||
if r.resetMonthly != nil {
|
||||
return r.resetMonthly(ctx, id, newWindowStart)
|
||||
}
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (r *stubUserSubscriptionRepo) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -1,28 +1,28 @@
|
||||
package middleware
|
||||
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
// AuthSubject is the minimal authenticated identity stored in gin context.
|
||||
// Decision: {UserID int64, Concurrency int}
|
||||
type AuthSubject struct {
|
||||
UserID int64
|
||||
Concurrency int
|
||||
}
|
||||
|
||||
func GetAuthSubjectFromContext(c *gin.Context) (AuthSubject, bool) {
|
||||
value, exists := c.Get(string(ContextKeyUser))
|
||||
if !exists {
|
||||
return AuthSubject{}, false
|
||||
}
|
||||
subject, ok := value.(AuthSubject)
|
||||
return subject, ok
|
||||
}
|
||||
|
||||
func GetUserRoleFromContext(c *gin.Context) (string, bool) {
|
||||
value, exists := c.Get(string(ContextKeyUserRole))
|
||||
if !exists {
|
||||
return "", false
|
||||
}
|
||||
role, ok := value.(string)
|
||||
return role, ok
|
||||
}
|
||||
package middleware
|
||||
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
// AuthSubject is the minimal authenticated identity stored in gin context.
|
||||
// Decision: {UserID int64, Concurrency int}
|
||||
type AuthSubject struct {
|
||||
UserID int64
|
||||
Concurrency int
|
||||
}
|
||||
|
||||
func GetAuthSubjectFromContext(c *gin.Context) (AuthSubject, bool) {
|
||||
value, exists := c.Get(string(ContextKeyUser))
|
||||
if !exists {
|
||||
return AuthSubject{}, false
|
||||
}
|
||||
subject, ok := value.(AuthSubject)
|
||||
return subject, ok
|
||||
}
|
||||
|
||||
func GetUserRoleFromContext(c *gin.Context) (string, bool) {
|
||||
value, exists := c.Get(string(ContextKeyUserRole))
|
||||
if !exists {
|
||||
return "", false
|
||||
}
|
||||
role, ok := value.(string)
|
||||
return role, ok
|
||||
}
|
||||
|
||||
@@ -1,24 +1,24 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CORS 跨域中间件
|
||||
func CORS() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 设置允许跨域的响应头
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
|
||||
|
||||
// 处理预检请求
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(204)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CORS 跨域中间件
|
||||
func CORS() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 设置允许跨域的响应头
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
|
||||
|
||||
// 处理预检请求
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(204)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,81 +1,81 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// NewJWTAuthMiddleware 创建 JWT 认证中间件
|
||||
func NewJWTAuthMiddleware(authService *service.AuthService, userService *service.UserService) JWTAuthMiddleware {
|
||||
return JWTAuthMiddleware(jwtAuth(authService, userService))
|
||||
}
|
||||
|
||||
// jwtAuth JWT认证中间件实现
|
||||
func jwtAuth(authService *service.AuthService, userService *service.UserService) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 从Authorization header中提取token
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
AbortWithError(c, 401, "UNAUTHORIZED", "Authorization header is required")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证Bearer scheme
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||
AbortWithError(c, 401, "INVALID_AUTH_HEADER", "Authorization header format must be 'Bearer {token}'")
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := parts[1]
|
||||
if tokenString == "" {
|
||||
AbortWithError(c, 401, "EMPTY_TOKEN", "Token cannot be empty")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证token
|
||||
claims, err := authService.ValidateToken(tokenString)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrTokenExpired) {
|
||||
AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired")
|
||||
return
|
||||
}
|
||||
AbortWithError(c, 401, "INVALID_TOKEN", "Invalid token")
|
||||
return
|
||||
}
|
||||
|
||||
// 从数据库获取最新的用户信息
|
||||
user, err := userService.GetByID(c.Request.Context(), claims.UserID)
|
||||
if err != nil {
|
||||
AbortWithError(c, 401, "USER_NOT_FOUND", "User not found")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
if !user.IsActive() {
|
||||
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
|
||||
return
|
||||
}
|
||||
|
||||
// Security: Validate TokenVersion to ensure token hasn't been invalidated
|
||||
// This check ensures tokens issued before a password change are rejected
|
||||
if claims.TokenVersion != user.TokenVersion {
|
||||
AbortWithError(c, 401, "TOKEN_REVOKED", "Token has been revoked (password changed)")
|
||||
return
|
||||
}
|
||||
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: user.ID,
|
||||
Concurrency: user.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), user.Role)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// Deprecated: prefer GetAuthSubjectFromContext in auth_subject.go.
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// NewJWTAuthMiddleware 创建 JWT 认证中间件
|
||||
func NewJWTAuthMiddleware(authService *service.AuthService, userService *service.UserService) JWTAuthMiddleware {
|
||||
return JWTAuthMiddleware(jwtAuth(authService, userService))
|
||||
}
|
||||
|
||||
// jwtAuth JWT认证中间件实现
|
||||
func jwtAuth(authService *service.AuthService, userService *service.UserService) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 从Authorization header中提取token
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
AbortWithError(c, 401, "UNAUTHORIZED", "Authorization header is required")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证Bearer scheme
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||
AbortWithError(c, 401, "INVALID_AUTH_HEADER", "Authorization header format must be 'Bearer {token}'")
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := parts[1]
|
||||
if tokenString == "" {
|
||||
AbortWithError(c, 401, "EMPTY_TOKEN", "Token cannot be empty")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证token
|
||||
claims, err := authService.ValidateToken(tokenString)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrTokenExpired) {
|
||||
AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired")
|
||||
return
|
||||
}
|
||||
AbortWithError(c, 401, "INVALID_TOKEN", "Invalid token")
|
||||
return
|
||||
}
|
||||
|
||||
// 从数据库获取最新的用户信息
|
||||
user, err := userService.GetByID(c.Request.Context(), claims.UserID)
|
||||
if err != nil {
|
||||
AbortWithError(c, 401, "USER_NOT_FOUND", "User not found")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
if !user.IsActive() {
|
||||
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
|
||||
return
|
||||
}
|
||||
|
||||
// Security: Validate TokenVersion to ensure token hasn't been invalidated
|
||||
// This check ensures tokens issued before a password change are rejected
|
||||
if claims.TokenVersion != user.TokenVersion {
|
||||
AbortWithError(c, 401, "TOKEN_REVOKED", "Token has been revoked (password changed)")
|
||||
return
|
||||
}
|
||||
|
||||
c.Set(string(ContextKeyUser), AuthSubject{
|
||||
UserID: user.ID,
|
||||
Concurrency: user.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), user.Role)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// Deprecated: prefer GetAuthSubjectFromContext in auth_subject.go.
|
||||
|
||||
@@ -1,52 +1,52 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Logger 请求日志中间件
|
||||
func Logger() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 开始时间
|
||||
startTime := time.Now()
|
||||
|
||||
// 处理请求
|
||||
c.Next()
|
||||
|
||||
// 结束时间
|
||||
endTime := time.Now()
|
||||
|
||||
// 执行时间
|
||||
latency := endTime.Sub(startTime)
|
||||
|
||||
// 请求方法
|
||||
method := c.Request.Method
|
||||
|
||||
// 请求路径
|
||||
path := c.Request.URL.Path
|
||||
|
||||
// 状态码
|
||||
statusCode := c.Writer.Status()
|
||||
|
||||
// 客户端IP
|
||||
clientIP := c.ClientIP()
|
||||
|
||||
// 日志格式: [时间] 状态码 | 延迟 | IP | 方法 路径
|
||||
log.Printf("[GIN] %v | %3d | %13v | %15s | %-7s %s",
|
||||
endTime.Format("2006/01/02 - 15:04:05"),
|
||||
statusCode,
|
||||
latency,
|
||||
clientIP,
|
||||
method,
|
||||
path,
|
||||
)
|
||||
|
||||
// 如果有错误,额外记录错误信息
|
||||
if len(c.Errors) > 0 {
|
||||
log.Printf("[GIN] Errors: %v", c.Errors.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Logger 请求日志中间件
|
||||
func Logger() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 开始时间
|
||||
startTime := time.Now()
|
||||
|
||||
// 处理请求
|
||||
c.Next()
|
||||
|
||||
// 结束时间
|
||||
endTime := time.Now()
|
||||
|
||||
// 执行时间
|
||||
latency := endTime.Sub(startTime)
|
||||
|
||||
// 请求方法
|
||||
method := c.Request.Method
|
||||
|
||||
// 请求路径
|
||||
path := c.Request.URL.Path
|
||||
|
||||
// 状态码
|
||||
statusCode := c.Writer.Status()
|
||||
|
||||
// 客户端IP
|
||||
clientIP := c.ClientIP()
|
||||
|
||||
// 日志格式: [时间] 状态码 | 延迟 | IP | 方法 路径
|
||||
log.Printf("[GIN] %v | %3d | %13v | %15s | %-7s %s",
|
||||
endTime.Format("2006/01/02 - 15:04:05"),
|
||||
statusCode,
|
||||
latency,
|
||||
clientIP,
|
||||
method,
|
||||
path,
|
||||
)
|
||||
|
||||
// 如果有错误,额外记录错误信息
|
||||
if len(c.Errors) > 0 {
|
||||
log.Printf("[GIN] Errors: %v", c.Errors.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,73 +1,73 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ContextKey 定义上下文键类型
|
||||
type ContextKey string
|
||||
|
||||
const (
|
||||
// ContextKeyUser 用户上下文键
|
||||
ContextKeyUser ContextKey = "user"
|
||||
// ContextKeyUserRole 当前用户角色(string)
|
||||
ContextKeyUserRole ContextKey = "user_role"
|
||||
// ContextKeyApiKey API密钥上下文键
|
||||
ContextKeyApiKey ContextKey = "api_key"
|
||||
// ContextKeySubscription 订阅上下文键
|
||||
ContextKeySubscription ContextKey = "subscription"
|
||||
// ContextKeyForcePlatform 强制平台(用于 /antigravity 路由)
|
||||
ContextKeyForcePlatform ContextKey = "force_platform"
|
||||
)
|
||||
|
||||
// ForcePlatform 返回设置强制平台的中间件
|
||||
// 同时设置 request.Context(供 Service 使用)和 gin.Context(供 Handler 快速检查)
|
||||
func ForcePlatform(platform string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 设置到 request.Context,使用 ctxkey.ForcePlatform 供 Service 层读取
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.ForcePlatform, platform)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
// 同时设置到 gin.Context,供 Handler 快速检查
|
||||
c.Set(string(ContextKeyForcePlatform), platform)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// HasForcePlatform 检查是否有强制平台(用于 Handler 跳过分组检查)
|
||||
func HasForcePlatform(c *gin.Context) bool {
|
||||
_, exists := c.Get(string(ContextKeyForcePlatform))
|
||||
return exists
|
||||
}
|
||||
|
||||
// GetForcePlatformFromContext 从 gin.Context 获取强制平台
|
||||
func GetForcePlatformFromContext(c *gin.Context) (string, bool) {
|
||||
value, exists := c.Get(string(ContextKeyForcePlatform))
|
||||
if !exists {
|
||||
return "", false
|
||||
}
|
||||
platform, ok := value.(string)
|
||||
return platform, ok
|
||||
}
|
||||
|
||||
// ErrorResponse 标准错误响应结构
|
||||
type ErrorResponse struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// NewErrorResponse 创建错误响应
|
||||
func NewErrorResponse(code, message string) ErrorResponse {
|
||||
return ErrorResponse{
|
||||
Code: code,
|
||||
Message: message,
|
||||
}
|
||||
}
|
||||
|
||||
// AbortWithError 中断请求并返回JSON错误
|
||||
func AbortWithError(c *gin.Context, statusCode int, code, message string) {
|
||||
c.JSON(statusCode, NewErrorResponse(code, message))
|
||||
c.Abort()
|
||||
}
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ContextKey 定义上下文键类型
|
||||
type ContextKey string
|
||||
|
||||
const (
|
||||
// ContextKeyUser 用户上下文键
|
||||
ContextKeyUser ContextKey = "user"
|
||||
// ContextKeyUserRole 当前用户角色(string)
|
||||
ContextKeyUserRole ContextKey = "user_role"
|
||||
// ContextKeyApiKey API密钥上下文键
|
||||
ContextKeyApiKey ContextKey = "api_key"
|
||||
// ContextKeySubscription 订阅上下文键
|
||||
ContextKeySubscription ContextKey = "subscription"
|
||||
// ContextKeyForcePlatform 强制平台(用于 /antigravity 路由)
|
||||
ContextKeyForcePlatform ContextKey = "force_platform"
|
||||
)
|
||||
|
||||
// ForcePlatform 返回设置强制平台的中间件
|
||||
// 同时设置 request.Context(供 Service 使用)和 gin.Context(供 Handler 快速检查)
|
||||
func ForcePlatform(platform string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 设置到 request.Context,使用 ctxkey.ForcePlatform 供 Service 层读取
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.ForcePlatform, platform)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
// 同时设置到 gin.Context,供 Handler 快速检查
|
||||
c.Set(string(ContextKeyForcePlatform), platform)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// HasForcePlatform 检查是否有强制平台(用于 Handler 跳过分组检查)
|
||||
func HasForcePlatform(c *gin.Context) bool {
|
||||
_, exists := c.Get(string(ContextKeyForcePlatform))
|
||||
return exists
|
||||
}
|
||||
|
||||
// GetForcePlatformFromContext 从 gin.Context 获取强制平台
|
||||
func GetForcePlatformFromContext(c *gin.Context) (string, bool) {
|
||||
value, exists := c.Get(string(ContextKeyForcePlatform))
|
||||
if !exists {
|
||||
return "", false
|
||||
}
|
||||
platform, ok := value.(string)
|
||||
return platform, ok
|
||||
}
|
||||
|
||||
// ErrorResponse 标准错误响应结构
|
||||
type ErrorResponse struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// NewErrorResponse 创建错误响应
|
||||
func NewErrorResponse(code, message string) ErrorResponse {
|
||||
return ErrorResponse{
|
||||
Code: code,
|
||||
Message: message,
|
||||
}
|
||||
}
|
||||
|
||||
// AbortWithError 中断请求并返回JSON错误
|
||||
func AbortWithError(c *gin.Context, statusCode int, code, message string) {
|
||||
c.JSON(statusCode, NewErrorResponse(code, message))
|
||||
c.Abort()
|
||||
}
|
||||
|
||||
@@ -1,64 +1,64 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Recovery converts panics into the project's standard JSON error envelope.
|
||||
//
|
||||
// It preserves Gin's broken-pipe handling by not attempting to write a response
|
||||
// when the client connection is already gone.
|
||||
func Recovery() gin.HandlerFunc {
|
||||
return gin.CustomRecoveryWithWriter(gin.DefaultErrorWriter, func(c *gin.Context, recovered any) {
|
||||
recoveredErr, _ := recovered.(error)
|
||||
|
||||
if isBrokenPipe(recoveredErr) {
|
||||
if recoveredErr != nil {
|
||||
_ = c.Error(recoveredErr)
|
||||
}
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
if c.Writer.Written() {
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
response.ErrorWithDetails(
|
||||
c,
|
||||
http.StatusInternalServerError,
|
||||
infraerrors.UnknownMessage,
|
||||
infraerrors.UnknownReason,
|
||||
nil,
|
||||
)
|
||||
c.Abort()
|
||||
})
|
||||
}
|
||||
|
||||
func isBrokenPipe(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var opErr *net.OpError
|
||||
if !errors.As(err, &opErr) {
|
||||
return false
|
||||
}
|
||||
|
||||
var syscallErr *os.SyscallError
|
||||
if !errors.As(opErr.Err, &syscallErr) {
|
||||
return false
|
||||
}
|
||||
|
||||
msg := strings.ToLower(syscallErr.Error())
|
||||
return strings.Contains(msg, "broken pipe") || strings.Contains(msg, "connection reset by peer")
|
||||
}
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Recovery converts panics into the project's standard JSON error envelope.
|
||||
//
|
||||
// It preserves Gin's broken-pipe handling by not attempting to write a response
|
||||
// when the client connection is already gone.
|
||||
func Recovery() gin.HandlerFunc {
|
||||
return gin.CustomRecoveryWithWriter(gin.DefaultErrorWriter, func(c *gin.Context, recovered any) {
|
||||
recoveredErr, _ := recovered.(error)
|
||||
|
||||
if isBrokenPipe(recoveredErr) {
|
||||
if recoveredErr != nil {
|
||||
_ = c.Error(recoveredErr)
|
||||
}
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
if c.Writer.Written() {
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
response.ErrorWithDetails(
|
||||
c,
|
||||
http.StatusInternalServerError,
|
||||
infraerrors.UnknownMessage,
|
||||
infraerrors.UnknownReason,
|
||||
nil,
|
||||
)
|
||||
c.Abort()
|
||||
})
|
||||
}
|
||||
|
||||
func isBrokenPipe(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var opErr *net.OpError
|
||||
if !errors.As(err, &opErr) {
|
||||
return false
|
||||
}
|
||||
|
||||
var syscallErr *os.SyscallError
|
||||
if !errors.As(opErr.Err, &syscallErr) {
|
||||
return false
|
||||
}
|
||||
|
||||
msg := strings.ToLower(syscallErr.Error())
|
||||
return strings.Contains(msg, "broken pipe") || strings.Contains(msg, "connection reset by peer")
|
||||
}
|
||||
|
||||
@@ -1,81 +1,81 @@
|
||||
//go:build unit
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRecovery(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
handler gin.HandlerFunc
|
||||
wantHTTPCode int
|
||||
wantBody response.Response
|
||||
}{
|
||||
{
|
||||
name: "panic_returns_standard_json_500",
|
||||
handler: func(c *gin.Context) {
|
||||
panic("boom")
|
||||
},
|
||||
wantHTTPCode: http.StatusInternalServerError,
|
||||
wantBody: response.Response{
|
||||
Code: http.StatusInternalServerError,
|
||||
Message: infraerrors.UnknownMessage,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no_panic_passthrough",
|
||||
handler: func(c *gin.Context) {
|
||||
response.Success(c, gin.H{"ok": true})
|
||||
},
|
||||
wantHTTPCode: http.StatusOK,
|
||||
wantBody: response.Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: map[string]any{"ok": true},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "panic_after_write_does_not_override_body",
|
||||
handler: func(c *gin.Context) {
|
||||
response.Success(c, gin.H{"ok": true})
|
||||
panic("boom")
|
||||
},
|
||||
wantHTTPCode: http.StatusOK,
|
||||
wantBody: response.Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: map[string]any{"ok": true},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := gin.New()
|
||||
r.Use(Recovery())
|
||||
r.GET("/t", tt.handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, tt.wantHTTPCode, w.Code)
|
||||
|
||||
var got response.Response
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
|
||||
require.Equal(t, tt.wantBody, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
//go:build unit
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRecovery(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
handler gin.HandlerFunc
|
||||
wantHTTPCode int
|
||||
wantBody response.Response
|
||||
}{
|
||||
{
|
||||
name: "panic_returns_standard_json_500",
|
||||
handler: func(c *gin.Context) {
|
||||
panic("boom")
|
||||
},
|
||||
wantHTTPCode: http.StatusInternalServerError,
|
||||
wantBody: response.Response{
|
||||
Code: http.StatusInternalServerError,
|
||||
Message: infraerrors.UnknownMessage,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no_panic_passthrough",
|
||||
handler: func(c *gin.Context) {
|
||||
response.Success(c, gin.H{"ok": true})
|
||||
},
|
||||
wantHTTPCode: http.StatusOK,
|
||||
wantBody: response.Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: map[string]any{"ok": true},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "panic_after_write_does_not_override_body",
|
||||
handler: func(c *gin.Context) {
|
||||
response.Success(c, gin.H{"ok": true})
|
||||
panic("boom")
|
||||
},
|
||||
wantHTTPCode: http.StatusOK,
|
||||
wantBody: response.Response{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
Data: map[string]any{"ok": true},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := gin.New()
|
||||
r.Use(Recovery())
|
||||
r.GET("/t", tt.handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, tt.wantHTTPCode, w.Code)
|
||||
|
||||
var got response.Response
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
|
||||
require.Equal(t, tt.wantBody, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RequestBodyLimit 使用 MaxBytesReader 限制请求体大小。
|
||||
func RequestBodyLimit(maxBytes int64) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxBytes)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RequestBodyLimit 使用 MaxBytesReader 限制请求体大小。
|
||||
func RequestBodyLimit(maxBytes int64) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxBytes)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,22 +1,22 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
// JWTAuthMiddleware JWT 认证中间件类型
|
||||
type JWTAuthMiddleware gin.HandlerFunc
|
||||
|
||||
// AdminAuthMiddleware 管理员认证中间件类型
|
||||
type AdminAuthMiddleware gin.HandlerFunc
|
||||
|
||||
// ApiKeyAuthMiddleware API Key 认证中间件类型
|
||||
type ApiKeyAuthMiddleware gin.HandlerFunc
|
||||
|
||||
// ProviderSet 中间件层的依赖注入
|
||||
var ProviderSet = wire.NewSet(
|
||||
NewJWTAuthMiddleware,
|
||||
NewAdminAuthMiddleware,
|
||||
NewApiKeyAuthMiddleware,
|
||||
)
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/wire"
|
||||
)
|
||||
|
||||
// JWTAuthMiddleware JWT 认证中间件类型
|
||||
type JWTAuthMiddleware gin.HandlerFunc
|
||||
|
||||
// AdminAuthMiddleware 管理员认证中间件类型
|
||||
type AdminAuthMiddleware gin.HandlerFunc
|
||||
|
||||
// ApiKeyAuthMiddleware API Key 认证中间件类型
|
||||
type ApiKeyAuthMiddleware gin.HandlerFunc
|
||||
|
||||
// ProviderSet 中间件层的依赖注入
|
||||
var ProviderSet = wire.NewSet(
|
||||
NewJWTAuthMiddleware,
|
||||
NewAdminAuthMiddleware,
|
||||
NewApiKeyAuthMiddleware,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user