This commit is contained in:
yangjianbo
2026-02-06 06:56:23 +08:00
94 changed files with 10861 additions and 858 deletions

View File

@@ -598,7 +598,7 @@ func newContractDeps(t *testing.T) *contractDeps {
}
userService := service.NewUserService(userRepo, nil)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg)
usageRepo := newStubUsageLogRepo()
usageService := service.NewUsageService(usageRepo, userRepo, nil, nil)
@@ -612,7 +612,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
@@ -1619,6 +1619,10 @@ func (r *stubUsageLogRepo) GetUserDashboardStats(ctx context.Context, userID int
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*usagestats.UserDashboardStats, error) {
return nil, errors.New("not implemented")
}
func (r *stubUsageLogRepo) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) {
return nil, errors.New("not implemented")
}

View File

@@ -58,10 +58,39 @@ func ProvideRouter(
// ProvideHTTPServer 提供 HTTP 服务器
func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server {
handler := h2c.NewHandler(router, &http2.Server{})
httpHandler := http.Handler(router)
globalMaxSize := cfg.Server.MaxRequestBodySize
if globalMaxSize <= 0 {
globalMaxSize = cfg.Gateway.MaxBodySize
}
if globalMaxSize > 0 {
httpHandler = http.MaxBytesHandler(httpHandler, globalMaxSize)
log.Printf("Global max request body size: %d bytes (%.2f MB)", globalMaxSize, float64(globalMaxSize)/(1<<20))
}
// 根据配置决定是否启用 H2C
if cfg.Server.H2C.Enabled {
h2cConfig := cfg.Server.H2C
httpHandler = h2c.NewHandler(router, &http2.Server{
MaxConcurrentStreams: h2cConfig.MaxConcurrentStreams,
IdleTimeout: time.Duration(h2cConfig.IdleTimeout) * time.Second,
MaxReadFrameSize: uint32(h2cConfig.MaxReadFrameSize),
MaxUploadBufferPerConnection: int32(h2cConfig.MaxUploadBufferPerConnection),
MaxUploadBufferPerStream: int32(h2cConfig.MaxUploadBufferPerStream),
})
log.Printf("HTTP/2 Cleartext (h2c) enabled: max_concurrent_streams=%d, idle_timeout=%ds, max_read_frame_size=%d, max_upload_buffer_per_connection=%d, max_upload_buffer_per_stream=%d",
h2cConfig.MaxConcurrentStreams,
h2cConfig.IdleTimeout,
h2cConfig.MaxReadFrameSize,
h2cConfig.MaxUploadBufferPerConnection,
h2cConfig.MaxUploadBufferPerStream,
)
}
return &http.Server{
Addr: cfg.Server.Address(),
Handler: handler,
Handler: httpHandler,
// ReadHeaderTimeout: 读取请求头的超时时间,防止慢速请求头攻击
ReadHeaderTimeout: time.Duration(cfg.Server.ReadHeaderTimeout) * time.Second,
// IdleTimeout: 空闲连接超时时间,释放不活跃的连接资源

View File

@@ -93,6 +93,7 @@ func newTestAPIKeyService(repo service.APIKeyRepository) *service.APIKeyService
nil, // userRepo (unused in GetByKey)
nil, // groupRepo
nil, // userSubRepo
nil, // userGroupRateRepo
nil, // cache
&config.Config{},
)
@@ -187,6 +188,7 @@ func TestApiKeyAuthWithSubscriptionGoogleSetsGroupContext(t *testing.T) {
nil,
nil,
nil,
nil,
&config.Config{RunMode: config.RunModeSimple},
)

View File

@@ -59,7 +59,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil)
router := newAuthTestRouter(apiKeyService, subscriptionService, cfg)
@@ -73,7 +73,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
t.Run("standard_mode_enforces_quota_check", func(t *testing.T) {
cfg := &config.Config{RunMode: config.RunModeStandard}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
now := time.Now()
sub := &service.UserSubscription{
@@ -150,7 +150,7 @@ func TestAPIKeyAuthSetsGroupContext(t *testing.T) {
}
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
router := gin.New()
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
router.GET("/t", func(c *gin.Context) {
@@ -208,7 +208,7 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) {
}
cfg := &config.Config{RunMode: config.RunModeSimple}
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
router := gin.New()
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))

View File

@@ -34,12 +34,16 @@ func Logger() gin.HandlerFunc {
// 客户端IP
clientIP := c.ClientIP()
// 日志格式: [时间] 状态码 | 延迟 | IP | 方法 路径
log.Printf("[GIN] %v | %3d | %13v | %15s | %-7s %s",
// 协议版本
protocol := c.Request.Proto
// 日志格式: [时间] 状态码 | 延迟 | IP | 协议 | 方法 路径
log.Printf("[GIN] %v | %3d | %13v | %15s | %-6s | %-7s %s",
endTime.Format("2006/01/02 - 15:04:05"),
statusCode,
latency,
clientIP,
protocol,
method,
path,
)

View File

@@ -67,6 +67,9 @@ func RegisterAdminRoutes(
// 用户属性管理
registerUserAttributeRoutes(admin, h)
// 错误透传规则管理
registerErrorPassthroughRoutes(admin, h)
}
}
@@ -387,3 +390,14 @@ func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
attrs.DELETE("/:id", h.Admin.UserAttribute.DeleteDefinition)
}
}
func registerErrorPassthroughRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
rules := admin.Group("/error-passthrough-rules")
{
rules.GET("", h.Admin.ErrorPassthrough.List)
rules.GET("/:id", h.Admin.ErrorPassthrough.GetByID)
rules.POST("", h.Admin.ErrorPassthrough.Create)
rules.PUT("/:id", h.Admin.ErrorPassthrough.Update)
rules.DELETE("/:id", h.Admin.ErrorPassthrough.Delete)
}
}

View File

@@ -28,6 +28,12 @@ func RegisterAuthRoutes(
auth.POST("/login", h.Auth.Login)
auth.POST("/login/2fa", h.Auth.Login2FA)
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
// Token刷新接口添加速率限制每分钟最多 30 次Redis 故障时 fail-close
auth.POST("/refresh", rateLimiter.LimitWithOptions("refresh-token", 30, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
}), h.Auth.RefreshToken)
// 登出接口公开允许未认证用户调用以撤销Refresh Token
auth.POST("/logout", h.Auth.Logout)
// 优惠码验证接口添加速率限制:每分钟最多 10 次Redis 故障时 fail-close
auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{
FailureMode: middleware.RateLimitFailClose,
@@ -59,5 +65,7 @@ func RegisterAuthRoutes(
authenticated.Use(gin.HandlerFunc(jwtAuth))
{
authenticated.GET("/auth/me", h.Auth.GetCurrentUser)
// 撤销所有会话(需要认证)
authenticated.POST("/auth/revoke-all-sessions", h.Auth.RevokeAllSessions)
}
}

View File

@@ -49,6 +49,7 @@ func RegisterUserRoutes(
groups := authenticated.Group("/groups")
{
groups.GET("/available", h.APIKey.GetAvailableGroups)
groups.GET("/rates", h.APIKey.GetUserGroupRates)
}
// 使用记录