/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.learning.extract;
import java.util.Iterator;
import java.util.Random;
import java.util.List;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.logging.Logger;
import java.io.File;
import cc.mallet.extract.Extraction;
import cc.mallet.extract.TokenizationFilter;
import cc.mallet.grmm.inference.Inferencer;
import cc.mallet.grmm.learning.*;
import cc.mallet.grmm.util.RememberTokenizationPipe;
import cc.mallet.grmm.util.PipedIterator;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.PipeUtils;
import cc.mallet.pipe.iterator.PipeInputIterator;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Instance;
import cc.mallet.util.CollectionUtils;
import cc.mallet.util.FileUtils;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Timing;
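/**
 * Configures and trains an {@link ACRFExtractor}. Templates, pipes, data sources and training
 * options are supplied through chainable setters; {@link #trainExtractor()} then trains an ACRF
 * and wraps it together with the tokenization and feature pipes.
 *
 * Created: Mar 31, 2005
 *
 * @author <A HREF="mailto:casutton@cs.umass.edu">casutton@cs.umass.edu</A>
 * @version $Id: ACRFExtractorTrainer.java,v 1.1 2007/10/22 21:38:02 mccallum Exp $
 */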
public class ACRFExtractorTrainer {

  private static final Logger logger = MalletLogger.getLogger (ACRFExtractorTrainer.class.getName());

  /** Maximum number of iterations handed to the trainer for full training (default 99999). */
  private int numIter = 99999;

  /** Templates used to build the ACRF; in per-template training they are added in array order. */
  protected ACRF.Template[] tmpls;

  /** Training data; loaded lazily by {@link #setupData()} when not supplied via setData. */
  protected InstanceList training;
  /** Testing data; stays null when no test source is configured. */
  protected InstanceList testing;

  private Iterator<Instance> testIterator;
  private Iterator<Instance> trainIterator;

  ACRFTrainer trainer = new DefaultAcrfTrainer ();

  protected Pipe featurePipe;
  protected Pipe tokPipe;

  protected ACRFEvaluator evaluator = new DefaultAcrfTrainer.LogEvaluator ();

  TokenizationFilter filter;

  private Inferencer inferencer;
  private Inferencer viterbiInferencer;

  /** Checkpoint interval in iterations; non-positive disables checkpointing. */
  private int numCheckpointIterations = -1;
  private File checkpointDirectory = null;

  private boolean usePerTemplateTrain = false;
  private int perTemplateIterations = 100;

  private boolean cacheUnrolledGraphs;

  // For data subsets: fractions <= 0 mean "use all of the data".
  private Random r;
  private double trainingPct = -1;
  private double testingPct = -1;

  // Using cascaded setter idiom: every setter returns this so calls can be chained.

  /**
   * Sets the templates used to construct the ACRF.
   * @param tmpls the templates; per-template training adds them one at a time in array order
   * @return this trainer
   */
  public ACRFExtractorTrainer setTemplates (ACRF.Template[] tmpls)
  {
    this.tmpls = tmpls;
    return this;
  }

  /**
   * Sets raw instance iterators from which the data is loaded lazily by {@link #setupData()}.
   * They are only consulted while no training list has been set.
   * @param trainIterator source of raw training instances
   * @param testIterator source of raw testing instances; may be null to skip test data
   * @return this trainer
   */
  public ACRFExtractorTrainer setDataSource (Iterator<Instance> trainIterator, Iterator<Instance> testIterator)
  {
    this.trainIterator = trainIterator;
    this.testIterator = testIterator;
    return this;
  }

  /**
   * Sets already-processed training and testing lists directly. While {@code training}
   * is non-null, the iterators given to {@link #setDataSource} are not used.
   * @param training training instances
   * @param testing testing instances; may be null
   * @return this trainer
   */
  public ACRFExtractorTrainer setData (InstanceList training, InstanceList testing)
  {
    this.training = training;
    this.testing = testing;
    return this;
  }

  /**
   * Sets the iteration limit used for full training (default 99999).
   * @param numIter maximum number of iterations passed to the trainer
   * @return this trainer
   */
  public ACRFExtractorTrainer setNumIterations (int numIter)
  {
    this.numIter = numIter;
    return this;
  }

  /** @return the iteration limit used for full training */
  public int getNumIter ()
  {
    return numIter;
  }

  /**
   * Sets the tokenization and feature pipes. The feature pipe is stored with a new
   * {@link RememberTokenizationPipe} prepended to it.
   * @param tokPipe pipe applied to raw instances (via PipedIterator) before featurization
   * @param featurePipe feature-extraction pipe
   * @return this trainer
   */
  public ACRFExtractorTrainer setPipes (Pipe tokPipe, Pipe featurePipe)
  {
    RememberTokenizationPipe rtp = new RememberTokenizationPipe ();
    this.featurePipe = PipeUtils.concatenatePipes (rtp, featurePipe);
    this.tokPipe = tokPipe;
    return this;
  }

  /**
   * Sets the evaluator invoked during training (default: DefaultAcrfTrainer.LogEvaluator).
   * @param evaluator the evaluator
   * @return this trainer
   */
  public ACRFExtractorTrainer setEvaluator (ACRFEvaluator evaluator)
  {
    this.evaluator = evaluator;
    return this;
  }

  /**
   * Sets the trainer used to fit the ACRF (default: DefaultAcrfTrainer).
   * @param acrfTrainer the trainer
   * @return this trainer
   */
  public ACRFExtractorTrainer setTrainingMethod (ACRFTrainer acrfTrainer)
  {
    trainer = acrfTrainer;
    return this;
  }

  /**
   * Sets the tokenization filter installed on the trained extractor.
   * Kept under its historical (misspelled) name for backward compatibility;
   * prefer {@link #setTokenizationFilter(TokenizationFilter)}.
   * @param filter the filter; null leaves the extractor's default in place
   * @return this trainer
   */
  public ACRFExtractorTrainer setTokenizatioFilter (TokenizationFilter filter)
  {
    this.filter = filter;
    return this;
  }

  /**
   * Sets the tokenization filter installed on the trained extractor.
   * @param filter the filter; null leaves the extractor's default in place
   * @return this trainer
   */
  public ACRFExtractorTrainer setTokenizationFilter (TokenizationFilter filter)
  {
    return setTokenizatioFilter (filter);
  }

  /**
   * If true, every ACRF built by this trainer gets {@code setCacheUnrolledGraphs(true)}.
   * @param cacheUnrolledGraphs whether to enable unrolled-graph caching
   * @return this trainer
   */
  public ACRFExtractorTrainer setCacheUnrolledGraphs (boolean cacheUnrolledGraphs)
  {
    this.cacheUnrolledGraphs = cacheUnrolledGraphs;
    return this;
  }

  /**
   * Enables checkpointing: when positive, a gzipped serialized extractor is written every
   * {@code numCheckpointIterations} iterations. Non-positive values (default -1) disable it.
   * @param numCheckpointIterations checkpoint interval in iterations
   * @return this trainer
   */
  public ACRFExtractorTrainer setNumCheckpointIterations (int numCheckpointIterations)
  {
    this.numCheckpointIterations = numCheckpointIterations;
    return this;
  }

  /**
   * Sets the directory checkpoints are written to. If null, file names resolve relative
   * to the working directory (java.io.File semantics for a null parent).
   * @param checkpointDirectory target directory
   * @return this trainer
   */
  public ACRFExtractorTrainer setCheckpointDirectory (File checkpointDirectory)
  {
    this.checkpointDirectory = checkpointDirectory;
    return this;
  }

  /**
   * Chooses per-template (incremental) training instead of training all templates at once.
   * @param usePerTemplateTrain true to use per-template training
   * @return this trainer
   */
  public ACRFExtractorTrainer setUsePerTemplateTrain (boolean usePerTemplateTrain)
  {
    this.usePerTemplateTrain = usePerTemplateTrain;
    return this;
  }

  /**
   * Sets the iteration limit for each per-template round (default 100).
   * @param numIter iterations per round
   * @return this trainer
   */
  public ACRFExtractorTrainer setPerTemplateIterations (int numIter)
  {
    this.perTemplateIterations = numIter;
    return this;
  }

  /** @return the trainer used to fit the ACRF */
  public ACRFTrainer getTrainer ()
  {
    return trainer;
  }

  /** @return the tokenization filter, or null if none was set */
  public TokenizationFilter getFilter ()
  {
    return filter;
  }

  // Main methods

  /**
   * Trains an ACRF (per-template or all at once, depending on {@link #setUsePerTemplateTrain})
   * and wraps it in an extractor together with the tokenization and feature pipes.
   * @return the trained extractor
   */
  public ACRFExtractor trainExtractor ()
  {
    ACRF acrf = usePerTemplateTrain ? perTemplateTrain () : trainAcrf ();
    ACRFExtractor extor = new ACRFExtractor (acrf, tokPipe, featurePipe);
    if (filter != null) extor.setTokenizationFilter (filter);
    return extor;
  }

  /**
   * Incremental training: round {@code ti} builds a fresh ACRF from the first {@code ti+1}
   * templates and trains it for {@link #perTemplateIterations} iterations. If the last round's
   * trainer call did not return true (treated as "converged"), the final ACRF is trained
   * further for {@link #numIter} iterations.
   * @return the ACRF from the last round
   * @throws IllegalStateException if no templates were set
   */
  private ACRF perTemplateTrain ()
  {
    if (tmpls == null || tmpls.length == 0)
      throw new IllegalStateException ("Per-template training requires at least one template; call setTemplates first");

    Timing timing = new Timing ();
    boolean hasConverged = false;
    ACRF miniAcrf = null;

    if (training == null) setupData ();

    for (int ti = 0; ti < tmpls.length; ti++) {
      ACRF.Template[] theseTmpls = new ACRF.Template[ti+1];
      System.arraycopy (tmpls, 0, theseTmpls, 0, theseTmpls.length);
      logger.info ("***PerTemplateTrain: Round "+ti+"\n Templates: "+
          CollectionUtils.dumpToString (Arrays.asList (theseTmpls), " "));
      miniAcrf = new ACRF (featurePipe, theseTmpls);
      setupAcrf (miniAcrf);
      // Per-round prefix keeps each round's checkpoints from overwriting the previous round's.
      ACRFEvaluator eval = setupEvaluator ("tmpl"+ti);
      hasConverged = trainer.train (miniAcrf, training, null, testing, eval, perTemplateIterations);
      timing.tick ("PerTemplateTrain round "+ti);
    }

    // finish by training to convergence
    if (!hasConverged) {
      ACRFEvaluator eval = setupEvaluator ("full");
      trainer.train (miniAcrf, training, null, testing, eval, numIter);
    }

    // the last acrf is the one to go with
    return miniAcrf;
  }

  /**
   * Trains a new ACRF object with the given settings. Subclasses may override this method
   * to implement alternative training procedures.
   * @return a trained ACRF
   */
  public ACRF trainAcrf ()
  {
    if (training == null) setupData ();
    ACRF acrf = new ACRF (featurePipe, tmpls);
    setupAcrf (acrf);
    ACRFEvaluator eval = setupEvaluator ("");
    trainer.train (acrf, training, null, testing, eval, numIter);
    return acrf;
  }

  /** Applies the configured caching flag and inferencers (when set) to a freshly built ACRF. */
  private void setupAcrf (ACRF acrf)
  {
    if (cacheUnrolledGraphs) acrf.setCacheUnrolledGraphs (true);
    if (inferencer != null) acrf.setInferencer (inferencer);
    if (viterbiInferencer != null) acrf.setViterbiInferencer (viterbiInferencer);
  }

  /**
   * Returns the configured evaluator, chained with a checkpointing evaluator when
   * checkpointing is enabled.
   * @param checkpointPrefix prefix for checkpoint file names; "" keeps the plain
   *        {@code extor.<iter>.ser.gz} naming
   */
  private ACRFEvaluator setupEvaluator (String checkpointPrefix)
  {
    if (numCheckpointIterations <= 0) return evaluator;
    List<ACRFEvaluator> evals = new ArrayList<ACRFEvaluator> ();
    evals.add (evaluator);
    evals.add (new CheckpointingEvaluator (checkpointDirectory, numCheckpointIterations, tokPipe, featurePipe, checkpointPrefix));
    return new AcrfSerialEvaluator (evals);
  }

  /**
   * Loads training (and, if a test iterator was given, testing) data from the iterators set via
   * {@link #setDataSource}. Each raw instance goes through a PipedIterator with the tokenization
   * pipe and is added through the feature pipe. If subset fractions are set, each list is
   * randomly reduced to its own fraction.
   * @throws IllegalStateException if no training iterator was configured
   */
  protected void setupData ()
  {
    if (trainIterator == null)
      throw new IllegalStateException ("No training data: call setData or setDataSource before training");

    Timing timing = new Timing ();
    training = new InstanceList (featurePipe);
    training.addThruPipe (new PipedIterator (trainIterator, tokPipe));
    if (trainingPct > 0) training = subsetData (training, trainingPct);
    if (testIterator != null) {
      testing = new InstanceList (featurePipe);
      testing.addThruPipe (new PipedIterator (testIterator, tokPipe));
      // Bug fix: previously subset with trainingPct, ignoring the configured testingPct.
      if (testingPct > 0) testing = subsetData (testing, testingPct);
    }
    timing.tick ("Data loading");
  }

  /** Splits {@code data} with the configured Random in proportions {pct, 1-pct} and keeps the first part. */
  private InstanceList subsetData (InstanceList data, double pct)
  {
    InstanceList[] lsts = data.split (r, new double[] { pct, 1 - pct });
    return lsts[0];
  }

  /** @return the training data, loading it from the data source first if needed */
  public InstanceList getTrainingData ()
  {
    if (training == null) setupData ();
    return training;
  }

  /**
   * @return the testing data, loading the data source first if nothing has been loaded yet;
   *         may be null when no test source was configured
   */
  public InstanceList getTestingData ()
  {
    // setupData loads both lists at once, so only trigger it when nothing is loaded yet.
    // Keying on testing == null re-ran it whenever there was no test set, rebuilding
    // training from an already-traversed iterator or clobbering data given via setData.
    if (training == null) setupData ();
    return testing;
  }

  /**
   * Runs the extractor over the testing data, loading it first if needed.
   * @param extor the extractor to apply
   * @return the extraction result
   */
  public Extraction extractOnTestData (ACRFExtractor extor)
  {
    return extor.extract (getTestingData ());
  }

  /**
   * Sets the inferencer installed on every ACRF built by this trainer.
   * @param inferencer the inferencer; null leaves the ACRF's default
   * @return this trainer
   */
  public ACRFExtractorTrainer setInferencer (Inferencer inferencer)
  {
    this.inferencer = inferencer;
    return this;
  }

  /**
   * Sets the Viterbi inferencer installed on every ACRF built by this trainer.
   * @param viterbiInferencer the inferencer; null leaves the ACRF's default
   * @return this trainer
   */
  public ACRFExtractorTrainer setViterbiInferencer (Inferencer viterbiInferencer)
  {
    this.viterbiInferencer = viterbiInferencer;
    return this;
  }

  /**
   * Requests random subsets of the data loaded via {@link #setDataSource}. Only applies when
   * data is loaded from iterators, not to lists passed to {@link #setData}.
   * @param random randomness source for the split
   * @param trainingPct fraction of training data to keep; <= 0 keeps everything
   * @param testingPct fraction of testing data to keep; <= 0 keeps everything
   * @return this trainer
   */
  public ACRFExtractorTrainer setDataSubsets (Random random, double trainingPct, double testingPct)
  {
    r = random;
    this.trainingPct = trainingPct;
    this.testingPct = testingPct;
    return this;
  }

  // checkpointing

  /**
   * Evaluator that performs no evaluation; it only serializes the current model as a gzipped
   * ACRFExtractor every {@code interval} iterations (iteration 0 excluded).
   */
  private static class CheckpointingEvaluator extends ACRFEvaluator {

    private final File directory;
    private final int interval;
    private final Pipe tokPipe;
    private final Pipe featurePipe;
    /** Prepended (with a '.') to checkpoint file names; "" for the plain naming scheme. */
    private final String prefix;

    public CheckpointingEvaluator (File directory, int interval, Pipe tokPipe, Pipe featurePipe)
    {
      this (directory, interval, tokPipe, featurePipe, "");
    }

    public CheckpointingEvaluator (File directory, int interval, Pipe tokPipe, Pipe featurePipe, String prefix)
    {
      this.directory = directory;
      this.interval = interval;
      this.tokPipe = tokPipe;
      this.featurePipe = featurePipe;
      this.prefix = (prefix == null) ? "" : prefix;
    }

    /**
     * Writes a checkpoint when {@code iter} is a positive multiple of the interval.
     * @return always true (presumably "keep training"; NOTE(review): confirm against ACRFEvaluator)
     */
    public boolean evaluate (ACRF acrf, int iter, InstanceList training, InstanceList validation, InstanceList testing)
    {
      if (iter > 0 && iter % interval == 0) {
        ACRFExtractor extor = new ACRFExtractor (acrf, tokPipe, featurePipe);
        FileUtils.writeGzippedObject (new File (directory, checkpointFileName (iter)), extor);
      }
      return true;
    }

    /** Builds e.g. "extor.10.ser.gz", or "tmpl0.extor.10.ser.gz" when a prefix is set. */
    private String checkpointFileName (int iter)
    {
      String base = "extor."+iter+".ser.gz";
      return (prefix.length () == 0) ? base : prefix + "." + base;
    }

    /** No-op: this evaluator only writes checkpoints. */
    public void test (InstanceList gold, List returned, String description) { }
  }
}