/*
* Copyright 2003-2006 Rick Knowles <winstone-devel at lists sourceforge net>
* Distributed under the terms of either:
* - the common development and distribution license (CDDL), v1.0; or
* - the GNU Lesser General Public License, v2.1 or later
*/
package winstone;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletRequestWrapper;
import javax.servlet.ServletResponse;
import javax.servlet.ServletResponseWrapper;
/**
* This class implements both the RequestDispatcher and FilterChain components. On
* the first call to include() or forward(), it starts the filter chain execution
* if one exists. On the final doFilter() or if there is no chain, we call the include()
* or forward() again, and the servlet is executed.
*
* @author <a href="mailto:rick_knowles@hotmail.com">Rick Knowles</a>
* @version $Id: RequestDispatcher.java,v 1.18 2007/04/23 02:55:35 rickknowles Exp $
*/
public class RequestDispatcher implements javax.servlet.RequestDispatcher,
        javax.servlet.FilterChain {

    // Names of the javax.servlet.include.* request attributes
    static final String INCLUDE_REQUEST_URI = "javax.servlet.include.request_uri";
    static final String INCLUDE_CONTEXT_PATH = "javax.servlet.include.context_path";
    static final String INCLUDE_SERVLET_PATH = "javax.servlet.include.servlet_path";
    static final String INCLUDE_PATH_INFO = "javax.servlet.include.path_info";
    static final String INCLUDE_QUERY_STRING = "javax.servlet.include.query_string";

    // Names of the javax.servlet.forward.* request attributes
    static final String FORWARD_REQUEST_URI = "javax.servlet.forward.request_uri";
    static final String FORWARD_CONTEXT_PATH = "javax.servlet.forward.context_path";
    static final String FORWARD_SERVLET_PATH = "javax.servlet.forward.servlet_path";
    static final String FORWARD_PATH_INFO = "javax.servlet.forward.path_info";
    static final String FORWARD_QUERY_STRING = "javax.servlet.forward.query_string";

    // Names of the javax.servlet.error.* request attributes
    static final String ERROR_STATUS_CODE = "javax.servlet.error.status_code";
    static final String ERROR_EXCEPTION_TYPE = "javax.servlet.error.exception_type";
    static final String ERROR_MESSAGE = "javax.servlet.error.message";
    static final String ERROR_EXCEPTION = "javax.servlet.error.exception";
    static final String ERROR_REQUEST_URI = "javax.servlet.error.request_uri";
    static final String ERROR_SERVLET_NAME = "javax.servlet.error.servlet_name";

    private WebAppConfiguration webAppConfig;
    private ServletConfiguration servletConfig;
    private String servletPath;
    private String pathInfo;
    private String queryString;
    private String requestURI;
    private Integer errorStatusCode;
    private Throwable errorException;
    private String errorSummaryMessage;
    private AuthenticationHandler authHandler;
    private Mapping forwardFilterPatterns[];
    private Mapping includeFilterPatterns[];
    // Resolved filter chain; null until set by a setFor...() method or the first forward/include
    private FilterConfiguration matchingFilters[];
    // Index of the next filter in matchingFilters to execute
    private int matchingFiltersEvaluated;
    // null = not yet dispatched, TRUE = include in progress, FALSE = forward in progress
    private Boolean doInclude;
    private boolean isErrorDispatch;
    // Whether forward()/include() should set the javax.servlet.forward/include attributes
    private boolean useRequestAttributes;
    // Servlet/webapp of the including request, restored by finishInclude()
    private WebAppConfiguration includedWebAppConfig;
    private ServletConfiguration includedServletConfig;

    /**
     * Constructor. Records the target web application and servlet and resets the filter
     * chain position. The paths, filter patterns and dispatch type are supplied afterwards
     * by one of the setFor...Dispatcher() methods.
     */
    public RequestDispatcher(WebAppConfiguration webAppConfig, ServletConfiguration servletConfig) {
        this.servletConfig = servletConfig;
        this.webAppConfig = webAppConfig;
        this.matchingFiltersEvaluated = 0;
    }

    /**
     * Configures this dispatcher for a dispatch by servlet name. The filter chain is resolved
     * lazily on the first forward() or include() call from the matching pattern set, and no
     * forward/include request attributes are set.
     */
    public void setForNamedDispatcher(Mapping forwardFilterPatterns[],
            Mapping includeFilterPatterns[]) {
        this.forwardFilterPatterns = forwardFilterPatterns;
        this.includeFilterPatterns = includeFilterPatterns;
        this.matchingFilters = null; // set after the call to forward or include
        this.useRequestAttributes = false;
        this.isErrorDispatch = false;
    }

    /**
     * Configures this dispatcher for a dispatch by path. The filter chain is resolved lazily
     * on the first forward() or include() call, and the forward/include request attributes
     * are set during dispatch.
     */
    public void setForURLDispatcher(String servletPath, String pathInfo,
            String queryString, String requestURIInsideWebapp,
            Mapping forwardFilterPatterns[], Mapping includeFilterPatterns[]) {
        this.servletPath = servletPath;
        this.pathInfo = pathInfo;
        this.queryString = queryString;
        this.requestURI = requestURIInsideWebapp;
        this.forwardFilterPatterns = forwardFilterPatterns;
        this.includeFilterPatterns = includeFilterPatterns;
        this.matchingFilters = null; // set after the call to forward or include
        this.useRequestAttributes = true;
        this.isErrorDispatch = false;
    }

    /**
     * Configures this dispatcher to render an error page. The "ERROR" filter chain is
     * resolved immediately, and forward() will set the javax.servlet.error.* attributes
     * from the supplied status code, message and exception.
     */
    public void setForErrorDispatcher(String servletPath, String pathInfo,
            String queryString, int statusCode, String summaryMessage,
            Throwable exception, String errorHandlerURI,
            Mapping errorFilterPatterns[]) {
        this.servletPath = servletPath;
        this.pathInfo = pathInfo;
        this.queryString = queryString;
        this.requestURI = errorHandlerURI;
        this.errorStatusCode = new Integer(statusCode);
        this.errorException = exception;
        this.errorSummaryMessage = summaryMessage;
        this.matchingFilters = getMatchingFilters(errorFilterPatterns, this.webAppConfig,
                fullPathOf(servletPath, pathInfo), getName(), "ERROR", (servletPath != null));
        this.useRequestAttributes = true;
        this.isErrorDispatch = true;
    }

    /**
     * Configures this dispatcher for the initial client request. The "REQUEST" filter chain
     * is resolved immediately, and forward() will run the supplied authentication handler
     * (if any) before executing the chain.
     */
    public void setForInitialDispatcher(String servletPath, String pathInfo,
            String queryString, String requestURIInsideWebapp, Mapping requestFilterPatterns[],
            AuthenticationHandler authHandler) {
        this.servletPath = servletPath;
        this.pathInfo = pathInfo;
        this.queryString = queryString;
        this.requestURI = requestURIInsideWebapp;
        this.authHandler = authHandler;
        this.matchingFilters = getMatchingFilters(requestFilterPatterns, this.webAppConfig,
                fullPathOf(servletPath, pathInfo), getName(), "REQUEST", (servletPath != null));
        this.useRequestAttributes = false;
        this.isErrorDispatch = false;
    }

    /** Returns the name of the target servlet. */
    public String getName() {
        return this.servletConfig.getServletName();
    }

    /**
     * Includes the execution of a servlet into the current request.
     *
     * Note this method enters itself twice when a filter chain exists: once with the initial
     * call, and once again from doFilter() when all the filters have completed. Only the
     * initial (outermost) call pushes the include state onto the request/response, and only
     * that call pops it again - in a finally block, so the teardown happens exactly once
     * whether the filters and servlet complete normally or throw.
     */
    public void include(ServletRequest request, ServletResponse response)
            throws ServletException, IOException {
        boolean isOutermostCall = (this.doInclude == null);

        // On the first call, log and initialise the include state
        if (isOutermostCall) {
            Logger.log(Logger.DEBUG, Launcher.RESOURCES,
                    "RequestDispatcher.IncludeMessage", new String[] {
                    getName(), this.requestURI });
            WinstoneRequest wr = getUnwrappedRequest(request);
            // Add the query string to the included query string stack
            wr.addIncludeQueryParameters(this.queryString);

            // Set the include request attributes (path-based dispatchers only)
            if (useRequestAttributes) {
                wr.addIncludeAttributes(this.webAppConfig.getContextPath() + this.requestURI,
                        this.webAppConfig.getContextPath(), this.servletPath, this.pathInfo,
                        this.queryString);
            }
            // Add another include buffer to the response stack
            WinstoneResponse wresp = getUnwrappedResponse(response);
            wresp.startIncludeBuffer();

            // Remember the including servlet/webapp so finishInclude() can restore them
            this.includedServletConfig = wr.getServletConfig();
            this.includedWebAppConfig = wr.getWebAppConfig();
            wr.setServletConfig(this.servletConfig);
            wr.setWebAppConfig(this.webAppConfig);
            wresp.setWebAppConfig(this.webAppConfig);
            this.doInclude = Boolean.TRUE;
        }

        try {
            if (this.matchingFilters == null) {
                this.matchingFilters = getMatchingFilters(this.includeFilterPatterns,
                        this.webAppConfig, fullPathOf(this.servletPath, this.pathInfo),
                        getName(), "INCLUDE", (this.servletPath != null));
            }

            // Make sure the filter chain is exhausted first; the last filter's
            // chain.doFilter() re-enters this method, which then runs the servlet
            if (this.matchingFiltersEvaluated < this.matchingFilters.length) {
                doFilter(request, response);
            } else {
                this.servletConfig.execute(request, response,
                        this.webAppConfig.getContextPath() + this.requestURI);
            }
        } finally {
            // Pop the include state exactly once, from the call that pushed it
            if (isOutermostCall) {
                finishInclude(request, response);
            }
        }
    }

    /**
     * Undoes the include setup: pops the include query string, removes the include
     * attributes (if they were set), finishes the include buffer, and restores the
     * including servlet and webapp configuration.
     */
    private void finishInclude(ServletRequest request, ServletResponse response)
            throws IOException {
        WinstoneRequest wr = getUnwrappedRequest(request);
        wr.removeIncludeQueryString();

        // Remove the include request attributes
        if (useRequestAttributes) {
            wr.removeIncludeAttributes();
        }
        // Remove the include buffer from the response stack
        WinstoneResponse wresp = getUnwrappedResponse(response);
        wresp.finishIncludeBuffer();

        if (this.includedServletConfig != null) {
            wr.setServletConfig(this.includedServletConfig);
            this.includedServletConfig = null;
        }

        if (this.includedWebAppConfig != null) {
            wr.setWebAppConfig(this.includedWebAppConfig);
            wresp.setWebAppConfig(this.includedWebAppConfig);
            this.includedWebAppConfig = null;
        }
    }

    /**
     * Forwards to another servlet. When that servlet has finished executing, flushes the
     * response and closes the output stream to cut off any further output.
     *
     * Note this method enters itself twice when a filter chain exists: once with the
     * initial call, and once again from doFilter() when all the filters have completed.
     *
     * @throws IllegalStateException if the response has already been committed
     */
    public void forward(ServletRequest request, ServletResponse response)
            throws ServletException, IOException {
        // Only on the first call to forward, we should set any forwarding attributes
        if (this.doInclude == null) {
            Logger.log(Logger.DEBUG, Launcher.RESOURCES,
                    "RequestDispatcher.ForwardMessage", new String[] {
                    getName(), this.requestURI });
            if (response.isCommitted()) {
                throw new IllegalStateException(Launcher.RESOURCES.getString(
                        "RequestDispatcher.ForwardCommitted"));
            }

            WinstoneRequest req = getUnwrappedRequest(request);
            WinstoneResponse rsp = getUnwrappedResponse(response);

            // Clear the include stack if one has been accumulated
            rsp.resetBuffer();
            req.clearIncludeStackForForward();
            rsp.clearIncludeStackForForward();

            // Set request attributes (because it's the first step in the filter chain of a
            // forward or error). These capture the request's values before it is re-targeted.
            if (useRequestAttributes) {
                req.setAttribute(FORWARD_REQUEST_URI, req.getRequestURI());
                req.setAttribute(FORWARD_CONTEXT_PATH, req.getContextPath());
                req.setAttribute(FORWARD_SERVLET_PATH, req.getServletPath());
                req.setAttribute(FORWARD_PATH_INFO, req.getPathInfo());
                req.setAttribute(FORWARD_QUERY_STRING, req.getQueryString());

                if (this.isErrorDispatch) {
                    req.setAttribute(ERROR_REQUEST_URI, req.getRequestURI());
                    req.setAttribute(ERROR_STATUS_CODE, this.errorStatusCode);
                    req.setAttribute(ERROR_MESSAGE,
                            errorSummaryMessage != null ? errorSummaryMessage : "");
                    if (req.getServletConfig() != null) {
                        req.setAttribute(ERROR_SERVLET_NAME, req.getServletConfig().getServletName());
                    }

                    if (this.errorException != null) {
                        req.setAttribute(ERROR_EXCEPTION_TYPE, this.errorException.getClass());
                        req.setAttribute(ERROR_EXCEPTION, this.errorException);
                    }

                    // Revert back to the original request and response
                    rsp.setErrorStatusCode(this.errorStatusCode.intValue());
                    request = req;
                    response = rsp;
                }
            }

            // Re-target the request at this dispatcher's servlet and webapp
            req.setServletPath(this.servletPath);
            req.setPathInfo(this.pathInfo);
            req.setRequestURI(this.webAppConfig.getContextPath() + this.requestURI);
            req.setForwardQueryString(this.queryString);
            req.setWebAppConfig(this.webAppConfig);
            req.setServletConfig(this.servletConfig);
            req.setRequestAttributeListeners(this.webAppConfig.getRequestAttributeListeners());

            rsp.setWebAppConfig(this.webAppConfig);

            // Forwards haven't set up the filter pattern set yet
            if (this.matchingFilters == null) {
                this.matchingFilters = getMatchingFilters(this.forwardFilterPatterns,
                        this.webAppConfig, fullPathOf(this.servletPath, this.pathInfo),
                        getName(), "FORWARD", (this.servletPath != null));
            }

            // Otherwise we are an initial or error dispatcher, so check security if initial -
            // if we should not continue, return
            else if (!this.isErrorDispatch && !continueAfterSecurityCheck(request, response)) {
                return;
            }

            this.doInclude = Boolean.FALSE;
        }

        // Make sure the filter chain is exhausted first. Only the outermost call (before any
        // filter has run) closes the output once everything has executed.
        boolean outsideFilter = (this.matchingFiltersEvaluated == 0);
        if (this.matchingFiltersEvaluated < this.matchingFilters.length) {
            doFilter(request, response);
        } else {
            this.servletConfig.execute(request, response, this.webAppConfig.getContextPath() + this.requestURI);
        }

        // Stop any output after the final filter has been executed (e.g. from forwarding servlet)
        if (outsideFilter) {
            WinstoneResponse rsp = getUnwrappedResponse(response);
            rsp.flushBuffer();
            rsp.getWinstoneOutputStream().setClosed(true);
        }
    }

    /**
     * Runs the authentication handler (if one was supplied) against this dispatcher's path.
     *
     * @return the handler's verdict, or true when there is no handler
     */
    private boolean continueAfterSecurityCheck(ServletRequest request,
            ServletResponse response) throws IOException, ServletException {
        // Evaluate security constraints
        if (this.authHandler != null) {
            return this.authHandler.processAuthentication(request, response,
                    fullPathOf(this.servletPath, this.pathInfo));
        } else {
            return true;
        }
    }

    /**
     * Handles the processing of the chain of filters: each call executes the next filter
     * (which is expected to call back into this method), and once the chain is exhausted
     * the original forward() or include() is re-entered to run the servlet.
     */
    public void doFilter(ServletRequest request, ServletResponse response)
            throws ServletException, IOException {
        // Execute the next filter, if any remain; it calls back here to continue the chain
        if (this.matchingFiltersEvaluated < this.matchingFilters.length) {
            FilterConfiguration filter = this.matchingFilters[this.matchingFiltersEvaluated++];
            Logger.log(Logger.DEBUG, Launcher.RESOURCES,
                    "RequestDispatcher.ExecutingFilter", filter.getFilterName());
            filter.execute(request, response, this);
            return;
        }

        // Forward / include as requested in the beginning
        if (this.doInclude == null) {
            return; // will never happen, because we can't call doFilter before forward/include
        } else if (this.doInclude.booleanValue()) {
            include(request, response);
        } else {
            forward(request, response);
        }
    }

    /**
     * Returns the filters that apply to this dispatch. The result is cached in the web
     * app's filter match cache under a key built from the chain type and either the URI or
     * the servlet name, so repeated dispatches to the same target reuse it.
     *
     * A pattern with a link name matches when that name equals the servlet name or "*";
     * a pattern without one matches by URL, but only when isURLBasedMatch is true.
     */
    private static FilterConfiguration[] getMatchingFilters(Mapping filterPatterns[],
            WebAppConfiguration webAppConfig, String fullPath, String servletName,
            String filterChainType, boolean isURLBasedMatch) {

        String cacheKey = null;
        if (isURLBasedMatch) {
            cacheKey = filterChainType + ":URI:" + fullPath;
        } else {
            cacheKey = filterChainType + ":Servlet:" + servletName;
        }
        FilterConfiguration matchingFilters[] = null;
        Map cache = webAppConfig.getFilterMatchCache();
        synchronized (cache) {
            matchingFilters = (FilterConfiguration []) cache.get(cacheKey);
            if (matchingFilters == null) {
                Logger.log(Logger.FULL_DEBUG, Launcher.RESOURCES,
                        "RequestDispatcher.CalcFilterChain", cacheKey);
                List outFilters = new ArrayList();
                for (int n = 0; n < filterPatterns.length; n++) {
                    Mapping filterPattern = filterPatterns[n];

                    // If the servlet name matches this name, execute it
                    if ((filterPattern.getLinkName() != null)
                            && (filterPattern.getLinkName().equals(servletName) ||
                                filterPattern.getLinkName().equals("*"))) {
                        outFilters.add(webAppConfig.getFilters().get(filterPattern.getMappedTo()));
                    }
                    // If the url path matches this filters mappings
                    else if ((filterPattern.getLinkName() == null) && isURLBasedMatch
                            && filterPattern.match(fullPath, null, null)) {
                        outFilters.add(webAppConfig.getFilters().get(filterPattern.getMappedTo()));
                    }
                }
                matchingFilters = (FilterConfiguration []) outFilters.toArray(new FilterConfiguration[0]);
                cache.put(cacheKey, matchingFilters);
            } else {
                Logger.log(Logger.FULL_DEBUG, Launcher.RESOURCES,
                        "RequestDispatcher.UseCachedFilterChain", cacheKey);
            }
        }
        return matchingFilters;
    }

    /** Concatenates the servlet path and path info, treating a null path info as empty. */
    private static String fullPathOf(String servletPath, String pathInfo) {
        return servletPath + (pathInfo == null ? "" : pathInfo);
    }

    /**
     * Unwrap back to the original container allocated request object
     */
    protected WinstoneRequest getUnwrappedRequest(ServletRequest request) {
        ServletRequest workingRequest = request;
        while (workingRequest instanceof ServletRequestWrapper) {
            workingRequest = ((ServletRequestWrapper) workingRequest).getRequest();
        }
        return (WinstoneRequest) workingRequest;
    }

    /**
     * Unwrap back to the original container allocated response object
     */
    protected WinstoneResponse getUnwrappedResponse(ServletResponse response) {
        ServletResponse workingResponse = response;
        while (workingResponse instanceof ServletResponseWrapper) {
            workingResponse = ((ServletResponseWrapper) workingResponse).getResponse();
        }
        return (WinstoneResponse) workingResponse;
    }
}