/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.inference;
import org._3pq.jgrapht.GraphHelper;
import org._3pq.jgrapht.UndirectedGraph;
import org._3pq.jgrapht.alg.ConnectivityInspector;
import org._3pq.jgrapht.graph.SimpleGraph;
import org._3pq.jgrapht.graph.ListenableUndirectedGraph;
import org._3pq.jgrapht.traverse.BreadthFirstIterator;
import cc.mallet.grmm.types.*;
import cc.mallet.grmm.util.Graphs;
import cc.mallet.types.Alphabet;
import cc.mallet.util.MalletLogger;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.*;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
* Does inference in general graphical models using
* the Hugin junction tree algorithm.
*
* Created: Mon Nov 10 23:58:44 2003
*
* @author <a href="mailto:casutton@cs.umass.edu">Charles Sutton</a>
* @version $Id: JunctionTreeInferencer.java,v 1.1 2007/10/22 21:37:49 mccallum Exp $
*/
public class JunctionTreeInferencer extends AbstractInferencer {
/** Logger used for triangulation and junction-tree construction diagnostics. */
private static Logger logger = MalletLogger.getLogger(JunctionTreeInferencer.class.getName());
/**
 * Set by the computeMarginals methods and read by createBlankFactor to
 * choose between a LogTableFactor and a TableFactor.
 */
private boolean inLogSpace;
/** Propagation strategy that computes marginals on the junction tree and answers marginal lookups. */
private JunctionTreePropagation propagator;
/**
 * Creates an inferencer that uses the propagator returned by
 * JunctionTreePropagation.createSumProductInferencer().
 */
public JunctionTreeInferencer()
{
this (JunctionTreePropagation.createSumProductInferencer ());
} // end of default constructor
/**
 * Creates an inferencer that delegates message passing and marginal
 * lookups to the given propagator.
 *
 * @param propagator propagation strategy to use; stored as-is (no null check is performed here)
 */
public JunctionTreeInferencer (JunctionTreePropagation propagator)
{
this.propagator = propagator;
}
/**
 * Factory for an inferencer whose propagator is the one returned by
 * JunctionTreePropagation.createMaxProductInferencer().
 *
 * @return a new inferencer wrapping that propagator
 */
public static JunctionTreeInferencer createForMaxProduct ()
{
  JunctionTreePropagation maxProduct = JunctionTreePropagation.createMaxProductInferencer ();
  return new JunctionTreeInferencer (maxProduct);
}
/**
 * Tests whether the graph has an edge between two variables.
 *
 * @param g  graph to query
 * @param v1 one endpoint
 * @param v2 the other endpoint
 * @return true iff g.getEdge (v1, v2) yields a non-null edge
 */
private boolean isAdjacent (UndirectedGraph g, Variable v1, Variable v2)
{
  Object edge = g.getEdge (v1, v2);
  return edge != null;
}
/**
 * Junction tree most recently built by buildJunctionTree or supplied to
 * computeMarginals(JunctionTree). Transient, so it is not serialized.
 */
transient protected JunctionTree jtCurrent;
/**
 * Elimination cliques (VarSet instances) collected by triangulate() and
 * consumed by buildJtStructure(). Transient, so it is not serialized.
 */
transient private ArrayList cliques;
/**
 * Returns the number of edges that would be added to a graph if a
 * given vertex would be removed in the triangulation procedure.
 * The return value is the number of edges in the elimination
 * clique of V that are not already present, i.e. the number of
 * unordered pairs of neighbors of V that are not adjacent.
 *
 * @param mdl graph being triangulated
 * @param v   candidate vertex for elimination
 * @return number of fill-in edges eliminating v would require
 */
private int newEdgesRequired(UndirectedGraph mdl, Variable v)
{
  // Snapshot the neighbors once: the list does not change while we only
  // query adjacency, so there is no need to rebuild it per outer neighbor.
  List neighbors = GraphHelper.neighborListOf (mdl, v);
  int numNeighbors = neighbors.size ();
  int rating = 0;
  // Visit each unordered pair exactly once so every missing edge is
  // counted once (visiting ordered pairs would count it twice).
  for (int i = 0; i < numNeighbors; i++) {
    Variable neighbor1 = (Variable) neighbors.get (i);
    for (int j = i + 1; j < numNeighbors; j++) {
      Variable neighbor2 = (Variable) neighbors.get (j);
      if (neighbor1 != neighbor2 && !isAdjacent (mdl, neighbor1, neighbor2)) {
        rating++;
      }
    }
  }
  return rating;
}
/**
 * Returns the weight associated with removing a given vertex in the
 * triangulation procedure: the product of the number of outcomes of
 * every current neighbor of V (V itself is not included in the product).
 * The product saturates at the int range instead of wrapping around, so
 * that very large neighborhoods still compare as "heavier" than small ones.
 *
 * @param mdl graph being triangulated
 * @param v   candidate vertex for elimination
 * @return product of the neighbors' outcome counts, clamped to the int range
 */
private int weightRequired (UndirectedGraph mdl, Variable v)
{
  long rating = 1;
  for (Iterator it = neighborsIterator (mdl, v); it.hasNext ();) {
    Variable neighbor = (Variable) it.next ();
    // rating is always within the int range here, so this long
    // multiplication cannot itself overflow.
    long scaled = rating * neighbor.getNumOutcomes ();
    rating = Math.max (Integer.MIN_VALUE, Math.min (Integer.MAX_VALUE, scaled));
  }
  return (int) rating;
}
/**
 * Makes the neighbors of V pairwise adjacent by adding every missing
 * edge between them to the graph. V's own edges are not touched.
 *
 * @param mdl graph to modify in place
 * @param v   vertex whose neighborhood is turned into a clique
 * @throws RuntimeException wrapping any exception raised while adding an edge
 */
private void connectNeighbors(UndirectedGraph mdl, Variable v)
{
  // The edges added below never involve v, so v's neighbor list is
  // invariant for the whole method and can be fetched once.
  List neighbors = GraphHelper.neighborListOf (mdl, v);
  int numNeighbors = neighbors.size ();
  // Each unordered pair is handled once, in the same order in which the
  // edges would be created by a full scan over ordered pairs.
  for (int i = 0; i < numNeighbors; i++) {
    Variable neighbor1 = (Variable) neighbors.get (i);
    for (int j = i + 1; j < numNeighbors; j++) {
      Variable neighbor2 = (Variable) neighbors.get (j);
      if (neighbor1 == neighbor2 || isAdjacent (mdl, neighbor1, neighbor2)) {
        continue;
      }
      try {
        mdl.addEdge (neighbor1, neighbor2);
      } catch (Exception e) {
        // Kept so callers see the same wrapped failure as before.
        throw new RuntimeException (e);
      }
    }
  }
}
// TODO: could be expressed as a generic "any element matches" helper.
/**
 * Returns true iff some clique in L contains every variable of C.
 * A clique equal to C also counts, because the test is containsAll.
 *
 * @param l list of VarSet cliques to search
 * @param c clique to look for a superset of
 */
private boolean findSuperClique(List l, VarSet c)
{
  boolean found = false;
  Iterator it = l.iterator ();
  while (!found && it.hasNext ()) {
    VarSet candidate = (VarSet) it.next ();
    found = candidate.containsAll (c);
  }
  return found;
}
/**
 * Three-way integer comparison (like Perl's &lt;=&gt; operator).
 *
 * @return -1 if i1 &lt; i2, 1 if i1 &gt; i2, and 0 if they are equal
 */
private static int cmp(int i1, int i2)
{
  return (i1 < i2) ? -1 : ((i1 > i2) ? 1 : 0);
}
/**
 * Chooses the next vertex to eliminate during triangulation.
 * The winner is the vertex with the smallest newEdgesRequired value;
 * ties are broken by the smallest weightRequired value, and remaining
 * ties by position in the list (the earliest candidate wins).
 *
 * @param mdl graph being triangulated
 * @param lst candidate vertices (Variable instances); must be non-empty,
 *            since the first element is fetched unconditionally
 * @return the chosen vertex
 */
public Variable pickVertexToRemove (UndirectedGraph mdl, ArrayList lst)
{
  Iterator candidates = lst.iterator ();
  Variable winner = (Variable) candidates.next ();
  int fewestEdges = newEdgesRequired (mdl, winner);
  int lowestWeight = weightRequired (mdl, winner);

  while (candidates.hasNext ()) {
    Variable challenger = (Variable) candidates.next ();
    int edges = newEdgesRequired (mdl, challenger);
    if (edges > fewestEdges) {
      continue;   // strictly worse on the primary criterion
    }
    int weight = weightRequired (mdl, challenger);
    // Either strictly fewer fill-in edges, or the same number with a lower weight.
    if (edges < fewestEdges || weight < lowestWeight) {
      winner = challenger;
      fewestEdges = edges;
      lowestWeight = weight;
    }
  }
  return winner;
}
/**
 * Triangulates a copy of the given graph by greedy vertex elimination
 * and records the resulting elimination cliques in the cliques field.
 * The graph passed in is not modified; all edge additions and vertex
 * removals happen on a duplicate.
 *
 * Each round picks a vertex with pickVertexToRemove, forms the clique
 * made of that vertex and its current neighbors, keeps the clique unless
 * an already recorded clique contains it, connects the neighbors
 * pairwise, and finally removes the vertex.
 *
 * @param mdl undirected graph whose vertices are Variable instances
 */
private void triangulate(final UndirectedGraph mdl)
{
  UndirectedGraph mdl2 = dupGraph (mdl);
  ArrayList vars = new ArrayList(mdl.vertexSet());
  cliques = new ArrayList();

  // debug
  if (logger.isLoggable (Level.FINER)) {
    logger.finer ("Triangulating model: "+mdl);
    StringBuffer ret = new StringBuffer ();
    for (int i = 0; i < vars.size(); i++) {
      Variable next = (Variable) vars.get(i);
      ret.append (next.toString ()).append ("\n");
    }
    logger.finer(ret.toString ());
  }

  while (!vars.isEmpty()) {
    Variable v = pickVertexToRemove (mdl2, vars);
    if (logger.isLoggable (Level.FINER)) {
      logger.finer("Triangulating vertex " + v);
    }

    // Elimination clique: v together with its current neighbors.
    VarSet varSet = new BitVarSet (v.getUniverse (), GraphHelper.neighborListOf (mdl2, v));
    varSet.add(v);
    if (!findSuperClique(cliques, varSet)) {
      cliques.add(varSet);
      if (logger.isLoggable (Level.FINER)) {
        logger.finer (" Elim clique " + varSet + " size " + varSet.size () + " weight " + varSet.weight ());
      }
    }

    // The fill-in edges must be added while v is still in the graph,
    // because v's neighborhood is looked up through v. Only afterwards
    // is v dropped from both the candidate list and the working graph.
    connectNeighbors (mdl2, v);
    vars.remove(v);
    mdl2.removeVertex (v);
  }

  if (logger.isLoggable(Level.FINE)) {
    logger.fine("Triangulation done. Cliques are: ");
    int totSize = 0, totWeight = 0, maxSize = 0, maxWeight = 0;
    for (Iterator it = cliques.iterator(); it.hasNext();) {
      VarSet c = (VarSet) it.next();
      logger.finer(c.toString());
      totSize += c.size();
      maxSize = Math.max(c.size(), maxSize);
      totWeight += c.weight();
      maxWeight = Math.max(c.weight(), maxWeight);
    }
    // double, so that the averages below use floating-point division
    double sz = cliques.size();
    logger.fine("Jt created " + sz + " cliques. Size: avg " + (totSize / sz)
                + " max " + (maxSize) + " Weight: avg " + (totWeight / sz)
                + " max " + (maxWeight));
  }
}
/**
 * Builds an Alphabet over Variable entries and registers every element
 * of the given list in it, in list order.
 *
 * @param vars list of vertices to register
 * @return the populated alphabet
 */
private Alphabet makeVertexMap(ArrayList vars)
{
  Object[] entries = vars.toArray ();
  Alphabet map = new Alphabet (vars.size (), Variable.class);
  // second argument true: entries not yet present are added
  map.lookupIndices (entries, true);
  return map;
}
/**
 * Size of the separator between the two cliques of a candidate edge.
 *
 * @param pair two-element array of cliques
 * @return pair[0].intersectionSize (pair[1])
 */
private static int sepsetSize(BitVarSet[] pair)
{
  assert pair.length == 2;
  BitVarSet first = pair[0];
  BitVarSet second = pair[1];
  return first.intersectionSize (second);
}
/**
 * Cost of a candidate edge: the sum of the weights of the two cliques
 * it would connect.
 *
 * @param pair two-element array of cliques
 * @return pair[0].weight () + pair[1].weight ()
 */
private static int sepsetCost(VarSet[] pair)
{
  assert pair.length == 2;
  int firstWeight = pair[0].weight ();
  int secondWeight = pair[1].weight ();
  return firstWeight + secondWeight;
}
// Orders candidate junction-tree edges (two-element BitVarSet arrays).
// A pair sorts first when it should be added to the tree first:
//   1. larger separator (more shared variables) first, which is what
//      makes the resulting clique tree satisfy the running intersection
//      property;
//   2. then smaller cost (sum of the two clique weights);
//   3. then by the hashCode of the pair arrays, an arbitrary tie-break
//      that is consistent within one run.
private static Comparator sepsetChooser = new Comparator() {
  public int compare(Object o1, Object o2)
  {
    if (o1 == o2) {
      return 0;
    }
    BitVarSet[] left = (BitVarSet[]) o1;
    BitVarSet[] right = (BitVarSet[]) o2;

    int leftSize = sepsetSize (left);
    int rightSize = sepsetSize (right);
    // arguments swapped: bigger separators come first
    int bySize = cmp (rightSize, leftSize);
    if (bySize != 0) {
      return bySize;
    }

    int leftCost = sepsetCost (left);
    int rightCost = sepsetCost (right);
    int byCost = cmp (leftCost, rightCost);
    if (byCost != 0) {
      return byCost;
    }

    return cmp (o1.hashCode (), o2.hashCode ());
  }
};
/**
 * Converts an undirected graph over cliques into a JunctionTree.
 * The first vertex returned by the vertex set's iterator becomes the
 * root; the graph is then walked breadth-first from the root and, for
 * every visited vertex, each neighbor that is not its parent in the
 * tree is passed to jt.addNode (vertex, neighbor).
 *
 * @param g graph to convert; must have at least one vertex
 *          (presumably a tree, as built by buildJtStructure)
 * @return the populated junction tree
 */
private JunctionTree graphToJt (UndirectedGraph g)
{
  JunctionTree jt = new JunctionTree (g.vertexSet ().size ());
  Object root = g.vertexSet ().iterator ().next ();
  jt.add (root);

  Iterator bfs = new BreadthFirstIterator (g, root);
  while (bfs.hasNext ()) {
    Object current = bfs.next ();
    Iterator nbrs = GraphHelper.neighborListOf (g, current).iterator ();
    while (nbrs.hasNext ()) {
      Object nbr = nbrs.next ();
      if (jt.getParent (current) == nbr) {
        continue;   // that neighbor is where we came from
      }
      jt.addNode (current, nbr);
    }
  }
  return jt;
}
/**
 * Builds the junction tree structure over the cliques found by
 * triangulate(). Every pair of cliques is a candidate edge; candidates
 * are considered in the order defined by sepsetChooser, and a candidate
 * is accepted whenever its two cliques are not yet connected, until
 * (number of cliques - 1) edges have been added.
 *
 * @return junction tree built from the selected edges
 * @throws NoSuchElementException if the candidates run out before the
 *         required number of edges has been added
 */
private JunctionTree buildJtStructure()
{
  // Collect all candidate edges. A sorted list is used rather than a
  // TreeSet: a TreeSet silently drops an element that compares equal to
  // one already present, which would lose a candidate edge whenever two
  // distinct pairs tie on size, cost and hash code.
  List candidates = new ArrayList ();
  for (Iterator it = cliques.iterator(); it.hasNext();) {
    BitVarSet c1 = (BitVarSet) it.next();
    for (Iterator it2 = cliques.iterator(); it2.hasNext();) {
      BitVarSet c2 = (BitVarSet) it2.next();
      // only pair c1 with the cliques that precede it in the list
      if (c1 == c2) break;
      candidates.add (new BitVarSet[]{c1, c2});
    }
  }
  Collections.sort (candidates, sepsetChooser);

  // The edges are first collected in a plain graph, which (unlike the
  // tree) accepts edges between not-yet-connected parts, and the graph
  // is converted to a JunctionTree at the end.
  ListenableUndirectedGraph g = new ListenableUndirectedGraph (new SimpleGraph ());

  // first add every clique to the graph
  for (Iterator it = cliques.iterator(); it.hasNext();) {
    VarSet c = (VarSet) it.next();
    g.addVertex (c);
  }

  // The inspector listens to g so that its connectivity answers stay
  // current as edges are added.
  ConnectivityInspector inspector = new ConnectivityInspector (g);
  g.addGraphListener (inspector);

  // then add n - 1 edges, skipping any that would close a cycle
  int numCliques = cliques.size();
  int edgesAdded = 0;
  Iterator candidateIt = candidates.iterator ();
  while (edgesAdded < numCliques - 1) {
    VarSet[] pair = (VarSet[]) candidateIt.next ();
    if (!inspector.pathExists(pair[0], pair[1])) {
      g.addEdge(pair[0], pair[1]);
      edgesAdded++;
    }
  }

  JunctionTree jt = graphToJt(g);
  if (logger.isLoggable (Level.FINER)) {
    logger.finer (" jt structure was " + jt);
  }
  return jt;
}
/**
 * Initializes the clique potentials of a junction tree from a model.
 * Every cluster is first reset to the constant factor 1.0; then each
 * factor of the model is multiplied into the potential of the cluster
 * returned by jt.findParentCluster for that factor's variables.
 *
 * @param mdl model whose factors are distributed over the tree
 * @param jt  junction tree whose potentials are (re)initialized
 */
private void initJtCpts(FactorGraph mdl, JunctionTree jt)
{
  Iterator clusterIt = jt.getVerticesIterator ();
  while (clusterIt.hasNext ()) {
    VarSet cluster = (VarSet) clusterIt.next ();
    jt.setCPF (cluster, new ConstantFactor (1.0));
  }

  Iterator factorIt = mdl.factors ().iterator ();
  while (factorIt.hasNext ()) {
    Factor ptl = (Factor) factorIt.next ();
    VarSet home = jt.findParentCluster (ptl.varSet ());
    assert home != null
      : "Unable to find parent cluster for ptl " + ptl + "in jt " + jt;
    Factor current = jt.getCPF (home);
    Factor product = current.multiply (ptl);
    jt.setCPF (home, product);
  }
}
/**
 * Creates a new table factor over the given variables, using the
 * log-space implementation when inLogSpace is set.
 *
 * @param c variables of the new factor
 * @return a LogTableFactor if inLogSpace, otherwise a TableFactor
 */
private AbstractTableFactor createBlankFactor (VarSet c)
{
  if (!inLogSpace) {
    return new TableFactor (c);
  }
  return new LogTableFactor (c);
}
/**
 * Builds (or reuses) the junction tree for the model, runs the
 * propagator on it, and adds the propagator's message count to this
 * inferencer's running total.
 *
 * @param mdl model to run inference on; its factor at index 0 is
 *            inspected to decide whether it is a LogTableFactor
 */
public void computeMarginals (FactorGraph mdl)
{
  Object firstFactor = mdl.getFactor (0);
  inLogSpace = (firstFactor instanceof LogTableFactor);

  buildJunctionTree (mdl);   // leaves the tree in jtCurrent
  propagator.computeMarginals (jtCurrent);

  totalMessagesSent += propagator.getTotalMessagesSent ();
}
/**
 * Runs the propagator on an already constructed junction tree, which
 * becomes the current tree of this inferencer, and adds the
 * propagator's message count to the running total.
 *
 * @param jt junction tree to propagate on
 */
public void computeMarginals (JunctionTree jt)
{
  jtCurrent = jt;
  // NOTE(review): the original author marked this value as uncertain ("??").
  inLogSpace = false;

  propagator.computeMarginals (jtCurrent);
  totalMessagesSent += propagator.getTotalMessagesSent ();
}
/**
 * Constructs a junction tree from a given factor graph, without running
 * belief propagation on it. The result has the structure of a junction
 * tree, but its factors do not correspond to the true marginals until
 * propagation has been run.
 *
 * If the model already carries a cached tree for this inferencer class,
 * that tree is reused after clearing its potentials; otherwise a new
 * tree is built and stored in the model's inference cache. In both
 * cases the potentials are then initialized from the model's factors.
 *
 * @param mdl Factor graph to compute JT for.
 * @return the junction tree, which is also kept as the current tree
 */
public JunctionTree buildJunctionTree(FactorGraph mdl)
{
  jtCurrent = (JunctionTree) mdl.getInferenceCache (JunctionTreeInferencer.class);

  if (jtCurrent == null) {
    // The factor graph is triangulated via the topology of the
    // corresponding MRF rather than directly as a bipartite graph
    // (a historical choice: triangulate() was written for MRFs).
    // mdlToGraph() is also valid for FactorGraphs that are
    // DirectedModels, where it has the effect of moralizing.
    UndirectedGraph g = Graphs.mdlToGraph (mdl);
    triangulate (g);
    jtCurrent = buildJtStructure ();
    mdl.setInferenceCache (JunctionTreeInferencer.class, jtCurrent);
  } else {
    jtCurrent.clearCPFs ();
  }

  initJtCpts (mdl, jtCurrent);
  return jtCurrent;
}
/**
 * Returns a new SimpleGraph populated from the given graph via
 * GraphHelper.addGraph, so that the copy can be modified without
 * touching the original.
 *
 * @param original graph to duplicate
 * @return the duplicate
 */
private UndirectedGraph dupGraph (UndirectedGraph original)
{
  UndirectedGraph duplicate = new SimpleGraph ();
  GraphHelper.addGraph (duplicate, original);
  return duplicate;
}
/**
 * Looks up the marginal of a single variable in the current junction
 * tree by delegating to the propagator.
 *
 * @param var variable whose marginal is wanted
 * @return the factor returned by the propagator
 */
public Factor lookupMarginal(Variable var)
{
  Factor marginal = propagator.lookupMarginal (jtCurrent, var);
  return marginal;
}
/**
 * Looks up the joint marginal of a set of variables in the current
 * junction tree by delegating to the propagator.
 *
 * @param varSet variables whose joint marginal is wanted
 * @return the factor returned by the propagator
 */
public Factor lookupMarginal(VarSet varSet)
{
  Factor marginal = propagator.lookupMarginal (jtCurrent, varSet);
  return marginal;
}
/**
 * Delegates to the current junction tree's lookupLogJoint for the
 * given assignment.
 *
 * @param assn assignment to evaluate
 * @return the value reported by the junction tree
 */
public double lookupLogJoint(Assignment assn)
{
  double logJoint = jtCurrent.lookupLogJoint (assn);
  return logJoint;
}
/**
 * Delegates to the current junction tree's dumpLogJoint for the given
 * assignment.
 *
 * @param assn assignment to evaluate
 * @return the value reported by the junction tree
 */
public double dumpLogJoint(Assignment assn)
{
  double logJoint = jtCurrent.dumpLogJoint (assn);
  return logJoint;
}
/**
 * Returns the current junction tree, i.e. the one set by the last call
 * to {@link #computeMarginals} or {@link #buildJunctionTree}. The
 * internal instance itself is returned (no copy is made), so the caller
 * must not modify the return value.
 *
 * @return the current junction tree, or null if none has been set yet
 */
public JunctionTree lookupJunctionTree ()
{
return jtCurrent;
}
/**
 * Returns an iterator over the neighbors of a vertex, taken from the
 * list produced by GraphHelper.neighborListOf at the time of the call.
 *
 * @param g graph to query
 * @param v vertex whose neighbors are wanted
 */
private Iterator neighborsIterator (UndirectedGraph g, Variable v)
{
  List neighbors = GraphHelper.neighborListOf (g, v);
  return neighbors.iterator ();
}
/**
 * Prints the current junction tree to standard output, or a notice
 * when there is no current tree.
 */
public void dump ()
{
  if (jtCurrent == null) {
    System.out.println("NO current junction tree");
    return;
  }
  System.out.println("Current junction tree");
  jtCurrent.dump();
}
/**
 * Running total of messages, increased by the propagator's reported
 * count after each computeMarginals call. Transient, so it is not serialized.
 */
transient private int totalMessagesSent = 0;
/**
 * Returns the total number of messages this inferencer has sent, as
 * accumulated over its computeMarginals calls.
 *
 * @return the accumulated message count
 */
public int getTotalMessagesSent () { return totalMessagesSent; }
// Serialization
/** Version identifier of the serialized form of this class. */
private static final long serialVersionUID = 1;
// If serialization-incompatible changes are made to these classes,
// then smarts can be added to these methods for backward compatibility.
/**
 * Serialization hook; currently it only performs default serialization.
 *
 * @param out stream to write this object to
 * @throws IOException if the underlying stream fails
 */
private void writeObject (ObjectOutputStream out) throws IOException {
out.defaultWriteObject ();
}
/**
 * Deserialization hook; currently it only performs default
 * deserialization. Transient fields are not restored by it.
 *
 * @param in stream to read this object from
 * @throws IOException if the underlying stream fails
 * @throws ClassNotFoundException if a serialized class cannot be resolved
 */
private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
in.defaultReadObject ();
}
} // JunctionTreeInferencer