Package pattern.model.tree

Source Code of pattern.model.tree.TreeModel

/*
* Copyright (c) 2007-2013 Concurrent, Inc. All Rights Reserved.
*
* Project and contact information: http://www.concurrentinc.com/
*/

package pattern.model.tree;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

import javax.xml.xpath.XPathConstants;

import org.jgrapht.DirectedGraph;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;

import pattern.PMML;
import pattern.PatternException;
import pattern.Schema;
import pattern.model.Model;
import storm.trident.tuple.TridentTuple;

public class TreeModel extends Model implements Serializable {
  /** Field LOG */
  private static final Logger LOG = LoggerFactory.getLogger(TreeModel.class);

  public Context context = null;
  public Tree tree;

  /**
   * Constructor for a TreeModel as a standalone classifier (PMML versions
   * 1-3).
   *
   * @param pmml
   *            PMML model
   * @throws PatternException
   */
  public TreeModel(PMML pmml) throws PatternException {
    schema = pmml.getSchema();
    context = new Context();

    schema.parseMiningSchema(pmml
        .getNodeList("/PMML/TreeModel/MiningSchema/MiningField"));
    tree = new Tree("default");

    String node_expr = "./TreeModel/Node[1]";
    NodeList root_node = pmml.getNodeList(node_expr);

    buildTree(pmml, context, (Element) root_node.item(0), tree);
  }

  /**
   * Constructor for a TreeModel as part of an ensemble (PMML verion 4+), such
   * as in Random Forest.
   *
   * @param pmml
   *            PMML model
   * @param context
   *            tree context
   * @param parent
   *            parent node in the XML
   * @throws PatternException
   */
  public TreeModel(PMML pmml, Context context, Node parent)
      throws PatternException {
    String id = ((Element) parent).getAttribute("id");
    tree = new Tree(id);

    String node_expr = "./TreeModel/Node[1]";
    NodeList root_node = (NodeList) pmml.getReader().read(parent,
        node_expr, XPathConstants.NODESET);

    buildTree(pmml, context, (Element) root_node.item(0), tree);
  }

  /**
   * Prepare to classify with this model. Called immediately before the
   * enclosing Operation instance is put into play processing Tuples.
   */
  @Override
  public void prepare() {
    context.prepare(schema);
  }

  /**
   * Classify an input tuple, returning the predicted label.
   *
   *
   * @param values
   *            tuple values
   * @param fields
   * @return String
   * @throws PatternException
   */
  @Override
  public String classifyTuple(TridentTuple values) throws PatternException {
    // TODO
    return "null";
  }

  /**
   * Generate a serializable graph representation for a tree.
   *
   * @param pmml
   *            PMML model
   * @param shared_context
   *            tree context
   * @param root
   *            root node in the XML
   * @param tree
   *            serializable tree structure
   * @throws PatternException
   */
  public void buildTree(PMML pmml, Context shared_context, Element root,
      Tree tree) throws PatternException {
    Vertex vertex = makeVertex(root, tree.getGraph());
    tree.setRoot(vertex);

    buildNode(pmml, shared_context, root, vertex, tree.getGraph());
  }

  /**
   * @param pmml
   *            PMML model
   * @param shared_context
   *            tree context
   * @param node
   *            predicate node in the XML
   * @param vertex
   *            tree vertex
   * @param graph
   *            serializable graph
   * @throws PatternException
   */
  protected void buildNode(PMML pmml, Context shared_context, Element node,
      Vertex vertex, DirectedGraph<Vertex, Edge> graph)
      throws PatternException {
    // build a list of parameters from which the predicate will be evaluated

    Schema schema = pmml.getSchema();
    String[] param_names = schema.getParamNames();
    List<String> params = new ArrayList<String>();

    for (int i = 0; i < param_names.length; i++)
      params.add(param_names[i]);

    // walk the node list to construct serializable predicates

    NodeList child_nodes = node.getChildNodes();

    for (int i = 0; i < child_nodes.getLength(); i++) {
      Node child = child_nodes.item(i);

      if (child.getNodeType() == Node.ELEMENT_NODE) {
        if ("SimplePredicate".equals(child.getNodeName())
            || "SimpleSetPredicate".equals(child.getNodeName())) {
          Integer predicate_id = shared_context.makePredicate(schema,
              pmml.getReader(), (Element) child, params);

          if (node.hasAttribute("score")) {
            String score = (node).getAttribute("score");
            vertex.setScore(score);
          }

          for (Edge e : graph.edgesOf(vertex))
            e.setPredicateId(predicate_id);
        } else if ("Node".equals(child.getNodeName())) {
          Vertex child_vertex = makeVertex((Element) child, graph);
          Edge edge = graph.addEdge(vertex, child_vertex);

          buildNode(pmml, shared_context, (Element) child,
              child_vertex, graph);
        }
      }
    }
  }

  /**
   * @param node
   *            predicate node in the XML
   * @param graph
   *            serializable graph
   * @return Vertex
   */
  protected Vertex makeVertex(Element node, DirectedGraph<Vertex, Edge> graph) {
    String id = (node).getAttribute("id");
    Vertex vertex = new Vertex(id);
    graph.addVertex(vertex);

    return vertex;
  }

  /** @return String */
  @Override
  public String toString() {
    StringBuilder buf = new StringBuilder();

    if (schema != null) {
      buf.append(schema);
      buf.append("\n");
      buf.append("---------");
      buf.append("\n");
    }

    if (context != null) {
      buf.append(context);
      buf.append("\n");
      buf.append("---------");
      buf.append("\n");
    }

    buf.append(tree);
    buf.append(tree.getRoot());

    for (Edge edge : tree.getGraph().edgeSet())
      buf.append(edge);

    buf.append("\n");

    return buf.toString();
  }
}
TOP

Related Classes of pattern.model.tree.TreeModel

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., now owned by Oracle Corporation. Contact coftware#gmail.com.