/* Copyright (C) 2006 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.inference;
import java.util.Collection;
import java.util.Iterator;
import java.util.logging.Logger;
import java.util.logging.Level;
import java.io.Serializable;
import java.io.ObjectOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.HashVarSet;
import cc.mallet.grmm.types.VarSet;
import cc.mallet.grmm.types.Variable;
import cc.mallet.util.MalletLogger;
/**
 * An implementation of Hugin-style propagation for junction trees.
 * This destructively modifies the junction tree so that its clique potentials
 * are the true marginals of the underlying graph.
 * <p/>
 * End users will not usually need to use this class directly. Use
 * <tt>JunctionTreeInferencer</tt> instead.
 * <p/>
 * This class is not an instance of Inferencer because it destructively
 * modifies the junction tree, which the Inferencer methods do not do to
 * factor graphs.
 * <p/>
 * Created: Feb 1, 2006
 *
 * @author <A HREF="mailto:casutton@cs.umass.edu">casutton@cs.umass.edu</A>
 * @version $Id: JunctionTreePropagation.java,v 1.1 2007/10/22 21:37:49 mccallum Exp $
 */
class JunctionTreePropagation implements Serializable {
private static Logger logger = MalletLogger.getLogger (JunctionTreePropagation.class.getName ());
transient private int totalMessagesSent = 0;
private MessageStrategy strategy;
public JunctionTreePropagation (MessageStrategy strategy)
this.strategy = strategy;
public static JunctionTreePropagation createSumProductInferencer ()
return new JunctionTreePropagation (new SumProductMessageStrategy ());
public static JunctionTreePropagation createMaxProductInferencer ()
return new JunctionTreePropagation (new MaxProductMessageStrategy ());
public int getTotalMessagesSent ()
return totalMessagesSent;
public void computeMarginals (JunctionTree jt)
propagate (jt);
jt.normalizeAll (); // Necessary if jt originally unnormalized
/* Hugin-style propagation for junction trees */
// bottom-up pass
private void collectEvidence (JunctionTree jt, VarSet parent, VarSet child)
logger.finer ("collectEvidence " + parent + " --> " + child);
for (Iterator it = jt.getChildren (child).iterator (); it.hasNext ();) {
VarSet gchild = (VarSet) it.next ();
collectEvidence (jt, child, gchild);
if (parent != null) {
strategy.sendMessage (jt, child, parent);
// top-down pass
private void distributeEvidence (JunctionTree jt, VarSet parent)
for (Iterator it = jt.getChildren (parent).iterator (); it.hasNext ();) {
VarSet child = (VarSet) it.next ();
strategy.sendMessage (jt, parent, child);
distributeEvidence (jt, child);
private void propagate (JunctionTree jt)
VarSet root = (VarSet) jt.getRoot ();
collectEvidence (jt, null, root);
distributeEvidence (jt, root);
public Factor lookupMarginal (JunctionTree jt, VarSet varSet)
if (jt == null) { throw new IllegalStateException ("Call computeMarginals() first."); }
VarSet parent = jt.findParentCluster (varSet);
if (parent == null) {
throw new UnsupportedOperationException
("No parent cluster in " + jt + " for clique " + varSet);
Factor cpf = jt.getCPF (parent);
if (logger.isLoggable (Level.FINER)) {
logger.finer ("Lookup jt marginal: clique " + varSet + " cluster " + parent);
logger.finest (" cpf " + cpf);
Factor marginal = strategy.extractBelief (cpf, varSet);
marginal.normalize ();
return marginal;
public Factor lookupMarginal (JunctionTree jt, Variable var)
if (jt == null) { throw new IllegalStateException ("Call computeMarginals() first."); }
VarSet parent = jt.findParentCluster (var);
Factor cpf = jt.getCPF (parent);
if (logger.isLoggable (Level.FINER)) {
logger.finer ("Lookup jt marginal: var " + var + " cluster " + parent);
logger.finest (" cpf " + cpf);
Factor marginal = strategy.extractBelief (cpf, new HashVarSet (new Variable[] { var }));
marginal.normalize ();
return marginal;
* Implements a strategy pattern for message sending. This allows sum-product
* and max-product messages, e.g., to be different implementations of this strategy.
public interface MessageStrategy {
* Sends a message from the clique FROM to TO in a junction tree.
public void sendMessage (JunctionTree jt, VarSet from, VarSet to);
public Factor extractBelief (Factor cpf, VarSet varSet);
public static class SumProductMessageStrategy implements MessageStrategy, Serializable {
* This sends a sum-product message, normalized to avoid
* underflow.
public void sendMessage (JunctionTree jt, VarSet from, VarSet to)
Collection sepset = jt.getSepset (from, to);
Factor fromCpf = jt.getCPF (from);
Factor toCpf = jt.getCPF (to);
Factor oldSepsetPot = jt.getSepsetPot (from, to);
Factor lambda = fromCpf.marginalize (sepset);
lambda.normalize ();
jt.setSepsetPot (lambda, from, to);
toCpf = toCpf.multiply (lambda);
toCpf.divideBy (oldSepsetPot);
toCpf.normalize ();
jt.setCPF (to, toCpf);
public Factor extractBelief (Factor cpf, VarSet varSet)
return cpf.marginalize (varSet);
// Serialization
private static final long serialVersionUID = 1;
private static final int CUURENT_SERIAL_VERSION = 1;
private void writeObject (ObjectOutputStream out) throws IOException
out.defaultWriteObject ();
private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException
in.defaultReadObject ();
in.readInt (); // version
public static class MaxProductMessageStrategy implements MessageStrategy, Serializable {
* This sends a max-product message.
public void sendMessage (JunctionTree jt, VarSet from, VarSet to)
// System.err.println ("Send message "+from+" --> "+to);
Collection sepset = jt.getSepset (from, to);
Factor fromCpf = jt.getCPF (from);
Factor toCpf = jt.getCPF (to);
Factor oldSepsetPot = jt.getSepsetPot (from, to);
Factor lambda = fromCpf.extractMax (sepset);
lambda.normalize ();
jt.setSepsetPot (lambda, from, to);
toCpf = toCpf.multiply (lambda);
toCpf.divideBy (oldSepsetPot);
toCpf.normalize ();
jt.setCPF (to, toCpf);
public Factor extractBelief (Factor cpf, VarSet varSet)
return cpf.extractMax (varSet);
// Serialization
private static final long serialVersionUID = 1;
private static final int CUURENT_SERIAL_VERSION = 1;
private void writeObject (ObjectOutputStream out) throws IOException
out.defaultWriteObject ();
private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException
in.defaultReadObject ();
in.readInt (); // version
// Serialization
private static final long serialVersionUID = 1;
private static final int CUURENT_SERIAL_VERSION = 1;
private void writeObject (ObjectOutputStream out) throws IOException
out.defaultWriteObject ();
private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException
in.defaultReadObject ();
in.readInt (); // version