diff --git hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/http/CrossOriginFilter.java hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/http/CrossOriginFilter.java index e69de29..f08f52e 100644 --- hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/http/CrossOriginFilter.java +++ hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/http/CrossOriginFilter.java @@ -0,0 +1,139 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.http; + +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.FilterConfig; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.apache.commons.lang.StringUtils; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +public class CrossOriginFilter implements Filter { + + private static final Log LOG = LogFactory.getLog(CrossOriginFilter.class); + + // HTTP CORS Request Headers + private static final String ORIGIN = "Origin"; + private static final String ACCESS_CONTROL_REQUEST_METHOD = "Access-Control-Request-Method"; + private static final String ACCESS_CONTROL_REQUEST_HEADERS = "Access-Control-Request-Headers"; + + // HTTP CORS Response Headers + private static final String ACCESS_CONTROL_ALLOW_ORIGIN = "Access-Control-Allow-Origin"; + private static final String ACCESS_CONTROL_ALLOW_CREDENTIALS = "Access-Control-Allow-Credentials"; + private static final String ACCESS_CONTROL_ALLOW_METHODS = "Access-Control-Allow-Methods"; + private static final String ACCESS_CONTROL_ALLOW_HEADERS = "Access-Control-Allow-Headers"; + + // Filter configuration + public static final String ALLOWED_ORIGINS = "access.control.allowed.origins"; + + private List allowedMethods = new ArrayList(); + private List allowedHeaders = new ArrayList(); + private List allowedOrigins = new ArrayList(); + + @Override + public void init(FilterConfig filterConfig) throws ServletException { + initializeAllowedMethods(filterConfig); + initializeAllowedHeaders(filterConfig); + initializeAllowedOrigins(filterConfig); + } + + @Override + public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) + throws IOException, ServletException { + doCrossFilter((HttpServletRequest) req, (HttpServletResponse) res); + chain.doFilter(req, res); + } + + @Override + public void destroy() { + allowedMethods.clear(); + allowedHeaders.clear(); + allowedOrigins.clear(); + } + + private void doCrossFilter(HttpServletRequest req, HttpServletResponse res) { + + String origin = req.getHeader(ORIGIN); + if (!isCrossOrigin(origin)) { + return; + } + if (!isOriginAllowed(origin)) { + return; + } + if (!isMethodAllowed(req)) { + return; + } + if (!areHeadersAllowed(req)) { + return; + } + + res.setHeader(ACCESS_CONTROL_ALLOW_ORIGIN, origin); + res.setHeader(ACCESS_CONTROL_ALLOW_CREDENTIALS, "true"); + res.setHeader(ACCESS_CONTROL_ALLOW_METHODS, StringUtils.join(allowedMethods, ',')); + res.setHeader(ACCESS_CONTROL_ALLOW_HEADERS, StringUtils.join(allowedHeaders, ',')); + } + + private void initializeAllowedMethods(FilterConfig filterConfig) { + allowedMethods.add("GET"); + allowedMethods.add("POST"); + allowedMethods.add("HEAD"); + } + + private void initializeAllowedHeaders(FilterConfig filterConfig) { + allowedHeaders.add("X-Requested-With"); + allowedHeaders.add("Content-Type"); + allowedHeaders.add("Accept"); + allowedHeaders.add("Origin"); + } + + private void initializeAllowedOrigins(FilterConfig filterConfig) { + String allowedOriginsConfig = filterConfig.getInitParameter(ALLOWED_ORIGINS); + allowedOrigins = Arrays.asList(allowedOriginsConfig.trim().split("\\s*,\\s*")); + } + + private boolean isCrossOrigin(String origin) { + return origin != null; + } + + private boolean areHeadersAllowed(HttpServletRequest httpReq) { + String accessControlRequestHeaders = httpReq.getHeader(ACCESS_CONTROL_REQUEST_HEADERS); + String headers[] = accessControlRequestHeaders.trim().split("\\s*,\\s*"); + return allowedHeaders.containsAll(Arrays.asList(headers)); + } + + private boolean isOriginAllowed(String origin) { + return allowedOrigins.contains(origin); + } + + private boolean isMethodAllowed(HttpServletRequest httpReq) { + String accessControlRequestMethod = httpReq.getHeader(ACCESS_CONTROL_REQUEST_METHOD); + return allowedMethods.contains(accessControlRequestMethod); + } +}