feat: 品牌重命名 Sub2API -> TianShuAPI
Some checks failed
CI / test (push) Has been cancelled
CI / golangci-lint (push) Has been cancelled

- 前端: 所有界面显示、i18n 文本、组件中的品牌名称
- 后端: 服务层、设置默认值、邮件模板、安装向导
- 数据库: 迁移脚本注释
- 保持功能完全一致,仅更改品牌名称

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
huangzhenpc
2026-01-04 17:50:29 +08:00
parent e27c1acf79
commit d274c8cb14
417 changed files with 112280 additions and 112280 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -1,54 +1,54 @@
package server
import (
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/google/wire"
)
// ProviderSet 提供服务器层的依赖
var ProviderSet = wire.NewSet(
ProvideRouter,
ProvideHTTPServer,
)
// ProvideRouter 提供路由器
func ProvideRouter(
cfg *config.Config,
handlers *handler.Handlers,
jwtAuth middleware2.JWTAuthMiddleware,
adminAuth middleware2.AdminAuthMiddleware,
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
apiKeyService *service.ApiKeyService,
subscriptionService *service.SubscriptionService,
) *gin.Engine {
if cfg.Server.Mode == "release" {
gin.SetMode(gin.ReleaseMode)
}
r := gin.New()
r.Use(middleware2.Recovery())
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg)
}
// ProvideHTTPServer 提供 HTTP 服务器
func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server {
return &http.Server{
Addr: cfg.Server.Address(),
Handler: router,
// ReadHeaderTimeout: 读取请求头的超时时间,防止慢速请求头攻击
ReadHeaderTimeout: time.Duration(cfg.Server.ReadHeaderTimeout) * time.Second,
// IdleTimeout: 空闲连接超时时间,释放不活跃的连接资源
IdleTimeout: time.Duration(cfg.Server.IdleTimeout) * time.Second,
// 注意:不设置 WriteTimeout因为流式响应可能持续十几分钟
// 不设置 ReadTimeout因为大请求体可能需要较长时间读取
}
}
package server
import (
"net/http"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/google/wire"
)
// ProviderSet 提供服务器层的依赖
var ProviderSet = wire.NewSet(
ProvideRouter,
ProvideHTTPServer,
)
// ProvideRouter 提供路由器
func ProvideRouter(
cfg *config.Config,
handlers *handler.Handlers,
jwtAuth middleware2.JWTAuthMiddleware,
adminAuth middleware2.AdminAuthMiddleware,
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
apiKeyService *service.ApiKeyService,
subscriptionService *service.SubscriptionService,
) *gin.Engine {
if cfg.Server.Mode == "release" {
gin.SetMode(gin.ReleaseMode)
}
r := gin.New()
r.Use(middleware2.Recovery())
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg)
}
// ProvideHTTPServer 提供 HTTP 服务器
func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server {
return &http.Server{
Addr: cfg.Server.Address(),
Handler: router,
// ReadHeaderTimeout: 读取请求头的超时时间,防止慢速请求头攻击
ReadHeaderTimeout: time.Duration(cfg.Server.ReadHeaderTimeout) * time.Second,
// IdleTimeout: 空闲连接超时时间,释放不活跃的连接资源
IdleTimeout: time.Duration(cfg.Server.IdleTimeout) * time.Second,
// 注意:不设置 WriteTimeout因为流式响应可能持续十几分钟
// 不设置 ReadTimeout因为大请求体可能需要较长时间读取
}
}

View File

@@ -1,140 +1,140 @@
package middleware
import (
"crypto/subtle"
"errors"
"strings"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// NewAdminAuthMiddleware 创建管理员认证中间件
func NewAdminAuthMiddleware(
authService *service.AuthService,
userService *service.UserService,
settingService *service.SettingService,
) AdminAuthMiddleware {
return AdminAuthMiddleware(adminAuth(authService, userService, settingService))
}
// adminAuth 管理员认证中间件实现
// 支持两种认证方式(通过不同的 header 区分):
// 1. Admin API Key: x-api-key: <admin-api-key>
// 2. JWT Token: Authorization: Bearer <jwt-token> (需要管理员角色)
func adminAuth(
authService *service.AuthService,
userService *service.UserService,
settingService *service.SettingService,
) gin.HandlerFunc {
return func(c *gin.Context) {
// 检查 x-api-key headerAdmin API Key 认证)
apiKey := c.GetHeader("x-api-key")
if apiKey != "" {
if !validateAdminApiKey(c, apiKey, settingService, userService) {
return
}
c.Next()
return
}
// 检查 Authorization headerJWT 认证)
authHeader := c.GetHeader("Authorization")
if authHeader != "" {
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) == 2 && parts[0] == "Bearer" {
if !validateJWTForAdmin(c, parts[1], authService, userService) {
return
}
c.Next()
return
}
}
// 无有效认证信息
AbortWithError(c, 401, "UNAUTHORIZED", "Authorization required")
}
}
// validateAdminApiKey 验证管理员 API Key
func validateAdminApiKey(
c *gin.Context,
key string,
settingService *service.SettingService,
userService *service.UserService,
) bool {
storedKey, err := settingService.GetAdminApiKey(c.Request.Context())
if err != nil {
AbortWithError(c, 500, "INTERNAL_ERROR", "Internal server error")
return false
}
// 未配置或不匹配,统一返回相同错误(避免信息泄露)
if storedKey == "" || subtle.ConstantTimeCompare([]byte(key), []byte(storedKey)) != 1 {
AbortWithError(c, 401, "INVALID_ADMIN_KEY", "Invalid admin API key")
return false
}
// 获取真实的管理员用户
admin, err := userService.GetFirstAdmin(c.Request.Context())
if err != nil {
AbortWithError(c, 500, "INTERNAL_ERROR", "No admin user found")
return false
}
c.Set(string(ContextKeyUser), AuthSubject{
UserID: admin.ID,
Concurrency: admin.Concurrency,
})
c.Set(string(ContextKeyUserRole), admin.Role)
c.Set("auth_method", "admin_api_key")
return true
}
// validateJWTForAdmin 验证 JWT 并检查管理员权限
func validateJWTForAdmin(
c *gin.Context,
token string,
authService *service.AuthService,
userService *service.UserService,
) bool {
// 验证 JWT token
claims, err := authService.ValidateToken(token)
if err != nil {
if errors.Is(err, service.ErrTokenExpired) {
AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired")
return false
}
AbortWithError(c, 401, "INVALID_TOKEN", "Invalid token")
return false
}
// 从数据库获取用户
user, err := userService.GetByID(c.Request.Context(), claims.UserID)
if err != nil {
AbortWithError(c, 401, "USER_NOT_FOUND", "User not found")
return false
}
// 检查用户状态
if !user.IsActive() {
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
return false
}
// 检查管理员权限
if !user.IsAdmin() {
AbortWithError(c, 403, "FORBIDDEN", "Admin access required")
return false
}
c.Set(string(ContextKeyUser), AuthSubject{
UserID: user.ID,
Concurrency: user.Concurrency,
})
c.Set(string(ContextKeyUserRole), user.Role)
c.Set("auth_method", "jwt")
return true
}
package middleware
import (
"crypto/subtle"
"errors"
"strings"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// NewAdminAuthMiddleware 创建管理员认证中间件
func NewAdminAuthMiddleware(
authService *service.AuthService,
userService *service.UserService,
settingService *service.SettingService,
) AdminAuthMiddleware {
return AdminAuthMiddleware(adminAuth(authService, userService, settingService))
}
// adminAuth 管理员认证中间件实现
// 支持两种认证方式(通过不同的 header 区分):
// 1. Admin API Key: x-api-key: <admin-api-key>
// 2. JWT Token: Authorization: Bearer <jwt-token> (需要管理员角色)
func adminAuth(
authService *service.AuthService,
userService *service.UserService,
settingService *service.SettingService,
) gin.HandlerFunc {
return func(c *gin.Context) {
// 检查 x-api-key headerAdmin API Key 认证)
apiKey := c.GetHeader("x-api-key")
if apiKey != "" {
if !validateAdminApiKey(c, apiKey, settingService, userService) {
return
}
c.Next()
return
}
// 检查 Authorization headerJWT 认证)
authHeader := c.GetHeader("Authorization")
if authHeader != "" {
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) == 2 && parts[0] == "Bearer" {
if !validateJWTForAdmin(c, parts[1], authService, userService) {
return
}
c.Next()
return
}
}
// 无有效认证信息
AbortWithError(c, 401, "UNAUTHORIZED", "Authorization required")
}
}
// validateAdminApiKey 验证管理员 API Key
func validateAdminApiKey(
c *gin.Context,
key string,
settingService *service.SettingService,
userService *service.UserService,
) bool {
storedKey, err := settingService.GetAdminApiKey(c.Request.Context())
if err != nil {
AbortWithError(c, 500, "INTERNAL_ERROR", "Internal server error")
return false
}
// 未配置或不匹配,统一返回相同错误(避免信息泄露)
if storedKey == "" || subtle.ConstantTimeCompare([]byte(key), []byte(storedKey)) != 1 {
AbortWithError(c, 401, "INVALID_ADMIN_KEY", "Invalid admin API key")
return false
}
// 获取真实的管理员用户
admin, err := userService.GetFirstAdmin(c.Request.Context())
if err != nil {
AbortWithError(c, 500, "INTERNAL_ERROR", "No admin user found")
return false
}
c.Set(string(ContextKeyUser), AuthSubject{
UserID: admin.ID,
Concurrency: admin.Concurrency,
})
c.Set(string(ContextKeyUserRole), admin.Role)
c.Set("auth_method", "admin_api_key")
return true
}
// validateJWTForAdmin 验证 JWT 并检查管理员权限
func validateJWTForAdmin(
c *gin.Context,
token string,
authService *service.AuthService,
userService *service.UserService,
) bool {
// 验证 JWT token
claims, err := authService.ValidateToken(token)
if err != nil {
if errors.Is(err, service.ErrTokenExpired) {
AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired")
return false
}
AbortWithError(c, 401, "INVALID_TOKEN", "Invalid token")
return false
}
// 从数据库获取用户
user, err := userService.GetByID(c.Request.Context(), claims.UserID)
if err != nil {
AbortWithError(c, 401, "USER_NOT_FOUND", "User not found")
return false
}
// 检查用户状态
if !user.IsActive() {
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
return false
}
// 检查管理员权限
if !user.IsAdmin() {
AbortWithError(c, 403, "FORBIDDEN", "Admin access required")
return false
}
c.Set(string(ContextKeyUser), AuthSubject{
UserID: user.ID,
Concurrency: user.Concurrency,
})
c.Set(string(ContextKeyUserRole), user.Role)
c.Set("auth_method", "jwt")
return true
}

View File

@@ -1,27 +1,27 @@
package middleware
import (
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// AdminOnly 管理员权限中间件
// 必须在JWTAuth中间件之后使用
func AdminOnly() gin.HandlerFunc {
return func(c *gin.Context) {
role, ok := GetUserRoleFromContext(c)
if !ok {
AbortWithError(c, 401, "UNAUTHORIZED", "User not found in context")
return
}
// 检查是否为管理员
if role != service.RoleAdmin {
AbortWithError(c, 403, "FORBIDDEN", "Admin access required")
return
}
c.Next()
}
}
package middleware
import (
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// AdminOnly 管理员权限中间件
// 必须在JWTAuth中间件之后使用
func AdminOnly() gin.HandlerFunc {
return func(c *gin.Context) {
role, ok := GetUserRoleFromContext(c)
if !ok {
AbortWithError(c, 401, "UNAUTHORIZED", "User not found in context")
return
}
// 检查是否为管理员
if role != service.RoleAdmin {
AbortWithError(c, 403, "FORBIDDEN", "Admin access required")
return
}
c.Next()
}
}

View File

@@ -1,178 +1,178 @@
package middleware
import (
"errors"
"log"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// NewApiKeyAuthMiddleware 创建 API Key 认证中间件
func NewApiKeyAuthMiddleware(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) ApiKeyAuthMiddleware {
return ApiKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService, cfg))
}
// apiKeyAuthWithSubscription API Key认证中间件支持订阅验证
func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) {
// 尝试从Authorization header中提取API key (Bearer scheme)
authHeader := c.GetHeader("Authorization")
var apiKeyString string
if authHeader != "" {
// 验证Bearer scheme
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) == 2 && parts[0] == "Bearer" {
apiKeyString = parts[1]
}
}
// 如果Authorization header中没有尝试从x-api-key header中提取
if apiKeyString == "" {
apiKeyString = c.GetHeader("x-api-key")
}
// 如果x-api-key header中没有尝试从x-goog-api-key header中提取Gemini CLI兼容
if apiKeyString == "" {
apiKeyString = c.GetHeader("x-goog-api-key")
}
// 如果header中没有尝试从query参数中提取Google API key风格
if apiKeyString == "" {
apiKeyString = c.Query("key")
}
// 兼容常见别名
if apiKeyString == "" {
apiKeyString = c.Query("api_key")
}
// 如果所有header都没有API key
if apiKeyString == "" {
AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme), x-api-key header, x-goog-api-key header, or key/api_key query parameter")
return
}
// 从数据库验证API key
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
if err != nil {
if errors.Is(err, service.ErrApiKeyNotFound) {
AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key")
return
}
AbortWithError(c, 500, "INTERNAL_ERROR", "Failed to validate API key")
return
}
// 检查API key是否激活
if !apiKey.IsActive() {
AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled")
return
}
// 检查关联的用户
if apiKey.User == nil {
AbortWithError(c, 401, "USER_NOT_FOUND", "User associated with API key not found")
return
}
// 检查用户状态
if !apiKey.User.IsActive() {
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
return
}
if cfg.RunMode == config.RunModeSimple {
// 简易模式:跳过余额和订阅检查,但仍需设置必要的上下文
c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency,
})
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
c.Next()
return
}
// 判断计费方式:订阅模式 vs 余额模式
isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
if isSubscriptionType && subscriptionService != nil {
// 订阅模式:验证订阅
subscription, err := subscriptionService.GetActiveSubscription(
c.Request.Context(),
apiKey.User.ID,
apiKey.Group.ID,
)
if err != nil {
AbortWithError(c, 403, "SUBSCRIPTION_NOT_FOUND", "No active subscription found for this group")
return
}
// 验证订阅状态(是否过期、暂停等)
if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil {
AbortWithError(c, 403, "SUBSCRIPTION_INVALID", err.Error())
return
}
// 激活滑动窗口(首次使用时)
if err := subscriptionService.CheckAndActivateWindow(c.Request.Context(), subscription); err != nil {
log.Printf("Failed to activate subscription windows: %v", err)
}
// 检查并重置过期窗口
if err := subscriptionService.CheckAndResetWindows(c.Request.Context(), subscription); err != nil {
log.Printf("Failed to reset subscription windows: %v", err)
}
// 预检查用量限制使用0作为额外费用进行预检查
if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil {
AbortWithError(c, 429, "USAGE_LIMIT_EXCEEDED", err.Error())
return
}
// 将订阅信息存入上下文
c.Set(string(ContextKeySubscription), subscription)
} else {
// 余额模式:检查用户余额
if apiKey.User.Balance <= 0 {
AbortWithError(c, 403, "INSUFFICIENT_BALANCE", "Insufficient account balance")
return
}
}
// 将API key和用户信息存入上下文
c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency,
})
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
c.Next()
}
}
// GetApiKeyFromContext 从上下文中获取API key
func GetApiKeyFromContext(c *gin.Context) (*service.ApiKey, bool) {
value, exists := c.Get(string(ContextKeyApiKey))
if !exists {
return nil, false
}
apiKey, ok := value.(*service.ApiKey)
return apiKey, ok
}
// GetSubscriptionFromContext 从上下文中获取订阅信息
func GetSubscriptionFromContext(c *gin.Context) (*service.UserSubscription, bool) {
value, exists := c.Get(string(ContextKeySubscription))
if !exists {
return nil, false
}
subscription, ok := value.(*service.UserSubscription)
return subscription, ok
}
package middleware
import (
"errors"
"log"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// NewApiKeyAuthMiddleware 创建 API Key 认证中间件
func NewApiKeyAuthMiddleware(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) ApiKeyAuthMiddleware {
return ApiKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService, cfg))
}
// apiKeyAuthWithSubscription API Key认证中间件支持订阅验证
func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) {
// 尝试从Authorization header中提取API key (Bearer scheme)
authHeader := c.GetHeader("Authorization")
var apiKeyString string
if authHeader != "" {
// 验证Bearer scheme
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) == 2 && parts[0] == "Bearer" {
apiKeyString = parts[1]
}
}
// 如果Authorization header中没有尝试从x-api-key header中提取
if apiKeyString == "" {
apiKeyString = c.GetHeader("x-api-key")
}
// 如果x-api-key header中没有尝试从x-goog-api-key header中提取Gemini CLI兼容
if apiKeyString == "" {
apiKeyString = c.GetHeader("x-goog-api-key")
}
// 如果header中没有尝试从query参数中提取Google API key风格
if apiKeyString == "" {
apiKeyString = c.Query("key")
}
// 兼容常见别名
if apiKeyString == "" {
apiKeyString = c.Query("api_key")
}
// 如果所有header都没有API key
if apiKeyString == "" {
AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme), x-api-key header, x-goog-api-key header, or key/api_key query parameter")
return
}
// 从数据库验证API key
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
if err != nil {
if errors.Is(err, service.ErrApiKeyNotFound) {
AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key")
return
}
AbortWithError(c, 500, "INTERNAL_ERROR", "Failed to validate API key")
return
}
// 检查API key是否激活
if !apiKey.IsActive() {
AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled")
return
}
// 检查关联的用户
if apiKey.User == nil {
AbortWithError(c, 401, "USER_NOT_FOUND", "User associated with API key not found")
return
}
// 检查用户状态
if !apiKey.User.IsActive() {
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
return
}
if cfg.RunMode == config.RunModeSimple {
// 简易模式:跳过余额和订阅检查,但仍需设置必要的上下文
c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency,
})
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
c.Next()
return
}
// 判断计费方式:订阅模式 vs 余额模式
isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
if isSubscriptionType && subscriptionService != nil {
// 订阅模式:验证订阅
subscription, err := subscriptionService.GetActiveSubscription(
c.Request.Context(),
apiKey.User.ID,
apiKey.Group.ID,
)
if err != nil {
AbortWithError(c, 403, "SUBSCRIPTION_NOT_FOUND", "No active subscription found for this group")
return
}
// 验证订阅状态(是否过期、暂停等)
if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil {
AbortWithError(c, 403, "SUBSCRIPTION_INVALID", err.Error())
return
}
// 激活滑动窗口(首次使用时)
if err := subscriptionService.CheckAndActivateWindow(c.Request.Context(), subscription); err != nil {
log.Printf("Failed to activate subscription windows: %v", err)
}
// 检查并重置过期窗口
if err := subscriptionService.CheckAndResetWindows(c.Request.Context(), subscription); err != nil {
log.Printf("Failed to reset subscription windows: %v", err)
}
// 预检查用量限制使用0作为额外费用进行预检查
if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil {
AbortWithError(c, 429, "USAGE_LIMIT_EXCEEDED", err.Error())
return
}
// 将订阅信息存入上下文
c.Set(string(ContextKeySubscription), subscription)
} else {
// 余额模式:检查用户余额
if apiKey.User.Balance <= 0 {
AbortWithError(c, 403, "INSUFFICIENT_BALANCE", "Insufficient account balance")
return
}
}
// 将API key和用户信息存入上下文
c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency,
})
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
c.Next()
}
}
// GetApiKeyFromContext 从上下文中获取API key
func GetApiKeyFromContext(c *gin.Context) (*service.ApiKey, bool) {
value, exists := c.Get(string(ContextKeyApiKey))
if !exists {
return nil, false
}
apiKey, ok := value.(*service.ApiKey)
return apiKey, ok
}
// GetSubscriptionFromContext 从上下文中获取订阅信息
func GetSubscriptionFromContext(c *gin.Context) (*service.UserSubscription, bool) {
value, exists := c.Get(string(ContextKeySubscription))
if !exists {
return nil, false
}
subscription, ok := value.(*service.UserSubscription)
return subscription, ok
}

View File

@@ -1,137 +1,137 @@
package middleware
import (
"errors"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// ApiKeyAuthGoogle is a Google-style error wrapper for API key auth.
func ApiKeyAuthGoogle(apiKeyService *service.ApiKeyService, cfg *config.Config) gin.HandlerFunc {
return ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)
}
// ApiKeyAuthWithSubscriptionGoogle behaves like ApiKeyAuthWithSubscription but returns Google-style errors:
// {"error":{"code":401,"message":"...","status":"UNAUTHENTICATED"}}
//
// It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations.
func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) {
apiKeyString := extractAPIKeyFromRequest(c)
if apiKeyString == "" {
abortWithGoogleError(c, 401, "API key is required")
return
}
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
if err != nil {
if errors.Is(err, service.ErrApiKeyNotFound) {
abortWithGoogleError(c, 401, "Invalid API key")
return
}
abortWithGoogleError(c, 500, "Failed to validate API key")
return
}
if !apiKey.IsActive() {
abortWithGoogleError(c, 401, "API key is disabled")
return
}
if apiKey.User == nil {
abortWithGoogleError(c, 401, "User associated with API key not found")
return
}
if !apiKey.User.IsActive() {
abortWithGoogleError(c, 401, "User account is not active")
return
}
// 简易模式:跳过余额和订阅检查
if cfg.RunMode == config.RunModeSimple {
c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency,
})
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
c.Next()
return
}
isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
if isSubscriptionType && subscriptionService != nil {
subscription, err := subscriptionService.GetActiveSubscription(
c.Request.Context(),
apiKey.User.ID,
apiKey.Group.ID,
)
if err != nil {
abortWithGoogleError(c, 403, "No active subscription found for this group")
return
}
if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil {
abortWithGoogleError(c, 403, err.Error())
return
}
_ = subscriptionService.CheckAndActivateWindow(c.Request.Context(), subscription)
_ = subscriptionService.CheckAndResetWindows(c.Request.Context(), subscription)
if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil {
abortWithGoogleError(c, 429, err.Error())
return
}
c.Set(string(ContextKeySubscription), subscription)
} else {
if apiKey.User.Balance <= 0 {
abortWithGoogleError(c, 403, "Insufficient account balance")
return
}
}
c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency,
})
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
c.Next()
}
}
func extractAPIKeyFromRequest(c *gin.Context) string {
authHeader := c.GetHeader("Authorization")
if authHeader != "" {
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) == 2 && parts[0] == "Bearer" && strings.TrimSpace(parts[1]) != "" {
return strings.TrimSpace(parts[1])
}
}
if v := strings.TrimSpace(c.GetHeader("x-api-key")); v != "" {
return v
}
if v := strings.TrimSpace(c.GetHeader("x-goog-api-key")); v != "" {
return v
}
if v := strings.TrimSpace(c.Query("key")); v != "" {
return v
}
if v := strings.TrimSpace(c.Query("api_key")); v != "" {
return v
}
return ""
}
func abortWithGoogleError(c *gin.Context, status int, message string) {
c.JSON(status, gin.H{
"error": gin.H{
"code": status,
"message": message,
"status": googleapi.HTTPStatusToGoogleStatus(status),
},
})
c.Abort()
}
package middleware
import (
"errors"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// ApiKeyAuthGoogle is a Google-style error wrapper for API key auth.
func ApiKeyAuthGoogle(apiKeyService *service.ApiKeyService, cfg *config.Config) gin.HandlerFunc {
return ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)
}
// ApiKeyAuthWithSubscriptionGoogle behaves like ApiKeyAuthWithSubscription but returns Google-style errors:
// {"error":{"code":401,"message":"...","status":"UNAUTHENTICATED"}}
//
// It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations.
func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) {
apiKeyString := extractAPIKeyFromRequest(c)
if apiKeyString == "" {
abortWithGoogleError(c, 401, "API key is required")
return
}
apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString)
if err != nil {
if errors.Is(err, service.ErrApiKeyNotFound) {
abortWithGoogleError(c, 401, "Invalid API key")
return
}
abortWithGoogleError(c, 500, "Failed to validate API key")
return
}
if !apiKey.IsActive() {
abortWithGoogleError(c, 401, "API key is disabled")
return
}
if apiKey.User == nil {
abortWithGoogleError(c, 401, "User associated with API key not found")
return
}
if !apiKey.User.IsActive() {
abortWithGoogleError(c, 401, "User account is not active")
return
}
// 简易模式:跳过余额和订阅检查
if cfg.RunMode == config.RunModeSimple {
c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency,
})
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
c.Next()
return
}
isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
if isSubscriptionType && subscriptionService != nil {
subscription, err := subscriptionService.GetActiveSubscription(
c.Request.Context(),
apiKey.User.ID,
apiKey.Group.ID,
)
if err != nil {
abortWithGoogleError(c, 403, "No active subscription found for this group")
return
}
if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil {
abortWithGoogleError(c, 403, err.Error())
return
}
_ = subscriptionService.CheckAndActivateWindow(c.Request.Context(), subscription)
_ = subscriptionService.CheckAndResetWindows(c.Request.Context(), subscription)
if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil {
abortWithGoogleError(c, 429, err.Error())
return
}
c.Set(string(ContextKeySubscription), subscription)
} else {
if apiKey.User.Balance <= 0 {
abortWithGoogleError(c, 403, "Insufficient account balance")
return
}
}
c.Set(string(ContextKeyApiKey), apiKey)
c.Set(string(ContextKeyUser), AuthSubject{
UserID: apiKey.User.ID,
Concurrency: apiKey.User.Concurrency,
})
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
c.Next()
}
}
func extractAPIKeyFromRequest(c *gin.Context) string {
authHeader := c.GetHeader("Authorization")
if authHeader != "" {
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) == 2 && parts[0] == "Bearer" && strings.TrimSpace(parts[1]) != "" {
return strings.TrimSpace(parts[1])
}
}
if v := strings.TrimSpace(c.GetHeader("x-api-key")); v != "" {
return v
}
if v := strings.TrimSpace(c.GetHeader("x-goog-api-key")); v != "" {
return v
}
if v := strings.TrimSpace(c.Query("key")); v != "" {
return v
}
if v := strings.TrimSpace(c.Query("api_key")); v != "" {
return v
}
return ""
}
func abortWithGoogleError(c *gin.Context, status int, message string) {
c.JSON(status, gin.H{
"error": gin.H{
"code": status,
"message": message,
"status": googleapi.HTTPStatusToGoogleStatus(status),
},
})
c.Abort()
}

View File

@@ -1,227 +1,227 @@
package middleware
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type fakeApiKeyRepo struct {
getByKey func(ctx context.Context, key string) (*service.ApiKey, error)
}
func (f fakeApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error {
return errors.New("not implemented")
}
func (f fakeApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
return nil, errors.New("not implemented")
}
func (f fakeApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (f fakeApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
if f.getByKey == nil {
return nil, errors.New("unexpected call")
}
return f.getByKey(ctx, key)
}
func (f fakeApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error {
return errors.New("not implemented")
}
func (f fakeApiKeyRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (f fakeApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (f fakeApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
return nil, errors.New("not implemented")
}
func (f fakeApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (f fakeApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
return false, errors.New("not implemented")
}
func (f fakeApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (f fakeApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
return nil, errors.New("not implemented")
}
func (f fakeApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (f fakeApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
type googleErrorResponse struct {
Error struct {
Code int `json:"code"`
Message string `json:"message"`
Status string `json:"status"`
} `json:"error"`
}
func newTestApiKeyService(repo service.ApiKeyRepository) *service.ApiKeyService {
return service.NewApiKeyService(
repo,
nil, // userRepo (unused in GetByKey)
nil, // groupRepo
nil, // userSubRepo
nil, // cache
&config.Config{},
)
}
func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return nil, errors.New("should not be called")
},
})
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusUnauthorized, rec.Code)
var resp googleErrorResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, http.StatusUnauthorized, resp.Error.Code)
require.Equal(t, "API key is required", resp.Error.Message)
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
}
func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return nil, service.ErrApiKeyNotFound
},
})
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
req.Header.Set("Authorization", "Bearer invalid")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusUnauthorized, rec.Code)
var resp googleErrorResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, http.StatusUnauthorized, resp.Error.Code)
require.Equal(t, "Invalid API key", resp.Error.Message)
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
}
func TestApiKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return nil, errors.New("db down")
},
})
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
req.Header.Set("Authorization", "Bearer any")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusInternalServerError, rec.Code)
var resp googleErrorResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, http.StatusInternalServerError, resp.Error.Code)
require.Equal(t, "Failed to validate API key", resp.Error.Message)
require.Equal(t, "INTERNAL", resp.Error.Status)
}
func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return &service.ApiKey{
ID: 1,
Key: key,
Status: service.StatusDisabled,
User: &service.User{
ID: 123,
Status: service.StatusActive,
},
}, nil
},
})
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
req.Header.Set("Authorization", "Bearer disabled")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusUnauthorized, rec.Code)
var resp googleErrorResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, http.StatusUnauthorized, resp.Error.Code)
require.Equal(t, "API key is disabled", resp.Error.Message)
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
}
func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return &service.ApiKey{
ID: 1,
Key: key,
Status: service.StatusActive,
User: &service.User{
ID: 123,
Status: service.StatusActive,
Balance: 0,
},
}, nil
},
})
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
req.Header.Set("Authorization", "Bearer ok")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusForbidden, rec.Code)
var resp googleErrorResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, http.StatusForbidden, resp.Error.Code)
require.Equal(t, "Insufficient account balance", resp.Error.Message)
require.Equal(t, "PERMISSION_DENIED", resp.Error.Status)
}
package middleware
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type fakeApiKeyRepo struct {
getByKey func(ctx context.Context, key string) (*service.ApiKey, error)
}
func (f fakeApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error {
return errors.New("not implemented")
}
func (f fakeApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
return nil, errors.New("not implemented")
}
func (f fakeApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (f fakeApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
if f.getByKey == nil {
return nil, errors.New("unexpected call")
}
return f.getByKey(ctx, key)
}
func (f fakeApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error {
return errors.New("not implemented")
}
func (f fakeApiKeyRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (f fakeApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (f fakeApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
return nil, errors.New("not implemented")
}
func (f fakeApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (f fakeApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
return false, errors.New("not implemented")
}
func (f fakeApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (f fakeApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
return nil, errors.New("not implemented")
}
func (f fakeApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (f fakeApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
type googleErrorResponse struct {
Error struct {
Code int `json:"code"`
Message string `json:"message"`
Status string `json:"status"`
} `json:"error"`
}
func newTestApiKeyService(repo service.ApiKeyRepository) *service.ApiKeyService {
return service.NewApiKeyService(
repo,
nil, // userRepo (unused in GetByKey)
nil, // groupRepo
nil, // userSubRepo
nil, // cache
&config.Config{},
)
}
func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return nil, errors.New("should not be called")
},
})
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusUnauthorized, rec.Code)
var resp googleErrorResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, http.StatusUnauthorized, resp.Error.Code)
require.Equal(t, "API key is required", resp.Error.Message)
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
}
func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return nil, service.ErrApiKeyNotFound
},
})
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
req.Header.Set("Authorization", "Bearer invalid")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusUnauthorized, rec.Code)
var resp googleErrorResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, http.StatusUnauthorized, resp.Error.Code)
require.Equal(t, "Invalid API key", resp.Error.Message)
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
}
func TestApiKeyAuthWithSubscriptionGoogle_RepoError(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return nil, errors.New("db down")
},
})
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
req.Header.Set("Authorization", "Bearer any")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusInternalServerError, rec.Code)
var resp googleErrorResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, http.StatusInternalServerError, resp.Error.Code)
require.Equal(t, "Failed to validate API key", resp.Error.Message)
require.Equal(t, "INTERNAL", resp.Error.Status)
}
func TestApiKeyAuthWithSubscriptionGoogle_DisabledKey(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return &service.ApiKey{
ID: 1,
Key: key,
Status: service.StatusDisabled,
User: &service.User{
ID: 123,
Status: service.StatusActive,
},
}, nil
},
})
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
req.Header.Set("Authorization", "Bearer disabled")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusUnauthorized, rec.Code)
var resp googleErrorResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, http.StatusUnauthorized, resp.Error.Code)
require.Equal(t, "API key is disabled", resp.Error.Message)
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
}
func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return &service.ApiKey{
ID: 1,
Key: key,
Status: service.StatusActive,
User: &service.User{
ID: 123,
Status: service.StatusActive,
Balance: 0,
},
}, nil
},
})
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
req.Header.Set("Authorization", "Bearer ok")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusForbidden, rec.Code)
var resp googleErrorResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, http.StatusForbidden, resp.Error.Code)
require.Equal(t, "Insufficient account balance", resp.Error.Message)
require.Equal(t, "PERMISSION_DENIED", resp.Error.Status)
}

View File

@@ -1,290 +1,290 @@
//go:build unit
package middleware
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
gin.SetMode(gin.TestMode)
limit := 1.0
group := &service.Group{
ID: 42,
Name: "sub",
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeSubscription,
DailyLimitUSD: &limit,
}
user := &service.User{
ID: 7,
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
Concurrency: 3,
}
apiKey := &service.ApiKey{
ID: 100,
UserID: user.ID,
Key: "test-key",
Status: service.StatusActive,
User: user,
Group: group,
}
apiKey.GroupID = &group.ID
apiKeyRepo := &stubApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
if key != apiKey.Key {
return nil, service.ErrApiKeyNotFound
}
clone := *apiKey
return &clone, nil
},
}
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil)
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("x-api-key", apiKey.Key)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
})
t.Run("standard_mode_enforces_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeStandard}
apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
now := time.Now()
sub := &service.UserSubscription{
ID: 55,
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: now.Add(24 * time.Hour),
DailyWindowStart: &now,
DailyUsageUSD: 10,
}
subscriptionRepo := &stubUserSubscriptionRepo{
getActive: func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
if userID != sub.UserID || groupID != sub.GroupID {
return nil, service.ErrSubscriptionNotFound
}
clone := *sub
return &clone, nil
},
updateStatus: func(ctx context.Context, subscriptionID int64, status string) error { return nil },
activateWindow: func(ctx context.Context, id int64, start time.Time) error { return nil },
resetDaily: func(ctx context.Context, id int64, start time.Time) error { return nil },
resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil },
resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil },
}
subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil)
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("x-api-key", apiKey.Key)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusTooManyRequests, w.Code)
require.Contains(t, w.Body.String(), "USAGE_LIMIT_EXCEEDED")
})
}
func newAuthTestRouter(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
router := gin.New()
router.Use(gin.HandlerFunc(NewApiKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
router.GET("/t", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
return router
}
type stubApiKeyRepo struct {
getByKey func(ctx context.Context, key string) (*service.ApiKey, error)
}
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error {
return errors.New("not implemented")
}
func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
return nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
if r.getByKey != nil {
return r.getByKey(ctx, key)
}
return nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error {
return errors.New("not implemented")
}
func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
return nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
return false, errors.New("not implemented")
}
func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
return nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
type stubUserSubscriptionRepo struct {
getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error)
updateStatus func(ctx context.Context, subscriptionID int64, status string) error
activateWindow func(ctx context.Context, id int64, start time.Time) error
resetDaily func(ctx context.Context, id int64, start time.Time) error
resetWeekly func(ctx context.Context, id int64, start time.Time) error
resetMonthly func(ctx context.Context, id int64, start time.Time) error
}
func (r *stubUserSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error {
return errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) {
return nil, errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
return nil, errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
if r.getActive != nil {
return r.getActive(ctx, userID, groupID)
}
return nil, errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) Update(ctx context.Context, sub *service.UserSubscription) error {
return errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
return nil, errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
return nil, errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
return false, errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error {
return errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error {
if r.updateStatus != nil {
return r.updateStatus(ctx, subscriptionID, status)
}
return errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error {
return errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) ActivateWindows(ctx context.Context, id int64, start time.Time) error {
if r.activateWindow != nil {
return r.activateWindow(ctx, id, start)
}
return errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
if r.resetDaily != nil {
return r.resetDaily(ctx, id, newWindowStart)
}
return errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
if r.resetWeekly != nil {
return r.resetWeekly(ctx, id, newWindowStart)
}
return errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
if r.resetMonthly != nil {
return r.resetMonthly(ctx, id, newWindowStart)
}
return errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
return errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
return 0, errors.New("not implemented")
}
//go:build unit
package middleware
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
gin.SetMode(gin.TestMode)
limit := 1.0
group := &service.Group{
ID: 42,
Name: "sub",
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeSubscription,
DailyLimitUSD: &limit,
}
user := &service.User{
ID: 7,
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
Concurrency: 3,
}
apiKey := &service.ApiKey{
ID: 100,
UserID: user.ID,
Key: "test-key",
Status: service.StatusActive,
User: user,
Group: group,
}
apiKey.GroupID = &group.ID
apiKeyRepo := &stubApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
if key != apiKey.Key {
return nil, service.ErrApiKeyNotFound
}
clone := *apiKey
return &clone, nil
},
}
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil)
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("x-api-key", apiKey.Key)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
})
t.Run("standard_mode_enforces_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeStandard}
apiKeyService := service.NewApiKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
now := time.Now()
sub := &service.UserSubscription{
ID: 55,
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: now.Add(24 * time.Hour),
DailyWindowStart: &now,
DailyUsageUSD: 10,
}
subscriptionRepo := &stubUserSubscriptionRepo{
getActive: func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
if userID != sub.UserID || groupID != sub.GroupID {
return nil, service.ErrSubscriptionNotFound
}
clone := *sub
return &clone, nil
},
updateStatus: func(ctx context.Context, subscriptionID int64, status string) error { return nil },
activateWindow: func(ctx context.Context, id int64, start time.Time) error { return nil },
resetDaily: func(ctx context.Context, id int64, start time.Time) error { return nil },
resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil },
resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil },
}
subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil)
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.Header.Set("x-api-key", apiKey.Key)
router.ServeHTTP(w, req)
require.Equal(t, http.StatusTooManyRequests, w.Code)
require.Contains(t, w.Body.String(), "USAGE_LIMIT_EXCEEDED")
})
}
func newAuthTestRouter(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
router := gin.New()
router.Use(gin.HandlerFunc(NewApiKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
router.GET("/t", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
return router
}
type stubApiKeyRepo struct {
getByKey func(ctx context.Context, key string) (*service.ApiKey, error)
}
func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.ApiKey) error {
return errors.New("not implemented")
}
func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.ApiKey, error) {
return nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.ApiKey, error) {
if r.getByKey != nil {
return r.getByKey(ctx, key)
}
return nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.ApiKey) error {
return errors.New("not implemented")
}
func (r *stubApiKeyRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (r *stubApiKeyRepo) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
return nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) CountByUserID(ctx context.Context, userID int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (r *stubApiKeyRepo) ExistsByKey(ctx context.Context, key string) (bool, error) {
return false, errors.New("not implemented")
}
func (r *stubApiKeyRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.ApiKey, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.ApiKey, error) {
return nil, errors.New("not implemented")
}
func (r *stubApiKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
return 0, errors.New("not implemented")
}
type stubUserSubscriptionRepo struct {
getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error)
updateStatus func(ctx context.Context, subscriptionID int64, status string) error
activateWindow func(ctx context.Context, id int64, start time.Time) error
resetDaily func(ctx context.Context, id int64, start time.Time) error
resetWeekly func(ctx context.Context, id int64, start time.Time) error
resetMonthly func(ctx context.Context, id int64, start time.Time) error
}
func (r *stubUserSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error {
return errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) {
return nil, errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
return nil, errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) {
if r.getActive != nil {
return r.getActive(ctx, userID, groupID)
}
return nil, errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) Update(ctx context.Context, sub *service.UserSubscription) error {
return errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
return nil, errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) {
return nil, errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]service.UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) {
return false, errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error {
return errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error {
if r.updateStatus != nil {
return r.updateStatus(ctx, subscriptionID, status)
}
return errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error {
return errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) ActivateWindows(ctx context.Context, id int64, start time.Time) error {
if r.activateWindow != nil {
return r.activateWindow(ctx, id, start)
}
return errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
if r.resetDaily != nil {
return r.resetDaily(ctx, id, newWindowStart)
}
return errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
if r.resetWeekly != nil {
return r.resetWeekly(ctx, id, newWindowStart)
}
return errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error {
if r.resetMonthly != nil {
return r.resetMonthly(ctx, id, newWindowStart)
}
return errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
return errors.New("not implemented")
}
func (r *stubUserSubscriptionRepo) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) {
return 0, errors.New("not implemented")
}

View File

@@ -1,28 +1,28 @@
package middleware
import "github.com/gin-gonic/gin"
// AuthSubject is the minimal authenticated identity stored in gin context.
// Decision: {UserID int64, Concurrency int}
type AuthSubject struct {
UserID int64
Concurrency int
}
func GetAuthSubjectFromContext(c *gin.Context) (AuthSubject, bool) {
value, exists := c.Get(string(ContextKeyUser))
if !exists {
return AuthSubject{}, false
}
subject, ok := value.(AuthSubject)
return subject, ok
}
func GetUserRoleFromContext(c *gin.Context) (string, bool) {
value, exists := c.Get(string(ContextKeyUserRole))
if !exists {
return "", false
}
role, ok := value.(string)
return role, ok
}
package middleware
import "github.com/gin-gonic/gin"
// AuthSubject is the minimal authenticated identity stored in gin context.
// Decision: {UserID int64, Concurrency int}
type AuthSubject struct {
UserID int64
Concurrency int
}
func GetAuthSubjectFromContext(c *gin.Context) (AuthSubject, bool) {
value, exists := c.Get(string(ContextKeyUser))
if !exists {
return AuthSubject{}, false
}
subject, ok := value.(AuthSubject)
return subject, ok
}
func GetUserRoleFromContext(c *gin.Context) (string, bool) {
value, exists := c.Get(string(ContextKeyUserRole))
if !exists {
return "", false
}
role, ok := value.(string)
return role, ok
}

View File

@@ -1,24 +1,24 @@
package middleware
import (
"github.com/gin-gonic/gin"
)
// CORS 跨域中间件
func CORS() gin.HandlerFunc {
return func(c *gin.Context) {
// 设置允许跨域的响应头
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key")
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
// 处理预检请求
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(204)
return
}
c.Next()
}
}
package middleware
import (
"github.com/gin-gonic/gin"
)
// CORS 跨域中间件
func CORS() gin.HandlerFunc {
return func(c *gin.Context) {
// 设置允许跨域的响应头
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key")
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
// 处理预检请求
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(204)
return
}
c.Next()
}
}

View File

@@ -1,81 +1,81 @@
package middleware
import (
"errors"
"strings"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// NewJWTAuthMiddleware 创建 JWT 认证中间件
func NewJWTAuthMiddleware(authService *service.AuthService, userService *service.UserService) JWTAuthMiddleware {
return JWTAuthMiddleware(jwtAuth(authService, userService))
}
// jwtAuth JWT认证中间件实现
func jwtAuth(authService *service.AuthService, userService *service.UserService) gin.HandlerFunc {
return func(c *gin.Context) {
// 从Authorization header中提取token
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
AbortWithError(c, 401, "UNAUTHORIZED", "Authorization header is required")
return
}
// 验证Bearer scheme
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 || parts[0] != "Bearer" {
AbortWithError(c, 401, "INVALID_AUTH_HEADER", "Authorization header format must be 'Bearer {token}'")
return
}
tokenString := parts[1]
if tokenString == "" {
AbortWithError(c, 401, "EMPTY_TOKEN", "Token cannot be empty")
return
}
// 验证token
claims, err := authService.ValidateToken(tokenString)
if err != nil {
if errors.Is(err, service.ErrTokenExpired) {
AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired")
return
}
AbortWithError(c, 401, "INVALID_TOKEN", "Invalid token")
return
}
// 从数据库获取最新的用户信息
user, err := userService.GetByID(c.Request.Context(), claims.UserID)
if err != nil {
AbortWithError(c, 401, "USER_NOT_FOUND", "User not found")
return
}
// 检查用户状态
if !user.IsActive() {
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
return
}
// Security: Validate TokenVersion to ensure token hasn't been invalidated
// This check ensures tokens issued before a password change are rejected
if claims.TokenVersion != user.TokenVersion {
AbortWithError(c, 401, "TOKEN_REVOKED", "Token has been revoked (password changed)")
return
}
c.Set(string(ContextKeyUser), AuthSubject{
UserID: user.ID,
Concurrency: user.Concurrency,
})
c.Set(string(ContextKeyUserRole), user.Role)
c.Next()
}
}
// Deprecated: prefer GetAuthSubjectFromContext in auth_subject.go.
package middleware
import (
"errors"
"strings"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// NewJWTAuthMiddleware 创建 JWT 认证中间件
func NewJWTAuthMiddleware(authService *service.AuthService, userService *service.UserService) JWTAuthMiddleware {
return JWTAuthMiddleware(jwtAuth(authService, userService))
}
// jwtAuth JWT认证中间件实现
func jwtAuth(authService *service.AuthService, userService *service.UserService) gin.HandlerFunc {
return func(c *gin.Context) {
// 从Authorization header中提取token
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
AbortWithError(c, 401, "UNAUTHORIZED", "Authorization header is required")
return
}
// 验证Bearer scheme
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 || parts[0] != "Bearer" {
AbortWithError(c, 401, "INVALID_AUTH_HEADER", "Authorization header format must be 'Bearer {token}'")
return
}
tokenString := parts[1]
if tokenString == "" {
AbortWithError(c, 401, "EMPTY_TOKEN", "Token cannot be empty")
return
}
// 验证token
claims, err := authService.ValidateToken(tokenString)
if err != nil {
if errors.Is(err, service.ErrTokenExpired) {
AbortWithError(c, 401, "TOKEN_EXPIRED", "Token has expired")
return
}
AbortWithError(c, 401, "INVALID_TOKEN", "Invalid token")
return
}
// 从数据库获取最新的用户信息
user, err := userService.GetByID(c.Request.Context(), claims.UserID)
if err != nil {
AbortWithError(c, 401, "USER_NOT_FOUND", "User not found")
return
}
// 检查用户状态
if !user.IsActive() {
AbortWithError(c, 401, "USER_INACTIVE", "User account is not active")
return
}
// Security: Validate TokenVersion to ensure token hasn't been invalidated
// This check ensures tokens issued before a password change are rejected
if claims.TokenVersion != user.TokenVersion {
AbortWithError(c, 401, "TOKEN_REVOKED", "Token has been revoked (password changed)")
return
}
c.Set(string(ContextKeyUser), AuthSubject{
UserID: user.ID,
Concurrency: user.Concurrency,
})
c.Set(string(ContextKeyUserRole), user.Role)
c.Next()
}
}
// Deprecated: prefer GetAuthSubjectFromContext in auth_subject.go.

View File

@@ -1,52 +1,52 @@
package middleware
import (
"log"
"time"
"github.com/gin-gonic/gin"
)
// Logger 请求日志中间件
func Logger() gin.HandlerFunc {
return func(c *gin.Context) {
// 开始时间
startTime := time.Now()
// 处理请求
c.Next()
// 结束时间
endTime := time.Now()
// 执行时间
latency := endTime.Sub(startTime)
// 请求方法
method := c.Request.Method
// 请求路径
path := c.Request.URL.Path
// 状态码
statusCode := c.Writer.Status()
// 客户端IP
clientIP := c.ClientIP()
// 日志格式: [时间] 状态码 | 延迟 | IP | 方法 路径
log.Printf("[GIN] %v | %3d | %13v | %15s | %-7s %s",
endTime.Format("2006/01/02 - 15:04:05"),
statusCode,
latency,
clientIP,
method,
path,
)
// 如果有错误,额外记录错误信息
if len(c.Errors) > 0 {
log.Printf("[GIN] Errors: %v", c.Errors.String())
}
}
}
package middleware
import (
"log"
"time"
"github.com/gin-gonic/gin"
)
// Logger 请求日志中间件
func Logger() gin.HandlerFunc {
return func(c *gin.Context) {
// 开始时间
startTime := time.Now()
// 处理请求
c.Next()
// 结束时间
endTime := time.Now()
// 执行时间
latency := endTime.Sub(startTime)
// 请求方法
method := c.Request.Method
// 请求路径
path := c.Request.URL.Path
// 状态码
statusCode := c.Writer.Status()
// 客户端IP
clientIP := c.ClientIP()
// 日志格式: [时间] 状态码 | 延迟 | IP | 方法 路径
log.Printf("[GIN] %v | %3d | %13v | %15s | %-7s %s",
endTime.Format("2006/01/02 - 15:04:05"),
statusCode,
latency,
clientIP,
method,
path,
)
// 如果有错误,额外记录错误信息
if len(c.Errors) > 0 {
log.Printf("[GIN] Errors: %v", c.Errors.String())
}
}
}

View File

@@ -1,73 +1,73 @@
package middleware
import (
"context"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/gin-gonic/gin"
)
// ContextKey 定义上下文键类型
type ContextKey string
const (
// ContextKeyUser 用户上下文键
ContextKeyUser ContextKey = "user"
// ContextKeyUserRole 当前用户角色string
ContextKeyUserRole ContextKey = "user_role"
// ContextKeyApiKey API密钥上下文键
ContextKeyApiKey ContextKey = "api_key"
// ContextKeySubscription 订阅上下文键
ContextKeySubscription ContextKey = "subscription"
// ContextKeyForcePlatform 强制平台(用于 /antigravity 路由)
ContextKeyForcePlatform ContextKey = "force_platform"
)
// ForcePlatform 返回设置强制平台的中间件
// 同时设置 request.Context供 Service 使用)和 gin.Context供 Handler 快速检查)
func ForcePlatform(platform string) gin.HandlerFunc {
return func(c *gin.Context) {
// 设置到 request.Context使用 ctxkey.ForcePlatform 供 Service 层读取
ctx := context.WithValue(c.Request.Context(), ctxkey.ForcePlatform, platform)
c.Request = c.Request.WithContext(ctx)
// 同时设置到 gin.Context供 Handler 快速检查
c.Set(string(ContextKeyForcePlatform), platform)
c.Next()
}
}
// HasForcePlatform 检查是否有强制平台(用于 Handler 跳过分组检查)
func HasForcePlatform(c *gin.Context) bool {
_, exists := c.Get(string(ContextKeyForcePlatform))
return exists
}
// GetForcePlatformFromContext 从 gin.Context 获取强制平台
func GetForcePlatformFromContext(c *gin.Context) (string, bool) {
value, exists := c.Get(string(ContextKeyForcePlatform))
if !exists {
return "", false
}
platform, ok := value.(string)
return platform, ok
}
// ErrorResponse 标准错误响应结构
type ErrorResponse struct {
Code string `json:"code"`
Message string `json:"message"`
}
// NewErrorResponse 创建错误响应
func NewErrorResponse(code, message string) ErrorResponse {
return ErrorResponse{
Code: code,
Message: message,
}
}
// AbortWithError 中断请求并返回JSON错误
func AbortWithError(c *gin.Context, statusCode int, code, message string) {
c.JSON(statusCode, NewErrorResponse(code, message))
c.Abort()
}
package middleware
import (
"context"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/gin-gonic/gin"
)
// ContextKey 定义上下文键类型
type ContextKey string
const (
// ContextKeyUser 用户上下文键
ContextKeyUser ContextKey = "user"
// ContextKeyUserRole 当前用户角色string
ContextKeyUserRole ContextKey = "user_role"
// ContextKeyApiKey API密钥上下文键
ContextKeyApiKey ContextKey = "api_key"
// ContextKeySubscription 订阅上下文键
ContextKeySubscription ContextKey = "subscription"
// ContextKeyForcePlatform 强制平台(用于 /antigravity 路由)
ContextKeyForcePlatform ContextKey = "force_platform"
)
// ForcePlatform 返回设置强制平台的中间件
// 同时设置 request.Context供 Service 使用)和 gin.Context供 Handler 快速检查)
func ForcePlatform(platform string) gin.HandlerFunc {
return func(c *gin.Context) {
// 设置到 request.Context使用 ctxkey.ForcePlatform 供 Service 层读取
ctx := context.WithValue(c.Request.Context(), ctxkey.ForcePlatform, platform)
c.Request = c.Request.WithContext(ctx)
// 同时设置到 gin.Context供 Handler 快速检查
c.Set(string(ContextKeyForcePlatform), platform)
c.Next()
}
}
// HasForcePlatform 检查是否有强制平台(用于 Handler 跳过分组检查)
func HasForcePlatform(c *gin.Context) bool {
_, exists := c.Get(string(ContextKeyForcePlatform))
return exists
}
// GetForcePlatformFromContext 从 gin.Context 获取强制平台
func GetForcePlatformFromContext(c *gin.Context) (string, bool) {
value, exists := c.Get(string(ContextKeyForcePlatform))
if !exists {
return "", false
}
platform, ok := value.(string)
return platform, ok
}
// ErrorResponse 标准错误响应结构
type ErrorResponse struct {
Code string `json:"code"`
Message string `json:"message"`
}
// NewErrorResponse 创建错误响应
func NewErrorResponse(code, message string) ErrorResponse {
return ErrorResponse{
Code: code,
Message: message,
}
}
// AbortWithError 中断请求并返回JSON错误
func AbortWithError(c *gin.Context, statusCode int, code, message string) {
c.JSON(statusCode, NewErrorResponse(code, message))
c.Abort()
}

View File

@@ -1,64 +1,64 @@
package middleware
import (
"errors"
"net"
"net/http"
"os"
"strings"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/gin-gonic/gin"
)
// Recovery converts panics into the project's standard JSON error envelope.
//
// It preserves Gin's broken-pipe handling by not attempting to write a response
// when the client connection is already gone.
func Recovery() gin.HandlerFunc {
return gin.CustomRecoveryWithWriter(gin.DefaultErrorWriter, func(c *gin.Context, recovered any) {
recoveredErr, _ := recovered.(error)
if isBrokenPipe(recoveredErr) {
if recoveredErr != nil {
_ = c.Error(recoveredErr)
}
c.Abort()
return
}
if c.Writer.Written() {
c.Abort()
return
}
response.ErrorWithDetails(
c,
http.StatusInternalServerError,
infraerrors.UnknownMessage,
infraerrors.UnknownReason,
nil,
)
c.Abort()
})
}
func isBrokenPipe(err error) bool {
if err == nil {
return false
}
var opErr *net.OpError
if !errors.As(err, &opErr) {
return false
}
var syscallErr *os.SyscallError
if !errors.As(opErr.Err, &syscallErr) {
return false
}
msg := strings.ToLower(syscallErr.Error())
return strings.Contains(msg, "broken pipe") || strings.Contains(msg, "connection reset by peer")
}
package middleware
import (
"errors"
"net"
"net/http"
"os"
"strings"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/gin-gonic/gin"
)
// Recovery converts panics into the project's standard JSON error envelope.
//
// It preserves Gin's broken-pipe handling by not attempting to write a response
// when the client connection is already gone.
func Recovery() gin.HandlerFunc {
return gin.CustomRecoveryWithWriter(gin.DefaultErrorWriter, func(c *gin.Context, recovered any) {
recoveredErr, _ := recovered.(error)
if isBrokenPipe(recoveredErr) {
if recoveredErr != nil {
_ = c.Error(recoveredErr)
}
c.Abort()
return
}
if c.Writer.Written() {
c.Abort()
return
}
response.ErrorWithDetails(
c,
http.StatusInternalServerError,
infraerrors.UnknownMessage,
infraerrors.UnknownReason,
nil,
)
c.Abort()
})
}
func isBrokenPipe(err error) bool {
if err == nil {
return false
}
var opErr *net.OpError
if !errors.As(err, &opErr) {
return false
}
var syscallErr *os.SyscallError
if !errors.As(opErr.Err, &syscallErr) {
return false
}
msg := strings.ToLower(syscallErr.Error())
return strings.Contains(msg, "broken pipe") || strings.Contains(msg, "connection reset by peer")
}

View File

@@ -1,81 +1,81 @@
//go:build unit
package middleware
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestRecovery(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
handler gin.HandlerFunc
wantHTTPCode int
wantBody response.Response
}{
{
name: "panic_returns_standard_json_500",
handler: func(c *gin.Context) {
panic("boom")
},
wantHTTPCode: http.StatusInternalServerError,
wantBody: response.Response{
Code: http.StatusInternalServerError,
Message: infraerrors.UnknownMessage,
},
},
{
name: "no_panic_passthrough",
handler: func(c *gin.Context) {
response.Success(c, gin.H{"ok": true})
},
wantHTTPCode: http.StatusOK,
wantBody: response.Response{
Code: 0,
Message: "success",
Data: map[string]any{"ok": true},
},
},
{
name: "panic_after_write_does_not_override_body",
handler: func(c *gin.Context) {
response.Success(c, gin.H{"ok": true})
panic("boom")
},
wantHTTPCode: http.StatusOK,
wantBody: response.Response{
Code: 0,
Message: "success",
Data: map[string]any{"ok": true},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := gin.New()
r.Use(Recovery())
r.GET("/t", tt.handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
r.ServeHTTP(w, req)
require.Equal(t, tt.wantHTTPCode, w.Code)
var got response.Response
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
require.Equal(t, tt.wantBody, got)
})
}
}
//go:build unit
package middleware
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestRecovery(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
handler gin.HandlerFunc
wantHTTPCode int
wantBody response.Response
}{
{
name: "panic_returns_standard_json_500",
handler: func(c *gin.Context) {
panic("boom")
},
wantHTTPCode: http.StatusInternalServerError,
wantBody: response.Response{
Code: http.StatusInternalServerError,
Message: infraerrors.UnknownMessage,
},
},
{
name: "no_panic_passthrough",
handler: func(c *gin.Context) {
response.Success(c, gin.H{"ok": true})
},
wantHTTPCode: http.StatusOK,
wantBody: response.Response{
Code: 0,
Message: "success",
Data: map[string]any{"ok": true},
},
},
{
name: "panic_after_write_does_not_override_body",
handler: func(c *gin.Context) {
response.Success(c, gin.H{"ok": true})
panic("boom")
},
wantHTTPCode: http.StatusOK,
wantBody: response.Response{
Code: 0,
Message: "success",
Data: map[string]any{"ok": true},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := gin.New()
r.Use(Recovery())
r.GET("/t", tt.handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
r.ServeHTTP(w, req)
require.Equal(t, tt.wantHTTPCode, w.Code)
var got response.Response
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
require.Equal(t, tt.wantBody, got)
})
}
}

View File

@@ -1,15 +1,15 @@
package middleware
import (
"net/http"
"github.com/gin-gonic/gin"
)
// RequestBodyLimit 使用 MaxBytesReader 限制请求体大小。
func RequestBodyLimit(maxBytes int64) gin.HandlerFunc {
return func(c *gin.Context) {
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxBytes)
c.Next()
}
}
package middleware
import (
"net/http"
"github.com/gin-gonic/gin"
)
// RequestBodyLimit 使用 MaxBytesReader 限制请求体大小。
func RequestBodyLimit(maxBytes int64) gin.HandlerFunc {
return func(c *gin.Context) {
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxBytes)
c.Next()
}
}

View File

@@ -1,22 +1,22 @@
package middleware
import (
"github.com/gin-gonic/gin"
"github.com/google/wire"
)
// JWTAuthMiddleware JWT 认证中间件类型
type JWTAuthMiddleware gin.HandlerFunc
// AdminAuthMiddleware 管理员认证中间件类型
type AdminAuthMiddleware gin.HandlerFunc
// ApiKeyAuthMiddleware API Key 认证中间件类型
type ApiKeyAuthMiddleware gin.HandlerFunc
// ProviderSet 中间件层的依赖注入
var ProviderSet = wire.NewSet(
NewJWTAuthMiddleware,
NewAdminAuthMiddleware,
NewApiKeyAuthMiddleware,
)
package middleware
import (
"github.com/gin-gonic/gin"
"github.com/google/wire"
)
// JWTAuthMiddleware JWT 认证中间件类型
type JWTAuthMiddleware gin.HandlerFunc
// AdminAuthMiddleware 管理员认证中间件类型
type AdminAuthMiddleware gin.HandlerFunc
// ApiKeyAuthMiddleware API Key 认证中间件类型
type ApiKeyAuthMiddleware gin.HandlerFunc
// ProviderSet 中间件层的依赖注入
var ProviderSet = wire.NewSet(
NewJWTAuthMiddleware,
NewAdminAuthMiddleware,
NewApiKeyAuthMiddleware,
)

View File

@@ -1,62 +1,62 @@
package server
import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/server/routes"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/web"
"github.com/gin-gonic/gin"
)
// SetupRouter 配置路由器中间件和路由
func SetupRouter(
r *gin.Engine,
handlers *handler.Handlers,
jwtAuth middleware2.JWTAuthMiddleware,
adminAuth middleware2.AdminAuthMiddleware,
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
apiKeyService *service.ApiKeyService,
subscriptionService *service.SubscriptionService,
cfg *config.Config,
) *gin.Engine {
// 应用中间件
r.Use(middleware2.Logger())
r.Use(middleware2.CORS())
// Serve embedded frontend if available
if web.HasEmbeddedFrontend() {
r.Use(web.ServeEmbeddedFrontend())
}
// 注册路由
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg)
return r
}
// registerRoutes 注册所有 HTTP 路由
func registerRoutes(
r *gin.Engine,
h *handler.Handlers,
jwtAuth middleware2.JWTAuthMiddleware,
adminAuth middleware2.AdminAuthMiddleware,
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
apiKeyService *service.ApiKeyService,
subscriptionService *service.SubscriptionService,
cfg *config.Config,
) {
// 通用路由(健康检查、状态等)
routes.RegisterCommonRoutes(r)
// API v1
v1 := r.Group("/api/v1")
// 注册各模块路由
routes.RegisterAuthRoutes(v1, h, jwtAuth)
routes.RegisterUserRoutes(v1, h, jwtAuth)
routes.RegisterAdminRoutes(v1, h, adminAuth)
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, cfg)
}
package server
import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/server/routes"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/web"
"github.com/gin-gonic/gin"
)
// SetupRouter 配置路由器中间件和路由
func SetupRouter(
r *gin.Engine,
handlers *handler.Handlers,
jwtAuth middleware2.JWTAuthMiddleware,
adminAuth middleware2.AdminAuthMiddleware,
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
apiKeyService *service.ApiKeyService,
subscriptionService *service.SubscriptionService,
cfg *config.Config,
) *gin.Engine {
// 应用中间件
r.Use(middleware2.Logger())
r.Use(middleware2.CORS())
// Serve embedded frontend if available
if web.HasEmbeddedFrontend() {
r.Use(web.ServeEmbeddedFrontend())
}
// 注册路由
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg)
return r
}
// registerRoutes 注册所有 HTTP 路由
func registerRoutes(
r *gin.Engine,
h *handler.Handlers,
jwtAuth middleware2.JWTAuthMiddleware,
adminAuth middleware2.AdminAuthMiddleware,
apiKeyAuth middleware2.ApiKeyAuthMiddleware,
apiKeyService *service.ApiKeyService,
subscriptionService *service.SubscriptionService,
cfg *config.Config,
) {
// 通用路由(健康检查、状态等)
routes.RegisterCommonRoutes(r)
// API v1
v1 := r.Group("/api/v1")
// 注册各模块路由
routes.RegisterAuthRoutes(v1, h, jwtAuth)
routes.RegisterUserRoutes(v1, h, jwtAuth)
routes.RegisterAdminRoutes(v1, h, adminAuth)
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, cfg)
}

View File

@@ -1,265 +1,265 @@
package routes
import (
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin"
)
// RegisterAdminRoutes 注册管理员路由
func RegisterAdminRoutes(
v1 *gin.RouterGroup,
h *handler.Handlers,
adminAuth middleware.AdminAuthMiddleware,
) {
admin := v1.Group("/admin")
admin.Use(gin.HandlerFunc(adminAuth))
{
// 仪表盘
registerDashboardRoutes(admin, h)
// 用户管理
registerUserManagementRoutes(admin, h)
// 分组管理
registerGroupRoutes(admin, h)
// 账号管理
registerAccountRoutes(admin, h)
// OpenAI OAuth
registerOpenAIOAuthRoutes(admin, h)
// Gemini OAuth
registerGeminiOAuthRoutes(admin, h)
// Antigravity OAuth
registerAntigravityOAuthRoutes(admin, h)
// 代理管理
registerProxyRoutes(admin, h)
// 卡密管理
registerRedeemCodeRoutes(admin, h)
// 系统设置
registerSettingsRoutes(admin, h)
// 系统管理
registerSystemRoutes(admin, h)
// 订阅管理
registerSubscriptionRoutes(admin, h)
// 使用记录管理
registerUsageRoutes(admin, h)
// 用户属性管理
registerUserAttributeRoutes(admin, h)
}
}
func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
dashboard := admin.Group("/dashboard")
{
dashboard.GET("/stats", h.Admin.Dashboard.GetStats)
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
dashboard.GET("/models", h.Admin.Dashboard.GetModelStats)
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetApiKeyUsageTrend)
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchApiKeysUsage)
}
}
func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
users := admin.Group("/users")
{
users.GET("", h.Admin.User.List)
users.GET("/:id", h.Admin.User.GetByID)
users.POST("", h.Admin.User.Create)
users.PUT("/:id", h.Admin.User.Update)
users.DELETE("/:id", h.Admin.User.Delete)
users.POST("/:id/balance", h.Admin.User.UpdateBalance)
users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys)
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
// User attribute values
users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes)
users.PUT("/:id/attributes", h.Admin.UserAttribute.UpdateUserAttributes)
}
}
func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
groups := admin.Group("/groups")
{
groups.GET("", h.Admin.Group.List)
groups.GET("/all", h.Admin.Group.GetAll)
groups.GET("/:id", h.Admin.Group.GetByID)
groups.POST("", h.Admin.Group.Create)
groups.PUT("/:id", h.Admin.Group.Update)
groups.DELETE("/:id", h.Admin.Group.Delete)
groups.GET("/:id/stats", h.Admin.Group.GetStats)
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
}
}
func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts := admin.Group("/accounts")
{
accounts.GET("", h.Admin.Account.List)
accounts.GET("/:id", h.Admin.Account.GetByID)
accounts.POST("", h.Admin.Account.Create)
accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS)
accounts.PUT("/:id", h.Admin.Account.Update)
accounts.DELETE("/:id", h.Admin.Account.Delete)
accounts.POST("/:id/test", h.Admin.Account.Test)
accounts.POST("/:id/refresh", h.Admin.Account.Refresh)
accounts.POST("/:id/refresh-tier", h.Admin.Account.RefreshTier)
accounts.GET("/:id/stats", h.Admin.Account.GetStats)
accounts.POST("/:id/clear-error", h.Admin.Account.ClearError)
accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels)
accounts.POST("/batch", h.Admin.Account.BatchCreate)
accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials)
accounts.POST("/batch-refresh-tier", h.Admin.Account.BatchRefreshTier)
accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate)
// Claude OAuth routes
accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL)
accounts.POST("/generate-setup-token-url", h.Admin.OAuth.GenerateSetupTokenURL)
accounts.POST("/exchange-code", h.Admin.OAuth.ExchangeCode)
accounts.POST("/exchange-setup-token-code", h.Admin.OAuth.ExchangeSetupTokenCode)
accounts.POST("/cookie-auth", h.Admin.OAuth.CookieAuth)
accounts.POST("/setup-token-cookie-auth", h.Admin.OAuth.SetupTokenCookieAuth)
}
}
func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
openai := admin.Group("/openai")
{
openai.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL)
openai.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode)
openai.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken)
openai.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken)
openai.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth)
}
}
func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
gemini := admin.Group("/gemini")
{
gemini.POST("/oauth/auth-url", h.Admin.GeminiOAuth.GenerateAuthURL)
gemini.POST("/oauth/exchange-code", h.Admin.GeminiOAuth.ExchangeCode)
gemini.GET("/oauth/capabilities", h.Admin.GeminiOAuth.GetCapabilities)
}
}
func registerAntigravityOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
antigravity := admin.Group("/antigravity")
{
antigravity.POST("/oauth/auth-url", h.Admin.AntigravityOAuth.GenerateAuthURL)
antigravity.POST("/oauth/exchange-code", h.Admin.AntigravityOAuth.ExchangeCode)
}
}
func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
proxies := admin.Group("/proxies")
{
proxies.GET("", h.Admin.Proxy.List)
proxies.GET("/all", h.Admin.Proxy.GetAll)
proxies.GET("/:id", h.Admin.Proxy.GetByID)
proxies.POST("", h.Admin.Proxy.Create)
proxies.PUT("/:id", h.Admin.Proxy.Update)
proxies.DELETE("/:id", h.Admin.Proxy.Delete)
proxies.POST("/:id/test", h.Admin.Proxy.Test)
proxies.GET("/:id/stats", h.Admin.Proxy.GetStats)
proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts)
proxies.POST("/batch", h.Admin.Proxy.BatchCreate)
}
}
func registerRedeemCodeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
codes := admin.Group("/redeem-codes")
{
codes.GET("", h.Admin.Redeem.List)
codes.GET("/stats", h.Admin.Redeem.GetStats)
codes.GET("/export", h.Admin.Redeem.Export)
codes.GET("/:id", h.Admin.Redeem.GetByID)
codes.POST("/generate", h.Admin.Redeem.Generate)
codes.DELETE("/:id", h.Admin.Redeem.Delete)
codes.POST("/batch-delete", h.Admin.Redeem.BatchDelete)
codes.POST("/:id/expire", h.Admin.Redeem.Expire)
}
}
func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
adminSettings := admin.Group("/settings")
{
adminSettings.GET("", h.Admin.Setting.GetSettings)
adminSettings.PUT("", h.Admin.Setting.UpdateSettings)
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSmtpConnection)
adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail)
// Admin API Key 管理
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminApiKey)
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminApiKey)
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminApiKey)
}
}
func registerSystemRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
system := admin.Group("/system")
{
system.GET("/version", h.Admin.System.GetVersion)
system.GET("/check-updates", h.Admin.System.CheckUpdates)
system.POST("/update", h.Admin.System.PerformUpdate)
system.POST("/rollback", h.Admin.System.Rollback)
system.POST("/restart", h.Admin.System.RestartService)
}
}
func registerSubscriptionRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
subscriptions := admin.Group("/subscriptions")
{
subscriptions.GET("", h.Admin.Subscription.List)
subscriptions.GET("/:id", h.Admin.Subscription.GetByID)
subscriptions.GET("/:id/progress", h.Admin.Subscription.GetProgress)
subscriptions.POST("/assign", h.Admin.Subscription.Assign)
subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign)
subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend)
subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke)
}
// 分组下的订阅列表
admin.GET("/groups/:id/subscriptions", h.Admin.Subscription.ListByGroup)
// 用户下的订阅列表
admin.GET("/users/:id/subscriptions", h.Admin.Subscription.ListByUser)
}
func registerUsageRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
usage := admin.Group("/usage")
{
usage.GET("", h.Admin.Usage.List)
usage.GET("/stats", h.Admin.Usage.Stats)
usage.GET("/search-users", h.Admin.Usage.SearchUsers)
usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys)
}
}
func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
attrs := admin.Group("/user-attributes")
{
attrs.GET("", h.Admin.UserAttribute.ListDefinitions)
attrs.POST("", h.Admin.UserAttribute.CreateDefinition)
attrs.POST("/batch", h.Admin.UserAttribute.GetBatchUserAttributes)
attrs.PUT("/reorder", h.Admin.UserAttribute.ReorderDefinitions)
attrs.PUT("/:id", h.Admin.UserAttribute.UpdateDefinition)
attrs.DELETE("/:id", h.Admin.UserAttribute.DeleteDefinition)
}
}
package routes
import (
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin"
)
// RegisterAdminRoutes 注册管理员路由
func RegisterAdminRoutes(
v1 *gin.RouterGroup,
h *handler.Handlers,
adminAuth middleware.AdminAuthMiddleware,
) {
admin := v1.Group("/admin")
admin.Use(gin.HandlerFunc(adminAuth))
{
// 仪表盘
registerDashboardRoutes(admin, h)
// 用户管理
registerUserManagementRoutes(admin, h)
// 分组管理
registerGroupRoutes(admin, h)
// 账号管理
registerAccountRoutes(admin, h)
// OpenAI OAuth
registerOpenAIOAuthRoutes(admin, h)
// Gemini OAuth
registerGeminiOAuthRoutes(admin, h)
// Antigravity OAuth
registerAntigravityOAuthRoutes(admin, h)
// 代理管理
registerProxyRoutes(admin, h)
// 卡密管理
registerRedeemCodeRoutes(admin, h)
// 系统设置
registerSettingsRoutes(admin, h)
// 系统管理
registerSystemRoutes(admin, h)
// 订阅管理
registerSubscriptionRoutes(admin, h)
// 使用记录管理
registerUsageRoutes(admin, h)
// 用户属性管理
registerUserAttributeRoutes(admin, h)
}
}
func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
dashboard := admin.Group("/dashboard")
{
dashboard.GET("/stats", h.Admin.Dashboard.GetStats)
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
dashboard.GET("/models", h.Admin.Dashboard.GetModelStats)
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetApiKeyUsageTrend)
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchApiKeysUsage)
}
}
func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
users := admin.Group("/users")
{
users.GET("", h.Admin.User.List)
users.GET("/:id", h.Admin.User.GetByID)
users.POST("", h.Admin.User.Create)
users.PUT("/:id", h.Admin.User.Update)
users.DELETE("/:id", h.Admin.User.Delete)
users.POST("/:id/balance", h.Admin.User.UpdateBalance)
users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys)
users.GET("/:id/usage", h.Admin.User.GetUserUsage)
// User attribute values
users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes)
users.PUT("/:id/attributes", h.Admin.UserAttribute.UpdateUserAttributes)
}
}
func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
groups := admin.Group("/groups")
{
groups.GET("", h.Admin.Group.List)
groups.GET("/all", h.Admin.Group.GetAll)
groups.GET("/:id", h.Admin.Group.GetByID)
groups.POST("", h.Admin.Group.Create)
groups.PUT("/:id", h.Admin.Group.Update)
groups.DELETE("/:id", h.Admin.Group.Delete)
groups.GET("/:id/stats", h.Admin.Group.GetStats)
groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys)
}
}
func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts := admin.Group("/accounts")
{
accounts.GET("", h.Admin.Account.List)
accounts.GET("/:id", h.Admin.Account.GetByID)
accounts.POST("", h.Admin.Account.Create)
accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS)
accounts.PUT("/:id", h.Admin.Account.Update)
accounts.DELETE("/:id", h.Admin.Account.Delete)
accounts.POST("/:id/test", h.Admin.Account.Test)
accounts.POST("/:id/refresh", h.Admin.Account.Refresh)
accounts.POST("/:id/refresh-tier", h.Admin.Account.RefreshTier)
accounts.GET("/:id/stats", h.Admin.Account.GetStats)
accounts.POST("/:id/clear-error", h.Admin.Account.ClearError)
accounts.GET("/:id/usage", h.Admin.Account.GetUsage)
accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats)
accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit)
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels)
accounts.POST("/batch", h.Admin.Account.BatchCreate)
accounts.POST("/batch-update-credentials", h.Admin.Account.BatchUpdateCredentials)
accounts.POST("/batch-refresh-tier", h.Admin.Account.BatchRefreshTier)
accounts.POST("/bulk-update", h.Admin.Account.BulkUpdate)
// Claude OAuth routes
accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL)
accounts.POST("/generate-setup-token-url", h.Admin.OAuth.GenerateSetupTokenURL)
accounts.POST("/exchange-code", h.Admin.OAuth.ExchangeCode)
accounts.POST("/exchange-setup-token-code", h.Admin.OAuth.ExchangeSetupTokenCode)
accounts.POST("/cookie-auth", h.Admin.OAuth.CookieAuth)
accounts.POST("/setup-token-cookie-auth", h.Admin.OAuth.SetupTokenCookieAuth)
}
}
func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
openai := admin.Group("/openai")
{
openai.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL)
openai.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode)
openai.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken)
openai.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken)
openai.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth)
}
}
func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
gemini := admin.Group("/gemini")
{
gemini.POST("/oauth/auth-url", h.Admin.GeminiOAuth.GenerateAuthURL)
gemini.POST("/oauth/exchange-code", h.Admin.GeminiOAuth.ExchangeCode)
gemini.GET("/oauth/capabilities", h.Admin.GeminiOAuth.GetCapabilities)
}
}
func registerAntigravityOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
antigravity := admin.Group("/antigravity")
{
antigravity.POST("/oauth/auth-url", h.Admin.AntigravityOAuth.GenerateAuthURL)
antigravity.POST("/oauth/exchange-code", h.Admin.AntigravityOAuth.ExchangeCode)
}
}
func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
proxies := admin.Group("/proxies")
{
proxies.GET("", h.Admin.Proxy.List)
proxies.GET("/all", h.Admin.Proxy.GetAll)
proxies.GET("/:id", h.Admin.Proxy.GetByID)
proxies.POST("", h.Admin.Proxy.Create)
proxies.PUT("/:id", h.Admin.Proxy.Update)
proxies.DELETE("/:id", h.Admin.Proxy.Delete)
proxies.POST("/:id/test", h.Admin.Proxy.Test)
proxies.GET("/:id/stats", h.Admin.Proxy.GetStats)
proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts)
proxies.POST("/batch", h.Admin.Proxy.BatchCreate)
}
}
func registerRedeemCodeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
codes := admin.Group("/redeem-codes")
{
codes.GET("", h.Admin.Redeem.List)
codes.GET("/stats", h.Admin.Redeem.GetStats)
codes.GET("/export", h.Admin.Redeem.Export)
codes.GET("/:id", h.Admin.Redeem.GetByID)
codes.POST("/generate", h.Admin.Redeem.Generate)
codes.DELETE("/:id", h.Admin.Redeem.Delete)
codes.POST("/batch-delete", h.Admin.Redeem.BatchDelete)
codes.POST("/:id/expire", h.Admin.Redeem.Expire)
}
}
func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
adminSettings := admin.Group("/settings")
{
adminSettings.GET("", h.Admin.Setting.GetSettings)
adminSettings.PUT("", h.Admin.Setting.UpdateSettings)
adminSettings.POST("/test-smtp", h.Admin.Setting.TestSmtpConnection)
adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail)
// Admin API Key 管理
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminApiKey)
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminApiKey)
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminApiKey)
}
}
func registerSystemRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
system := admin.Group("/system")
{
system.GET("/version", h.Admin.System.GetVersion)
system.GET("/check-updates", h.Admin.System.CheckUpdates)
system.POST("/update", h.Admin.System.PerformUpdate)
system.POST("/rollback", h.Admin.System.Rollback)
system.POST("/restart", h.Admin.System.RestartService)
}
}
func registerSubscriptionRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
subscriptions := admin.Group("/subscriptions")
{
subscriptions.GET("", h.Admin.Subscription.List)
subscriptions.GET("/:id", h.Admin.Subscription.GetByID)
subscriptions.GET("/:id/progress", h.Admin.Subscription.GetProgress)
subscriptions.POST("/assign", h.Admin.Subscription.Assign)
subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign)
subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend)
subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke)
}
// 分组下的订阅列表
admin.GET("/groups/:id/subscriptions", h.Admin.Subscription.ListByGroup)
// 用户下的订阅列表
admin.GET("/users/:id/subscriptions", h.Admin.Subscription.ListByUser)
}
func registerUsageRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
usage := admin.Group("/usage")
{
usage.GET("", h.Admin.Usage.List)
usage.GET("/stats", h.Admin.Usage.Stats)
usage.GET("/search-users", h.Admin.Usage.SearchUsers)
usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys)
}
}
func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
attrs := admin.Group("/user-attributes")
{
attrs.GET("", h.Admin.UserAttribute.ListDefinitions)
attrs.POST("", h.Admin.UserAttribute.CreateDefinition)
attrs.POST("/batch", h.Admin.UserAttribute.GetBatchUserAttributes)
attrs.PUT("/reorder", h.Admin.UserAttribute.ReorderDefinitions)
attrs.PUT("/:id", h.Admin.UserAttribute.UpdateDefinition)
attrs.DELETE("/:id", h.Admin.UserAttribute.DeleteDefinition)
}
}

View File

@@ -1,36 +1,36 @@
package routes
import (
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin"
)
// RegisterAuthRoutes 注册认证相关路由
func RegisterAuthRoutes(
v1 *gin.RouterGroup,
h *handler.Handlers,
jwtAuth middleware.JWTAuthMiddleware,
) {
// 公开接口
auth := v1.Group("/auth")
{
auth.POST("/register", h.Auth.Register)
auth.POST("/login", h.Auth.Login)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
}
// 公开设置(无需认证)
settings := v1.Group("/settings")
{
settings.GET("/public", h.Setting.GetPublicSettings)
}
// 需要认证的当前用户信息
authenticated := v1.Group("")
authenticated.Use(gin.HandlerFunc(jwtAuth))
{
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
}
}
package routes
import (
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin"
)
// RegisterAuthRoutes 注册认证相关路由
func RegisterAuthRoutes(
v1 *gin.RouterGroup,
h *handler.Handlers,
jwtAuth middleware.JWTAuthMiddleware,
) {
// 公开接口
auth := v1.Group("/auth")
{
auth.POST("/register", h.Auth.Register)
auth.POST("/login", h.Auth.Login)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
}
// 公开设置(无需认证)
settings := v1.Group("/settings")
{
settings.GET("/public", h.Setting.GetPublicSettings)
}
// 需要认证的当前用户信息
authenticated := v1.Group("")
authenticated.Use(gin.HandlerFunc(jwtAuth))
{
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
}
}

View File

@@ -1,32 +1,32 @@
package routes
import (
"net/http"
"github.com/gin-gonic/gin"
)
// RegisterCommonRoutes 注册通用路由(健康检查、状态等)
func RegisterCommonRoutes(r *gin.Engine) {
// 健康检查
r.GET("/health", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
// Claude Code 遥测日志忽略直接返回200
r.POST("/api/event_logging/batch", func(c *gin.Context) {
c.Status(http.StatusOK)
})
// Setup status endpoint (always returns needs_setup: false in normal mode)
// This is used by the frontend to detect when the service has restarted after setup
r.GET("/setup/status", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"code": 0,
"data": gin.H{
"needs_setup": false,
"step": "completed",
},
})
})
}
package routes
import (
"net/http"
"github.com/gin-gonic/gin"
)
// RegisterCommonRoutes 注册通用路由(健康检查、状态等)
func RegisterCommonRoutes(r *gin.Engine) {
// 健康检查
r.GET("/health", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
// Claude Code 遥测日志忽略直接返回200
r.POST("/api/event_logging/batch", func(c *gin.Context) {
c.Status(http.StatusOK)
})
// Setup status endpoint (always returns needs_setup: false in normal mode)
// This is used by the frontend to detect when the service has restarted after setup
r.GET("/setup/status", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"code": 0,
"data": gin.H{
"needs_setup": false,
"step": "completed",
},
})
})
}

View File

@@ -1,74 +1,74 @@
package routes
import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// RegisterGatewayRoutes 注册 API 网关路由Claude/OpenAI/Gemini 兼容)
func RegisterGatewayRoutes(
r *gin.Engine,
h *handler.Handlers,
apiKeyAuth middleware.ApiKeyAuthMiddleware,
apiKeyService *service.ApiKeyService,
subscriptionService *service.SubscriptionService,
cfg *config.Config,
) {
bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize)
// API网关Claude API兼容
gateway := r.Group("/v1")
gateway.Use(bodyLimit)
gateway.Use(gin.HandlerFunc(apiKeyAuth))
{
gateway.POST("/messages", h.Gateway.Messages)
gateway.POST("/messages/count_tokens", h.Gateway.CountTokens)
gateway.GET("/models", h.Gateway.Models)
gateway.GET("/usage", h.Gateway.Usage)
// OpenAI Responses API
gateway.POST("/responses", h.OpenAIGateway.Responses)
}
// Gemini 原生 API 兼容层Gemini SDK/CLI 直连)
gemini := r.Group("/v1beta")
gemini.Use(bodyLimit)
gemini.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
{
gemini.GET("/models", h.Gateway.GeminiV1BetaListModels)
gemini.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
// Gin treats ":" as a param marker, but Gemini uses "{model}:{action}" in the same segment.
gemini.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels)
}
// OpenAI Responses API不带v1前缀的别名
r.POST("/responses", bodyLimit, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses)
// Antigravity 模型列表
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), h.Gateway.AntigravityModels)
// Antigravity 专用路由(仅使用 antigravity 账户,不混合调度)
antigravityV1 := r.Group("/antigravity/v1")
antigravityV1.Use(bodyLimit)
antigravityV1.Use(middleware.ForcePlatform(service.PlatformAntigravity))
antigravityV1.Use(gin.HandlerFunc(apiKeyAuth))
{
antigravityV1.POST("/messages", h.Gateway.Messages)
antigravityV1.POST("/messages/count_tokens", h.Gateway.CountTokens)
antigravityV1.GET("/models", h.Gateway.AntigravityModels)
antigravityV1.GET("/usage", h.Gateway.Usage)
}
antigravityV1Beta := r.Group("/antigravity/v1beta")
antigravityV1Beta.Use(bodyLimit)
antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity))
antigravityV1Beta.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
{
antigravityV1Beta.GET("/models", h.Gateway.GeminiV1BetaListModels)
antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
antigravityV1Beta.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels)
}
}
package routes
import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// RegisterGatewayRoutes 注册 API 网关路由Claude/OpenAI/Gemini 兼容)
func RegisterGatewayRoutes(
r *gin.Engine,
h *handler.Handlers,
apiKeyAuth middleware.ApiKeyAuthMiddleware,
apiKeyService *service.ApiKeyService,
subscriptionService *service.SubscriptionService,
cfg *config.Config,
) {
bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize)
// API网关Claude API兼容
gateway := r.Group("/v1")
gateway.Use(bodyLimit)
gateway.Use(gin.HandlerFunc(apiKeyAuth))
{
gateway.POST("/messages", h.Gateway.Messages)
gateway.POST("/messages/count_tokens", h.Gateway.CountTokens)
gateway.GET("/models", h.Gateway.Models)
gateway.GET("/usage", h.Gateway.Usage)
// OpenAI Responses API
gateway.POST("/responses", h.OpenAIGateway.Responses)
}
// Gemini 原生 API 兼容层Gemini SDK/CLI 直连)
gemini := r.Group("/v1beta")
gemini.Use(bodyLimit)
gemini.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
{
gemini.GET("/models", h.Gateway.GeminiV1BetaListModels)
gemini.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
// Gin treats ":" as a param marker, but Gemini uses "{model}:{action}" in the same segment.
gemini.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels)
}
// OpenAI Responses API不带v1前缀的别名
r.POST("/responses", bodyLimit, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses)
// Antigravity 模型列表
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), h.Gateway.AntigravityModels)
// Antigravity 专用路由(仅使用 antigravity 账户,不混合调度)
antigravityV1 := r.Group("/antigravity/v1")
antigravityV1.Use(bodyLimit)
antigravityV1.Use(middleware.ForcePlatform(service.PlatformAntigravity))
antigravityV1.Use(gin.HandlerFunc(apiKeyAuth))
{
antigravityV1.POST("/messages", h.Gateway.Messages)
antigravityV1.POST("/messages/count_tokens", h.Gateway.CountTokens)
antigravityV1.GET("/models", h.Gateway.AntigravityModels)
antigravityV1.GET("/usage", h.Gateway.Usage)
}
antigravityV1Beta := r.Group("/antigravity/v1beta")
antigravityV1Beta.Use(bodyLimit)
antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity))
antigravityV1Beta.Use(middleware.ApiKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
{
antigravityV1Beta.GET("/models", h.Gateway.GeminiV1BetaListModels)
antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
antigravityV1Beta.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels)
}
}

View File

@@ -1,72 +1,72 @@
package routes
import (
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin"
)
// RegisterUserRoutes 注册用户相关路由(需要认证)
func RegisterUserRoutes(
v1 *gin.RouterGroup,
h *handler.Handlers,
jwtAuth middleware.JWTAuthMiddleware,
) {
authenticated := v1.Group("")
authenticated.Use(gin.HandlerFunc(jwtAuth))
{
// 用户接口
user := authenticated.Group("/user")
{
user.GET("/profile", h.User.GetProfile)
user.PUT("/password", h.User.ChangePassword)
user.PUT("", h.User.UpdateProfile)
}
// API Key管理
keys := authenticated.Group("/keys")
{
keys.GET("", h.APIKey.List)
keys.GET("/:id", h.APIKey.GetByID)
keys.POST("", h.APIKey.Create)
keys.PUT("/:id", h.APIKey.Update)
keys.DELETE("/:id", h.APIKey.Delete)
}
// 用户可用分组(非管理员接口)
groups := authenticated.Group("/groups")
{
groups.GET("/available", h.APIKey.GetAvailableGroups)
}
// 使用记录
usage := authenticated.Group("/usage")
{
usage.GET("", h.Usage.List)
usage.GET("/:id", h.Usage.GetByID)
usage.GET("/stats", h.Usage.Stats)
// User dashboard endpoints
usage.GET("/dashboard/stats", h.Usage.DashboardStats)
usage.GET("/dashboard/trend", h.Usage.DashboardTrend)
usage.GET("/dashboard/models", h.Usage.DashboardModels)
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardApiKeysUsage)
}
// 卡密兑换
redeem := authenticated.Group("/redeem")
{
redeem.POST("", h.Redeem.Redeem)
redeem.GET("/history", h.Redeem.GetHistory)
}
// 用户订阅
subscriptions := authenticated.Group("/subscriptions")
{
subscriptions.GET("", h.Subscription.List)
subscriptions.GET("/active", h.Subscription.GetActive)
subscriptions.GET("/progress", h.Subscription.GetProgress)
subscriptions.GET("/summary", h.Subscription.GetSummary)
}
}
}
package routes
import (
"github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin"
)
// RegisterUserRoutes 注册用户相关路由(需要认证)
func RegisterUserRoutes(
v1 *gin.RouterGroup,
h *handler.Handlers,
jwtAuth middleware.JWTAuthMiddleware,
) {
authenticated := v1.Group("")
authenticated.Use(gin.HandlerFunc(jwtAuth))
{
// 用户接口
user := authenticated.Group("/user")
{
user.GET("/profile", h.User.GetProfile)
user.PUT("/password", h.User.ChangePassword)
user.PUT("", h.User.UpdateProfile)
}
// API Key管理
keys := authenticated.Group("/keys")
{
keys.GET("", h.APIKey.List)
keys.GET("/:id", h.APIKey.GetByID)
keys.POST("", h.APIKey.Create)
keys.PUT("/:id", h.APIKey.Update)
keys.DELETE("/:id", h.APIKey.Delete)
}
// 用户可用分组(非管理员接口)
groups := authenticated.Group("/groups")
{
groups.GET("/available", h.APIKey.GetAvailableGroups)
}
// 使用记录
usage := authenticated.Group("/usage")
{
usage.GET("", h.Usage.List)
usage.GET("/:id", h.Usage.GetByID)
usage.GET("/stats", h.Usage.Stats)
// User dashboard endpoints
usage.GET("/dashboard/stats", h.Usage.DashboardStats)
usage.GET("/dashboard/trend", h.Usage.DashboardTrend)
usage.GET("/dashboard/models", h.Usage.DashboardModels)
usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardApiKeysUsage)
}
// 卡密兑换
redeem := authenticated.Group("/redeem")
{
redeem.POST("", h.Redeem.Redeem)
redeem.GET("/history", h.Redeem.GetHistory)
}
// 用户订阅
subscriptions := authenticated.Group("/subscriptions")
{
subscriptions.GET("", h.Subscription.List)
subscriptions.GET("/active", h.Subscription.GetActive)
subscriptions.GET("/progress", h.Subscription.GetProgress)
subscriptions.GET("/summary", h.Subscription.GetSummary)
}
}
}