package statechurn;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.logging.Level;
import java.util.logging.Logger;
import primitives.graph.*;
import java.util.*;
import translator.DOTWriter;
import java.util.HashSet;
public class KTails {
private HashMap<String, Node> nodes;
private int merges = 0;
private int counter = 0;
private Stack<Node> toremove;
private Graph g;
private HashMap<Transition, Boolean> visitedTs;
private HashMap<String, HashMap<String, HashMap>> tailCache;
private HashMap<String, HashMap<String, Boolean>> sameCache;
private int k = 5;
private int cacheMisses;
private int cacheHits;
public KTails(HashMap<String, Node> nodes, int k) {
this.nodes = nodes;
this.k = k;
toremove = new Stack<Node>();
g = new Graph();
visitedTs = new HashMap<Transition, Boolean>();
tailCache = new HashMap<String, HashMap<String, HashMap>>();
sameCache = new HashMap<String, HashMap<String, Boolean>>();
}
public void deleteFromCaches(Node n) {
tailCache.remove(n.getLabel());
sameCache.remove(n.getLabel());
}
public boolean equivalent(HashMap<String, HashMap> trace1, HashMap<String, HashMap> trace2) {
if (trace1.isEmpty()) {
return false;
}
return (trace1.equals(trace2));
}
public void clearVisitedTs() {
visitedTs = new HashMap<Transition, Boolean>();
}
public boolean canMerge(Node node1, Node node2) {
if (node1 == node2) {
return false;
}
String label1 = node1.getLabel();
String label2 = node2.getLabel();
if (label1.compareTo(label2) > 0) {
label1 = node2.getLabel();
label2 = node1.getLabel();
}
HashMap<String, Boolean> cHit = sameCache.get(label1);
if (cHit == null) {
cHit = new HashMap<String, Boolean>();
sameCache.put(label1, cHit);
}
Boolean b = cHit.get(label2);
if (b == null) {
cacheMisses++;
HashMap<String, HashMap> traces1 = tailCache.get(node1.getLabel());
HashMap<String, HashMap> traces2 = tailCache.get(node2.getLabel());
if (traces1 == null) {
traces1 = tracesOf(node1,k);
tailCache.put(node1.getLabel(), traces1);
}
if (traces2 == null) {
traces2 = tracesOf(node2, k);
tailCache.put(node2.getLabel(), traces2);
}
boolean res = equivalent(traces1, traces2);
cHit.put(label2, res);
return res;
} else {
cacheHits++;
return b;
}
}
public HashMap<String, HashMap> tracesOf(Node node, int length) {
Stack<ArrayList<Object>> stack = new Stack<ArrayList<Object>>();
HashMap<String, HashMap> ret = new HashMap<String, HashMap>();
stack.push(l(node, length, ret));
while (!stack.isEmpty()) {
ArrayList<Object> top = stack.pop();
Node n = (Node) top.get(0);
int l = (Integer) top.get(1);
HashMap<String, HashMap<String, Object>> rtarg = (HashMap<String, HashMap<String, Object>>) top.get(2);
if (l > 0) {
Set<Transition> ts = n.getTransitionsAsT();
for (Transition it : ts) {
Node dest = it.getDestinationNode();
HashMap<String, Object> target = rtarg.get(it.getLabel());
if (target == null) {
target = new HashMap<String, Object>();
rtarg.put(it.getLabel(), target);
}
stack.push(l(dest, l - 1, target));
}
}
}
return ret;
}
public Graph getGraph() {
Graph g = new Graph();
for (String k : nodes.keySet()) {
Node n = nodes.get(k);
g.addNode(n);
}
return g;
}
//TODO make sure this works.
public boolean doStep() {
Node na = null;
Node nb = null;
HashSet<MergePair> s = new HashSet<MergePair>();
boolean brokeForMemory= false;
outerloop:
for (String label1 : nodes.keySet()) {
for (String label2 : nodes.keySet()) {
if (!label1.equals(label2)) {
Node n1 = nodes.get(label1);
Node n2 = nodes.get(label2);
if (n1.refCount > 0 && n2.refCount > 0) {
if (canMerge(n1, n2)) {
s.add(new MergePair(n2,n1));
break;
}
if(!s.isEmpty() && (double)(Runtime.getRuntime().totalMemory()) / Runtime.getRuntime().maxMemory() > 0.95){
brokeForMemory = true;
break outerloop;
}
}
}
}
}
long start = System.currentTimeMillis();
int oldSize = nodes.size();
for(MergePair l : s){
na = l.getLeft();
nb = l.getRight();
if (na != null && nb != null && na.refCount > 0 && nb.refCount > 0){
replace(nb, na);
pruneNodes();
}
double time = (System.currentTimeMillis() - start) / 1000;
if (time > 5) {
int size = nodes.size();
int delta = oldSize - size;
oldSize = size;
double rate = (delta / (0.000001 + time));
System.err.println(
String.format(
"[doStep] %d merges have been made in %.2fs [%.2f/sec], machine has %d nodes.", delta, time, rate, size));
start = System.currentTimeMillis();
}
}
return !s.isEmpty() || brokeForMemory;
}
public void replace(Node node, Node with) {
//replace these with loops
if (node == with) {
return;
}
for (String l : nodes.keySet()) {
Node n = nodes.get(l);
if (n != null) {
Set<Transition> transitions = n.getTransitionsAsT();
for (Transition it : transitions) {
Node target = it.getDestinationNode();
String label = it.getLabel();
if (target == node) {
n.deleteTransition(it);
n.connect(with, label);
}
}
}
}
clearVisitedTs();
Set<Transition> ts = node.getTransitionsAsT();
for (Transition it : ts) {
merge(with, it);
}
nodes.remove(with.getLabel());
tailCache.remove(with.getLabel());
if (!with.getLabel().equals("Start")) {
with.setLabel("merge" + counter);
counter++;
}
if (node.getLabel().equals("Start")) {
with.setLabel("Start");
}
nodes.put(with.getLabel(), with);
toremove.add(node);
merges++;
}
public static ArrayList l(Object a, Object b) {
ArrayList ret = new ArrayList();
ret.add(a);
ret.add(b);
return ret;
}
public static ArrayList l(Object a, Object b, Object c) {
ArrayList ret = new ArrayList();
ret.add(a);
ret.add(b);
ret.add(c);
return ret;
}
public void merge(Node _to, Transition _from) {
Stack<ArrayList> stack = new Stack<ArrayList>();
stack.push(l(_to, _from));
while (!stack.isEmpty()) {
List rs = stack.pop();
Node to = (Node) rs.get(0);
Transition from = (Transition) rs.get(1);
if (visitedTs.get(from) != null) {
continue;
}
visitedTs.put(from, true);
if (!to.hasTransitionWithLabel(from.getLabel())) {
to.connect(from.getDestinationNode(), from.getLabel());
} else {
try {
Transition t = to.transitionWithLabel(from.getLabel());
if (t != from) {
if (t.getDestinationNode() == from.getDestinationNode()) {
} else {
Set<Transition> ts = from.getDestinationNode().getTransitionsAsT();
for (Transition it : ts) {
stack.push(l(t.getDestinationNode(), it));
}
}
}
} catch (TransitionNotFoundException e) {
}
}
}
}
public void pruneNodes() {
while (!toremove.isEmpty()) {
Node it = toremove.pop();
if (it == null) {
continue;
}
if (it.refCount <= 0) {
Set<Transition> ts = it.getTransitionsAsT();
for (Transition t : ts) {
toremove.push(t.getDestinationNode());
it.deleteTransition(t);
}
nodes.remove(it.getLabel());
deleteFromCaches(it);
}
}
toremove.clear();
}
public Graph doKTails() {
boolean go = true;
int iter = 0;
DOTWriter writer = new DOTWriter();
int startSize = nodes.size();
long startTime = System.currentTimeMillis();
g = getGraph();
dump(String.format("iter%d.dot", iter), writer.getRepresentation(g));
long start = System.currentTimeMillis();
int oldSize = nodes.size();
while (go) {
iter++;
go = doStep();
pruneNodes();
double time = (System.currentTimeMillis() - start) / 1000;
if (time > 5) {
int size = nodes.size();
int delta = oldSize - size;
oldSize = size;
double rate = (delta / (0.000001 + time));
System.err.println(
String.format(
"[doKTails] %d merges have been made in %.2fs [%.2f/sec], machine has %d nodes.", delta, time, rate, size));
start = System.currentTimeMillis();
}
}
g = getGraph();
double time = (System.currentTimeMillis() - startTime) / 1000.0;
long delta = startSize - nodes.size();
double rate = delta / (0.000001 + time);
System.err.println(
String.format(
"%d merges have been made in %.2fs [%.2f/sec], machine has %d nodes.", delta, time, rate, nodes.size()));
dump("final.dot", writer.getRepresentation(g));
return g;
}
public void dump(String filename, String content) {
BufferedWriter out = null;
try {
out = new BufferedWriter(new FileWriter(filename));
out.write(content);
out.close();
} catch (IOException ex) {
//Ignore exception since it's no big deal if we can't dump output.
} finally {
try {
out.close();
} catch (IOException ex) {
}
}
}
}