Package org.surfnet.oaaas.conext

Source Code of org.surfnet.oaaas.conext.SAMLAuthenticator

/*
* Copyright 2012 SURFnet bv, The Netherlands
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*      http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.surfnet.oaaas.conext;

import nl.surfnet.coin.api.client.OpenConextOAuthClient;
import nl.surfnet.coin.api.client.domain.Group20;
import nl.surfnet.spring.security.opensaml.AuthnRequestGenerator;
import nl.surfnet.spring.security.opensaml.Provisioner;
import nl.surfnet.spring.security.opensaml.SAMLMessageHandler;
import nl.surfnet.spring.security.opensaml.ServiceProviderAuthenticationException;
import nl.surfnet.spring.security.opensaml.util.IDService;
import nl.surfnet.spring.security.opensaml.util.TimeService;
import nl.surfnet.spring.security.opensaml.xml.EndpointGenerator;
import org.apache.commons.beanutils.BeanUtils;
import org.apache.commons.lang.StringUtils;
import org.opensaml.common.binding.SAMLMessageContext;
import org.opensaml.saml2.core.AuthnRequest;
import org.opensaml.saml2.core.RequesterID;
import org.opensaml.saml2.core.Response;
import org.opensaml.saml2.core.Scoping;
import org.opensaml.saml2.core.impl.RequesterIDBuilder;
import org.opensaml.saml2.core.impl.ScopingBuilder;
import org.opensaml.saml2.metadata.Endpoint;
import org.opensaml.saml2.metadata.SingleSignOnService;
import org.opensaml.ws.message.decoder.MessageDecodingException;
import org.opensaml.ws.message.encoder.MessageEncodingException;
import org.opensaml.xml.security.CriteriaSet;
import org.opensaml.xml.security.credential.Credential;
import org.opensaml.xml.security.credential.UsageType;
import org.opensaml.xml.security.criteria.EntityIDCriteria;
import org.opensaml.xml.security.criteria.UsageCriteria;
import org.opensaml.xml.validation.ValidationException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.io.support.PropertiesLoaderUtils;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import org.surfnet.oaaas.auth.AbstractAuthenticator;
import org.surfnet.oaaas.auth.principal.AuthenticatedPrincipal;
import org.surfnet.oaaas.model.AuthorizationRequest;
import org.surfnet.oaaas.model.Client;
import org.surfnet.oaaas.repository.AuthorizationRequestRepository;

import javax.inject.Inject;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.util.List;
import java.util.Properties;

@Component
public class SAMLAuthenticator extends AbstractAuthenticator {

  private static final Logger LOG = LoggerFactory.getLogger(SAMLAuthenticator.class);
  private static final String RELAY_STATE_FROM_SAML = "RELAY_STATE_FROM_SAML";
  private static final String PRINCIPAL_FROM_SAML = "PRINCIPAL_FROM_SAML";
  private static final String CLIENT_SAML_ENTITY_NAME = "CLIENT_SAML_ENTITY_NAME";

  private TimeService timeService = new TimeService();
  private IDService idService = new IDService();
  private ScopingBuilder scopingBuilder = new ScopingBuilder();
  private RequesterIDBuilder requesterIDBuilder = new RequesterIDBuilder();

  private OpenSAMLContext openSAMLContext;
  private OpenConextOAuthClient apiClient;
  private String callbackFlagParameter = "apiOauthCallback";
  private boolean enrichPricipal;
  private String adminGroup;


  @Inject
  private AuthorizationRequestRepository authorizationRequestRepository;

  private final Properties properties;

  {
    try {
      properties = PropertiesLoaderUtils.loadAllProperties("surfconext.authn.properties");
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }


  @Override
  public void init(FilterConfig filterConfig) throws ServletException {
    try {
      super.init(filterConfig);
      openSAMLContext = createOpenSAMLContext(properties);
      enrichPricipal = Boolean.valueOf(properties.getProperty("api-enrich-principal"));
      if (enrichPricipal) {
        apiClient = createOpenConextOAuthClient(properties);
        adminGroup = properties.getProperty("admin.client.apis.teamname");
      }
    } catch (Exception e) {
      throw new ServletException(e);
    }
  }


  protected OpenConextOAuthClient createOpenConextOAuthClient(Properties properties) throws ClassNotFoundException, IllegalAccessException, InstantiationException, InvocationTargetException {
    OpenConextOAuthClient apiClient = (OpenConextOAuthClient) getClass().getClassLoader().loadClass(properties.getProperty("openConextApiClient")).newInstance();
    BeanUtils.setProperty(apiClient, "callbackUrl", properties.getProperty("api-callbackuri"));
    BeanUtils.setProperty(apiClient, "consumerSecret", properties.getProperty("api-consumersecret"));
    BeanUtils.setProperty(apiClient, "consumerKey", properties.getProperty("api-consumerkey"));
    BeanUtils.setProperty(apiClient, "endpointBaseUrl", properties.getProperty("api-baseurl"));
    return apiClient;
  }

  /**
   * Default Context factory method.
   */
  protected OpenSAMLContext createOpenSAMLContext(Properties properties) {
    return new OpenSAMLContext(properties, createProvisioner());
  }

  /**
   * Default Provisioner factory method.
   */
  protected Provisioner createProvisioner() {
    SAMLProvisioner samlProvisioner = new SAMLProvisioner();
    samlProvisioner.setUuidAttribute((String) properties.get("samlUuidAttribute"));
    return samlProvisioner;
  }

  @Override
  public boolean canCommence(HttpServletRequest request) {
    return isSAMLResponse(request) || isOAuthCallback(request);
  }

  @Override
  public void authenticate(HttpServletRequest request, HttpServletResponse response, FilterChain chain,
                           String authStateValue, String returnUri) throws IOException, ServletException {
    LOG.debug("Hitting SAML Authenticator filter");
    if (isSAMLResponse(request)) {
      Response samlResponse = extractSamlResponse(request);
      SAMLAuthenticatedPrincipal principal = (SAMLAuthenticatedPrincipal) openSAMLContext.assertionConsumer().consume(samlResponse);
      if (enrichPricipal) {
        //need to save the Principal and the AuthState somewhere
        request.getSession().setAttribute(PRINCIPAL_FROM_SAML, principal);
        request.getSession().setAttribute(RELAY_STATE_FROM_SAML, getSAMLRelayState(request));
        response.sendRedirect(apiClient.getAuthorizationUrl());
      } else {
        proceedWithChain(request, response, chain, principal, getSAMLRelayState(request));
      }
    } else if (isOAuthCallback(request)) {
      SAMLAuthenticatedPrincipal principal = (SAMLAuthenticatedPrincipal) request.getSession().getAttribute(PRINCIPAL_FROM_SAML);
      String authState = (String) request.getSession().getAttribute(RELAY_STATE_FROM_SAML);
      if (principal == null) { //huh
        throw new ServiceProviderAuthenticationException("No principal anymore in the session");
      }
      String userId = principal.getName();
      if (StringUtils.isEmpty(userId)) {
        throw new ServiceProviderAuthenticationException("No userId in SAML assertion!");
      }
      apiClient.oauthCallback(request, userId);
      List<Group20> groups = apiClient.getGroups20(userId, userId);
      if (!CollectionUtils.isEmpty(groups)) {
        for (Group20 group : groups) {
          principal.addGroup(group.getId());
          if (StringUtils.isNotBlank(this.adminGroup) && adminGroup.equalsIgnoreCase(group.getId())) {
            principal.setAdminPrincipal(true);
          }
        }
      }
      proceedWithChain(request, response, chain, principal, authState);
    } else {
      sendAuthnRequest(response, authStateValue, getReturnUri(request));
    }
  }

  private void proceedWithChain(HttpServletRequest request, HttpServletResponse response, FilterChain chain, AuthenticatedPrincipal principal, String authStateValue) throws IOException, ServletException {
    super.setPrincipal(request, principal);
    super.setAuthStateValue(request, authStateValue);
    chain.doFilter(request, response);
  }

  private boolean isOAuthCallback(HttpServletRequest request) {
    return request.getParameter(callbackFlagParameter) != null;
  }

  protected String getSAMLRelayState(HttpServletRequest request) {
    return request.getParameter("RelayState");
  }

  protected boolean isSAMLResponse(HttpServletRequest request) {
    return request.getParameter("SAMLResponse") != null;
  }

  private Response extractSamlResponse(HttpServletRequest request) {

    SAMLMessageContext messageContext;

    final SAMLMessageHandler samlMessageHandler = openSAMLContext.samlMessageHandler();
    try {
      messageContext = samlMessageHandler.extractSAMLMessageContext(request);
    } catch (MessageDecodingException me) {
      throw new ServiceProviderAuthenticationException("Could not decode SAML Response", me);
    } catch (org.opensaml.xml.security.SecurityException se) {
      throw new ServiceProviderAuthenticationException("Could not decode SAML Response", se);
    }

    LOG.debug("Message received from issuer: " + messageContext.getInboundMessageIssuer());

    if (!(messageContext.getInboundSAMLMessage() instanceof Response)) {
      throw new ServiceProviderAuthenticationException("SAML Message was not a Response.");
    }

    final Response inboundSAMLMessage = (Response) messageContext.getInboundSAMLMessage();

    try {
      openSAMLContext.validatorSuite().validate(inboundSAMLMessage);
      return inboundSAMLMessage;
    } catch (ValidationException ve) {
      LOG.warn("Response Message failed Validation", ve);
      throw new RuntimeException("Invalid SAML Response Message", ve);
    }
  }

  private void sendAuthnRequest(HttpServletResponse response, String authState, String returnUri) throws IOException {
    AuthnRequestGenerator authnRequestGenerator = new AuthnRequestGenerator(openSAMLContext.entityId(), timeService,
            idService);
    EndpointGenerator endpointGenerator = new EndpointGenerator();

    final String target = openSAMLContext.getIdpUrl();

    Endpoint endpoint = endpointGenerator.generateEndpoint(
            SingleSignOnService.DEFAULT_ELEMENT_NAME, target, openSAMLContext.assertionConsumerUri());

    AuthnRequest authnRequest = authnRequestGenerator.generateAuthnRequest(target, openSAMLContext.assertionConsumerUri());

    Client client = getClientByRequest(authState);
    String spEntityIdBy = client.getAttributes().get(CLIENT_SAML_ENTITY_NAME);

    if (StringUtils.isNotEmpty(spEntityIdBy)) {
      Scoping scoping = scopingBuilder.buildObject();
      scoping.getRequesterIDs().add(createRequesterID(spEntityIdBy));
      authnRequest.setScoping(scoping);
    } else {
      LOG.warn("For Client {} there is no key CLIENT_SAML_ENTITY_NAME configured to identify the SP entity name. NO SCOPING IS APPLIED", client.getClientId());
    }

    CriteriaSet criteriaSet = new CriteriaSet();
    criteriaSet.add(new EntityIDCriteria(openSAMLContext.entityId()));
    criteriaSet.add(new UsageCriteria(UsageType.SIGNING));
    try {

      Credential signingCredential = openSAMLContext.keyStoreCredentialResolver().resolveSingle(criteriaSet);
      String relayState = authState;
      LOG.debug("Sending authnRequest to {}", target);
      openSAMLContext.samlMessageHandler().sendSAMLMessage(authnRequest, endpoint, response, relayState, signingCredential);
    } catch (MessageEncodingException mee) {
      LOG.error("Could not send authnRequest to Identity Provider.", mee);
      response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
    } catch (org.opensaml.xml.security.SecurityException e) {
      throw new RuntimeException(e);
    }
  }

  private RequesterID createRequesterID(String id) {
    RequesterID requesterID = requesterIDBuilder.buildObject();
    requesterID.setRequesterID(id);
    return requesterID;
  }

  /**
   * Get the Client
   */
  protected Client getClientByRequest(String authState) {
    AuthorizationRequest authorizationRequest = authorizationRequestRepository.findByAuthState(authState);
    return authorizationRequest.getClient();

  }


}
TOP

Related Classes of org.surfnet.oaaas.conext.SAMLAuthenticator

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.