Package org.renjin.compiler.pipeline

Source Code of org.renjin.compiler.pipeline.DeferredGraph

package org.renjin.compiler.pipeline;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.renjin.compiler.pipeline.optimize.Optimizers;
import org.renjin.primitives.vector.DeferredComputation;
import org.renjin.primitives.vector.MemoizedComputation;
import org.renjin.sexp.DoubleArrayVector;
import org.renjin.sexp.Vector;

import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.ListIterator;

/**
* Directed, acyclic graph (DAG) of a deferred computation.
*
* <p>This graph as is constructed at the moment that the
* interpreter actually needs the result of a computation.
*
*/
public class DeferredGraph {

  private DeferredNode rootNode;
  private List<DeferredNode> nodes = Lists.newArrayList();
  private int nextNodeId = 1;
  private IdentityHashMap<Vector, DeferredNode> nodeMap = Maps.newIdentityHashMap();

  public DeferredGraph(DeferredComputation root) {
    this.rootNode = new DeferredNode(nextNodeId(), root);
    nodes.add(rootNode);
    nodeMap.put(root, rootNode);
    addChildren(this.rootNode);

    Optimizers optimizers = new Optimizers();
    optimizers.optimize(this);
    removeOrphans();
  }

  private int nextNodeId() {
    return nextNodeId++;
  }

  private void addChildren(DeferredNode parent) {
    for(Vector operand : parent.getComputation().getOperands()) {
      DeferredNode node = nodeMap.get(operand);
      if(node == null) {
        node = new DeferredNode(nextNodeId(), operand);
        if(node.isComputation()) {
          addChildren(node);
        }
        node = tryMerge(node);
        nodeMap.put(operand, node);
      }
      parent.addOperand(node);
      node.addUse(parent);
    }
  }

  private DeferredNode tryMerge(DeferredNode newNode) {
    for(DeferredNode node : nodeMap.values()) {
      if(node.equivalent(newNode)) {
        return node;
      }
    }
    nodes.add(newNode);
    return newNode;
  }

  public void dumpGraph() {
    try {
      File tempFile = File.createTempFile("deferred", ".dot");
      PrintWriter writer = new PrintWriter(tempFile);
      printGraph(writer);
      writer.close();
      System.out.println("Dumping compute graph to " + tempFile.getAbsolutePath());
    } catch (IOException e) {
    }
  }

  public void printGraph(PrintWriter writer) {
    writer.println("digraph G {");
    printEdges(writer);
    printNodes(writer);
    writer.println("}");
    writer.flush();
  }

  private void printEdges(PrintWriter writer) {
    for(DeferredNode node : nodes) {
      for(DeferredNode operand : node.getOperands()) {
        writer.println(node.getDebugId() + " -> " + operand.getDebugId());
      }
    }
  }

  private void printNodes(PrintWriter writer) {
    for(DeferredNode node : nodes) {
      String shape = "box";
      if(node.isComputation()) {
        if(node.getComputation() instanceof  MemoizedComputation) {
          shape = "ellipse";
        } else {
          shape = "parallelogram";
        }
      }
      writer.println(node.getDebugId() + " [ label=\"" + node.getDebugLabel() + "\",  " +
          "shape=\"" + shape + "\"]");
    }
  }

  public DeferredNode getRoot() {
    return rootNode;
  }

  public List<DeferredNode> getNodes() {
    return nodes;
  }

  public void replaceNode(DeferredNode toReplace, DeferredNode replacementValue) {
    nodes.remove(toReplace);
    if(!nodes.contains(replacementValue)) {
      nodes.add(replacementValue);
    }

    for(DeferredNode operand : toReplace.getOperands()) {
      operand.removeUse(toReplace);
    }

    for(DeferredNode node : nodes) {
      node.replaceOperand(toReplace, replacementValue);
      node.replaceUse(toReplace, replacementValue);
    }
  }

  private void removeOrphans() {
    boolean removing;
    do {
      removing = false;
      ListIterator<DeferredNode> it = nodes.listIterator();
      while(it.hasNext()) {
        DeferredNode node = it.next();
        if(node != rootNode && !node.isUsed()) {
          removing = true;
          it.remove();
        }
      }
    } while(removing);
  }

}
TOP

Related Classes of org.renjin.compiler.pipeline.DeferredGraph

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.