/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.cxf.rs.security.xml;
import java.io.IOException;
import java.io.InputStream;
import java.security.Key;
import java.security.PublicKey;
import java.security.cert.X509Certificate;
import java.util.LinkedList;
import java.util.List;
import java.util.logging.Logger;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.ws.rs.core.Response;
import javax.xml.stream.XMLStreamException;
import javax.xml.stream.XMLStreamReader;
import org.apache.cxf.common.logging.LogUtils;
import org.apache.cxf.interceptor.Fault;
import org.apache.cxf.interceptor.StaxInInterceptor;
import org.apache.cxf.jaxrs.utils.ExceptionUtils;
import org.apache.cxf.jaxrs.utils.JAXRSUtils;
import org.apache.cxf.message.Message;
import org.apache.cxf.phase.AbstractPhaseInterceptor;
import org.apache.cxf.phase.Phase;
import org.apache.cxf.rs.security.common.CryptoLoader;
import org.apache.cxf.rs.security.common.SecurityUtils;
import org.apache.cxf.rs.security.common.TrustValidator;
import org.apache.cxf.staxutils.StaxUtils;
import org.apache.cxf.ws.security.SecurityConstants;
import org.apache.wss4j.common.crypto.Crypto;
import org.apache.wss4j.common.crypto.CryptoType;
import org.apache.wss4j.common.ext.WSPasswordCallback;
import org.apache.wss4j.common.ext.WSSecurityException;
import org.apache.xml.security.exceptions.XMLSecurityException;
import org.apache.xml.security.stax.ext.InboundXMLSec;
import org.apache.xml.security.stax.ext.XMLSec;
import org.apache.xml.security.stax.ext.XMLSecurityConstants;
import org.apache.xml.security.stax.ext.XMLSecurityProperties;
import org.apache.xml.security.stax.securityEvent.AlgorithmSuiteSecurityEvent;
import org.apache.xml.security.stax.securityEvent.SecurityEvent;
import org.apache.xml.security.stax.securityEvent.SecurityEventConstants;
import org.apache.xml.security.stax.securityEvent.SecurityEventConstants.Event;
import org.apache.xml.security.stax.securityEvent.SecurityEventListener;
import org.apache.xml.security.stax.securityEvent.TokenSecurityEvent;
import org.apache.xml.security.stax.securityToken.SecurityToken;
/**
 * A new StAX-based interceptor for processing messages with XML Signature + Encryption content.
 * <p>
 * The incoming {@link XMLStreamReader} (created from the {@link InputStream} content when no
 * reader is present yet) is handed to Apache Santuario's streaming {@link InboundXMLSec}
 * processor, and the reader it returns replaces the message content, so decryption and
 * signature verification happen as the document is consumed. Security events raised during
 * processing are checked against the configured {@link EncryptionProperties} /
 * {@link SignatureProperties} algorithm restrictions and against the signature trust store,
 * and are collected into a list stored on the message and the exchange under the key
 * {@code SecurityEvent.class.getName() + ".in"}. A nested interceptor running at
 * {@link Phase#PRE_LOGICAL} then verifies that any required signature and/or encryption was
 * actually found.
 */
public class XmlSecInInterceptor extends AbstractPhaseInterceptor<Message> {
    private static final Logger LOG = LogUtils.getL7dLogger(XmlSecInInterceptor.class);

    /** Allowed encryption algorithms; when {@code null} no encryption algorithm checks are made. */
    private EncryptionProperties encryptionProperties;
    /** Allowed signature algorithms; when {@code null} no signature algorithm checks are made. */
    private SignatureProperties sigProps;
    /** Alias of the decryption key; when {@code null} the crypto's default identifier is used. */
    private String decryptionAlias;
    /** Alias of the certificate whose public key is used to verify signatures, if set. */
    private String signatureVerificationAlias;
    /** Whether the certificate of a trusted token is stored as {@link X509Certificate} content. */
    private boolean persistSignature = true;
    /** Whether a message without an XML Signature must be rejected. */
    private boolean requireSignature;
    /** Whether a message without an encrypted element must be rejected. */
    private boolean requireEncryption;

    public XmlSecInInterceptor() {
        super(Phase.POST_STREAM);
        getAfter().add(StaxInInterceptor.class.getName());
    }

    /**
     * Sets up streaming decryption and signature verification for the incoming message.
     * <p>
     * Server-side GET requests are skipped. Otherwise a {@link StaxActionInInterceptor} is
     * added to the chain to enforce the signature/encryption requirements, and the message's
     * {@link XMLStreamReader} content is replaced with a reader that performs the security
     * processing. Failures are reported via {@link #throwFault(String, Exception)}, which
     * raises a 400 Bad Request exception.
     *
     * @param message the incoming message (a server request or a client response)
     */
    @Override
    public void handleMessage(Message message) throws Fault {
        // Only a GET request received by a server is guaranteed to have no body to process.
        // A response received by a client must always be processed, otherwise the configured
        // requirements would be silently bypassed for responses to GET requests.
        if (isServerGet(message)) {
            return;
        }
        Message inMsg = getInMessage(message);
        XMLStreamReader originalXmlStreamReader = getXmlStreamReader(inMsg);

        // Registered before anything else so that the requirements are enforced even when
        // there turns out to be no content to process below.
        inMsg.getInterceptorChain().add(
            new StaxActionInInterceptor(requireSignature, requireEncryption));

        if (originalXmlStreamReader == null) {
            // Nothing to decrypt or verify; StaxActionInInterceptor rejects the message if a
            // signature or encryption was required.
            return;
        }

        try {
            XMLSecurityProperties properties = new XMLSecurityProperties();
            configureDecryptionKeys(inMsg, properties);
            Crypto signatureCrypto = getSignatureCrypto(inMsg);
            configureSignatureKeys(signatureCrypto, inMsg, properties);
            SecurityEventListener securityEventListener =
                configureSecurityEventListener(signatureCrypto, inMsg, properties);
            InboundXMLSec inboundXMLSec = XMLSec.getInboundWSSec(properties);

            XMLStreamReader newXmlStreamReader =
                inboundXMLSec.processInMessage(originalXmlStreamReader, null, securityEventListener);
            inMsg.setContent(XMLStreamReader.class, newXmlStreamReader);
        } catch (XMLStreamException e) {
            throwFault(e.getMessage(), e);
        } catch (XMLSecurityException e) {
            throwFault(e.getMessage(), e);
        } catch (IOException e) {
            throwFault(e.getMessage(), e);
        } catch (UnsupportedCallbackException e) {
            throwFault(e.getMessage(), e);
        }
    }

    /**
     * Returns whether the message is a GET request received on the server side.
     * Equivalent to checking the HTTP method and {@code MessageUtils.isRequestor(message)}.
     */
    private static boolean isServerGet(Message message) {
        String method = (String)message.get(Message.HTTP_REQUEST_METHOD);
        boolean requestor = Boolean.TRUE.equals(message.get(Message.REQUESTOR_ROLE));
        return "GET".equals(method) && !requestor;
    }

    /**
     * Returns the in message of the exchange when an out message already exists, falling
     * back to {@code message} itself when no such in message is available.
     */
    private static Message getInMessage(Message message) {
        Message outMsg = message.getExchange().getOutMessage();
        Message inMsg = outMsg == null ? null : outMsg.getExchange().getInMessage();
        return inMsg == null ? message : inMsg;
    }

    /**
     * Returns the message's {@link XMLStreamReader} content, creating one from the
     * {@link InputStream} content if necessary, or {@code null} if neither is present.
     */
    private static XMLStreamReader getXmlStreamReader(Message message) {
        XMLStreamReader reader = message.getContent(XMLStreamReader.class);
        if (reader == null) {
            InputStream is = message.getContent(InputStream.class);
            if (is != null) {
                reader = StaxUtils.createXMLStreamReader(is);
            }
        }
        return reader;
    }

    /**
     * Resolves the private key used to decrypt the message and sets it on {@code properties}.
     * <p>
     * For a two-way signed-and-encrypted exchange the signature crypto configuration holds
     * the decryption key; otherwise the encryption crypto configuration is used. The key alias
     * is {@link #getDecryptionAlias()} or, when unset, the crypto's default X.509 identifier.
     * Its password is obtained from the configured callback handler via a
     * {@link WSPasswordCallback#DECRYPT} callback. If no crypto or alias is available, no
     * decryption key is set.
     */
    private void configureDecryptionKeys(Message message, XMLSecurityProperties properties)
        throws IOException,
        UnsupportedCallbackException, WSSecurityException {
        String cryptoKey = null;
        String propKey = null;
        if (SecurityUtils.isSignedAndEncryptedTwoWay(message)) {
            cryptoKey = SecurityConstants.SIGNATURE_CRYPTO;
            propKey = SecurityConstants.SIGNATURE_PROPERTIES;
        } else {
            cryptoKey = SecurityConstants.ENCRYPT_CRYPTO;
            propKey = SecurityConstants.ENCRYPT_PROPERTIES;
        }

        Crypto crypto = null;
        try {
            crypto = new CryptoLoader().getCrypto(message, cryptoKey, propKey);
        } catch (Exception ex) {
            throwFault("Crypto can not be loaded", ex);
        }

        if (crypto != null) {
            String alias = decryptionAlias;
            if (alias == null) {
                alias = crypto.getDefaultX509Identifier();
            }
            if (alias != null) {
                CallbackHandler callback = SecurityUtils.getCallbackHandler(message, this.getClass());
                if (callback == null) {
                    // getCallbackHandler returns null when no handler is configured; fail
                    // explicitly rather than with a NullPointerException below.
                    throwFault("No callback handler is available to obtain the decryption key password",
                               null);
                    return;
                }
                WSPasswordCallback passwordCallback =
                    new WSPasswordCallback(alias, WSPasswordCallback.DECRYPT);
                callback.handle(new Callback[] {passwordCallback});

                Key privateKey = crypto.getPrivateKey(alias, passwordCallback.getPassword());
                properties.setDecryptionKey(privateKey);
            }
        }
    }

    /**
     * Loads the crypto used to verify signatures and their trust: the encryption crypto for
     * a two-way signed-and-encrypted exchange, the signature crypto otherwise.
     *
     * @return the crypto, or {@code null} if none is configured
     */
    private Crypto getSignatureCrypto(Message message) {
        String cryptoKey = null;
        String propKey = null;
        if (SecurityUtils.isSignedAndEncryptedTwoWay(message)) {
            cryptoKey = SecurityConstants.ENCRYPT_CRYPTO;
            propKey = SecurityConstants.ENCRYPT_PROPERTIES;
        } else {
            cryptoKey = SecurityConstants.SIGNATURE_CRYPTO;
            propKey = SecurityConstants.SIGNATURE_PROPERTIES;
        }

        try {
            return new CryptoLoader().getCrypto(message, cryptoKey, propKey);
        } catch (Exception ex) {
            throwFault("Crypto can not be loaded", ex);
            return null;
        }
    }

    /**
     * If both a signature crypto and {@link #getSignatureVerificationAlias()} are available,
     * sets the public key of the first certificate stored under that alias as the signature
     * verification key.
     */
    private void configureSignatureKeys(
        Crypto sigCrypto, Message message, XMLSecurityProperties properties
    ) throws IOException,
        UnsupportedCallbackException, WSSecurityException {

        if (sigCrypto != null && signatureVerificationAlias != null) {
            CryptoType cryptoType = new CryptoType(CryptoType.TYPE.ALIAS);
            cryptoType.setAlias(signatureVerificationAlias);
            X509Certificate[] certs = sigCrypto.getX509Certificates(cryptoType);
            if (certs != null && certs.length > 0) {
                properties.setSignatureVerificationKey(certs[0].getPublicKey());
            }
        }
    }

    /**
     * Creates the listener that receives Santuario security events during processing.
     * <p>
     * Algorithm-suite events are checked against the configured encryption and signature
     * properties; token events (other than encrypted-key tokens) are checked for signature
     * trust. Every event is then appended to a list which is stored on both the exchange and
     * {@code msg} under {@code SecurityEvent.class.getName() + ".in"}.
     *
     * @param sigCrypto crypto used for trust validation, possibly {@code null}
     * @param msg the message being processed
     * @param securityProperties the security properties (not used by this implementation)
     * @return the listener to pass to {@link InboundXMLSec#processInMessage}
     */
    protected SecurityEventListener configureSecurityEventListener(
        final Crypto sigCrypto, final Message msg, XMLSecurityProperties securityProperties
    ) {
        final List<SecurityEvent> incomingSecurityEventList = new LinkedList<SecurityEvent>();
        SecurityEventListener securityEventListener = new SecurityEventListener() {
            @Override
            public void registerSecurityEvent(SecurityEvent securityEvent) throws XMLSecurityException {
                if (securityEvent.getSecurityEventType() == SecurityEventConstants.AlgorithmSuite) {
                    if (encryptionProperties != null) {
                        checkEncryptionAlgorithms((AlgorithmSuiteSecurityEvent)securityEvent);
                    }
                    if (sigProps != null) {
                        checkSignatureAlgorithms((AlgorithmSuiteSecurityEvent)securityEvent);
                    }
                } else if (securityEvent.getSecurityEventType() != SecurityEventConstants.EncryptedKeyToken
                    && securityEvent instanceof TokenSecurityEvent<?>) {
                    checkSignatureTrust(sigCrypto, msg, (TokenSecurityEvent<?>)securityEvent);
                }
                incomingSecurityEventList.add(securityEvent);
            }
        };
        msg.getExchange().put(SecurityEvent.class.getName() + ".in", incomingSecurityEventList);
        msg.put(SecurityEvent.class.getName() + ".in", incomingSecurityEventList);

        return securityEventListener;
    }

    /**
     * Rejects an encryption-related algorithm that differs from the configured symmetric
     * encryption, key transport (symmetric or asymmetric key wrap) or encryption digest
     * algorithm. A {@code null} configured algorithm accepts any value.
     */
    private void checkEncryptionAlgorithms(AlgorithmSuiteSecurityEvent event)
        throws XMLSecurityException {
        if (XMLSecurityConstants.Enc.equals(event.getAlgorithmUsage())
            && encryptionProperties.getEncryptionSymmetricKeyAlgo() != null
            && !encryptionProperties.getEncryptionSymmetricKeyAlgo().equals(event.getAlgorithmURI())) {
            throw new XMLSecurityException("empty", "The symmetric encryption algorithm "
                                           + event.getAlgorithmURI() + " is not allowed");
        } else if ((XMLSecurityConstants.Sym_Key_Wrap.equals(event.getAlgorithmUsage())
            || XMLSecurityConstants.Asym_Key_Wrap.equals(event.getAlgorithmUsage()))
            && encryptionProperties.getEncryptionKeyTransportAlgo() != null
            && !encryptionProperties.getEncryptionKeyTransportAlgo().equals(event.getAlgorithmURI())) {
            throw new XMLSecurityException("empty", "The key transport algorithm "
                + event.getAlgorithmURI() + " is not allowed");
        } else if (XMLSecurityConstants.EncDig.equals(event.getAlgorithmUsage())
            && encryptionProperties.getEncryptionDigestAlgo() != null
            && !encryptionProperties.getEncryptionDigestAlgo().equals(event.getAlgorithmURI())) {
            throw new XMLSecurityException("empty", "The encryption digest algorithm "
                + event.getAlgorithmURI() + " is not allowed");
        }
    }

    /**
     * Rejects a signature-related algorithm that differs from the configured signature,
     * signature digest, c14n method or c14n transform algorithm. The enveloped-signature
     * transform is always accepted; a {@code null} configured algorithm accepts any value.
     */
    private void checkSignatureAlgorithms(AlgorithmSuiteSecurityEvent event)
        throws XMLSecurityException {
        if ((XMLSecurityConstants.Asym_Sig.equals(event.getAlgorithmUsage())
            || XMLSecurityConstants.Sym_Sig.equals(event.getAlgorithmUsage()))
            && sigProps.getSignatureAlgo() != null
            && !sigProps.getSignatureAlgo().equals(event.getAlgorithmURI())) {
            throw new XMLSecurityException("empty", "The signature algorithm "
                                           + event.getAlgorithmURI() + " is not allowed");
        } else if (XMLSecurityConstants.SigDig.equals(event.getAlgorithmUsage())
            && sigProps.getSignatureDigestAlgo() != null
            && !sigProps.getSignatureDigestAlgo().equals(event.getAlgorithmURI())) {
            throw new XMLSecurityException("empty", "The signature digest algorithm "
                + event.getAlgorithmURI() + " is not allowed");
        } else if (XMLSecurityConstants.SigC14n.equals(event.getAlgorithmUsage())
            && sigProps.getSignatureC14nMethod() != null
            && !sigProps.getSignatureC14nMethod().equals(event.getAlgorithmURI())) {
            throw new XMLSecurityException("empty", "The signature c14n algorithm "
                + event.getAlgorithmURI() + " is not allowed");
        } else if (XMLSecurityConstants.SigTransform.equals(event.getAlgorithmUsage())
            && !XMLSecurityConstants.NS_XMLDSIG_ENVELOPED_SIGNATURE.equals(event.getAlgorithmURI())
            && sigProps.getSignatureC14nTransform() != null
            && !sigProps.getSignatureC14nTransform().equals(event.getAlgorithmURI())) {
            throw new XMLSecurityException("empty", "The signature transformation algorithm "
                + event.getAlgorithmURI() + " is not allowed");
        }
    }

    /**
     * Validates the trust of the token carried by {@code event} against {@code sigCrypto},
     * using its first X.509 certificate (if any) and its public key. When
     * {@link #setPersistSignature(boolean) persistSignature} is enabled, that certificate is
     * then stored as the message's {@link X509Certificate} content.
     *
     * @throws XMLSecurityException if trust validation fails
     */
    private void checkSignatureTrust(
        Crypto sigCrypto, Message msg, TokenSecurityEvent<?> event
    ) throws XMLSecurityException {
        SecurityToken token = event.getSecurityToken();
        if (token != null) {
            X509Certificate[] certs = token.getX509Certificates();
            PublicKey publicKey = token.getPublicKey();
            X509Certificate cert = null;
            if (certs != null && certs.length > 0) {
                cert = certs[0];
            }

            // validate trust
            try {
                new TrustValidator().validateTrust(sigCrypto, cert, publicKey);
            } catch (WSSecurityException e) {
                throw new XMLSecurityException("empty", "Error during Signature Trust "
                                               + "validation: " + e.getMessage());
            }

            if (persistSignature) {
                msg.setContent(X509Certificate.class, cert);
            }
        }
    }

    /**
     * Logs {@code error} and throws a 400 Bad Request exception whose response entity is
     * {@code error}. The original exception, if any, is attached as the cause so that it is
     * available for diagnostics; it is not added to the response entity.
     *
     * @param error message logged and returned to the caller
     * @param ex the underlying exception, may be {@code null}
     */
    protected void throwFault(String error, Exception ex) {
        LOG.warning(error);
        Response response = JAXRSUtils.toResponseBuilder(400).entity(error).build();
        throw ExceptionUtils.toBadRequestException(ex, response);
    }

    public void setEncryptionProperties(EncryptionProperties properties) {
        this.encryptionProperties = properties;
    }

    public void setSignatureProperties(SignatureProperties properties) {
        this.sigProps = properties;
    }

    public String getDecryptionAlias() {
        return decryptionAlias;
    }

    public void setDecryptionAlias(String decryptionAlias) {
        this.decryptionAlias = decryptionAlias;
    }

    public String getSignatureVerificationAlias() {
        return signatureVerificationAlias;
    }

    public void setSignatureVerificationAlias(String signatureVerificationAlias) {
        this.signatureVerificationAlias = signatureVerificationAlias;
    }

    public void setPersistSignature(boolean persist) {
        this.persistSignature = persist;
    }

    public boolean isRequireSignature() {
        return requireSignature;
    }

    public void setRequireSignature(boolean requireSignature) {
        this.requireSignature = requireSignature;
    }

    public boolean isRequireEncryption() {
        return requireEncryption;
    }

    public void setRequireEncryption(boolean requireEncryption) {
        this.requireEncryption = requireEncryption;
    }

    /**
     * This interceptor handles parsing the StaX results (events) + checks to see whether the
     * required (if any) Actions (signature or encryption) were fulfilled.
     * <p>
     * It reads the security event list stored on the message by
     * {@code configureSecurityEventListener} and rejects the message if a required
     * {@code SignatureValue} or {@code EncryptedElement} event is missing.
     */
    private static class StaxActionInInterceptor extends AbstractPhaseInterceptor<Message> {

        private static final Logger LOG =
            LogUtils.getL7dLogger(StaxActionInInterceptor.class);

        private final boolean signatureRequired;
        private final boolean encryptionRequired;

        public StaxActionInInterceptor(boolean signatureRequired, boolean encryptionRequired) {
            super(Phase.PRE_LOGICAL);
            this.signatureRequired = signatureRequired;
            this.encryptionRequired = encryptionRequired;
        }

        @Override
        public void handleMessage(Message message) throws Fault {

            if (!(signatureRequired || encryptionRequired)) {
                return;
            }

            @SuppressWarnings("unchecked")
            final List<SecurityEvent> incomingSecurityEventList =
                (List<SecurityEvent>)message.get(SecurityEvent.class.getName() + ".in");

            if (incomingSecurityEventList == null) {
                LOG.warning("Security processing failed (actions mismatch)");
                XMLSecurityException ex =
                    new XMLSecurityException("empty", "The request was not signed or encrypted");
                throwFault(ex.getMessage(), ex);
                return;
            }

            // throwFault logs the message itself, so no separate warning is logged here.
            if (signatureRequired) {
                Event requiredEvent = SecurityEventConstants.SignatureValue;
                if (!isEventInResults(requiredEvent, incomingSecurityEventList)) {
                    XMLSecurityException ex =
                        new XMLSecurityException("empty", "The request was not signed");
                    throwFault(ex.getMessage(), ex);
                }
            }

            if (encryptionRequired) {
                boolean foundEncryptionPart =
                    isEventInResults(SecurityEventConstants.EncryptedElement, incomingSecurityEventList);
                if (!foundEncryptionPart) {
                    XMLSecurityException ex =
                        new XMLSecurityException("empty", "The request was not encrypted");
                    throwFault(ex.getMessage(), ex);
                }
            }
        }

        /** Returns whether any event in the list has the given event type. */
        private boolean isEventInResults(Event event, List<SecurityEvent> incomingSecurityEventList) {
            for (SecurityEvent incomingEvent : incomingSecurityEventList) {
                if (event == incomingEvent.getSecurityEventType()) {
                    return true;
                }
            }
            return false;
        }

        /**
         * Logs {@code error} and throws a 400 Bad Request exception whose response entity is
         * {@code error}, attaching {@code ex} (may be {@code null}) as the cause.
         */
        protected void throwFault(String error, Exception ex) {
            LOG.warning(error);
            Response response = JAXRSUtils.toResponseBuilder(400).entity(error).build();
            throw ExceptionUtils.toBadRequestException(ex, response);
        }
    }
}