Merge branch 'main' into feat/api-key-ip-restriction
This commit is contained in:
@@ -398,7 +398,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
settingRepo := newStubSettingRepo()
|
||||
settingService := service.NewSettingService(settingRepo, cfg)
|
||||
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService)
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil)
|
||||
@@ -575,6 +575,10 @@ func (stubGroupRepo) GetByID(ctx context.Context, id int64) (*service.Group, err
|
||||
return nil, service.ErrGroupNotFound
|
||||
}
|
||||
|
||||
func (stubGroupRepo) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) {
|
||||
return nil, service.ErrGroupNotFound
|
||||
}
|
||||
|
||||
func (stubGroupRepo) Update(ctx context.Context, group *service.Group) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/wire"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// ProviderSet 提供服务器层的依赖
|
||||
@@ -31,6 +32,7 @@ func ProvideRouter(
|
||||
apiKeyService *service.APIKeyService,
|
||||
subscriptionService *service.SubscriptionService,
|
||||
settingService *service.SettingService,
|
||||
redisClient *redis.Client,
|
||||
) *gin.Engine {
|
||||
if cfg.Server.Mode == "release" {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
@@ -48,7 +50,7 @@ func ProvideRouter(
|
||||
}
|
||||
}
|
||||
|
||||
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, settingService, cfg)
|
||||
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, settingService, cfg, redisClient)
|
||||
}
|
||||
|
||||
// ProvideHTTPServer 提供 HTTP 服务器
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
@@ -103,6 +105,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
|
||||
setGroupContext(c, apiKey.Group)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
@@ -161,6 +164,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
|
||||
setGroupContext(c, apiKey.Group)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
@@ -185,3 +189,14 @@ func GetSubscriptionFromContext(c *gin.Context) (*service.UserSubscription, bool
|
||||
subscription, ok := value.(*service.UserSubscription)
|
||||
return subscription, ok
|
||||
}
|
||||
|
||||
func setGroupContext(c *gin.Context, group *service.Group) {
|
||||
if !service.IsGroupContextValid(group) {
|
||||
return
|
||||
}
|
||||
if existing, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group); ok && existing != nil && existing.ID == group.ID && service.IsGroupContextValid(existing) {
|
||||
return
|
||||
}
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.Group, group)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
|
||||
@@ -63,6 +63,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
|
||||
setGroupContext(c, apiKey.Group)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
@@ -102,6 +103,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
|
||||
setGroupContext(c, apiKey.Group)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
@@ -133,6 +134,70 @@ func TestApiKeyAuthWithSubscriptionGoogle_QueryApiKeyRejected(t *testing.T) {
|
||||
require.Equal(t, "INVALID_ARGUMENT", resp.Error.Status)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogleSetsGroupContext(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
group := &service.Group{
|
||||
ID: 99,
|
||||
Name: "g1",
|
||||
Status: service.StatusActive,
|
||||
Platform: service.PlatformGemini,
|
||||
Hydrated: true,
|
||||
}
|
||||
user := &service.User{
|
||||
ID: 7,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 100,
|
||||
UserID: user.ID,
|
||||
Key: "test-key",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
Group: group,
|
||||
}
|
||||
apiKey.GroupID = &group.ID
|
||||
|
||||
apiKeyService := service.NewAPIKeyService(
|
||||
fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
},
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
&config.Config{RunMode: config.RunModeSimple},
|
||||
)
|
||||
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
r := gin.New()
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) {
|
||||
groupFromCtx, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group)
|
||||
if !ok || groupFromCtx == nil || groupFromCtx.ID != group.ID {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"ok": false})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_QueryKeyAllowedOnV1Beta(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -25,6 +26,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
ID: 42,
|
||||
Name: "sub",
|
||||
Status: service.StatusActive,
|
||||
Hydrated: true,
|
||||
SubscriptionType: service.SubscriptionTypeSubscription,
|
||||
DailyLimitUSD: &limit,
|
||||
}
|
||||
@@ -110,6 +112,129 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestAPIKeyAuthSetsGroupContext(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
group := &service.Group{
|
||||
ID: 101,
|
||||
Name: "g1",
|
||||
Status: service.StatusActive,
|
||||
Platform: service.PlatformAnthropic,
|
||||
Hydrated: true,
|
||||
}
|
||||
user := &service.User{
|
||||
ID: 7,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 100,
|
||||
UserID: user.ID,
|
||||
Key: "test-key",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
Group: group,
|
||||
}
|
||||
apiKey.GroupID = &group.ID
|
||||
|
||||
apiKeyRepo := &stubApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
|
||||
router := gin.New()
|
||||
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
|
||||
router.GET("/t", func(c *gin.Context) {
|
||||
groupFromCtx, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group)
|
||||
if !ok || groupFromCtx == nil || groupFromCtx.ID != group.ID {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"ok": false})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
group := &service.Group{
|
||||
ID: 101,
|
||||
Name: "g1",
|
||||
Status: service.StatusActive,
|
||||
Platform: service.PlatformAnthropic,
|
||||
Hydrated: true,
|
||||
}
|
||||
user := &service.User{
|
||||
ID: 7,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 100,
|
||||
UserID: user.ID,
|
||||
Key: "test-key",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
Group: group,
|
||||
}
|
||||
apiKey.GroupID = &group.ID
|
||||
|
||||
apiKeyRepo := &stubApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
|
||||
router := gin.New()
|
||||
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
|
||||
|
||||
invalidGroup := &service.Group{
|
||||
ID: group.ID,
|
||||
Platform: group.Platform,
|
||||
Status: group.Status,
|
||||
}
|
||||
router.GET("/t", func(c *gin.Context) {
|
||||
groupFromCtx, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group)
|
||||
if !ok || groupFromCtx == nil || groupFromCtx.ID != group.ID || !groupFromCtx.Hydrated || groupFromCtx == invalidGroup {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"ok": false})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
req = req.WithContext(context.WithValue(req.Context(), ctxkey.Group, invalidGroup))
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
|
||||
router := gin.New()
|
||||
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/web"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// SetupRouter 配置路由器中间件和路由
|
||||
@@ -24,6 +25,7 @@ func SetupRouter(
|
||||
subscriptionService *service.SubscriptionService,
|
||||
settingService *service.SettingService,
|
||||
cfg *config.Config,
|
||||
redisClient *redis.Client,
|
||||
) *gin.Engine {
|
||||
// 应用中间件
|
||||
r.Use(middleware2.Logger())
|
||||
@@ -44,7 +46,7 @@ func SetupRouter(
|
||||
}
|
||||
|
||||
// 注册路由
|
||||
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg)
|
||||
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, cfg, redisClient)
|
||||
|
||||
return r
|
||||
}
|
||||
@@ -59,6 +61,7 @@ func registerRoutes(
|
||||
apiKeyService *service.APIKeyService,
|
||||
subscriptionService *service.SubscriptionService,
|
||||
cfg *config.Config,
|
||||
redisClient *redis.Client,
|
||||
) {
|
||||
// 通用路由(健康检查、状态等)
|
||||
routes.RegisterCommonRoutes(r)
|
||||
@@ -67,7 +70,7 @@ func registerRoutes(
|
||||
v1 := r.Group("/api/v1")
|
||||
|
||||
// 注册各模块路由
|
||||
routes.RegisterAuthRoutes(v1, h, jwtAuth)
|
||||
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient)
|
||||
routes.RegisterUserRoutes(v1, h, jwtAuth)
|
||||
routes.RegisterAdminRoutes(v1, h, adminAuth)
|
||||
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, cfg)
|
||||
|
||||
@@ -44,6 +44,9 @@ func RegisterAdminRoutes(
|
||||
// 卡密管理
|
||||
registerRedeemCodeRoutes(admin, h)
|
||||
|
||||
// 优惠码管理
|
||||
registerPromoCodeRoutes(admin, h)
|
||||
|
||||
// 系统设置
|
||||
registerSettingsRoutes(admin, h)
|
||||
|
||||
@@ -201,6 +204,18 @@ func registerRedeemCodeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
}
|
||||
}
|
||||
|
||||
func registerPromoCodeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
promoCodes := admin.Group("/promo-codes")
|
||||
{
|
||||
promoCodes.GET("", h.Admin.Promo.List)
|
||||
promoCodes.GET("/:id", h.Admin.Promo.GetByID)
|
||||
promoCodes.POST("", h.Admin.Promo.Create)
|
||||
promoCodes.PUT("/:id", h.Admin.Promo.Update)
|
||||
promoCodes.DELETE("/:id", h.Admin.Promo.Delete)
|
||||
promoCodes.GET("/:id/usages", h.Admin.Promo.GetUsages)
|
||||
}
|
||||
}
|
||||
|
||||
func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
adminSettings := admin.Group("/settings")
|
||||
{
|
||||
|
||||
@@ -1,24 +1,34 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/middleware"
|
||||
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// RegisterAuthRoutes 注册认证相关路由
|
||||
func RegisterAuthRoutes(
|
||||
v1 *gin.RouterGroup,
|
||||
h *handler.Handlers,
|
||||
jwtAuth middleware.JWTAuthMiddleware,
|
||||
jwtAuth servermiddleware.JWTAuthMiddleware,
|
||||
redisClient *redis.Client,
|
||||
) {
|
||||
// 创建速率限制器
|
||||
rateLimiter := middleware.NewRateLimiter(redisClient)
|
||||
|
||||
// 公开接口
|
||||
auth := v1.Group("/auth")
|
||||
{
|
||||
auth.POST("/register", h.Auth.Register)
|
||||
auth.POST("/login", h.Auth.Login)
|
||||
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
|
||||
// 优惠码验证接口添加速率限制:每分钟最多 10 次
|
||||
auth.POST("/validate-promo-code", rateLimiter.Limit("validate-promo", 10, time.Minute), h.Auth.ValidatePromoCode)
|
||||
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
|
||||
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user