keel/net/http/middleware/cors.go
2022-06-16 17:24:22 +02:00

248 lines
6.0 KiB
Go

package middleware
import (
"net/http"
"regexp"
"strconv"
"strings"
"go.uber.org/zap"
keelhttp "github.com/foomo/keel/net/http"
)
// CORSOptions holds the configuration consumed by CORSWithOptions.
type CORSOptions struct {
	// AllowOrigins lists the accepted origins. An entry may be "*", an exact
	// origin, or a pattern in which "*" and "?" act as wildcards.
	AllowOrigins []string
	// AllowMethods is sent, comma-joined, as Access-Control-Allow-Methods on
	// preflight responses.
	AllowMethods []string
	// AllowHeaders is sent, comma-joined, as Access-Control-Allow-Headers on
	// preflight responses; when empty, the request's
	// Access-Control-Request-Headers value is echoed instead.
	AllowHeaders []string
	// AllowCredentials adds "Access-Control-Allow-Credentials: true" and, when
	// combined with a "*" origin, makes the middleware reflect the request's
	// Origin instead of "*".
	AllowCredentials bool
	// ExposeHeaders is sent, comma-joined, as Access-Control-Expose-Headers on
	// non-preflight responses.
	ExposeHeaders []string
	// MaxAge is sent as Access-Control-Max-Age on preflight responses when it
	// is greater than zero.
	MaxAge int
}

// CORSOption mutates a CORSOptions value; options are applied by CORS.
type CORSOption func(*CORSOptions)
// GetDefaultCORSOptions returns the baseline configuration used by CORS:
// every origin is allowed ("*") together with the GET, HEAD, PUT, PATCH,
// POST and DELETE methods. All remaining fields keep their zero value.
func GetDefaultCORSOptions() CORSOptions {
	var defaults CORSOptions
	defaults.AllowOrigins = []string{"*"}
	defaults.AllowMethods = []string{
		http.MethodGet,
		http.MethodHead,
		http.MethodPut,
		http.MethodPatch,
		http.MethodPost,
		http.MethodDelete,
	}
	return defaults
}
// CORSWithAllowOrigins returns an option that replaces AllowOrigins with v.
// The slice is stored as given, not copied.
func CORSWithAllowOrigins(v ...string) CORSOption {
	setOrigins := func(target *CORSOptions) {
		target.AllowOrigins = v
	}
	return setOrigins
}
// CORSWithAllowMethods returns an option that replaces AllowMethods with v.
func CORSWithAllowMethods(v ...string) CORSOption {
	setMethods := func(target *CORSOptions) {
		target.AllowMethods = v
	}
	return setMethods
}
// CORSWithAllowHeaders returns an option that replaces AllowHeaders with v.
func CORSWithAllowHeaders(v ...string) CORSOption {
	setHeaders := func(target *CORSOptions) {
		target.AllowHeaders = v
	}
	return setHeaders
}
// CORSWithAllowCredentials returns an option that sets AllowCredentials to v.
func CORSWithAllowCredentials(v bool) CORSOption {
	setCredentials := func(target *CORSOptions) {
		target.AllowCredentials = v
	}
	return setCredentials
}
// CORSWithExposeHeaders returns an option that replaces ExposeHeaders with v.
func CORSWithExposeHeaders(v ...string) CORSOption {
	setExposed := func(target *CORSOptions) {
		target.ExposeHeaders = v
	}
	return setExposed
}
// CORSWithMaxAge returns an option that sets MaxAge to v. Non-positive values
// cause the middleware to omit Access-Control-Max-Age.
func CORSWithMaxAge(v int) CORSOption {
	setMaxAge := func(target *CORSOptions) {
		target.MaxAge = v
	}
	return setMaxAge
}
// CORS builds a CORS middleware starting from GetDefaultCORSOptions and
// applying each option in order. Nil options are skipped.
func CORS(opts ...CORSOption) Middleware {
	cfg := GetDefaultCORSOptions()
	for _, apply := range opts {
		if apply == nil {
			continue
		}
		apply(&cfg)
	}
	return CORSWithOptions(cfg)
}
// CORSWithOptions returns a middleware that applies the given CORS
// configuration to every request.
//
// The request's Origin is checked against each AllowOrigins entry in order:
//   - "*" allows any origin; with AllowCredentials the request's own origin is
//     echoed back instead of "*",
//   - an exact string match allows that origin,
//   - an entry accepted by matchSubdomain allows the request's origin.
//
// If none of those match, every entry is also tried as a glob in which "*"
// matches any run of characters and "?" a single character.
//
// "Vary: Origin" is always added. Requests without an Origin header, and
// non-preflight requests from a disallowed origin, are passed to next without
// CORS headers. Preflight (OPTIONS) requests are always answered with
// 204 No Content and never reach next.
func CORSWithOptions(opts CORSOptions) Middleware {
	// Copy so that later mutation of the caller's slice cannot change, or race
	// with, the origins consulted at request time.
	allowOrigins := append([]string(nil), opts.AllowOrigins...)

	// Compile the glob patterns once here instead of on every request.
	allowOriginPatterns := make([]*regexp.Regexp, 0, len(allowOrigins))
	for _, origin := range allowOrigins {
		pattern := regexp.QuoteMeta(origin)
		pattern = strings.ReplaceAll(pattern, "\\*", ".*")
		pattern = strings.ReplaceAll(pattern, "\\?", ".")
		re, err := regexp.Compile("^" + pattern + "$")
		if err != nil {
			// QuoteMeta escapes every metacharacter, so this should not happen.
			// Skipping keeps the previous semantics, where a compile error
			// simply meant "no match".
			continue
		}
		allowOriginPatterns = append(allowOriginPatterns, re)
	}

	allowMethods := strings.Join(opts.AllowMethods, ",")
	allowHeaders := strings.Join(opts.AllowHeaders, ",")
	exposeHeaders := strings.Join(opts.ExposeHeaders, ",")
	maxAge := strconv.Itoa(opts.MaxAge)

	// resolveAllowOrigin returns the Access-Control-Allow-Origin value for a
	// non-empty origin, or "" when the origin is not allowed.
	resolveAllowOrigin := func(origin string) string {
		for _, value := range allowOrigins {
			switch {
			case value == "*" && opts.AllowCredentials:
				return origin
			case value == "*" || value == origin:
				return value
			case matchSubdomain(origin, value):
				return origin
			}
		}
		// Only run the glob patterns on origins that contain a scheme separator
		// and whose remainder is at most 253 bytes, the same bound
		// matchSubdomain uses to avoid work on overly long inputs.
		index := strings.Index(origin, "://")
		if index == -1 || len(origin[index+3:]) > 253 {
			return ""
		}
		for _, re := range allowOriginPatterns {
			if re.MatchString(origin) {
				return origin
			}
		}
		return ""
	}

	return func(l *zap.Logger, name string, next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			origin := r.Header.Get(keelhttp.HeaderOrigin)
			preflight := r.Method == http.MethodOptions
			w.Header().Add(keelhttp.HeaderVary, keelhttp.HeaderOrigin)

			// An empty Origin must never be resolved: with a "*" entry it would
			// otherwise be granted access.
			allowOrigin := ""
			if origin != "" {
				allowOrigin = resolveAllowOrigin(origin)
			}

			// No Origin header, or origin not allowed.
			if allowOrigin == "" {
				if preflight {
					w.WriteHeader(http.StatusNoContent)
				} else {
					next.ServeHTTP(w, r)
				}
				return
			}

			// Simple (non-preflight) request.
			if !preflight {
				w.Header().Set(keelhttp.HeaderAccessControlAllowOrigin, allowOrigin)
				if opts.AllowCredentials {
					w.Header().Set(keelhttp.HeaderAccessControlAllowCredentials, "true")
				}
				if exposeHeaders != "" {
					w.Header().Set(keelhttp.HeaderAccessControlExposeHeaders, exposeHeaders)
				}
				next.ServeHTTP(w, r)
				return
			}

			// Preflight request.
			w.Header().Add(keelhttp.HeaderVary, keelhttp.HeaderAccessControlRequestMethod)
			w.Header().Add(keelhttp.HeaderVary, keelhttp.HeaderAccessControlRequestHeaders)
			w.Header().Set(keelhttp.HeaderAccessControlAllowOrigin, allowOrigin)
			w.Header().Set(keelhttp.HeaderAccessControlAllowMethods, allowMethods)
			if opts.AllowCredentials {
				w.Header().Set(keelhttp.HeaderAccessControlAllowCredentials, "true")
			}
			if allowHeaders != "" {
				w.Header().Set(keelhttp.HeaderAccessControlAllowHeaders, allowHeaders)
			} else if h := r.Header.Get(keelhttp.HeaderAccessControlRequestHeaders); h != "" {
				// No explicit list configured: echo what the client asked for.
				w.Header().Set(keelhttp.HeaderAccessControlAllowHeaders, h)
			}
			if opts.MaxAge > 0 {
				w.Header().Set(keelhttp.HeaderAccessControlMaxAge, maxAge)
			}
			w.WriteHeader(http.StatusNoContent)
		})
	}
}
// matchScheme reports whether domain and pattern both contain a ':' and have
// identical text before their first ':' (the URL scheme for origin strings).
func matchScheme(domain, pattern string) bool {
	d := strings.IndexByte(domain, ':')
	if d < 0 {
		return false
	}
	p := strings.IndexByte(pattern, ':')
	if p < 0 {
		return false
	}
	return domain[:d] == pattern[:p]
}
// matchSubdomain reports whether domain (e.g. "https://a.example.com") is
// covered by a wildcard pattern (e.g. "https://*.example.com").
//
// Both values must share a scheme and contain "://". The dot-separated labels
// after "://" are compared from the right-most label inward. The match
// succeeds as soon as the pattern has a "*" label at that position. It fails
// on the first mismatching label, when the pattern runs out of labels, or
// when the domain runs out of labels before reaching a "*".
// Domains whose part after "://" exceeds 253 bytes are rejected up front.
func matchSubdomain(domain, pattern string) bool {
	if !matchScheme(domain, pattern) {
		return false
	}
	domSep := strings.Index(domain, "://")
	patSep := strings.Index(pattern, "://")
	if domSep < 0 || patSep < 0 {
		return false
	}
	domainAuthority := domain[domSep+3:]
	// Bound the work spent on invalid, overly long domains.
	if len(domainAuthority) > 253 {
		return false
	}
	domainLabels := strings.Split(domainAuthority, ".")
	patternLabels := strings.Split(pattern[patSep+3:], ".")
	// k counts labels from the right: k == 1 is the last label of each side.
	for k := 1; k <= len(domainLabels); k++ {
		if k > len(patternLabels) {
			return false
		}
		want := patternLabels[len(patternLabels)-k]
		if want == "*" {
			return true
		}
		if want != domainLabels[len(domainLabels)-k] {
			return false
		}
	}
	return false
}