Merge pull request #738 from DaydreamCoding/feat/ungrouped-key-setting

feat(gateway): 系统设置控制未分组 Key 调度 — Handler 层中间件拦截
This commit is contained in:
Wesley Liddick
2026-03-03 21:03:31 +08:00
committed by GitHub
13 changed files with 150 additions and 7 deletions

View File

@@ -123,6 +123,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
OpsQueryModeDefault: settings.OpsQueryModeDefault,
OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
AllowUngroupedKeyScheduling: settings.AllowUngroupedKeyScheduling,
})
}
@@ -193,6 +194,9 @@ type UpdateSettingsRequest struct {
OpsMetricsIntervalSeconds *int `json:"ops_metrics_interval_seconds"`
MinClaudeCodeVersion string `json:"min_claude_code_version"`
// 分组隔离
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
}
// UpdateSettings 更新系统设置
@@ -465,6 +469,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
EnableIdentityPatch: req.EnableIdentityPatch,
IdentityPatchPrompt: req.IdentityPatchPrompt,
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
OpsMonitoringEnabled: func() bool {
if req.OpsMonitoringEnabled != nil {
return *req.OpsMonitoringEnabled
@@ -561,6 +566,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault,
OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
AllowUngroupedKeyScheduling: updatedSettings.AllowUngroupedKeyScheduling,
})
}
@@ -709,6 +715,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.MinClaudeCodeVersion != after.MinClaudeCodeVersion {
changed = append(changed, "min_claude_code_version")
}
if before.AllowUngroupedKeyScheduling != after.AllowUngroupedKeyScheduling {
changed = append(changed, "allow_ungrouped_key_scheduling")
}
if before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled {
changed = append(changed, "purchase_subscription_enabled")
}

View File

@@ -77,6 +77,9 @@ type SystemSettings struct {
OpsMetricsIntervalSeconds int `json:"ops_metrics_interval_seconds"`
MinClaudeCodeVersion string `json:"min_claude_code_version"`
// 分组隔离
AllowUngroupedKeyScheduling bool `json:"allow_ungrouped_key_scheduling"`
}
type DefaultSubscriptionSetting struct {

View File

@@ -532,6 +532,7 @@ func TestAPIContracts(t *testing.T) {
"purchase_subscription_enabled": false,
"purchase_subscription_url": "",
"min_claude_code_version": "",
"allow_ungrouped_key_scheduling": false,
"custom_menu_items": []
}
}`,

View File

@@ -2,8 +2,11 @@ package middleware
import (
"context"
"net/http"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
)
@@ -71,3 +74,48 @@ func AbortWithError(c *gin.Context, statusCode int, code, message string) {
c.JSON(statusCode, NewErrorResponse(code, message))
c.Abort()
}
// ──────────────────────────────────────────────────────────
// RequireGroupAssignment — 未分组 Key 拦截中间件
// ──────────────────────────────────────────────────────────
// GatewayErrorWriter 定义网关错误响应格式(不同协议使用不同格式)
type GatewayErrorWriter func(c *gin.Context, status int, message string)
// AnthropicErrorWriter 按 Anthropic API 规范输出错误
func AnthropicErrorWriter(c *gin.Context, status int, message string) {
c.JSON(status, gin.H{
"type": "error",
"error": gin.H{"type": "permission_error", "message": message},
})
}
// GoogleErrorWriter 按 Google API 规范输出错误
func GoogleErrorWriter(c *gin.Context, status int, message string) {
c.JSON(status, gin.H{
"error": gin.H{
"code": status,
"message": message,
"status": googleapi.HTTPStatusToGoogleStatus(status),
},
})
}
// RequireGroupAssignment 检查 API Key 是否已分配到分组,
// 如果未分组且系统设置不允许未分组 Key 调度则返回 403。
func RequireGroupAssignment(settingService *service.SettingService, writeError GatewayErrorWriter) gin.HandlerFunc {
return func(c *gin.Context) {
apiKey, ok := GetAPIKeyFromContext(c)
if !ok || apiKey.GroupID != nil {
c.Next()
return
}
// 未分组 Key — 检查系统设置
if settingService.IsUngroupedKeySchedulingAllowed(c.Request.Context()) {
c.Next()
return
}
writeError(c, http.StatusForbidden, "API Key is not assigned to any group and cannot be used. Please contact the administrator to assign it to a group.")
c.Abort()
}
}

View File

@@ -81,7 +81,7 @@ func SetupRouter(
}
// 注册路由
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg, redisClient)
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient)
return r
}
@@ -96,6 +96,7 @@ func registerRoutes(
apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService,
opsService *service.OpsService,
settingService *service.SettingService,
cfg *config.Config,
redisClient *redis.Client,
) {
@@ -110,5 +111,5 @@ func registerRoutes(
routes.RegisterUserRoutes(v1, h, jwtAuth)
routes.RegisterSoraClientRoutes(v1, h, jwtAuth)
routes.RegisterAdminRoutes(v1, h, adminAuth)
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg)
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg)
}

View File

@@ -19,6 +19,7 @@ func RegisterGatewayRoutes(
apiKeyService *service.APIKeyService,
subscriptionService *service.SubscriptionService,
opsService *service.OpsService,
settingService *service.SettingService,
cfg *config.Config,
) {
bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize)
@@ -30,12 +31,17 @@ func RegisterGatewayRoutes(
clientRequestID := middleware.ClientRequestID()
opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService)
// 未分组 Key 拦截中间件(按协议格式区分错误响应)
requireGroupAnthropic := middleware.RequireGroupAssignment(settingService, middleware.AnthropicErrorWriter)
requireGroupGoogle := middleware.RequireGroupAssignment(settingService, middleware.GoogleErrorWriter)
// API网关Claude API兼容
gateway := r.Group("/v1")
gateway.Use(bodyLimit)
gateway.Use(clientRequestID)
gateway.Use(opsErrorLogger)
gateway.Use(gin.HandlerFunc(apiKeyAuth))
gateway.Use(requireGroupAnthropic)
{
gateway.POST("/messages", h.Gateway.Messages)
gateway.POST("/messages/count_tokens", h.Gateway.CountTokens)
@@ -61,6 +67,7 @@ func RegisterGatewayRoutes(
gemini.Use(clientRequestID)
gemini.Use(opsErrorLogger)
gemini.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
gemini.Use(requireGroupGoogle)
{
gemini.GET("/models", h.Gateway.GeminiV1BetaListModels)
gemini.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
@@ -69,11 +76,11 @@ func RegisterGatewayRoutes(
}
// OpenAI Responses API不带v1前缀的别名
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses)
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.ResponsesWebSocket)
r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.Responses)
r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.OpenAIGateway.ResponsesWebSocket)
// Antigravity 模型列表
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), h.Gateway.AntigravityModels)
r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), requireGroupAnthropic, h.Gateway.AntigravityModels)
// Antigravity 专用路由(仅使用 antigravity 账户,不混合调度)
antigravityV1 := r.Group("/antigravity/v1")
@@ -82,6 +89,7 @@ func RegisterGatewayRoutes(
antigravityV1.Use(opsErrorLogger)
antigravityV1.Use(middleware.ForcePlatform(service.PlatformAntigravity))
antigravityV1.Use(gin.HandlerFunc(apiKeyAuth))
antigravityV1.Use(requireGroupAnthropic)
{
antigravityV1.POST("/messages", h.Gateway.Messages)
antigravityV1.POST("/messages/count_tokens", h.Gateway.CountTokens)
@@ -95,6 +103,7 @@ func RegisterGatewayRoutes(
antigravityV1Beta.Use(opsErrorLogger)
antigravityV1Beta.Use(middleware.ForcePlatform(service.PlatformAntigravity))
antigravityV1Beta.Use(middleware.APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, cfg))
antigravityV1Beta.Use(requireGroupGoogle)
{
antigravityV1Beta.GET("/models", h.Gateway.GeminiV1BetaListModels)
antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel)
@@ -108,6 +117,7 @@ func RegisterGatewayRoutes(
soraV1.Use(opsErrorLogger)
soraV1.Use(middleware.ForcePlatform(service.PlatformSora))
soraV1.Use(gin.HandlerFunc(apiKeyAuth))
soraV1.Use(requireGroupAnthropic)
{
soraV1.POST("/chat/completions", h.SoraGateway.ChatCompletions)
soraV1.GET("/models", h.Gateway.Models)

View File

@@ -201,6 +201,9 @@ const (
// SettingKeyMinClaudeCodeVersion 最低 Claude Code 版本号要求 (semver, 如 "2.1.0",空值=不检查)
SettingKeyMinClaudeCodeVersion = "min_claude_code_version"
// SettingKeyAllowUngroupedKeyScheduling 允许未分组 API Key 调度(默认 false未分组 Key 返回 403
SettingKeyAllowUngroupedKeyScheduling = "allow_ungrouped_key_scheduling"
)
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).

View File

@@ -438,6 +438,9 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
// Claude Code version check
updates[SettingKeyMinClaudeCodeVersion] = settings.MinClaudeCodeVersion
// 分组隔离
updates[SettingKeyAllowUngroupedKeyScheduling] = strconv.FormatBool(settings.AllowUngroupedKeyScheduling)
err = s.settingRepo.SetMultiple(ctx, updates)
if err == nil {
// 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口
@@ -646,6 +649,9 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// Claude Code version check (default: empty = disabled)
SettingKeyMinClaudeCodeVersion: "",
// 分组隔离(默认不允许未分组 Key 调度)
SettingKeyAllowUngroupedKeyScheduling: "false",
}
return s.settingRepo.SetMultiple(ctx, defaults)
@@ -776,6 +782,9 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
// Claude Code version check
result.MinClaudeCodeVersion = settings[SettingKeyMinClaudeCodeVersion]
// 分组隔离
result.AllowUngroupedKeyScheduling = settings[SettingKeyAllowUngroupedKeyScheduling] == "true"
return result
}
@@ -1098,6 +1107,15 @@ func (s *SettingService) GetStreamTimeoutSettings(ctx context.Context) (*StreamT
return &settings, nil
}
// IsUngroupedKeySchedulingAllowed 查询是否允许未分组 Key 调度
func (s *SettingService) IsUngroupedKeySchedulingAllowed(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, SettingKeyAllowUngroupedKeyScheduling)
if err != nil {
return false // fail-closed: 查询失败时默认不允许
}
return value == "true"
}
// GetMinClaudeCodeVersion 获取最低 Claude Code 版本号要求
// 使用进程内 atomic.Value 缓存60 秒 TTL热路径零锁开销
// singleflight 防止缓存过期时 thundering herd

View File

@@ -65,6 +65,9 @@ type SystemSettings struct {
// Claude Code version check
MinClaudeCodeVersion string
// 分组隔离:允许未分组 Key 调度(默认 false → 403
AllowUngroupedKeyScheduling bool
}
type DefaultSubscriptionSetting struct {