Package edu.uci.jforestsx.learning.trees

Source Code of edu.uci.jforestsx.learning.trees.Tree

/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package edu.uci.jforestsx.learning.trees;

import java.util.ArrayList;
import java.util.List;

import edu.uci.jforestsx.dataset.Dataset;

import edu.uci.jforestsx.dataset.Feature;
import edu.uci.jforestsx.sample.Sample;
import edu.uci.jforestsx.util.ArraysUtil;

/**
* @author Yasser Ganjisaffar <ganjisaffar at gmail dot com>
*/

public abstract class Tree implements Cloneable {

  public int numLeaves;

  protected int[] leftChild;
  protected int[] rightChild;
  protected int[] splitFeatures;
  protected int[] thresholds;
  protected double[] originalThresholds;

  public Tree() {
  }

  public void init(int maxLeaves) {
    leftChild = new int[maxLeaves - 1];
    rightChild = new int[maxLeaves - 1];
    splitFeatures = new int[maxLeaves - 1];
    thresholds = new int[maxLeaves - 1];
    originalThresholds = new double[maxLeaves - 1];
    numLeaves = 1;
  }

  protected void copyTo(Tree copy) {
    copy.init(leftChild.length + 1);
    copy.numLeaves = numLeaves;
    System.arraycopy(leftChild, 0, copy.leftChild, 0, numLeaves - 1);
    System.arraycopy(rightChild, 0, copy.rightChild, 0, numLeaves - 1);
    System.arraycopy(splitFeatures, 0, copy.splitFeatures, 0, numLeaves - 1);
    System.arraycopy(thresholds, 0, copy.thresholds, 0, numLeaves - 1);
    System.arraycopy(originalThresholds, 0, copy.originalThresholds, 0, numLeaves - 1);
  }

  public int getLeftChild(int node) {
    return leftChild[node];
  }

  public void setLeftChild(int node, int newChild) {
    leftChild[node] = newChild;
  }

  public int getRightChild(int node) {
    return rightChild[node];
  }

  public void setRightChild(int node, int newChild) {
    rightChild[node] = newChild;
  }

  public int getSplitFeature(int node) {
    return splitFeatures[node];
  }

  public int getThreshold(int node) {
    return thresholds[node];
  }

  public double getOriginalThreshold(int node) {
    return originalThresholds[node];
  }

  /** SISTA added code: similar to getLeaf for one Dataset instance */
  public int getLeaf(Feature[] features) {
    if (numLeaves == 1) {
      return 0;
    }

    int node = 0;
    while (node >= 0) {
      int split = splitFeatures[node];
      assert(split < features.length);
      Feature feat = features[split];
      if(feat.upperBounds[feat.bins.get(0)] <= thresholds[node]) {
        node = leftChild[node];
      } else {
        node = rightChild[node];
      }
    }

    return ~node;
  }

  public int getLeaf(Dataset dataset, int instanceIndex) {
    if (numLeaves == 1) {
      return 0;
    }

    int node = 0;
    while (node >= 0) {
      /*
      // SISTA
      double v = dataset.getOriginalFeatureValue(instanceIndex, splitFeatures[node]);
      System.out.println("Classification point:");
      System.out.println("\tnode: " + node);
      System.out.println("\tsplitFeatures[node]: " + splitFeatures[node]);
      System.out.println("\tv: " + v);
      */

      if (dataset.getFeatureValue(instanceIndex, splitFeatures[node]) <= thresholds[node]) {
        node = leftChild[node];
      } else {
        node = rightChild[node];
      }
    }
    return ~node;
  }

  public int getLeafFromOriginalThreshold(Dataset dataset, int instanceIndex) {
    if (numLeaves == 1) {
      return 0;
    }

    int node = 0;
    while (node >= 0) {
     

      if (dataset.getOriginalFeatureValue(instanceIndex, splitFeatures[node]) <= originalThresholds[node]) {
        node = leftChild[node];
      } else {
        node = rightChild[node];
      }
    }
    return ~node;
  }

  public int getLeafFromOriginalThreshold(double[] featureVector) {
    if (numLeaves == 1) {
      return 0;
    }

    int node = 0;
    while (node >= 0) {
      if (featureVector[splitFeatures[node]] <= originalThresholds[node]) {
        node = leftChild[node];
      } else {
        node = rightChild[node];
      }
    }
    return ~node;
  }

  public int getParent(int node) {
    int parent = ArraysUtil.findIndex(leftChild, node, numLeaves - 1);
    if (parent >= 0) {
      return parent;
    }
    parent = ArraysUtil.findIndex(rightChild, node, numLeaves - 1);
    if (parent >= 0) {
      return parent;
    }
    return -1;
  }

  /**
   * Returns list of node parents
   */
  public int[] getNodeParents(int node) {
    List<Integer> parents = new ArrayList<Integer>();
    int parent = getParent(node);
    while (parent >= 0) {
      parents.add(parent);
      parent = getParent(parent);
    }
    int[] result = new int[parents.size()];
    for (int i = 0; i < result.length; i++) {
      result[i] = parents.get(parents.size() - i - 1);
    }
    return result;
  }
 
  /**
   * Extracts the list of leaves that belong to the subtree which
   * which has 'node' as its root. The result is added to 'leaves'
   * input list.
   */
  public void loadLeavesInSubtree(int node, List<Integer> leaves) {
    if (node < 0) {
      leaves.add(~node);
    } else {
      loadLeavesInSubtree(leftChild[node], leaves);
      loadLeavesInSubtree(rightChild[node], leaves);
    }
  }

  /**
   * turns a leaf of the tree into an internal node with two leaf-children
   */
  public int split(int leaf, TreeSplit split) {
    int indexOfNewNonLeaf = numLeaves - 1;

    // find the leaf's parent, and update its info
    int parent = ArraysUtil.findIndex(leftChild, ~leaf, numLeaves - 1);
    if (parent >= 0) {
      leftChild[parent] = indexOfNewNonLeaf;
    } else {
      parent = ArraysUtil.findIndex(rightChild, ~leaf, numLeaves - 1);
      if (parent >= 0) {
        rightChild[parent] = indexOfNewNonLeaf;
      }
    }

    splitFeatures[indexOfNewNonLeaf] = split.feature;
    thresholds[indexOfNewNonLeaf] = split.threshold;
    originalThresholds[indexOfNewNonLeaf] = split.originalThreshold;
    leftChild[indexOfNewNonLeaf] = ~leaf;
    rightChild[indexOfNewNonLeaf] = ~numLeaves;

    numLeaves++;
    return indexOfNewNonLeaf;
  }

  private int getNodeLabel(int n) {
    if (n < 0) {
      return -1 - (~n);
    }
    return n;
  }

  protected abstract void addCustomData(String linePrefix, StringBuilder sb);

  public abstract void loadCustomData(String str) throws Exception;

  public String toString(double weight, int indentationLevel) {
    String linePrefix = "";
    for (int i = 0; i < indentationLevel; i++) {
      linePrefix += "\t";
    }
    StringBuilder sb = new StringBuilder();
    sb.append("\n" + linePrefix + "<Tree leaves=\"" + numLeaves + "\" weight=\"" + weight + "\">");

    StringBuilder sbFeatures = new StringBuilder();
    StringBuilder sbLeftChildren = new StringBuilder();
    StringBuilder sbRightChild = new StringBuilder();
    StringBuilder sbThreshold = new StringBuilder();
    StringBuilder sbOriginalThreshold = new StringBuilder();

    int numNonLeaves = numLeaves - 1;
    for (int n = 0; n < numNonLeaves; n++) {
      sbFeatures.append(" " + splitFeatures[n]);
      sbLeftChildren.append(" " + getNodeLabel(leftChild[n]));
      sbRightChild.append(" " + getNodeLabel(rightChild[n]));
      sbThreshold.append(" " + thresholds[n]);
      sbOriginalThreshold.append(" " + originalThresholds[n]);
    }

    sb.append("\n" + linePrefix + "\t<SplitFeatures>" + sbFeatures.toString().trim() + "</SplitFeatures>");
    sb.append("\n" + linePrefix + "\t<LeftChildren>" + sbLeftChildren.toString().trim() + "</LeftChildren>");
    sb.append("\n" + linePrefix + "\t<RightChildren>" + sbRightChild.toString().trim() + "</RightChildren>");
    sb.append("\n" + linePrefix + "\t<Thresholds>" + sbThreshold.toString().trim() + "</Thresholds>");
    sb.append("\n" + linePrefix + "\t<OriginalThresholds>" + sbOriginalThreshold.toString().trim() + "</OriginalThresholds>");

    addCustomData(linePrefix, sb);

    sb.append("\n" + linePrefix + "</Tree>");
    return sb.toString();
  }

  protected String removeXmlTag(String line, String tagName) {
    return line.trim().replace(tagName, "").replace("<>", "").replace("</>", "");
  }

  public void loadFromString(int numLeaves, String splitFeaturesLine, String leftChildrenLine, String rightChildrenLine, String thresholdsLine,
      String originalThresholdsLine) throws Exception {
    splitFeatures = ArraysUtil.loadIntArrayFromLine(removeXmlTag(splitFeaturesLine, "SplitFeatures"), numLeaves - 1);
    leftChild = ArraysUtil.loadIntArrayFromLine(removeXmlTag(leftChildrenLine, "LeftChildren"), numLeaves - 1);
    rightChild = ArraysUtil.loadIntArrayFromLine(removeXmlTag(rightChildrenLine, "RightChildren"), numLeaves - 1);
    thresholds = ArraysUtil.loadIntArrayFromLine(removeXmlTag(thresholdsLine, "Thresholds"), numLeaves - 1);
    originalThresholds = ArraysUtil.loadDoubleArrayFromLine(removeXmlTag(originalThresholdsLine, "OriginalThresholds"), numLeaves - 1);
    this.numLeaves = numLeaves;
  }
 
  public abstract void backfit(Sample sample);
 
}
TOP

Related Classes of edu.uci.jforestsx.learning.trees.Tree

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.