feat(安全): 强化安全策略与配置校验

- 增加 CORS/CSP/安全响应头与代理信任配置

- 引入 URL 白名单与私网开关,校验上游与价格源

- 改善 API Key 处理与网关错误返回

- 管理端设置隐藏敏感字段并优化前端提示

- 增加计费熔断与相关配置示例

测试: go test ./...
This commit is contained in:
yangjianbo
2026-01-02 17:40:57 +08:00
parent 3fd9bd4a80
commit bd4bf00856
46 changed files with 1572 additions and 220 deletions

View File

@@ -19,6 +19,13 @@ func NewApiKeyAuthMiddleware(apiKeyService *service.ApiKeyService, subscriptionS
// apiKeyAuthWithSubscription API Key认证中间件支持订阅验证
func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) {
queryKey := strings.TrimSpace(c.Query("key"))
queryApiKey := strings.TrimSpace(c.Query("api_key"))
if queryKey != "" || queryApiKey != "" {
AbortWithError(c, 400, "api_key_in_query_deprecated", "API key in query parameter is deprecated. Please use Authorization header instead.")
return
}
// 尝试从Authorization header中提取API key (Bearer scheme)
authHeader := c.GetHeader("Authorization")
var apiKeyString string
@@ -41,19 +48,9 @@ func apiKeyAuthWithSubscription(apiKeyService *service.ApiKeyService, subscripti
apiKeyString = c.GetHeader("x-goog-api-key")
}
// 如果header中没有尝试从query参数中提取Google API key风格
if apiKeyString == "" {
apiKeyString = c.Query("key")
}
// 兼容常见别名
if apiKeyString == "" {
apiKeyString = c.Query("api_key")
}
// 如果所有header都没有API key
if apiKeyString == "" {
AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme), x-api-key header, x-goog-api-key header, or key/api_key query parameter")
AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme), x-api-key header, or x-goog-api-key header")
return
}

View File

@@ -22,6 +22,10 @@ func ApiKeyAuthGoogle(apiKeyService *service.ApiKeyService, cfg *config.Config)
// It is intended for Gemini native endpoints (/v1beta) to match Gemini SDK expectations.
func ApiKeyAuthWithSubscriptionGoogle(apiKeyService *service.ApiKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) {
if v := strings.TrimSpace(c.Query("api_key")); v != "" {
abortWithGoogleError(c, 400, "Query parameter api_key is deprecated. Use Authorization header or key instead.")
return
}
apiKeyString := extractAPIKeyFromRequest(c)
if apiKeyString == "" {
abortWithGoogleError(c, 401, "API key is required")
@@ -116,15 +120,18 @@ func extractAPIKeyFromRequest(c *gin.Context) string {
if v := strings.TrimSpace(c.GetHeader("x-goog-api-key")); v != "" {
return v
}
if v := strings.TrimSpace(c.Query("key")); v != "" {
return v
}
if v := strings.TrimSpace(c.Query("api_key")); v != "" {
return v
if allowGoogleQueryKey(c.Request.URL.Path) {
if v := strings.TrimSpace(c.Query("key")); v != "" {
return v
}
}
return ""
}
func allowGoogleQueryKey(path string) bool {
return strings.HasPrefix(path, "/v1beta") || strings.HasPrefix(path, "/antigravity/v1beta")
}
func abortWithGoogleError(c *gin.Context, status int, message string) {
c.JSON(status, gin.H{
"error": gin.H{

View File

@@ -109,6 +109,58 @@ func TestApiKeyAuthWithSubscriptionGoogle_MissingKey(t *testing.T) {
require.Equal(t, "UNAUTHENTICATED", resp.Error.Status)
}
func TestApiKeyAuthWithSubscriptionGoogle_QueryApiKeyRejected(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return nil, errors.New("should not be called")
},
})
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, &config.Config{}))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test?api_key=legacy", nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
var resp googleErrorResponse
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
require.Equal(t, http.StatusBadRequest, resp.Error.Code)
require.Equal(t, "Query parameter api_key is deprecated. Use Authorization header or key instead.", resp.Error.Message)
require.Equal(t, "INVALID_ARGUMENT", resp.Error.Status)
}
func TestApiKeyAuthWithSubscriptionGoogle_QueryKeyAllowedOnV1Beta(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
apiKeyService := newTestApiKeyService(fakeApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.ApiKey, error) {
return &service.ApiKey{
ID: 1,
Key: key,
Status: service.StatusActive,
User: &service.User{
ID: 123,
Status: service.StatusActive,
},
}, nil
},
})
cfg := &config.Config{RunMode: config.RunModeSimple}
r.Use(ApiKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg))
r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) })
req := httptest.NewRequest(http.MethodGet, "/v1beta/test?key=valid", nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
require.Equal(t, http.StatusOK, rec.Code)
}
func TestApiKeyAuthWithSubscriptionGoogle_InvalidKey(t *testing.T) {
gin.SetMode(gin.TestMode)

View File

@@ -1,24 +1,103 @@
package middleware
import (
"log"
"net/http"
"strings"
"sync"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
)
var corsWarningOnce sync.Once
// CORS 跨域中间件
func CORS() gin.HandlerFunc {
func CORS(cfg config.CORSConfig) gin.HandlerFunc {
allowedOrigins := normalizeOrigins(cfg.AllowedOrigins)
allowAll := false
for _, origin := range allowedOrigins {
if origin == "*" {
allowAll = true
break
}
}
wildcardWithSpecific := allowAll && len(allowedOrigins) > 1
if wildcardWithSpecific {
allowedOrigins = []string{"*"}
}
allowCredentials := cfg.AllowCredentials
corsWarningOnce.Do(func() {
if len(allowedOrigins) == 0 {
log.Println("Warning: CORS allowed_origins not configured; cross-origin requests will be rejected.")
}
if wildcardWithSpecific {
log.Println("Warning: CORS allowed_origins includes '*'; wildcard will take precedence over explicit origins.")
}
if allowAll && allowCredentials {
log.Println("Warning: CORS allowed_origins set to '*', disabling allow_credentials.")
}
})
if allowAll && allowCredentials {
allowCredentials = false
}
allowedSet := make(map[string]struct{}, len(allowedOrigins))
for _, origin := range allowedOrigins {
if origin == "" || origin == "*" {
continue
}
allowedSet[origin] = struct{}{}
}
return func(c *gin.Context) {
// 设置允许跨域的响应头
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
origin := strings.TrimSpace(c.GetHeader("Origin"))
originAllowed := allowAll
if origin != "" && !allowAll {
_, originAllowed = allowedSet[origin]
}
if originAllowed {
if allowAll {
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
} else if origin != "" {
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
c.Writer.Header().Add("Vary", "Origin")
}
if allowCredentials {
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
}
}
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key")
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH")
// 处理预检请求
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(204)
if c.Request.Method == http.MethodOptions {
if originAllowed {
c.AbortWithStatus(http.StatusNoContent)
} else {
c.AbortWithStatus(http.StatusForbidden)
}
return
}
c.Next()
}
}
func normalizeOrigins(values []string) []string {
if len(values) == 0 {
return nil
}
normalized := make([]string, 0, len(values))
for _, value := range values {
trimmed := strings.TrimSpace(value)
if trimmed == "" {
continue
}
normalized = append(normalized, trimmed)
}
return normalized
}

View File

@@ -0,0 +1,26 @@
package middleware
import (
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
)
// SecurityHeaders sets baseline security headers for all responses.
func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
policy := strings.TrimSpace(cfg.Policy)
if policy == "" {
policy = config.DefaultCSPPolicy
}
return func(c *gin.Context) {
c.Header("X-Content-Type-Options", "nosniff")
c.Header("X-Frame-Options", "DENY")
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
if cfg.Enabled {
c.Header("Content-Security-Policy", policy)
}
c.Next()
}
}