/*
* Copyright 2005-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ws.server.endpoint.adapter.method.jaxb;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.Reader;
import java.io.Writer;
import java.net.URL;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import javax.xml.bind.JAXBContext;
import javax.xml.bind.JAXBElement;
import javax.xml.bind.JAXBException;
import javax.xml.bind.JAXBIntrospector;
import javax.xml.bind.Marshaller;
import javax.xml.bind.Unmarshaller;
import javax.xml.bind.UnmarshallerHandler;
import javax.xml.namespace.QName;
import javax.xml.stream.XMLEventReader;
import javax.xml.stream.XMLEventWriter;
import javax.xml.stream.XMLStreamException;
import javax.xml.stream.XMLStreamReader;
import javax.xml.stream.XMLStreamWriter;
import javax.xml.transform.Result;
import javax.xml.transform.Source;
import javax.xml.transform.sax.SAXSource;
import javax.xml.transform.stream.StreamResult;
import javax.xml.transform.stream.StreamSource;
import org.w3c.dom.Node;
import org.xml.sax.ContentHandler;
import org.xml.sax.InputSource;
import org.xml.sax.XMLReader;
import org.xml.sax.ext.LexicalHandler;
import org.springframework.core.MethodParameter;
import org.springframework.util.Assert;
import org.springframework.ws.WebServiceMessage;
import org.springframework.ws.context.MessageContext;
import org.springframework.ws.server.endpoint.adapter.method.AbstractPayloadMethodProcessor;
import org.springframework.ws.stream.StreamingPayload;
import org.springframework.ws.stream.StreamingWebServiceMessage;
import org.springframework.xml.transform.TraxUtils;
/**
* Abstract base class for {@link org.springframework.ws.server.endpoint.adapter.method.MethodArgumentResolver
* MethodArgumentResolver} and {@link org.springframework.ws.server.endpoint.adapter.method.MethodReturnValueHandler
* MethodReturnValueHandler} implementations that use JAXB2. Creates {@link JAXBContext} objects lazily, and offers
* {@linkplain #marshalToResponsePayload(org.springframework.ws.context.MessageContext, Class, Object) marshalling} and
* {@linkplain #unmarshalFromRequestPayload(org.springframework.ws.context.MessageContext, Class) unmarshalling}
* methods.
*
* @author Arjen Poutsma
* @since 2.0
*/
public abstract class AbstractJaxb2PayloadMethodProcessor extends AbstractPayloadMethodProcessor {

    /** Cache of lazily created JAXB contexts, keyed by the class each context was created for. */
    private final ConcurrentMap<Class<?>, JAXBContext> jaxbContexts = new ConcurrentHashMap<Class<?>, JAXBContext>();

    /**
     * Delegates to {@link #handleReturnValueInternal} when the endpoint returned a non-{@code null} value;
     * a {@code null} return value is ignored.
     *
     * @param messageContext the current message context
     * @param returnType the return type of the endpoint method
     * @param returnValue the value returned by the endpoint method, may be {@code null}
     * @throws Exception in case of errors raised by {@link #handleReturnValueInternal}
     */
    @Override
    public final void handleReturnValue(MessageContext messageContext,
            MethodParameter returnType, Object returnValue) throws Exception {
        if (returnValue != null) {
            handleReturnValueInternal(messageContext, returnType, returnValue);
        }
    }

    /**
     * Template method invoked by {@link #handleReturnValue} for non-{@code null} return values only.
     *
     * @param messageContext the current message context
     * @param returnType the return type of the endpoint method
     * @param returnValue the value returned by the endpoint method; never {@code null}
     * @throws Exception in case of errors
     */
    protected abstract void handleReturnValueInternal(MessageContext messageContext,
            MethodParameter returnType, Object returnValue) throws Exception;

    /**
     * Marshals the given {@code jaxbElement} to the response payload of the given message context.
     * <p>If the response is a {@link StreamingWebServiceMessage}, marshalling is deferred: a
     * {@link StreamingPayload} is registered that marshals the element when it is written. Otherwise the
     * element is marshalled immediately into the response's payload {@link Result}.
     *
     * @param messageContext the message context to marshal to
     * @param clazz the clazz to create a marshaller for
     * @param jaxbElement the object to be marshalled
     * @throws JAXBException in case of JAXB2 errors
     */
    protected final void marshalToResponsePayload(MessageContext messageContext, Class<?> clazz, Object jaxbElement)
            throws JAXBException {
        Assert.notNull(messageContext, "'messageContext' must not be null");
        Assert.notNull(clazz, "'clazz' must not be null");
        Assert.notNull(jaxbElement, "'jaxbElement' must not be null");
        if (logger.isDebugEnabled()) {
            logger.debug("Marshalling [" + jaxbElement + "] to response payload");
        }
        WebServiceMessage response = messageContext.getResponse();
        if (response instanceof StreamingWebServiceMessage) {
            StreamingWebServiceMessage streamingResponse = (StreamingWebServiceMessage) response;
            StreamingPayload payload = new JaxbStreamingPayload(clazz, jaxbElement);
            streamingResponse.setStreamingPayload(payload);
        }
        else {
            Result responsePayload = response.getPayloadResult();
            try {
                Jaxb2ResultCallback callback = new Jaxb2ResultCallback(clazz, jaxbElement);
                TraxUtils.doWithResult(responsePayload, callback);
            }
            catch (Exception ex) {
                throw convertToJaxbException(ex);
            }
        }
    }

    /**
     * Unmarshals the request payload of the given message context.
     *
     * @param messageContext the message context to unmarshal from
     * @param clazz the class to unmarshal
     * @return the unmarshalled object, or {@code null} if the request has no payload
     * @throws JAXBException in case of JAXB2 errors
     */
    protected final Object unmarshalFromRequestPayload(MessageContext messageContext, Class<?> clazz)
            throws JAXBException {
        Source requestPayload = getRequestPayload(messageContext);
        if (requestPayload == null) {
            return null;
        }
        try {
            Jaxb2SourceCallback callback = new Jaxb2SourceCallback(clazz);
            TraxUtils.doWithSource(requestPayload, callback);
            if (logger.isDebugEnabled()) {
                logger.debug("Unmarshalled payload request to [" + callback.result + "]");
            }
            return callback.result;
        }
        catch (Exception ex) {
            throw convertToJaxbException(ex);
        }
    }

    /**
     * Unmarshals the request payload of the given message context as {@link JAXBElement}.
     *
     * @param messageContext the message context to unmarshal from
     * @param clazz the class to unmarshal
     * @return the unmarshalled element, or {@code null} if the request has no payload
     * @throws JAXBException in case of JAXB2 errors
     */
    protected final <T> JAXBElement<T> unmarshalElementFromRequestPayload(MessageContext messageContext, Class<T> clazz)
            throws JAXBException {
        Source requestPayload = getRequestPayload(messageContext);
        if (requestPayload == null) {
            return null;
        }
        try {
            JaxbElementSourceCallback<T> callback = new JaxbElementSourceCallback<T>(clazz);
            TraxUtils.doWithSource(requestPayload, callback);
            if (logger.isDebugEnabled()) {
                logger.debug("Unmarshalled payload request to [" + callback.result + "]");
            }
            return callback.result;
        }
        catch (Exception ex) {
            throw convertToJaxbException(ex);
        }
    }

    /**
     * Returns the payload source of the request, or {@code null} if the context has no request (or the
     * request itself reports no payload source).
     */
    private Source getRequestPayload(MessageContext messageContext) {
        WebServiceMessage request = messageContext.getRequest();
        return request != null ? request.getPayloadSource() : null;
    }

    /**
     * Returns {@code ex} unchanged if it already is a {@link JAXBException}; otherwise wraps it in one,
     * preserving it as the cause.
     */
    private JAXBException convertToJaxbException(Exception ex) {
        if (ex instanceof JAXBException) {
            return (JAXBException) ex;
        }
        else {
            return new JAXBException(ex);
        }
    }

    /**
     * Creates a new {@link Marshaller} to be used for marshalling objects to XML. Defaults to
     * {@link javax.xml.bind.JAXBContext#createMarshaller()}, but can be overridden in subclasses for further
     * customization. This hook is used for both regular and streaming responses.
     *
     * @param jaxbContext the JAXB context to create a marshaller for
     * @return the marshaller
     * @throws JAXBException in case of JAXB errors
     */
    protected Marshaller createMarshaller(JAXBContext jaxbContext) throws JAXBException {
        return jaxbContext.createMarshaller();
    }

    /** Creates a marshaller for the cached JAXB context of the given class, via the overridable hook. */
    private Marshaller createMarshaller(Class<?> clazz) throws JAXBException {
        return createMarshaller(getJaxbContext(clazz));
    }

    /**
     * Creates a new {@link Unmarshaller} to be used for unmarshalling XML to objects. Defaults to
     * {@link javax.xml.bind.JAXBContext#createUnmarshaller()}, but can be overridden in subclasses for further
     * customization.
     *
     * @param jaxbContext the JAXB context to create a unmarshaller for
     * @return the unmarshaller
     * @throws JAXBException in case of JAXB errors
     */
    protected Unmarshaller createUnmarshaller(JAXBContext jaxbContext) throws JAXBException {
        return jaxbContext.createUnmarshaller();
    }

    /** Creates an unmarshaller for the cached JAXB context of the given class, via the overridable hook. */
    private Unmarshaller createUnmarshaller(Class<?> clazz) throws JAXBException {
        return createUnmarshaller(getJaxbContext(clazz));
    }

    /**
     * Returns the cached {@link JAXBContext} for the given class, creating and caching it on first use.
     * <p>Two threads may race to create a context for the same class. In that case the instance that was
     * actually stored in the cache is returned, so all callers end up sharing one context per class.
     */
    private JAXBContext getJaxbContext(Class<?> clazz) throws JAXBException {
        Assert.notNull(clazz, "'clazz' must not be null");
        JAXBContext jaxbContext = this.jaxbContexts.get(clazz);
        if (jaxbContext == null) {
            jaxbContext = JAXBContext.newInstance(clazz);
            JAXBContext existing = this.jaxbContexts.putIfAbsent(clazz, jaxbContext);
            if (existing != null) {
                // another thread won the race; use the cached instance
                jaxbContext = existing;
            }
        }
        return jaxbContext;
    }

    // Callbacks

    /** {@link TraxUtils.SourceCallback} that unmarshals any supported source kind into a plain object. */
    private class Jaxb2SourceCallback implements TraxUtils.SourceCallback {

        private final Unmarshaller unmarshaller;

        /** The unmarshalled object; set by whichever callback method TraxUtils invokes. */
        private Object result;

        public Jaxb2SourceCallback(Class<?> clazz) throws JAXBException {
            this.unmarshaller = createUnmarshaller(clazz);
        }

        @Override
        public void domSource(Node node) throws JAXBException {
            result = unmarshaller.unmarshal(node);
        }

        @Override
        public void saxSource(XMLReader reader, InputSource inputSource) throws Exception {
            if (inputSource.getByteStream() == null && inputSource.getCharacterStream() == null
                    && inputSource.getSystemId() == null) {
                // The InputSource neither has a stream nor a system ID set; this means that
                // we are dealing with a custom SAXSource that is not backed by a SAX parser
                // but that generates a sequence of SAX events in some other way.
                // In this case, we need to use a ContentHandler to feed the SAX events into
                // the unmarshaller.
                UnmarshallerHandler handler = unmarshaller.getUnmarshallerHandler();
                reader.setContentHandler(handler);
                reader.parse(inputSource);
                result = handler.getResult();
            }
            else {
                // If a stream or system ID is set, we assume that the SAXSource is backed
                // by a SAX parser and we only pass the InputSource to the unmarshaller.
                // This effectively ignores the SAX parser and lets the unmarshaller take
                // care of the parsing (in a potentially more efficient way).
                // NOTE(review): the request payload is client-supplied; whether DTDs/external
                // entities are resolved here depends on the JAXB provider's default parser
                // configuration - confirm XXE hardening.
                result = unmarshaller.unmarshal(inputSource);
            }
        }

        @Override
        public void staxSource(XMLEventReader eventReader) throws JAXBException {
            result = unmarshaller.unmarshal(eventReader);
        }

        @Override
        public void staxSource(XMLStreamReader streamReader) throws JAXBException {
            result = unmarshaller.unmarshal(streamReader);
        }

        @Override
        public void streamSource(InputStream inputStream) throws IOException, JAXBException {
            result = unmarshaller.unmarshal(inputStream);
        }

        @Override
        public void streamSource(Reader reader) throws IOException, JAXBException {
            result = unmarshaller.unmarshal(reader);
        }

        @Override
        public void source(String systemId) throws Exception {
            result = unmarshaller.unmarshal(new URL(systemId));
        }
    }

    /**
     * {@link TraxUtils.SourceCallback} that unmarshals any supported source kind into a {@link JAXBElement}
     * of the declared type.
     */
    private class JaxbElementSourceCallback<T> implements TraxUtils.SourceCallback {

        private final Unmarshaller unmarshaller;

        private final Class<T> declaredType;

        /** The unmarshalled element; set by whichever callback method TraxUtils invokes. */
        private JAXBElement<T> result;

        public JaxbElementSourceCallback(Class<T> declaredType) throws JAXBException {
            this.unmarshaller = createUnmarshaller(declaredType);
            this.declaredType = declaredType;
        }

        @Override
        public void domSource(Node node) throws JAXBException {
            result = unmarshaller.unmarshal(node, declaredType);
        }

        @Override
        public void saxSource(XMLReader reader, InputSource inputSource) throws JAXBException {
            result = unmarshaller.unmarshal(new SAXSource(reader, inputSource), declaredType);
        }

        @Override
        public void staxSource(XMLEventReader eventReader) throws JAXBException {
            result = unmarshaller.unmarshal(eventReader, declaredType);
        }

        @Override
        public void staxSource(XMLStreamReader streamReader) throws JAXBException {
            result = unmarshaller.unmarshal(streamReader, declaredType);
        }

        @Override
        public void streamSource(InputStream inputStream) throws IOException, JAXBException {
            result = unmarshaller.unmarshal(new StreamSource(inputStream), declaredType);
        }

        @Override
        public void streamSource(Reader reader) throws IOException, JAXBException {
            result = unmarshaller.unmarshal(new StreamSource(reader), declaredType);
        }

        @Override
        public void source(String systemId) throws Exception {
            result = unmarshaller.unmarshal(new StreamSource(systemId), declaredType);
        }
    }

    /** {@link TraxUtils.ResultCallback} that marshals a single object into any supported result kind. */
    private class Jaxb2ResultCallback implements TraxUtils.ResultCallback {

        private final Marshaller marshaller;

        private final Object jaxbElement;

        private Jaxb2ResultCallback(Class<?> clazz, Object jaxbElement) throws JAXBException {
            this.marshaller = createMarshaller(clazz);
            this.jaxbElement = jaxbElement;
        }

        @Override
        public void domResult(Node node) throws JAXBException {
            marshaller.marshal(jaxbElement, node);
        }

        @Override
        public void saxResult(ContentHandler contentHandler, LexicalHandler lexicalHandler) throws JAXBException {
            // only the ContentHandler is passed to the marshaller; the lexicalHandler is not used
            marshaller.marshal(jaxbElement, contentHandler);
        }

        @Override
        public void staxResult(XMLEventWriter eventWriter) throws JAXBException {
            marshaller.marshal(jaxbElement, eventWriter);
        }

        @Override
        public void staxResult(XMLStreamWriter streamWriter) throws JAXBException {
            marshaller.marshal(jaxbElement, streamWriter);
        }

        @Override
        public void streamResult(OutputStream outputStream) throws JAXBException {
            marshaller.marshal(jaxbElement, outputStream);
        }

        @Override
        public void streamResult(Writer writer) throws JAXBException {
            marshaller.marshal(jaxbElement, writer);
        }

        @Override
        public void result(String systemId) throws Exception {
            marshaller.marshal(jaxbElement, new StreamResult(systemId));
        }
    }

    /**
     * {@link StreamingPayload} that defers marshalling of the element until {@link #writeTo} is called.
     */
    private class JaxbStreamingPayload implements StreamingPayload {

        private final Object jaxbElement;

        private final Marshaller marshaller;

        private final QName name;

        private JaxbStreamingPayload(Class<?> clazz, Object jaxbElement) throws JAXBException {
            JAXBContext jaxbContext = getJaxbContext(clazz);
            // go through the overridable hook so subclass customizations also apply to streaming responses
            this.marshaller = createMarshaller(jaxbContext);
            // the element is written into an existing stream writer, so no document-level events
            this.marshaller.setProperty(Marshaller.JAXB_FRAGMENT, Boolean.TRUE);
            this.jaxbElement = jaxbElement;
            JAXBIntrospector introspector = jaxbContext.createJAXBIntrospector();
            // per the JAXBIntrospector contract this is null if jaxbElement is not a JAXB element
            this.name = introspector.getElementName(jaxbElement);
        }

        @Override
        public QName getName() {
            return name;
        }

        @Override
        public void writeTo(XMLStreamWriter streamWriter) throws XMLStreamException {
            try {
                marshaller.marshal(jaxbElement, streamWriter);
            }
            catch (JAXBException ex) {
                throw new XMLStreamException("Could not marshal [" + jaxbElement + "]: " + ex.getMessage(), ex);
            }
        }
    }
}