package middleware import ( "log" "net/http" "strings" "sync" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/gin-gonic/gin" ) var corsWarningOnce sync.Once // CORS 跨域中间件 func CORS(cfg config.CORSConfig) gin.HandlerFunc { allowedOrigins := normalizeOrigins(cfg.AllowedOrigins) allowAll := false for _, origin := range allowedOrigins { if origin == "*" { allowAll = true break } } wildcardWithSpecific := allowAll && len(allowedOrigins) > 1 if wildcardWithSpecific { allowedOrigins = []string{"*"} } allowCredentials := cfg.AllowCredentials corsWarningOnce.Do(func() { if len(allowedOrigins) == 0 { log.Println("Warning: CORS allowed_origins not configured; cross-origin requests will be rejected.") } if wildcardWithSpecific { log.Println("Warning: CORS allowed_origins includes '*'; wildcard will take precedence over explicit origins.") } if allowAll && allowCredentials { log.Println("Warning: CORS allowed_origins set to '*', disabling allow_credentials.") } }) if allowAll && allowCredentials { allowCredentials = false } allowedSet := make(map[string]struct{}, len(allowedOrigins)) for _, origin := range allowedOrigins { if origin == "" || origin == "*" { continue } allowedSet[origin] = struct{}{} } return func(c *gin.Context) { origin := strings.TrimSpace(c.GetHeader("Origin")) originAllowed := allowAll if origin != "" && !allowAll { _, originAllowed = allowedSet[origin] } if originAllowed { if allowAll { c.Writer.Header().Set("Access-Control-Allow-Origin", "*") } else if origin != "" { c.Writer.Header().Set("Access-Control-Allow-Origin", origin) c.Writer.Header().Add("Vary", "Origin") } if allowCredentials { c.Writer.Header().Set("Access-Control-Allow-Credentials", "true") } } c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With, X-API-Key") c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH") // 处理预检请求 if c.Request.Method == http.MethodOptions { if originAllowed { c.AbortWithStatus(http.StatusNoContent) } else { c.AbortWithStatus(http.StatusForbidden) } return } c.Next() } } func normalizeOrigins(values []string) []string { if len(values) == 0 { return nil } normalized := make([]string, 0, len(values)) for _, value := range values { trimmed := strings.TrimSpace(value) if trimmed == "" { continue } normalized = append(normalized, trimmed) } return normalized }