/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.inference;
import java.util.HashSet;
import java.util.Set;
import java.util.Iterator;
import java.util.Collection;
import java.io.ObjectOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.FactorGraph;
import cc.mallet.grmm.types.TableFactor;
import cc.mallet.grmm.types.Variable;
/**
* The variable elimination algorithm for inference in graphical
* models.
*
* Created: Mon Sep 22 17:34:00 2003
*
* @author <a href="mailto:casutton@cs.umass.edu">Charles Sutton</a>
* @version $Id: VariableElimination.java,v 1.1 2007/10/22 21:37:49 mccallum Exp $
*/
public class VariableElimination
extends AbstractInferencer
{
private Factor eliminate (Collection allPhi, Variable node) {
HashSet phiSet = new HashSet();
/* collect the potentials that include this variable */
for (Iterator j = allPhi.iterator(); j.hasNext(); ) {
Factor cpf = (Factor) j.next ();
if (cpf.varSet().isEmpty() || cpf.containsVar (node)) {
phiSet.add (cpf);
j.remove ();
}
}
return TableFactor.multiplyAll (phiSet);
}
/**
* The bulk of the variable-elimination algorithm. Returns the
* marginal density of the variable QUERY in the undirected
* model MODEL, except that the density is un-normalized.
* The normalization is done in a separate function to make
* computeNormalizationFactor easier.
*/
public Factor unnormalizedMarginal
(FactorGraph model, Variable query)
{
/* here the elimination order is random */
/* note that using buckets would make this more efficient as
well. */
/* make a copy of potentials */
HashSet allPhi = new HashSet();
for (Iterator i = model.factorsIterator (); i.hasNext(); ){
Factor factor = (Factor) i.next ();
allPhi.add(factor.duplicate());
}
Set nodes = model.variablesSet ();
/* Eliminate each node in turn */
for (Iterator i = nodes.iterator(); i.hasNext(); ) {
Variable node = (Variable) i.next();
if (node == query) continue; // Eliminate the query variable last!
Factor newCPF = eliminate (allPhi, node);
/* Extract (marginalize) over this variables */
Factor singleCPF;
if(newCPF.varSet().size() == 1) {
singleCPF = newCPF;
} else {
singleCPF = newCPF.marginalizeOut (node);
}
/* add it back to the list of potentials */
allPhi.add(singleCPF);
}
/* Now, all the potentials that are left should contain only the
* query variable.... UNLESS the graph is disconnected. So just
* eliminate the query var.
*/
Factor marginal = eliminate (allPhi, query);
assert marginal.containsVar (query);
assert marginal.varSet().size() == 1;
return marginal;
}
/**
* Computes the normalization constant for a model.
*/
public double computeNormalizationFactor (FactorGraph m) {
/* What we'll do is get the unnormalized marginal of an arbitrary
* node; then sum the marginal to get the normalization factor. */
Variable var = (Variable) m.variablesSet ().iterator().next();
Factor marginal = unnormalizedMarginal (m, var);
return marginal.sum ();
}
transient FactorGraph mdlCurrent;
// Inert. All work done in lookupMarginal().
public void computeMarginals (FactorGraph m)
{
mdlCurrent = m;
}
public Factor lookupMarginal (Variable var)
{
Factor marginal = unnormalizedMarginal (mdlCurrent, var);
marginal.normalize();
return marginal;
}
// Serialization
private static final long serialVersionUID = 1;
// If seralization-incompatible changes are made to these classes,
// then smarts can be added to these methods for backward compatibility.
private void writeObject (ObjectOutputStream out) throws IOException {
out.defaultWriteObject ();
}
private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
in.defaultReadObject ();
}
}