Package cc.mallet.grmm.test

Source Code of cc.mallet.grmm.test.TestFactorGraph

/* Copyright (C) 2006 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.test;

import junit.framework.TestCase;
import junit.framework.Test;
import junit.framework.TestSuite;


import java.io.BufferedReader;
import java.io.StringReader;
import java.io.IOException;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Arrays;

import cc.mallet.grmm.inference.RandomGraphs;
import cc.mallet.grmm.types.*;
import cc.mallet.grmm.util.ModelReader;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.Randoms;
import cc.mallet.util.Timing;

/**
* Created: Mar 17, 2005
*
* @author <A HREF="mailto:casutton@cs.umass.edu>casutton@cs.umass.edu</A>
* @version $Id: TestFactorGraph.java,v 1.1 2007/10/22 21:37:40 mccallum Exp $
*/
public class TestFactorGraph extends TestCase {

  private Variable[] vars;
  private TableFactor tbl1;
  private TableFactor tbl2;
  private TableFactor tbl3;
  private LogTableFactor ltbl1;
  private LogTableFactor ltbl2;

  public TestFactorGraph (String name)
  {
    super (name);
  }

  protected void setUp () throws Exception
  {
    vars = new Variable[] {
            new Variable (2),
            new Variable (2),
            new Variable (2),
            new Variable (2),
    };

    tbl1 = new TableFactor (new Variable[] { vars[0], vars[1] }, new double[] { 0.8, 0.1, 0.1, 0.8 });
    tbl2 = new TableFactor (new Variable[] { vars[1], vars[2] }, new double[] { 0.2, 0.7, 0.8, 0.2 });
    tbl3 = new TableFactor (new Variable[] { vars[2], vars[3] }, new double[] { 0.2, 0.4, 0.6, 0.4 });

    ltbl1 = LogTableFactor.makeFromValues (new Variable[] { vars[0], vars[1] }, new double[] { 0.8, 0.1, 0.1, 0.8 });
    ltbl2 = LogTableFactor.makeFromValues (new Variable[] { vars[1], vars[2] }, new double[] { 0.2, 0.7, 0.8, 0.2 });

  }

  public void testMultiplyBy ()
  {
    FactorGraph fg = new FactorGraph ();
    fg.multiplyBy (tbl1);
    fg.multiplyBy (tbl2);

    assertEquals (2, fg.factors ().size());
    assertTrue (fg.factors ().contains (tbl1));
    assertTrue (fg.factors ().contains (tbl2));

    assertEquals (3, fg.numVariables ());
    assertTrue (fg.variablesSet ().contains (vars[0]));
    assertTrue (fg.variablesSet ().contains (vars[1]));
    assertTrue (fg.variablesSet ().contains (vars[2]));
  }

  public void testNumVariables ()
  {
    FactorGraph fg = new FactorGraph ();
    fg.multiplyBy (tbl1);
    fg.multiplyBy (tbl2);
    assertEquals (3, fg.numVariables ());
  }

  public void testMultiply ()
  {
    FactorGraph fg = new FactorGraph ();
    fg.multiplyBy (tbl1);
    fg.multiplyBy (tbl2);

    FactorGraph fg2 = (FactorGraph) fg.multiply (tbl3);
    assertEquals (2, fg.factors ().size());
    assertEquals (3, fg2.factors ().size());
    assertTrue (!fg.factors ().contains (tbl3));
    assertTrue (fg2.factors ().contains (tbl3));
  }

  public void testValue ()
  {
    FactorGraph fg = new FactorGraph ();
    fg.multiplyBy (tbl1);
    fg.multiplyBy (tbl2);


    Assignment assn = new Assignment (fg.varSet ().toVariableArray (), new int[] { 0, 1, 0 });
    assertEquals (0.08, fg.value (assn), 1e-5);
  }

  public void testMarginalize ()
  {
    FactorGraph fg = new FactorGraph ();
    fg.multiplyBy (tbl1);
    fg.multiplyBy (tbl2);

    Factor marg = fg.marginalize (vars[1]);
    Factor expected = new TableFactor (vars[1], new double[] { 0.81, 0.9 });

    assertTrue (expected.almostEquals (marg));
  }

  public void testSum ()
  {
    FactorGraph fg = new FactorGraph ();
    fg.multiplyBy (tbl1);
    fg.multiplyBy (tbl2);
    assertEquals (1.71, fg.sum (), 1e-5);
  }

  public void testNormalize ()
  {
    FactorGraph fg = new FactorGraph ();
    fg.multiplyBy (tbl1);
    fg.multiplyBy (tbl2);
    fg.normalize ();
    assertEquals (1.0, fg.sum(), 1e-5);
  }

  public void testLogNormalize ()
  {
    FactorGraph fg = new FactorGraph ();
    fg.multiplyBy (ltbl1);
    fg.multiplyBy (ltbl2);
    fg.normalize ();
    assertEquals (1.0, fg.sum(), 1e-5);
  }

  public void testEmbeddedFactorGraph ()
  {
    FactorGraph embeddedFg = new FactorGraph ();
    embeddedFg.multiplyBy (tbl1);
    embeddedFg.multiplyBy (tbl2);

    FactorGraph fg = new FactorGraph ();
    fg.multiplyBy (embeddedFg);
    fg.multiplyBy (tbl3);

    assertEquals (4, fg.varSet ().size ());
    assertEquals (2, fg.factors ().size ());

    Assignment assn = new Assignment (fg.varSet ().toVariableArray (), new int [4]);
    assertEquals (0.032, fg.value (assn), 1e-5);

    AbstractTableFactor tbl = fg.asTable ();
    assertEquals (4, tbl.varSet ().size ());
    assertEquals (0.032, tbl.value (assn), 1e-5);

  }

  public void testAsTable ()
  {
    FactorGraph fg = new FactorGraph ();
    fg.multiplyBy (tbl1);
    fg.multiplyBy (tbl2);

    AbstractTableFactor actual = fg.asTable ();
    AbstractTableFactor expected = (AbstractTableFactor) tbl1.multiply (tbl2);

    assertTrue (expected.almostEquals (actual));
  }

  public void testTableTimesFg ()
  {
    FactorGraph fg = new FactorGraph ();
    fg.multiplyBy (tbl1);
    fg.multiplyBy (tbl2);

    Factor product = tbl3.multiply (fg);

    assertTrue (product instanceof AbstractTableFactor);
    assertEquals (4, product.varSet ().size ());

    Assignment assn = new Assignment (product.varSet ().toVariableArray (), new int [4]);
    assertEquals (0.032, product.value (assn), 1e-5);
  }

  public void testLogTableTimesFg ()
  {
    FactorGraph fg = new FactorGraph ();
    fg.multiplyBy (tbl1);
    fg.multiplyBy (tbl2);

    Factor product = ltbl1.multiply (fg);

    assertTrue (product instanceof AbstractTableFactor);
    assertEquals (3, product.varSet ().size ());

    Assignment assn = new Assignment (product.varSet ().toVariableArray (), new int [3]);
    assertEquals (0.128, product.value (assn), 1e-5);
  }

  public void testRemove ()
  {
    FactorGraph fg = new FactorGraph ();
    fg.multiplyBy (tbl1);
    fg.multiplyBy (tbl2);
    assertEquals (2, fg.getDegree (vars[1]));

    fg.divideBy (tbl1);

    assertEquals (2, fg.varSet ().size ());

    Assignment assn = new Assignment (fg.varSet ().toVariableArray (), new int [2]);
    assertEquals (0.2, fg.value (assn), 1e-5);

    int nvs = 0;
    for (Iterator it = fg.varSetIterator (); it.hasNext(); it.next ()) {
      nvs++;
    }
    assertEquals (1, nvs);

    assertEquals (1, fg.getDegree (vars[1]));

    assertTrue (fg.get (0) != fg.get (1));
    assertEquals (vars[1], fg.get (0));
    assertEquals (vars[2], fg.get (1));

  }

  public void testRedundantDomains ()
  {
    FactorGraph fg = new FactorGraph ();
    fg.multiplyBy (tbl1);
    fg.multiplyBy (tbl2);
    fg.multiplyBy (ltbl1);

    assertEquals (3, fg.varSet ().size ());
    assertEquals ("Wrong factors in FG, was "+fg.dumpToString (), 3, fg.factors ().size ());


    Assignment assn = new Assignment (fg.varSet ().toVariableArray (), new int [3]);
    assertEquals (0.128, fg.value (assn), 1e-5);
  }

  private static String uniformMdlstr =
          "VAR sigma u1 u2 : continuous\n" +
          "VAR x1 x2 : 2\n" +
          "sigma ~ Uniform -0.5 0.5\n" +
          "u1 ~ Uniform -0.5 0.5\n" +
          "u2 ~ Uniform -0.5 0.5\n" +
          "x1 x2 ~ BinaryPair sigma\n" +
          "x1 ~ Unary u1\n" +
          "x2 ~ Unary u2\n";

  public void testContinousSample () throws IOException
  {
    ModelReader reader = new ModelReader ();
    FactorGraph fg = reader.readModel (new BufferedReader (new StringReader (uniformMdlstr)));
    Randoms r = new Randoms (324143);
    Assignment allAssn = new Assignment ();
    for (int i = 0; i < 10000; i++) {
      Assignment row = fg.sample (r);
      allAssn.addRow (row);
    }

    Variable x1 = fg.findVariable ("x1");
    Assignment assn1 = (Assignment) allAssn.marginalize (x1);
    int[] col = assn1.getColumnInt (x1);
    double mean = MatrixOps.sum (col) / ((double)col.length);
    assertEquals (0.5, mean, 0.025);
  }

  private static String uniformMdlstr2 =
          "VAR sigma u1 u2 : continuous\n" +
          "VAR x1 x2 : 2\n" +
          "sigma ~ Normal 0.0 0.2\n" +
          "u1 ~ Normal 0.0 0.2\n" +
          "u2 ~ Normal 0.0 0.2\n" +
          "x1 x2 ~ BinaryPair sigma\n" +
          "x1 ~ Unary u1\n" +
          "x2 ~ Unary u2\n";

  public void testContinousSample2 () throws IOException
  {
    ModelReader reader = new ModelReader ();
    FactorGraph fg = reader.readModel (new BufferedReader (new StringReader (uniformMdlstr2)));
    Randoms r = new Randoms (324143);
    Assignment allAssn = new Assignment ();
    for (int i = 0; i < 10000; i++) {
      Assignment row = fg.sample (r);
      allAssn.addRow (row);
    }

    Variable x1 = fg.findVariable ("x2");
    Assignment assn1 = (Assignment) allAssn.marginalize (x1);
    int[] col = assn1.getColumnInt (x1);
    double mean = MatrixOps.sum (col) / ((double)col.length);
    assertEquals (0.5, mean, 0.01);

    Variable x2 = fg.findVariable ("x2");
    Assignment assn2 = (Assignment) allAssn.marginalize (x2);
    int[] col2 = assn2.getColumnInt (x2);
    double mean2 = MatrixOps.sum (col2) / ((double)col2.length);
    assertEquals (0.5, mean2, 0.025);
  }

  public void testAllFactorsOf () throws IOException
  {
    ModelReader reader = new ModelReader ();
    FactorGraph fg = reader.readModel (new BufferedReader (new StringReader (uniformMdlstr2)));
    Variable var = new Variable (2);
    var.setLabel ("v0");
    List lst = fg.allFactorsOf (var);
    assertEquals (0, lst.size ());
  }

  public void testAllFactorsOf2 () throws IOException
  {
    Variable x1 = new Variable (2);
    Variable x2 = new Variable (2);

    FactorGraph fg = new FactorGraph ();
    fg.addFactor (new TableFactor (x1));
    fg.addFactor (new TableFactor (x2));
    fg.addFactor (new TableFactor (new Variable[] { x1, x2 }));

    List lst = fg.allFactorsOf (x1);
    assertEquals (1, lst.size ());
    for (Iterator it = lst.iterator (); it.hasNext ();) {
      Factor f = (Factor) it.next ();
      assertEquals (1, f.varSet().size());
      assertTrue (f.varSet ().contains (x1));
    }

    HashVarSet vs = new HashVarSet (new Variable[]{x1, x2});
    List lst2 = fg.allFactorsOf (vs);
    assertEquals (1, lst2.size ());

    Factor f = (Factor) lst2.get (0);
    assertTrue (f.varSet ().equals (vs));
  }


  public void testAsTable2 ()
  {
    Factor f1 = new TableFactor (vars[0], new double[] { 0.6, 0.4 });
    Factor f2 = new ConstantFactor (2.0);
    FactorGraph fg = new FactorGraph (new Factor[] { f1, f2 });
    AbstractTableFactor tbl = fg.asTable ();
    assertTrue (Arrays.equals(new double[] { 0.6 * 2.0, 0.4 * 2.0 }, tbl.toValueArray ()));
  }

  public void testClear ()
  {
    FactorGraph fg = new FactorGraph ();
    fg.multiplyBy (tbl1);
    fg.multiplyBy (tbl2);

    assertEquals (3, fg.numVariables ());
    assertEquals (2, fg.factors ().size ());

    fg.clear ();
    assertEquals (0, fg.numVariables ());
    assertEquals (0, fg.factors ().size ());

    for (int vi = 0; vi < tbl1.varSet ().size (); vi++) {
      assertTrue (!fg.containsVar (tbl1.getVariable (vi)));
    }
    for (int vi = 0; vi < tbl2.varSet ().size (); vi++) {
      assertTrue (!fg.containsVar (tbl2.getVariable (vi)));
    }
  }

  public void testCacheExpanding ()
  {
    FactorGraph baseFg = RandomGraphs.randomFrustratedGrid (25, 1.0, new java.util.Random (3324879));
    Assignment assn = new Assignment (baseFg, new int[baseFg.numVariables ()]);
    double val = baseFg.logValue (assn);

    Timing timing = new Timing ();

    int numReps = 100;
    for (int rep = 0; rep < numReps; rep++) {
      FactorGraph fg = new FactorGraph (baseFg.numVariables ());
      for (int fi = 0; fi < baseFg.factors().size(); fi++) {
        fg.multiplyBy (baseFg.getFactor (fi));
      }
      assertEquals (val, fg.logValue (assn), 1e-5);
    }
    long time1 = timing.elapsedTime ();
    timing.tick ("No-expansion time");

    for (int rep = 0; rep < numReps; rep++) {
      FactorGraph fg = new FactorGraph ();
      for (int fi = 0; fi < baseFg.factors().size(); fi++) {
        fg.multiplyBy (baseFg.getFactor (fi));
      }
      assertEquals (val, fg.logValue (assn), 1e-5);
    }
    long time2 = timing.elapsedTime ();
    timing.tick ("With-expansion time");

    assertTrue (time1 < time2);
  }


  public static Test suite ()
  {
    return new TestSuite (TestFactorGraph.class);
  }

  public static void main (String[] args) throws Throwable
  {
    TestSuite theSuite;
    if (args.length > 0) {
      theSuite = new TestSuite ();
      for (int i = 0; i < args.length; i++) {
        theSuite.addTest (new TestFactorGraph (args[i]));
      }
    } else {
      theSuite = (TestSuite) TestFactorGraph.suite ();
    }

    junit.textui.TestRunner.run (theSuite);
  }

}
TOP

Related Classes of cc.mallet.grmm.test.TestFactorGraph

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.