Files
sub2api/backend/internal/server/middleware/cors.go
yangjianbo 9634494ba9 fix: 修复代码审核发现的10个问题(P0安全+P1数据一致性+P2性能优化)
P0: OpenAI SSE 错误消息 JSON 注入 — 使用 json.Marshal 替代 fmt.Sprintf
P1: subscription 续期包裹 Ent 事务确保原子性
P1: CSP nonce 生成处理 crypto/rand 错误,失败降级为 unsafe-inline
P1: singleflight 透传数据库真实错误,不再吞没为 not found
P1: GetUserSubscriptionsWithProgress 提取 calculateProgress 消除 N+1
P2: billing_cache/gateway_helper 迁移到 math/rand/v2 消除全局锁争用
P2: generateRandomID 降级分支增加原子计数器防碰撞
P2: CORS 非白名单 origin 不再设置 Allow-Headers/Methods/Max-Age
P2: Turnstile 验证移除 VerifyCode 空值跳过条件防绕过
P2: Redis Cluster Lua 脚本空 KEYS 添加兼容性警告注释

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-07 22:13:45 +08:00

104 lines
2.7 KiB
Go

package middleware
import (
"log"
"net/http"
"strings"
"sync"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
)
// corsWarningOnce guards the configuration warnings logged by CORS so they
// are evaluated at most once per process, no matter how many times the
// middleware is constructed.
var corsWarningOnce sync.Once
// CORS returns a Gin middleware that enforces the cross-origin policy
// described by cfg.
//
// Policy construction (done once, when CORS is called):
//   - cfg.AllowedOrigins is whitespace-trimmed and blank entries are dropped.
//   - If any entry is "*", every origin is allowed, explicit entries are
//     ignored, and credentials are forcibly disabled.
//   - Otherwise the request's Origin header must match an entry exactly.
//   - Configuration warnings are logged through corsWarningOnce, i.e. at most
//     once per process.
//
// Per request:
//   - For an allowed origin the Access-Control-Allow-* headers are set.
//   - In allow-list (non-wildcard) mode "Vary: Origin" is added to every
//     response, allowed or not, so shared caches never replay a response
//     generated for one origin to a different origin.
//   - Every OPTIONS request is treated as a preflight and terminated here:
//     204 when the origin is allowed, 403 otherwise. In allow-list mode an
//     OPTIONS request without an Origin header is therefore rejected.
//   - Any other request is passed on to the next handler, with or without CORS
//     headers.
func CORS(cfg config.CORSConfig) gin.HandlerFunc {
	const (
		allowHeadersValue = "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key"
		allowMethodsValue = "POST, OPTIONS, GET, PUT, DELETE, PATCH"
		maxAgeValue       = "86400"
	)

	allowedOrigins := normalizeOrigins(cfg.AllowedOrigins)

	allowAll := false
	for _, origin := range allowedOrigins {
		if origin == "*" {
			allowAll = true
			break
		}
	}

	// A wildcard makes every explicit entry redundant; collapse the list.
	wildcardWithSpecific := allowAll && len(allowedOrigins) > 1
	if wildcardWithSpecific {
		allowedOrigins = []string{"*"}
	}

	allowCredentials := cfg.AllowCredentials
	corsWarningOnce.Do(func() {
		if len(allowedOrigins) == 0 {
			log.Println("Warning: CORS allowed_origins not configured; cross-origin requests will be rejected.")
		}
		if wildcardWithSpecific {
			log.Println("Warning: CORS allowed_origins includes '*'; wildcard will take precedence over explicit origins.")
		}
		if allowAll && allowCredentials {
			log.Println("Warning: CORS allowed_origins set to '*', disabling allow_credentials.")
		}
	})
	// Browsers reject "Access-Control-Allow-Origin: *" combined with
	// credentials, so the wildcard always wins over allow_credentials.
	if allowAll && allowCredentials {
		allowCredentials = false
	}

	allowedSet := make(map[string]struct{}, len(allowedOrigins))
	for _, origin := range allowedOrigins {
		if origin == "" || origin == "*" {
			continue
		}
		allowedSet[origin] = struct{}{}
	}

	return func(c *gin.Context) {
		header := c.Writer.Header()
		origin := strings.TrimSpace(c.GetHeader("Origin"))

		originAllowed := allowAll
		if !allowAll {
			// The response differs per request origin, including the case
			// where no CORS headers are emitted at all, so caches must key on
			// the Origin header for every response in this mode.
			header.Add("Vary", "Origin")
			if origin != "" {
				_, originAllowed = allowedSet[origin]
			}
		}

		if originAllowed {
			if allowAll {
				header.Set("Access-Control-Allow-Origin", "*")
			} else {
				// originAllowed can only be true here for a non-empty origin
				// that was found in allowedSet.
				header.Set("Access-Control-Allow-Origin", origin)
			}
			if allowCredentials {
				header.Set("Access-Control-Allow-Credentials", "true")
			}
			header.Set("Access-Control-Allow-Headers", allowHeadersValue)
			header.Set("Access-Control-Allow-Methods", allowMethodsValue)
			header.Set("Access-Control-Max-Age", maxAgeValue)
		}

		// Preflight requests are answered here and never reach later handlers.
		if c.Request.Method == http.MethodOptions {
			if originAllowed {
				c.AbortWithStatus(http.StatusNoContent)
			} else {
				c.AbortWithStatus(http.StatusForbidden)
			}
			return
		}

		c.Next()
	}
}
// normalizeOrigins trims surrounding whitespace from each entry of values and
// discards entries that are empty after trimming, keeping the remaining
// entries in their original order. It returns nil when values is empty.
func normalizeOrigins(values []string) []string {
	if len(values) == 0 {
		return nil
	}

	cleaned := make([]string, 0, len(values))
	for i := range values {
		if origin := strings.TrimSpace(values[i]); origin != "" {
			cleaned = append(cleaned, origin)
		}
	}
	return cleaned
}