Package org.picketlink.identity.federation.bindings.jboss.auth

Source Code of org.picketlink.identity.federation.bindings.jboss.auth.SAML2STSCommonLoginModule

/*
* JBoss, Home of Professional Open Source.
* Copyright 2008, Red Hat Middleware LLC, and individual contributors
* as indicated by the @author tags. See the copyright.txt file in the
* distribution for a full listing of individual contributors.
*
* This is free software; you can redistribute it and/or modify it
* under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1 of
* the License, or (at your option) any later version.
*
* This software is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this software; if not, write to the Free
* Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
* 02110-1301 USA, or see the FSF site: http://www.fsf.org.
*/
package org.picketlink.identity.federation.bindings.jboss.auth;

import java.security.Principal;
import java.security.acl.Group;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import javax.security.auth.Subject;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.login.LoginException;
import javax.xml.datatype.XMLGregorianCalendar;
import javax.xml.transform.Source;
import javax.xml.ws.Dispatch;

import org.jboss.security.SecurityConstants;
import org.jboss.security.SimplePrincipal;
import org.jboss.security.auth.callback.ObjectCallback;
import org.picketlink.identity.federation.bindings.jboss.subject.PicketLinkGroup;
import org.picketlink.identity.federation.bindings.jboss.subject.PicketLinkPrincipal;
import org.picketlink.identity.federation.core.ErrorCodes;
import org.picketlink.identity.federation.core.constants.AttributeConstants;
import org.picketlink.identity.federation.core.constants.PicketLinkFederationConstants;
import org.picketlink.identity.federation.core.exceptions.ProcessingException;
import org.picketlink.identity.federation.core.factories.JBossAuthCacheInvalidationFactory.TimeCacheExpiry;
import org.picketlink.identity.federation.core.saml.v2.util.AssertionUtil;
import org.picketlink.identity.federation.core.saml.v2.util.DocumentUtil;
import org.picketlink.identity.federation.core.util.StringUtil;
import org.picketlink.identity.federation.core.wstrust.STSClient;
import org.picketlink.identity.federation.core.wstrust.STSClientConfig.Builder;
import org.picketlink.identity.federation.core.wstrust.STSClientFactory;
import org.picketlink.identity.federation.core.wstrust.SamlCredential;
import org.picketlink.identity.federation.core.wstrust.WSTrustException;
import org.picketlink.identity.federation.core.wstrust.auth.AbstractSTSLoginModule;
import org.picketlink.identity.federation.core.wstrust.plugins.saml.SAMLUtil;
import org.picketlink.identity.federation.saml.v2.assertion.AssertionType;
import org.picketlink.identity.federation.saml.v2.assertion.BaseIDAbstractType;
import org.picketlink.identity.federation.saml.v2.assertion.NameIDType;
import org.picketlink.identity.federation.saml.v2.assertion.SubjectType;
import org.w3c.dom.Element;

/**
* <p>
* This {@code LoginModule} authenticates clients by validating their SAML assertions with an external security token service
* (such as PicketLinkSTS). If the supplied assertion contains roles, these roles are extracted and included in the
* {@code Group} returned by the {@code getRoleSets} method.
* </p>
* <p>
* This module defines the following module options:
* <ul>
* <li>
* configFile - this property identifies the properties file that will be used to establish communication with the external
* security token service.
* </li>
* <li>
* cache.invalidation: set it to true if you require invalidation of JBoss Auth Cache at SAML Principal expiration.
* </li>
* <li>
* jboss.security.security_domain: name of the security domain where this login module is configured. This is only required if
* the cache.invalidation option is configured.
* </li>
* <li>
* roleKey: a comma separated list of strings that define the attributes in SAML assertion for user roles
* </li>
* <li>
* localValidation: if you want to validate the assertion locally for signature and expiry
* </li>
* <li>
* localValidationSecurityDomain:  the security domain for the trust store information (via the JaasSecurityDomain)
* </li>
* <li>
* tokenEncodingType: encoding type of SAML token delivered via http request's header.
* Possible values are:
*    base64 - content encoded as base64. In case of encoding will vary between base64 and gzip use base64 and LoginModule will detect gzipped data.
*    gzip - gzipped content encoded as base64
*    none - content not encoded in any way
* </li>
* <li>
* samlTokenHttpHeader - name of http request header to fetch SAML token from. For example: "Authorize"
* </li>
* <li>
* samlTokenHttpHeaderRegEx - Java regular expression to be used to get SAML token from "samlTokenHttpHeader". Example: use: ."(.)".* to parse SAML token from header content like this: SAML_assertion="HHDHS=", at the same time set samlTokenHttpHeaderRegExGroup to 1.
* </li>
* <li>
* samlTokenHttpHeaderRegExGroup - Group value to be used when parsing out value of http request header specified by "samlTokenHttpHeader" using "samlTokenHttpHeaderRegEx".
* </li>
* </ul>
* </p>
* <p>
* Any properties specified besides the above properties are assumed to be used to configure how the {@code STSClient} will
* connect to the STS. For example, the JBossWS {@code StubExt.PROPERTY_SOCKET_FACTORY} can be specified in order to inform the
* socket factory that must be used to connect to the STS. All properties will be set in the request context of the
* {@code Dispatch} instance used by the {@code STSClient} to send requests to the STS.
* </p>
* <p>
* An example of a {@code configFile} can be seen bellow:
*
* <pre>
* serviceName=PicketLinkSTS
* portName=PicketLinkSTSPort
* endpointAddress=http://localhost:8080/picketlink-sts/PicketLinkSTS
* username=JBoss
* password=JBoss
* </pre>
*
* The first three properties specify the STS endpoint URL, service name, and port name. The last two properties specify the
* username and password that are to be used by the application server to authenticate to the STS and have the SAML assertions
* validated.
* </p>
* <p>
* <b>NOTE:</b> Sub-classes can use {@link #getSTSClient()} method to customize the {@link STSClient} class to make calls to
* STS/
* </p>
*
* @author <a href="mailto:sguilhen@redhat.com">Stefan Guilhen</a>
* @author Anil.Saldhana@redhat.com
*/
@SuppressWarnings("unchecked")
public abstract class SAML2STSCommonLoginModule extends SAMLTokenFromHttpRequestAbstractLoginModule {
   
    protected String stsConfigurationFile;

    protected Principal principal;

    protected SamlCredential credential;

    protected AssertionType assertion;

    protected boolean enableCacheInvalidation = false;

    protected String securityDomain = null;

    protected boolean localValidation = false;

    protected String localValidationSecurityDomain;

    protected String roleKey = AttributeConstants.ROLE_IDENTIFIER_ASSERTION;

    /**
     * Maximal number of clients in the STS Client Pool.
     */
    protected int maxClientsInPool = 0;
   
    /**
     * Number of clients initialized for in case pool is out of free clients.
     */
    protected int initialNumberOfClients = 0;

   
    /**
     * Options that are computed by this login module. Few options are removed and the rest are set in the dispatch sts call
     */
    protected Map<String, Object> options = new HashMap<String, Object>();

    /**
     * Original Options that are sent by the JDK JAAS Framework
     */
    protected Map<String, Object> rawOptions = new HashMap<String, Object>();

    /**
     * This is an option that should identify the configuration file for WSTrustClient.
     */
    public static final String STS_CONFIG_FILE = "configFile";

    /**
     * Key to specify the end point address
     */
    public static final String ENDPOINT_ADDRESS = "endpointAddress";

    /**
     * Key to specify the port name
     */
    public static final String PORT_NAME = "portName";

    /**
     * Key to specify the service name
     */
    public static final String SERVICE_NAME = "serviceName";

    /**
     * Key to specify the username
     */
    public static final String USERNAME_KEY = "username";

    /**
     * Key to specify the password
     */
    public static final String PASSWORD_KEY = "password";

    // A variable used by the unit test to pass local validation
    protected boolean localTestingOnly = false;

    /**
     * Paramater name.
     */
    public static final String MAX_CLIENTS_IN_POOL = "maxClientsInPool";

    /**
     * Paramater name.
     */
    public static final String INITIAL_NUMBER_OF_CLIENTS = "initialNumberOfClients";

    /*
     * (non-Javadoc)
     *
     * @see org.jboss.security.auth.spi.AbstractServerLoginModule#initialize(javax.security.auth.Subject,
     * javax.security.auth.callback.CallbackHandler, java.util.Map, java.util.Map)
     */
    @Override
    public void initialize(Subject subject, CallbackHandler callbackHandler, Map<String, ?> sharedState, Map<String, ?> options) {
        super.initialize(subject, callbackHandler, sharedState, options);
        this.options.putAll(options);
        this.rawOptions.putAll(options);

        if (logger.isTraceEnabled()) {
            logger.trace(options.toString());
        }
        // save the config file and cache validation options, removing them from the map - all remaining properties will
        // be set in the request context of the Dispatch instance used to send requests to the STS.
        this.stsConfigurationFile = (String) this.options.remove(STS_CONFIG_FILE);
        String cacheInvalidation = (String) this.options.remove("cache.invalidation");
        if (cacheInvalidation != null && !cacheInvalidation.isEmpty()) {
            this.enableCacheInvalidation = Boolean.parseBoolean(cacheInvalidation);

            this.securityDomain = (String) this.options.remove(SecurityConstants.SECURITY_DOMAIN_OPTION);
            if (this.securityDomain == null || this.securityDomain.isEmpty())
                throw logger.optionNotSet(SecurityConstants.SECURITY_DOMAIN_OPTION);
        }

        String roleKeyStr = (String) options.get("roleKey");
        if (StringUtil.isNotNull(roleKeyStr)) {
            roleKey = roleKeyStr.trim();
        }

        String localValidationStr = (String) options.get("localValidation");
        if (StringUtil.isNotNull(localValidationStr)) {
            localValidation = Boolean.parseBoolean(localValidationStr);
            localValidationSecurityDomain = (String) options.get("localValidationSecurityDomain");

            if (localValidationSecurityDomain == null) {
               logger.error(ErrorCodes.LOCAL_VALIDATION_SEC_DOMAIN_MUST_BE_SPECIFIED);
               throw logger.optionNotSet("localValidationSecurityDomain");
            }
           
            if (localValidationSecurityDomain.startsWith("java:") == false)
                localValidationSecurityDomain = SecurityConstants.JAAS_CONTEXT_ROOT + "/" + localValidationSecurityDomain;

            String localTestingOnlyStr = (String) options.get("localTestingOnly");
            if (StringUtil.isNotNull(localTestingOnlyStr)) {
                localTestingOnly = Boolean.valueOf(localTestingOnlyStr);
            }
        }

        String maxClientsString = (String) options.get(MAX_CLIENTS_IN_POOL);
        if (StringUtil.isNotNull(maxClientsString)) {
            try {
                this.maxClientsInPool = Integer.parseInt(maxClientsString);
            } catch (Exception e) {
                logger.cannotParseParameterValue(MAX_CLIENTS_IN_POOL, e);
            }
        }
       
        String initialNumberOfClientsString = (String) options.get(INITIAL_NUMBER_OF_CLIENTS);
        if (StringUtil.isNotNull(initialNumberOfClientsString)) {
            try {
                this.initialNumberOfClients = Integer.parseInt(initialNumberOfClientsString);
            } catch (Exception e) {
                logger.cannotParseParameterValue(INITIAL_NUMBER_OF_CLIENTS, e);
            }
        }
       
       
    }

    /*
     * (non-Javadoc)
     *
     * @see org.jboss.security.auth.spi.AbstractServerLoginModule#login()
     */
    @Override
    public boolean login() throws LoginException {
        // if shared data exists, set our principal and assertion variables.
        if (super.login()) {
            Object sharedPrincipal = super.sharedState.get("javax.security.auth.login.name");
            if (sharedPrincipal instanceof Principal)
                this.principal = (Principal) sharedPrincipal;
            else {
                try {
                    this.principal = createIdentity(sharedPrincipal.toString());
                } catch (Exception e) {
                    throw logger.authFailedToCreatePrincipal(e);
                }
            }

            Object credential = super.sharedState.get("javax.security.auth.login.password");
            if (credential instanceof SamlCredential)
                this.credential = (SamlCredential) credential;
            else
                throw logger.authSharedCredentialIsNotSAMLCredential(credential.getClass().getName());
            return true;
        }

        // obtain the assertion from the callback handler.
        ObjectCallback callback = new ObjectCallback(null);
        Element assertionElement = null;
        try {
            if (getSamlTokenHttpHeader() != null) {
                this.credential = getCredentialFromHttpRequest();
            }
            else {
                super.callbackHandler.handle(new Callback[] { callback });
               
                if (callback.getCredential() instanceof String) {
                    callback.setCredential(new SamlCredential(DocumentUtil.getDocument(callback.getCredential().toString()).getDocumentElement()));
                }
               
                if (callback.getCredential() instanceof SamlCredential == false)
                    throw logger.authSharedCredentialIsNotSAMLCredential(callback.getCredential().getClass().getName());
                this.credential = (SamlCredential) callback.getCredential();
            }
            assertionElement = this.credential.getAssertionAsElement();
        } catch (Exception e) {
            throw logger.authErrorHandlingCallback(e);
        }
   
       
        // if there is no shared data, validate the assertion using the STS.
        if (localValidation) {
            logger.trace("Local Validation is being Performed");
            try {
                boolean isValid = localValidation(assertionElement);
                if (isValid) {
                    logger.trace("Local Validation passed.");
                }
            } catch (Exception e) {
                LoginException le = new LoginException();
                le.initCause(e);
                throw le;
            }
        } else {
            logger.trace("Local Validation is disabled. Verifying with STS");

            // sts config file has to be present to call STS (using sts client)
            if (this.stsConfigurationFile == null)
                throw logger.authSTSConfigFileNotFound();

            // send the assertion to the STS for validation.
            STSClient client = this.getSTSClient();
            try {
                boolean isValid = client.validateToken(assertionElement);
                // if the STS says the assertion is invalid, throw an exception to signal that authentication has failed.
                if (isValid == false)
                    throw logger.authInvalidSAMLAssertionBySTS();
            } catch (WSTrustException we) {
                throw logger.authAssertionValidationError(we);
            }
        }

        // if the assertion is valid, create a principal containing the assertion subject.
        try {
            this.assertion = SAMLUtil.fromElement(assertionElement);
            SubjectType subject = assertion.getSubject();
            if (subject != null) {
                BaseIDAbstractType baseID = subject.getSubType().getBaseID();
                if (baseID instanceof NameIDType) {
                    NameIDType nameID = (NameIDType) baseID;
                    this.principal = new PicketLinkPrincipal(nameID.getValue());

                    // If the user has configured cache invalidation of subject based on saml token expiry
                    if (enableCacheInvalidation) {
                        TimeCacheExpiry cacheExpiry = this.getCacheExpiry();
                        XMLGregorianCalendar expiry = AssertionUtil.getExpiration(assertion);
                        if (expiry != null) {
                            Date expiryDate = expiry.toGregorianCalendar().getTime();

                            logger.trace("Creating Cache Entry for JBoss at [" + new Date() + "] , with expiration set to SAML expiry = " + expiryDate);

                            cacheExpiry.register(securityDomain, expiryDate, principal);
                        } else {
                            logger.samlAssertionWithoutExpiration(assertion.getID());
                        }
                    }
                }
            }
        } catch (Exception e) {
            throw logger.authFailedToParseSAMLAssertion(e);
        }

        // if password-stacking has been configured, set the principal and the assertion in the shared map.
        if (getUseFirstPass()) {
            super.sharedState.put("javax.security.auth.login.name", this.principal);
            super.sharedState.put("javax.security.auth.login.password", this.credential);
        }
        return (super.loginOk = true);
    }
   
   

    /* (non-Javadoc)
     * @see org.jboss.security.auth.spi.AbstractServerLoginModule#commit()
     */
    @Override
    public boolean commit() throws LoginException {
        if (super.commit()) {
            final boolean added = subject.getPublicCredentials().add(this.credential);
            if (added && logger.isTraceEnabled())
                logger.trace("Added Credential " + this.credential);
            return true;
        } else {
            return false;
        }
    }

    /**
     * Called if the overall authentication failed (phase 2).
     */
    @Override
    public boolean abort() throws LoginException {
        clearState();
        super.abort();
        return true;
    }

    @Override
    public boolean logout() throws LoginException {
        clearState();
        super.logout();
        return true;
    }

    private void clearState() {
        AbstractSTSLoginModule.removeAllSamlCredentials(subject);
        credential = null;
    }

    /*
     * (non-Javadoc)
     *
     * @see org.jboss.security.auth.spi.AbstractServerLoginModule#getIdentity()
     */
    @Override
    protected Principal getIdentity() {
        return this.principal;
    }

    /*
     * (non-Javadoc)
     *
     * @see org.jboss.security.auth.spi.AbstractServerLoginModule#getRoleSets()
     */
    @Override
    protected Group[] getRoleSets() throws LoginException {
        if (this.assertion == null) {
            try {
                this.assertion = SAMLUtil.fromElement(this.credential.getAssertionAsElement());
            } catch (Exception e) {
                throw logger.authFailedToParseSAMLAssertion(e);
            }
        }
        if (logger.isTraceEnabled()) {
            try {
                logger.trace("Assertion from where roles will be sought = " + AssertionUtil.asString(assertion));
            } catch (ProcessingException ignore) {
            }
        }

        List<String> roleKeys = new ArrayList<String>();
        if (StringUtil.isNotNull(roleKey)) {
            roleKeys.addAll(StringUtil.tokenize(roleKey));
        }

        String groupName = SecurityConstants.ROLES_IDENTIFIER;
        Group rolesGroup = new PicketLinkGroup(groupName);
        List<String> roles = AssertionUtil.getRoles(assertion, roleKeys);
        for (String role : roles) {
            rolesGroup.addMember(new SimplePrincipal(role));
        }

        return new Group[] { rolesGroup };
    }

    /**
     * Get the {@link STSClient} object with which we can make calls to the STS
     *
     * @return
     */
    protected STSClient getSTSClient() {
        /*
         * Builder builder = new Builder(this.stsConfigurationFile); STSClient client = new STSClient(builder.build());
         */

        Builder builder = null;
        STSClient client = null;
        if (rawOptions.containsKey(STS_CONFIG_FILE)) {
            builder = new Builder(this.stsConfigurationFile);
            client = STSClientFactory.getInstance(maxClientsInPool).create(initialNumberOfClients, builder.build());
        } else {
            builder = new Builder();
            builder.endpointAddress((String) rawOptions.get(ENDPOINT_ADDRESS));
            builder.portName((String) rawOptions.get(PORT_NAME)).serviceName((String) rawOptions.get(SERVICE_NAME));
            builder.username((String) rawOptions.get(USERNAME_KEY)).password((String) rawOptions.get(PASSWORD_KEY));

            String passwordString = (String) rawOptions.get(PASSWORD_KEY);
            if (passwordString != null && passwordString.startsWith(PicketLinkFederationConstants.PASS_MASK_PREFIX)) {
                // password is masked
                String salt = (String) rawOptions.get(PicketLinkFederationConstants.SALT);
                if (StringUtil.isNullOrEmpty(salt))
                    throw logger.optionNotSet("Salt");

                String iCount = (String) rawOptions.get(PicketLinkFederationConstants.ITERATION_COUNT);
                if (StringUtil.isNullOrEmpty(iCount))
                    throw logger.optionNotSet("Iteration Count");

                int iterationCount = Integer.parseInt(iCount);
                try {
                    builder.password(StringUtil.decode(passwordString, salt, iterationCount));
                } catch (Exception e) {
                    throw logger.unableToDecodePasswordError(passwordString);
                }
            }
            client = STSClientFactory.getInstance(maxClientsInPool).create(initialNumberOfClients, builder.build());
        }

        // if the login module options map still contains any properties, assume they are for configuring the connection
        // to the STS and set them in the Dispatch request context.
        if (!this.options.isEmpty()) {
            Dispatch<Source> dispatch = client.getDispatch();
            for (Map.Entry<String, ?> entry : this.options.entrySet())
                dispatch.getRequestContext().put(entry.getKey(), entry.getValue());
        }
        return client;
    }
   
    /**
     * Locally validate the SAML Assertion element
     *
     * @param assertionElement
     * @return
     * @throws Exception
     */
    protected abstract boolean localValidation(Element assertionElement) throws Exception;

    protected abstract TimeCacheExpiry getCacheExpiry() throws Exception;

}
TOP

Related Classes of org.picketlink.identity.federation.bindings.jboss.auth.SAML2STSCommonLoginModule

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.