Package cascading.pattern.pmml

Source Code of cascading.pattern.pmml.PMMLPlanner

/*
* Copyright (c) 2007-2013 Concurrent, Inc. All Rights Reserved.
*
* Project and contact information: http://www.cascading.org/
*
* This file is part of the Cascading project.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package cascading.pattern.pmml;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.InputStream;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;

import javax.xml.transform.sax.SAXSource;

import cascading.flow.AssemblyPlanner;
import cascading.flow.FlowDescriptors;
import cascading.flow.planner.PlannerException;
import cascading.pattern.PatternException;
import cascading.pattern.datafield.ContinuousDataField;
import cascading.pattern.ensemble.EnsembleSpec;
import cascading.pattern.ensemble.ParallelEnsembleAssembly;
import cascading.pattern.ensemble.selection.SelectionStrategy;
import cascading.pattern.model.ModelSchema;
import cascading.pattern.model.ModelScoringFunction;
import cascading.pattern.model.clustering.ClusteringFunction;
import cascading.pattern.model.clustering.ClusteringSpec;
import cascading.pattern.model.clustering.measure.EuclideanMeasure;
import cascading.pattern.model.clustering.measure.SquaredEuclideanMeasure;
import cascading.pattern.model.generalregression.CategoricalRegressionFunction;
import cascading.pattern.model.generalregression.GeneralRegressionSpec;
import cascading.pattern.model.generalregression.LinkFunction;
import cascading.pattern.model.generalregression.PredictionRegressionFunction;
import cascading.pattern.model.generalregression.RegressionTable;
import cascading.pattern.model.tree.Tree;
import cascading.pattern.model.tree.TreeFunction;
import cascading.pattern.model.tree.TreeSpec;
import cascading.pipe.Each;
import cascading.pipe.Pipe;
import cascading.pipe.assembly.Coerce;
import cascading.pipe.assembly.Retain;
import cascading.scheme.util.FieldTypeResolver;
import cascading.tap.Tap;
import cascading.tuple.Fields;
import cascading.util.Util;
import org.dmg.pmml.Cluster;
import org.dmg.pmml.ClusteringModel;
import org.dmg.pmml.ComparisonMeasure;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Euclidean;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldUsageType;
import org.dmg.pmml.GeneralRegressionModel;
import org.dmg.pmml.Measure;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.MiningModel;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.RegressionModel;
import org.dmg.pmml.Segment;
import org.dmg.pmml.SquaredEuclidean;
import org.dmg.pmml.TreeModel;
import org.dmg.pmml.True;
import org.jpmml.model.ImportFilter;
import org.jpmml.model.JAXBUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.xml.sax.InputSource;

import static cascading.pattern.pmml.DataFields.createDataFields;
import static cascading.pattern.pmml.TreeUtil.createTree;

/**
* Class PMMLPlanner is an implementation of a Cascading {@link AssemblyPlanner} that will convert a PMML document into
* a Cascading {@link cascading.flow.Flow} instance.
* <p/>
* Predictive Model Markup Language (<a href="http://en.wikipedia.org/wiki/Predictive_Model_Markup_Language">PMML</a>)
* is an XML format used to share common machine learning model parameters between applications. PMML documents are
* typically exported from tools such as R after a model has been created.
* <p/>
* PMML is very flexible, but sometimes it doesn't declare defaults for missing elements. For example, the
* "predicted" field is optional (the output field), but Cascading requires it for obvious reasons.
* <p/>
* So methods on this class help set sensible defaults if values are missing in the PMML document.
* <p/>
* PMMLPlanner in essence maps the PMML elements to Pattern model classes and populates those models with the
* given parameters. There may not always be a 1 to 1 correlation of elements to Pattern models, for example PMML
* regression is mapped to Pattern's {@link GeneralRegressionSpec}.
* <p/>
* To use, hand PMMLPlanner a PMML XML file, and pass the PMMLPlanner instance to
* {@link cascading.flow.FlowDef#addAssemblyPlanner(cascading.flow.AssemblyPlanner)}. Use the FlowDef as usual from
* there. The {@link cascading.flow.FlowConnector#connect(cascading.flow.FlowDef)} will then return a
* {@link cascading.flow.Flow} that will score your data.
* <p/>
* Optionally {@link #resolveAssembly(cascading.pipe.Pipe)} can be used directly, but requires more Cascading
* experience.
*/
public class PMMLPlanner implements AssemblyPlanner
  {
  private static final Logger LOG = LoggerFactory.getLogger( PMMLPlanner.class );

  File pmmlFile;
  InputStream pmmlStream;
  String branchName = "pmml";
  String headName;
  String tailName;

  private PMML pmml;
  private PMMLModel pmmlModel;
  private Fields retainIncomingFields;
  private Fields defaultPredictedField;

  /**
   * Creates a planner that reads its PMML document from the given file.
   * <p/>
   * The file is not opened here; it is only stored for later use.
   *
   * @param pmmlFile the file containing the PMML document
   */
  public PMMLPlanner( File pmmlFile )
    {
    this.pmmlFile = pmmlFile;
    }

  /**
   * Creates a planner that reads its PMML document from the given stream.
   * <p/>
   * The stream is not read here; it is only stored for later use.
   *
   * @param pmmlStream the stream providing the PMML document
   */
  public PMMLPlanner( InputStream pmmlStream )
    {
    this.pmmlStream = pmmlStream;
    }

  /**
   * Creates a planner without any PMML input, use one of the {@code setPMMLInput} methods to provide the
   * PMML file or stream before the planner is used.
   */
  public PMMLPlanner()
    {
    }

  public File getPMMLFile()
    {
    return pmmlFile;
    }

  /**
   * Sets the file the PMML document is read from.
   *
   * @param pmmlFile the file containing the PMML document
   * @return this planner instance, to allow call chaining
   */
  public PMMLPlanner setPMMLInput( File pmmlFile )
    {
    this.pmmlFile = pmmlFile;

    return this;
    }

  public InputStream getPMMLStream()
    {
    return pmmlStream;
    }

  /**
   * Sets the stream the PMML document is read from.
   *
   * @param pmmlStream the stream providing the PMML document
   * @return this planner instance, to allow call chaining
   */
  public PMMLPlanner setPMMLInput( InputStream pmmlStream )
    {
    this.pmmlStream = pmmlStream;

    return this;
    }

  /**
   * Sets the field to use as the predicted (output) field when a model in the PMML document
   * does not declare a predicted field itself.
   *
   * @param defaultPredictedField the fallback predicted field
   * @return this planner instance, to allow call chaining
   */
  public PMMLPlanner setDefaultPredictedField( Fields defaultPredictedField )
    {
    this.defaultPredictedField = defaultPredictedField;

    return this;
    }

  public Fields getDefaultPredictedField()
    {
    return defaultPredictedField;
    }

  /**
   * Gets head name.
   *
   * @return the head name
   */
  public String getHeadName()
    {
    return headName;
    }

  /**
   * Sets incoming source head name to use.
   * <p/>
   * If not set, will use the only available input source Tap name. If more than one
   * source was declared, the planner will throw an exception.
   * <p/>
   * Cascading binds heads of the pipe assembly to the corresponding named source Tap.
   *
   * @param headName the head name
   * @return this planner instance, to allow call chaining
   */
  public PMMLPlanner setHeadName( String headName )
    {
    this.headName = headName;

    return this;
    }

  public String getBranchName()
    {
    return branchName;
    }

  /**
   * Sets branch name to use for the resulting pipe assembly.
   *
   * @param branchName the branch name
   * @return this planner instance, to allow call chaining
   */
  public PMMLPlanner setBranchName( String branchName )
    {
    this.branchName = branchName;

    return this;
    }

  public String getTailName()
    {
    return tailName;
    }

  /**
   * Sets the outgoing tail name to use.
   * <p/>
   * If not set, will use the only available sink Tap name. If more than one
   * sink was declared, the planner will throw an exception.
   * <p/>
   * Cascading binds tails of the pipe assembly to the corresponding named sink Tap.
   *
   * @param tailName the tail name
   * @return this planner instance, to allow call chaining
   */
  public PMMLPlanner setTailName( String tailName )
    {
    this.tailName = tailName;

    return this;
    }

  public Fields getRetainIncomingFields()
    {
    return retainIncomingFields;
    }

  /**
   * Sets the incoming fields that should be retained, by default all incoming fields
   * are passed through the assembly from the source Tap.
   * <p/>
   * This must include the fields required by the underlying models.
   * <p/>
   * Use {@link #retainOnlyActiveIncomingFields()} to dynamically limit the result to
   * those required.
   *
   * @param retainIncomingFields the retain incoming fields
   * @return the retain incoming fields
   */
  public PMMLPlanner setRetainIncomingFields( Fields retainIncomingFields )
    {
    if( retainIncomingFields.isNone() )
      throw new IllegalArgumentException( "may not retain Fields.NONE" );

    this.retainIncomingFields = retainIncomingFields;

    return this;
    }

  /**
   * Retain only active incoming fields as declared by the PMML. All other fields
   * in the incoming source Tap will be discarded.
   *
   * @return the pMML planner
   */
  public PMMLPlanner retainOnlyActiveIncomingFields()
    {
    return setRetainIncomingFields( getActiveFields() );
    }

  protected PMMLModel getPMMLModel()
    {
    if( pmmlModel == null )
      this.pmmlModel = new PMMLModel( getPMML() );

    return pmmlModel;
    }

  private PMML getPMML()
    {
    if( pmml == null )
      {
      InputStream stream = getPMMLStream();

      if( stream == null )
        stream = getFileInputStream();

      pmml = parse( stream );
      }

    return pmml;
    }

  private FileInputStream getFileInputStream()
    {
    try
      {
      return new FileInputStream( getPMMLFile() );
      }
    catch( FileNotFoundException exception )
      {
      throw new PatternException( "could not find pmml file: " + getPMMLFile(), exception );
      }
    }

  /**
   * Sets the PMML document directly. When the given value is not {@code null}, the PMML file
   * or stream will not be parsed.
   *
   * @param pmml the PMML document to use
   */
  private void setPMML( PMML pmml )
    {
    this.pmml = pmml;
    }

  private static PMML parse( InputStream inputStream )
    {
    try
      {
      InputSource source = new InputSource( inputStream );
      SAXSource transformedSource = ImportFilter.apply( source );
      return JAXBUtil.unmarshalPMML( transformedSource );
      }
    catch( Exception exception )
      {
      throw new PatternException( "could not read inputStream: " + inputStream, exception );
      }
    }

  /**
   * Returns the "active" fields declared in the PMML document.
   *
   * @return the active fields
   */
  public Fields getActiveFields()
    {
    return getPMMLModel().getActiveFields();
    }

  /**
   * Returns predicted fields declared in the PMML document.
   *
   * @return the predicted fields
   */
  public Fields getPredictedFields()
    {
    return getPMMLModel().getPredictedFields();
    }

  public PMMLPlanner addDataTypes( Fields fields )
    {
    for( Comparable field : fields )
      {
      if( field instanceof Number )
        continue;

      addDataType( (String) field, fields.getType( fields.getPos( field ) ) );
      }

    return this;
    }

  public PMMLPlanner addDataType( String name, Type type )
    {
    FieldName fieldName = new FieldName( name );

    if( getPMMLModel().getDataField( fieldName ) != null )
      throw new IllegalStateException( "field name already in dictionary" );

    DataType dataType = DataTypes.getTypeToPmml( type );

    if( dataType == null )
      throw new IllegalStateException( "no data type mapping from: " + Util.getTypeName( type ) );

    getPMMLModel().addDataField( fieldName, OpType.CONTINUOUS, dataType );

    return this;
    }

  public FieldTypeResolver getFieldTypeResolver()
    {
    return new PMMLTypeResolver( this );
    }

  @Override
  public List<Pipe> resolveTails( Context context )
    {
    if( context.getTails().size() > 1 )
      throw new PlannerException( "assembly can only accept one inbound branch, got: " + context.getTails().size() );

    if( context.getTails().size() == 1 && getHeadName() != null )
      throw new PlannerException( "cannot specify a head name if there are incoming branches" );

    Pipe tail = null;

    String headName = findHeadName( context );

    if( context.getTails().size() == 0 && headName != null )
      tail = new Pipe( headName );
    else if( context.getTails().size() == 1 )
      tail = context.getTails().get( 0 );

    tail = applyCoercion( tail, context.getFlow().getSource( headName ) );

    tail = resolveAssembly( tail ); // branch name is applied

    tail = new Pipe( findTailName( context ), tail ); // bind the tail to the sink tailName

    return Arrays.asList( tail );
    }

  @Override
  public Map<String, String> getFlowDescriptor()
    {
    if( getPMMLFile() == null )
      return Collections.emptyMap();

    Map<String, String> map = new HashMap<String, String>();

    map.put( "pmml-resource", getPMMLFile().toString() );

    return map;
    }

  private Pipe applyCoercion( Pipe tail, Tap source )
    {
    Fields sourceFields = source.getSourceFields();
    FieldTypeResolver fieldTypeResolver = getFieldTypeResolver();

    Fields coercedFields = Fields.NONE;

    int count = 0;
    for( Comparable sourceField : sourceFields )
      {
      Type incoming = sourceFields.getType( sourceField );
      Type outgoing = fieldTypeResolver.inferTypeFrom( count++, sourceField.toString() );

      if( incoming != outgoing )
        coercedFields = coercedFields.append( new Fields( sourceField, outgoing ) );
      }

    if( coercedFields.isNone() )
      return tail;

    return new Coerce( tail, coercedFields );
    }

  private String findHeadName( Context context )
    {
    String headName;

    if( getHeadName() != null )
      headName = getHeadName();
    else if( context.getFlow().getSources().size() == 1 )
      headName = (String) context.getFlow().getSourceNames().get( 0 );
    else
      throw new IllegalStateException( "too man sources to choose from, found: " + context.getFlow().getSources().size() + ", use setHeadName" );

    return headName;
    }

  private String findTailName( Context context )
    {
    String tailName;

    if( getTailName() != null )
      tailName = getTailName();
    else if( context.getFlow().getSinks().size() == 1 )
      tailName = (String) context.getFlow().getSinkNames().get( 0 );
    else
      throw new IllegalStateException( "too man sinks to choose from, found: " + context.getFlow().getSinks().size() + ", use setTailName" );

    return tailName;
    }

  /**
   * Resolve assembly can be used if the give PMML needs to be extended beyond the capability of a Flow created
   * by a FlowConnector and FlowDef.
   * <p/>
   * The models will be appended to the given head Pipe instance, the final tail Pipe will be returned.
   *
   * @param pipe the pipe
   * @return the pipe
   */
  public Pipe resolveAssembly( Pipe pipe )
    {
    Pipe tail;

    if( pipe == null )
      tail = new Pipe( getBranchName() );
    else
      tail = new Pipe( getBranchName(), pipe );

    if( getRetainIncomingFields() != null )
      tail = new Retain( tail, getRetainIncomingFields() );

    for( Model model : getPMMLModel().getModels() )
      {
      if( !model.isScorable() )
        {
        LOG.info( "skipping non-scoreable model: " + model.getModelName() );
        continue;
        }

      if( model instanceof MiningModel )
        tail = handleMiningModel( tail, (MiningModel) model );
      else if( model instanceof ClusteringModel )
        tail = handleClusteringModel( tail, (ClusteringModel) model );
      else if( model instanceof RegressionModel )
        tail = handleRegressionModel( tail, (RegressionModel) model );
      else if( model instanceof GeneralRegressionModel )
        tail = handleGeneralRegressionModel( tail, (GeneralRegressionModel) model );
      else if( model instanceof TreeModel )
        tail = handleTreeModel( tail, (TreeModel) model );
      else
        throw new PatternException( "unsupported model type: " + model.getClass().getName() );
      }

    return tail;
    }

  private Pipe handleTreeModel( Pipe tail, TreeModel model )
    {
    ModelSchema modelSchema = createModelSchema( model );

    Tree tree = createTree( model, modelSchema );

    TreeSpec treeSpec = new TreeSpec( modelSchema, tree );

    return create( tail, modelSchema, new TreeFunction( treeSpec ) );
    }

  private Pipe handleGeneralRegressionModel( Pipe tail, GeneralRegressionModel model )
    {
    ModelSchema modelSchema = createModelSchema( model );

    Set<String> parameterList = GeneralRegressionUtil.createParameters( model );
    Set<String> covariateList = GeneralRegressionUtil.createCovariates( model );
    Set<String> factorsList = GeneralRegressionUtil.createFactors( model );

    RegressionTable regressionTable = GeneralRegressionUtil.createPPMatrix( model, parameterList, factorsList, covariateList );

    LinkFunction linkFunction = LinkFunction.getFunction( model.getLinkFunction().value() );

    GeneralRegressionSpec modelParam = new GeneralRegressionSpec( modelSchema, regressionTable, linkFunction );

    return create( tail, modelSchema, new PredictionRegressionFunction( modelParam ) );
    }

  private Pipe handleRegressionModel( Pipe tail, RegressionModel model )
    {
    if( model.getFunctionName() == MiningFunctionType.REGRESSION )
      return handlePredictionRegressionModel( tail, model );

    if( model.getFunctionName() == MiningFunctionType.CLASSIFICATION )
      return handleCategoricalRegressionModel( tail, model );

    throw new UnsupportedOperationException( "unsupported mining type, got: " + model.getFunctionName() );
    }

  private Pipe handleCategoricalRegressionModel( Pipe tail, RegressionModel model )
    {
    ModelSchema modelSchema = createModelSchema( model );

    List<String> predictedCategories = new ArrayList<String>( modelSchema.getPredictedCategories( modelSchema.getPredictedFieldNames().get( 0 ) ) );

    if( predictedCategories.isEmpty() )
      throw new PatternException( "no categories specified" );

    GeneralRegressionSpec regressionSpec = new GeneralRegressionSpec( modelSchema );

    regressionSpec.setNormalization( RegressionUtil.getNormalizationMethod( model ) );

    for( org.dmg.pmml.RegressionTable regressionTable : model.getRegressionTables() )
      regressionSpec.addRegressionTable( RegressionUtil.createTable( regressionTable ) );

    return create( tail, modelSchema, new CategoricalRegressionFunction( regressionSpec ) );
    }

  private Pipe handlePredictionRegressionModel( Pipe tail, RegressionModel model )
    {
    if( model.getRegressionTables().size() != 1 )
      throw new UnsupportedOperationException( "regression model only supports a single regression table, got: " + model.getRegressionTables().size() );

    ModelSchema modelSchema = createModelSchema( model );

    org.dmg.pmml.RegressionTable regressionTable = model.getRegressionTables().get( 0 );

    GeneralRegressionSpec regressionSpec = new GeneralRegressionSpec( modelSchema );

    regressionSpec.setLinkFunction( LinkFunction.NONE );

    regressionSpec.addRegressionTable( RegressionUtil.createTable( regressionTable ) );

    return create( tail, modelSchema, new PredictionRegressionFunction( regressionSpec ) );
    }

  private Pipe handleClusteringModel( Pipe tail, ClusteringModel model )
    {
    ModelSchema modelSchema = createModelSchema( model );

    if( model.getModelClass() != ClusteringModel.ModelClass.CENTER_BASED )
      throw new UnsupportedOperationException( "unsupported model class, got: " + model.getModelClass() );

    ComparisonMeasure comparisonMeasure = model.getComparisonMeasure();

    if( comparisonMeasure.getKind() != ComparisonMeasure.Kind.DISTANCE )
      throw new UnsupportedOperationException( "unsupported comparison kind, got: " + comparisonMeasure.getKind() );

    ClusteringSpec clusteringSpec = new ClusteringSpec( modelSchema );

    Measure measure = comparisonMeasure.getMeasure();
    if( measure instanceof Euclidean )
      clusteringSpec.setComparisonMeasure( new EuclideanMeasure() );
    else if( measure instanceof SquaredEuclidean )
      clusteringSpec.setComparisonMeasure( new SquaredEuclideanMeasure() );
    else
      throw new UnsupportedOperationException( "unsupported comparison measure: " + comparisonMeasure );

    clusteringSpec.setDefaultCompareFunction( ClusteringUtil.setComparisonFunction( model ) );

    for( Cluster cluster : model.getClusters() )
      {
      List<Double> exemplar = (List<Double>) PMMLUtil.parseArray( cluster.getArray() );

      LOG.debug( "exemplar: {}", exemplar );

      clusteringSpec.addCluster( new cascading.pattern.model.clustering.Cluster( cluster.getName(), exemplar ) );
      }

    return create( tail, modelSchema, new ClusteringFunction( clusteringSpec ) );
    }

  private Pipe handleMiningModel( Pipe tail, MiningModel model )
    {
    ModelSchema modelSchema = createModelSchema( model );
    List<TreeSpec> models = new LinkedList<TreeSpec>();

    for( Segment segment : model.getSegmentation().getSegments() )
      {
      if( !segment.getPredicate().getClass().equals( True.class ) )
        throw new PatternException( "segment predicates currently not supported, got: " + segment.getPredicate() );

      Model segmentModel = segment.getModel();

      if( segmentModel instanceof TreeModel )
        models.add( new TreeSpec( modelSchema, createTree( (TreeModel) segmentModel, modelSchema ) ) );
      else
        throw new PatternException( "ensemble model currently not supported, got: " + segmentModel );
      }

    EnsembleSpec<TreeSpec> miningSpec = new EnsembleSpec<TreeSpec>( modelSchema, models );

    LOG.debug( "creating: {}, input: {}, output: {}", new Object[]{miningSpec, modelSchema.getInputFields(),
                                                                   modelSchema.getDeclaredFields()} );

    SelectionStrategy strategy = PMMLUtil.getSelectionStrategy( model );

    miningSpec.setSelectionStrategy( strategy );

    tail = new ParallelEnsembleAssembly( tail, miningSpec );

    return tail;
    }

  private Pipe create( Pipe tail, ModelSchema schemaParam, ModelScoringFunction function )
    {
    Fields inputFields = schemaParam.getInputFields();
    Fields declaredFields = schemaParam.getDeclaredFields();

    LOG.debug( "creating: {}, input: {}, output: {}", new Object[]{schemaParam, inputFields, declaredFields} );

    return new Each( tail, inputFields, function, Fields.ALL );
    }

  private ModelSchema createModelSchema( Model model )
    {
    ModelSchema modelSchema = new ModelSchema();

    for( MiningField miningField : model.getMiningSchema().getMiningFields() )
      {
      DataField dataField = getPMMLModel().getDataField( miningField.getName() );

      if( miningField.getUsageType() == FieldUsageType.ACTIVE )
        modelSchema.addExpectedField( createDataFields( dataField ) );
      else if( miningField.getUsageType() == FieldUsageType.PREDICTED )
        modelSchema.setPredictedFields( createDataFields( dataField ) );
      }

    if( modelSchema.getPredictedFieldNames().isEmpty() && defaultPredictedField == null )
      throw new PatternException( "no predicted field name provided in PMML model, use setDefaultPredictedField() method" );

    if( modelSchema.getPredictedFieldNames().isEmpty() )
      modelSchema.setPredictedFields( new ContinuousDataField( defaultPredictedField ) );

    return modelSchema;
    }

  }
TOP

Related Classes of cascading.pattern.pmml.PMMLPlanner

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.