/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.test;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.*;
import org._3pq.jgrapht.UndirectedGraph;
import org._3pq.jgrapht.GraphHelper;
import cc.mallet.grmm.inference.RandomGraphs;
import cc.mallet.grmm.types.*;
import cc.mallet.grmm.util.Graphs;
import cc.mallet.util.ArrayUtils;
/**
* Created: Mar 17, 2005
*
* @author <A HREF="mailto:casutton@cs.umass.edu>casutton@cs.umass.edu</A>
* @version $Id: TestUndirectedModel.java,v 1.1 2007/10/22 21:37:41 mccallum Exp $
*/
public class TestUndirectedModel extends TestCase {
public TestUndirectedModel (String name)
{
super (name);
}
public void testOutputToDot () throws IOException
{
FactorGraph mdl = TestInference.createRandomGrid (3, 4, 2, new Random (4234));
PrintWriter out = new PrintWriter (new FileWriter (new File ("grmm-model.dot")));
mdl.printAsDot (out);
out.close ();
System.out.println ("Now you can open up grmm-model.dot in Graphviz.");
}
/**
* Tests that models can be created that have multiple factors over the same variable, and that
* potentialOfVertex returns the product in that case.
*/
public void testMultipleNodePotentials ()
{
Variable var = new Variable (2);
FactorGraph mdl = new FactorGraph (new Variable[]{var});
Factor ptl1 = new TableFactor (var, new double[]{0.5, 0.5});
mdl.addFactor (ptl1);
Factor ptl2 = new TableFactor (var, new double[]{0.25, 0.25});
mdl.addFactor (ptl2);
// verify that factorOf(var) doesn't work
try {
mdl.factorOf (var);
fail ();
} catch (RuntimeException e) {} // expected
List factors = mdl.allFactorsOf (var);
Factor total = TableFactor.multiplyAll (factors);
double[] expected = {0.125, 0.125};
assertTrue ("Arrays not equal\n Expected " + ArrayUtils.toString (expected)
+ "\n Actual " + ArrayUtils.toString (((TableFactor) total).toValueArray ()),
Arrays.equals (expected, ((TableFactor) total).toValueArray ()));
}
/**
* Tests that models can be created that have multiple factors over the same edge, and that
* potentialOfEdge returns the product in that case.
*/
public void testMultipleEdgePotentials ()
{
Variable v1 = new Variable (2);
Variable v2 = new Variable (2);
Variable[] vars = new Variable[]{v1, v2};
FactorGraph mdl = new FactorGraph (vars);
Factor ptl1 = new TableFactor (vars, new double[]{0.5, 0.5, 0.5, 0.5});
mdl.addFactor (ptl1);
Factor ptl2 = new TableFactor (vars, new double[]{0.25, 0.25, 0.5, 0.5});
mdl.addFactor (ptl2);
try {
mdl.factorOf (v1, v2);
fail ();
} catch (RuntimeException e) {}
Collection factors = mdl.allFactorsContaining (new HashVarSet (vars));
assertEquals (2, factors.size ());
assertTrue (factors.contains (ptl1));
assertTrue (factors.contains (ptl2));
double[] vals = {0.125, 0.125, 0.25, 0.25};
Factor total = TableFactor.multiplyAll (factors);
Factor expected = new TableFactor (vars, vals);
assertTrue ("Arrays not equal\n Expected " + ArrayUtils.toString (vals)
+ "\n Actual " + ArrayUtils.toString (((TableFactor) total).toValueArray ()),
expected.almostEquals (total, 1e-10));
}
public void testPotentialConnections ()
{
Variable v1 = new Variable (2);
Variable v2 = new Variable (2);
Variable v3 = new Variable (2);
Variable[] vars = new Variable[]{v1, v2, v3};
FactorGraph mdl = new FactorGraph ();
TableFactor ptl = new TableFactor (vars, new double [8]);
mdl.addFactor (ptl);
assertTrue (mdl.isAdjacent (v1, v2));
assertTrue (mdl.isAdjacent (v2, v3));
assertTrue (mdl.isAdjacent (v1, v3));
}
public void testThreeNodeModel ()
{
Random r = new Random (23534709);
FactorGraph mdl = new FactorGraph ();
Variable root = new Variable (2);
Variable childL = new Variable (2);
Variable childR = new Variable (2);
mdl.addFactor (root, childL, RandomGraphs.generateMixedPotentialValues (r, 1.5));
mdl.addFactor (root, childR, RandomGraphs.generateMixedPotentialValues (r, 1.5));
// assertTrue (mdl.isConnected (root, childL));
// assertTrue (mdl.isConnected (root, childR));
// assertTrue (mdl.isConnected (childL, childR));
assertTrue (mdl.isAdjacent (root, childR));
assertTrue (mdl.isAdjacent (root, childL));
assertTrue (!mdl.isAdjacent (childL, childR));
assertTrue (mdl.factorOf (root, childL) != null);
assertTrue (mdl.factorOf (root, childR) != null);
}
// Verify that potentialOfVertex and potentialOfEdge (which use
// caches) are consistent with the potentials set.
public void testUndirectedCaches ()
{
List models = TestInference.createTestModels ();
for (Iterator it = models.iterator (); it.hasNext ();) {
FactorGraph mdl = (FactorGraph) it.next ();
verifyCachesConsistent (mdl);
}
}
private void verifyCachesConsistent (FactorGraph mdl)
{
Factor pot, pot2, pot3;
for (Iterator it = mdl.factors ().iterator (); it.hasNext ();) {
pot = (Factor) it.next ();
// System.out.println("Testing model "+i+" potential "+pot);
Object[] vars = pot.varSet ().toArray ();
switch (vars.length) {
case 1:
pot2 = mdl.factorOf ((Variable) vars[0]);
assertTrue (pot == pot2);
break;
case 2:
Variable var1 = (Variable) vars[0];
Variable var2 = (Variable) vars[1];
pot2 = mdl.factorOf (var1, var2);
pot3 = mdl.factorOf (var2, var1);
assertTrue (pot == pot2);
assertTrue (pot2 == pot3);
break;
// Factors of size > 2 aren't now cached.
default:
break;
}
}
}
// Verify that potentialOfVertex and potentialOfEdge (which use
// caches) are consistent with the potentials set even if a vertex is removed.
public void testUndirectedCachesAfterRemove ()
{
List models = TestInference.createTestModels ();
for (Iterator mdlIt = models.iterator (); mdlIt.hasNext ();) {
FactorGraph mdl = (FactorGraph) mdlIt.next ();
mdl = (FactorGraph) mdl.duplicate ();
mdl.remove (mdl.get (0));
// Verify that indexing correct
for (Iterator it = mdl.variablesIterator (); it.hasNext ();) {
Variable var = (Variable) it.next ();
int idx = mdl.getIndex (var);
assertTrue (idx >= 0);
assertTrue (idx < mdl.numVariables ());
}
// Verify that caches consistent
verifyCachesConsistent (mdl);
}
}
public void testMdlToGraph ()
{
List models = TestInference.createTestModels ();
for (Iterator mdlIt = models.iterator (); mdlIt.hasNext ();) {
UndirectedModel mdl = (UndirectedModel) mdlIt.next ();
UndirectedGraph g = Graphs.mdlToGraph (mdl);
Set vertices = g.vertexSet ();
// check the number of vertices
assertEquals (mdl.numVariables (), vertices.size ());
// check the number of edges
int numEdgePtls = 0;
for (Iterator factorIt = mdl.factors ().iterator (); factorIt.hasNext ();) {
Factor factor = (Factor) factorIt.next ();
if (factor.varSet ().size() == 2) numEdgePtls++;
}
assertEquals (numEdgePtls, g.edgeSet ().size ());
// check that the neighbors of each vertex contain at least some of what they're supposed to
Iterator it = vertices.iterator ();
while (it.hasNext ()) {
Variable var = (Variable) it.next ();
assertTrue (vertices.contains (var));
Set neighborsInG = new HashSet (GraphHelper.neighborListOf (g, var));
neighborsInG.add (var);
Iterator factorIt = mdl.allFactorsContaining (var).iterator ();
while (factorIt.hasNext ()) {
Factor factor = (Factor) factorIt.next ();
assertTrue (neighborsInG.containsAll (factor.varSet ()));
}
}
}
}
public void testFactorOfSet ()
{
Variable[] vars = new Variable [3];
for (int i = 0; i < vars.length; i++) {
vars[i] = new Variable (2);
}
Factor factor = new TableFactor (vars, new double[] { 0, 1, 2, 3, 4, 5, 6, 7 });
FactorGraph fg = new FactorGraph (vars);
fg.addFactor (factor);
assertTrue (factor == fg.factorOf (factor.varSet ()));
HashSet set = new HashSet (factor.varSet ());
assertTrue (factor == fg.factorOf (set));
set.remove (vars[0]);
assertTrue (null == fg.factorOf (set));
}
public static Test suite ()
{
return new TestSuite (TestUndirectedModel.class);
}
public static void main (String[] args) throws Throwable
{
TestSuite theSuite;
if (args.length > 0) {
theSuite = new TestSuite ();
for (int i = 0; i < args.length; i++) {
theSuite.addTest (new TestUndirectedModel (args[i]));
}
} else {
theSuite = (TestSuite) suite ();
}
junit.textui.TestRunner.run (theSuite);
}
}