/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.extract.test;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;
import java.io.*;
import cc.mallet.extract.CRFExtractor;
import cc.mallet.extract.Extraction;
import cc.mallet.extract.LatticeViewer;
import cc.mallet.fst.CRF;
import cc.mallet.fst.CRFTrainerByLabelLikelihood;
import cc.mallet.fst.MEMM;
import cc.mallet.fst.MEMMTrainer;
import cc.mallet.fst.TokenAccuracyEvaluator;
import cc.mallet.fst.TransducerEvaluator;
import cc.mallet.fst.tests.TestCRF;
import cc.mallet.fst.tests.TestMEMM;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.iterator.ArrayIterator;
import cc.mallet.types.InstanceList;
/**
* Created: Oct 31, 2004
*
* @author <A HREF="mailto:casutton@cs.umass.edu>casutton@cs.umass.edu</A>
* @version $Id: TestLatticeViewer.java,v 1.1 2007/10/22 21:38:02 mccallum Exp $
*/
public class TestLatticeViewer extends TestCase {
public TestLatticeViewer (String name)
{
super (name);
}
private static File htmlFile = new File ("errors.html");
private static File latticeFile = new File ("lattice.html");
private static File htmlDir = new File ("html/");
public void testSpaceViewer () throws FileNotFoundException
{
Pipe pipe = TestMEMM.makeSpacePredictionPipe ();
String[] data0 = { TestCRF.data[0] };
String[] data1 = { TestCRF.data[1] };
InstanceList training = new InstanceList (pipe);
training.addThruPipe (new ArrayIterator (data0));
InstanceList testing = new InstanceList (pipe);
testing.addThruPipe (new ArrayIterator (data1));
CRF crf = new CRF (pipe, null);
crf.addFullyConnectedStatesForLabels ();
CRFTrainerByLabelLikelihood crft = new CRFTrainerByLabelLikelihood (crf);
crft.trainIncremental (training);
CRFExtractor extor = hackCrfExtor (crf);
Extraction extration = extor.extract (new ArrayIterator (data1));
PrintStream out = new PrintStream (new FileOutputStream (htmlFile));
LatticeViewer.extraction2html (extration, extor, out);
out.close();
out = new PrintStream (new FileOutputStream (latticeFile));
LatticeViewer.extraction2html (extration, extor, out, true);
out.close();
}
static CRFExtractor hackCrfExtor (CRF crf)
{
Pipe[] newPipes = new Pipe [3];
SerialPipes pipes = (SerialPipes) crf.getInputPipe ();
for (int i = 0; i < 3; i++) {
Pipe p0 = pipes.getPipe (0);
//pipes.removePipe (0); TODO Fix me
//p0.setParent (null);
newPipes[i] = p0;
}
Pipe tokPipe = new SerialPipes (newPipes);
CRFExtractor extor = new CRFExtractor (crf, (Pipe)tokPipe);
return extor;
}
public void testDualSpaceViewer () throws IOException
{
Pipe pipe = TestMEMM.makeSpacePredictionPipe ();
String[] data0 = { TestCRF.data[0] };
String[] data1 = TestCRF.data;
InstanceList training = new InstanceList (pipe);
training.addThruPipe (new ArrayIterator (data0));
InstanceList testing = new InstanceList (pipe);
testing.addThruPipe (new ArrayIterator (data1));
CRF crf = new CRF (pipe, null);
crf.addFullyConnectedStatesForLabels ();
CRFTrainerByLabelLikelihood crft = new CRFTrainerByLabelLikelihood (crf);
TokenAccuracyEvaluator eval = new TokenAccuracyEvaluator (new InstanceList[] {training, testing}, new String[] {"Training", "Testing"});
for (int i = 0; i < 5; i++) {
crft.train (training, 1);
eval.evaluate(crft);
}
CRFExtractor extor = hackCrfExtor (crf);
Extraction e1 = extor.extract (new ArrayIterator (data1));
Pipe pipe2 = TestMEMM.makeSpacePredictionPipe ();
InstanceList training2 = new InstanceList (pipe2);
training2.addThruPipe (new ArrayIterator (data0));
InstanceList testing2 = new InstanceList (pipe2);
testing2.addThruPipe (new ArrayIterator (data1));
MEMM memm = new MEMM (pipe2, null);
memm.addFullyConnectedStatesForLabels ();
MEMMTrainer memmt = new MEMMTrainer (memm);
TransducerEvaluator memmeval = new TokenAccuracyEvaluator (new InstanceList[] {training2, testing2}, new String[] {"Training2", "Testing2"});
memmt.train (training2, 5);
memmeval.evaluate(memmt);
CRFExtractor extor2 = hackCrfExtor (memm);
Extraction e2 = extor2.extract (new ArrayIterator (data1));
if (!htmlDir.exists ()) htmlDir.mkdir ();
LatticeViewer.viewDualResults (htmlDir, e1, extor, e2, extor2);
}
public static Test suite ()
{
return new TestSuite (TestLatticeViewer.class);
}
public static void main (String[] args) throws Throwable
{
TestSuite theSuite;
if (args.length > 0) {
theSuite = new TestSuite ();
for (int i = 0; i < args.length; i++) {
theSuite.addTest (new TestLatticeViewer (args[i]));
}
} else {
theSuite = (TestSuite) suite ();
}
junit.textui.TestRunner.run (theSuite);
}
}