/* Copyright (C) 2006 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://mallet.cs.umass.edu/
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.inference;
import java.util.Collection;
import java.util.Iterator;
import java.util.logging.Logger;
import java.util.logging.Level;
import java.io.Serializable;
import java.io.ObjectOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.HashVarSet;
import cc.mallet.grmm.types.VarSet;
import cc.mallet.grmm.types.Variable;
import cc.mallet.util.MalletLogger;
/**
* An implementation of Hugin-style propagation for junction trees.
* This destructively modifies the junction tree so that its clique potentials
* are the true marginals of the underlying graph.
* <p/>
* End users will not usually need to use this class directly. Use
* <tt>JunctionTreeInferencer</tt> instead.
* <p/>
* This class is not an instance of Inferencer because it destructively
* modifies the junction tree, which the Inferencer methods do not do to
* factor graphs.
* <p/>
* Created: Feb 1, 2006
*
* @author <A HREF="mailto:casutton@cs.umass.edu>casutton@cs.umass.edu</A>
* @version $Id: JunctionTreePropagation.java,v 1.1 2007/10/22 21:37:49 mccallum Exp $
*/
class JunctionTreePropagation implements Serializable {
private static Logger logger = MalletLogger.getLogger (JunctionTreePropagation.class.getName ());
transient private int totalMessagesSent = 0;
private MessageStrategy strategy;
public JunctionTreePropagation (MessageStrategy strategy)
{
this.strategy = strategy;
}
public static JunctionTreePropagation createSumProductInferencer ()
{
return new JunctionTreePropagation (new SumProductMessageStrategy ());
}
public static JunctionTreePropagation createMaxProductInferencer ()
{
return new JunctionTreePropagation (new MaxProductMessageStrategy ());
}
public int getTotalMessagesSent ()
{
return totalMessagesSent;
}
public void computeMarginals (JunctionTree jt)
{
propagate (jt);
jt.normalizeAll (); // Necessary if jt originally unnormalized
}
/* Hugin-style propagation for junction trees */
// bottom-up pass
private void collectEvidence (JunctionTree jt, VarSet parent, VarSet child)
{
logger.finer ("collectEvidence " + parent + " --> " + child);
for (Iterator it = jt.getChildren (child).iterator (); it.hasNext ();) {
VarSet gchild = (VarSet) it.next ();
collectEvidence (jt, child, gchild);
}
if (parent != null) {
totalMessagesSent++;
strategy.sendMessage (jt, child, parent);
}
}
// top-down pass
private void distributeEvidence (JunctionTree jt, VarSet parent)
{
for (Iterator it = jt.getChildren (parent).iterator (); it.hasNext ();) {
VarSet child = (VarSet) it.next ();
totalMessagesSent++;
strategy.sendMessage (jt, parent, child);
distributeEvidence (jt, child);
}
}
private void propagate (JunctionTree jt)
{
VarSet root = (VarSet) jt.getRoot ();
collectEvidence (jt, null, root);
distributeEvidence (jt, root);
}
public Factor lookupMarginal (JunctionTree jt, VarSet varSet)
{
if (jt == null) { throw new IllegalStateException ("Call computeMarginals() first."); }
VarSet parent = jt.findParentCluster (varSet);
if (parent == null) {
throw new UnsupportedOperationException
("No parent cluster in " + jt + " for clique " + varSet);
}
Factor cpf = jt.getCPF (parent);
if (logger.isLoggable (Level.FINER)) {
logger.finer ("Lookup jt marginal: clique " + varSet + " cluster " + parent);
logger.finest (" cpf " + cpf);
}
Factor marginal = strategy.extractBelief (cpf, varSet);
marginal.normalize ();
return marginal;
}
public Factor lookupMarginal (JunctionTree jt, Variable var)
{
if (jt == null) { throw new IllegalStateException ("Call computeMarginals() first."); }
VarSet parent = jt.findParentCluster (var);
Factor cpf = jt.getCPF (parent);
if (logger.isLoggable (Level.FINER)) {
logger.finer ("Lookup jt marginal: var " + var + " cluster " + parent);
logger.finest (" cpf " + cpf);
}
Factor marginal = strategy.extractBelief (cpf, new HashVarSet (new Variable[] { var }));
marginal.normalize ();
return marginal;
}
///////////////////////////////////////////////////////////////////////////
// MEESAGE STRATEGIES
///////////////////////////////////////////////////////////////////////////
/**
* Implements a strategy pattern for message sending. This allows sum-product
* and max-product messages, e.g., to be different implementations of this strategy.
*/
public interface MessageStrategy {
/**
* Sends a message from the clique FROM to TO in a junction tree.
*/
public void sendMessage (JunctionTree jt, VarSet from, VarSet to);
public Factor extractBelief (Factor cpf, VarSet varSet);
}
public static class SumProductMessageStrategy implements MessageStrategy, Serializable {
/**
* This sends a sum-product message, normalized to avoid
* underflow.
*/
public void sendMessage (JunctionTree jt, VarSet from, VarSet to)
{
Collection sepset = jt.getSepset (from, to);
Factor fromCpf = jt.getCPF (from);
Factor toCpf = jt.getCPF (to);
Factor oldSepsetPot = jt.getSepsetPot (from, to);
Factor lambda = fromCpf.marginalize (sepset);
lambda.normalize ();
jt.setSepsetPot (lambda, from, to);
toCpf = toCpf.multiply (lambda);
toCpf.divideBy (oldSepsetPot);
toCpf.normalize ();
jt.setCPF (to, toCpf);
}
public Factor extractBelief (Factor cpf, VarSet varSet)
{
return cpf.marginalize (varSet);
}
// Serialization
private static final long serialVersionUID = 1;
private static final int CUURENT_SERIAL_VERSION = 1;
private void writeObject (ObjectOutputStream out) throws IOException
{
out.defaultWriteObject ();
out.writeInt (CUURENT_SERIAL_VERSION);
}
private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException
{
in.defaultReadObject ();
in.readInt (); // version
}
}
public static class MaxProductMessageStrategy implements MessageStrategy, Serializable {
/**
* This sends a max-product message.
*/
public void sendMessage (JunctionTree jt, VarSet from, VarSet to)
{
// System.err.println ("Send message "+from+" --> "+to);
Collection sepset = jt.getSepset (from, to);
Factor fromCpf = jt.getCPF (from);
Factor toCpf = jt.getCPF (to);
Factor oldSepsetPot = jt.getSepsetPot (from, to);
Factor lambda = fromCpf.extractMax (sepset);
lambda.normalize ();
jt.setSepsetPot (lambda, from, to);
toCpf = toCpf.multiply (lambda);
toCpf.divideBy (oldSepsetPot);
toCpf.normalize ();
jt.setCPF (to, toCpf);
}
public Factor extractBelief (Factor cpf, VarSet varSet)
{
return cpf.extractMax (varSet);
}
// Serialization
private static final long serialVersionUID = 1;
private static final int CUURENT_SERIAL_VERSION = 1;
private void writeObject (ObjectOutputStream out) throws IOException
{
out.defaultWriteObject ();
out.writeInt (CUURENT_SERIAL_VERSION);
}
private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException
{
in.defaultReadObject ();
in.readInt (); // version
}
}
// Serialization
private static final long serialVersionUID = 1;
private static final int CUURENT_SERIAL_VERSION = 1;
private void writeObject (ObjectOutputStream out) throws IOException
{
out.defaultWriteObject ();
out.writeInt (CUURENT_SERIAL_VERSION);
}
private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException
{
in.defaultReadObject ();
in.readInt (); // version
}
}