/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.util;
import java.util.*;
import cc.mallet.grmm.inference.Inferencer;
import cc.mallet.grmm.inference.JunctionTree;
import cc.mallet.grmm.inference.JunctionTreeInferencer;
import cc.mallet.grmm.types.*;
import gnu.trove.THashSet;
/**
* Static utilities that do useful things with factor graphs.
*
* Created: Sep 22, 2005
*
* @author <A HREF="mailto:casutton@cs.umass.edu>casutton@cs.umass.edu</A>
* @version $Id: Models.java,v 1.1 2007/10/22 21:37:58 mccallum Exp $
*/
public class Models {
/**
* Returns a new factor graph, the same as a given one, except that all the nodes in
* the given Assignment are clamped as evidence.
* @param mdl Old model. Will not be modified.
* @param assn Evidence to add
* @return A new factor graph.
*/
public static FactorGraph addEvidence (FactorGraph mdl, Assignment assn)
{
return addEvidence (mdl, assn, null);
}
public static FactorGraph addEvidence (FactorGraph mdl, Assignment assn, Map toSlicedMap)
{
FactorGraph newMdl = new FactorGraph (mdl.numVariables ());
addSlicedPotentials (mdl, newMdl, assn, toSlicedMap);
return newMdl;
}
public static UndirectedModel addEvidence (UndirectedModel mdl, Assignment assn)
{
UndirectedModel newMdl = new UndirectedModel (mdl.numVariables ());
addSlicedPotentials (mdl, newMdl, assn, null);
return newMdl;
}
private static void addSlicedPotentials (FactorGraph fromMdl, FactorGraph toMdl, Assignment assn, Map toSlicedMap)
{
Set inputVars = new THashSet (Arrays.asList (assn.getVars ()));
Set remainingVars = new THashSet (fromMdl.variablesSet ());
remainingVars.removeAll (inputVars);
for (Iterator it = fromMdl.factorsIterator (); it.hasNext ();) {
Factor ptl = (Factor) it.next ();
Set theseVars = new THashSet (ptl.varSet ());
theseVars.retainAll (remainingVars);
Factor slicedPtl = ptl.slice (assn);
toMdl.addFactor (slicedPtl);
if (toSlicedMap != null) {
toSlicedMap.put (ptl, slicedPtl);
}
}
}
/**
* Returns the highest-score Assignment in a model according to a given inferencer.
* @param mdl Factor graph to use
* @param inf Inferencer to use. No need to call <tt>computeMarginals</tt> first.
* @return An Assignment
*/
public static Assignment bestAssignment (FactorGraph mdl, Inferencer inf)
{
inf.computeMarginals (mdl);
int[] outcomes = new int [mdl.numVariables ()];
for (int i = 0; i < outcomes.length; i++) {
Variable var = mdl.get (i);
int best = inf.lookupMarginal (var).argmax ();
outcomes[i] = best;
}
return new Assignment (mdl, outcomes);
}
/**
* Computes the exact entropy of a factor graph distribution using the junction tree algorithm.
* If the model is intractable, then this method won't return a number anytime soon.
*/
public static double entropy (FactorGraph mdl)
{
JunctionTreeInferencer inf = new JunctionTreeInferencer ();
inf.computeMarginals (mdl);
JunctionTree jt = inf.lookupJunctionTree ();
return jt.entropy ();
}
/**
* Computes the KL divergence <tt>KL(mdl1||mdl2)</tt>. Junction tree is used to compute the entropy.
* <p>
* TODO: This probably won't handle when the jnuction tree for MDL2 contains a clique that's not present in the
* junction tree for mdl1. If so, this is a bug.
*
* @param mdl1
* @param mdl2
* @return KL(mdl1||mdl2)
*/
public static double KL (FactorGraph mdl1, FactorGraph mdl2)
{
JunctionTreeInferencer inf1 = new JunctionTreeInferencer ();
inf1.computeMarginals (mdl1);
JunctionTree jt1 = inf1.lookupJunctionTree ();
JunctionTreeInferencer inf2 = new JunctionTreeInferencer ();
inf2.computeMarginals (mdl2);
JunctionTree jt2 = inf2.lookupJunctionTree ();
double entropy = jt1.entropy ();
double energy = 0;
for (Iterator it = jt2.clusterPotentials ().iterator(); it.hasNext();) {
Factor marg2 = (Factor) it.next ();
Factor marg1 = inf1.lookupMarginal (marg2.varSet ());
for (AssignmentIterator assnIt = marg2.assignmentIterator (); assnIt.hasNext();) {
energy += marg1.value (assnIt) * marg2.logValue (assnIt);
assnIt.advance();
}
}
for (Iterator it = jt2.sepsetPotentials ().iterator(); it.hasNext();) {
Factor marg2 = (Factor) it.next ();
Factor marg1 = inf1.lookupMarginal (marg2.varSet ());
for (AssignmentIterator assnIt = marg2.assignmentIterator (); assnIt.hasNext();) {
energy -= marg1.value (assnIt) * marg2.logValue (assnIt);
assnIt.advance();
}
}
return -entropy - energy;
}
public static void removeConstantFactors (FactorGraph sliced)
{
List factors = new ArrayList (sliced.factors ());
for (Iterator it = factors.iterator (); it.hasNext();) {
Factor factor = (Factor) it.next ();
if (factor instanceof ConstantFactor) {
sliced.divideBy (factor);
}
}
}
}