/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.learning.extract;
/**
*
* Created: Aug 23, 2005
*
* @author <A HREF="mailto:casutton@cs.umass.edu">casutton@cs.umass.edu</A>
* @version $Id: AcrfExtractorTui.java,v 1.1 2007/10/22 21:38:02 mccallum Exp $
*/
import bsh.EvalError;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.logging.Logger;
import java.util.regex.Pattern;
import cc.mallet.extract.Extraction;
import cc.mallet.extract.ExtractionEvaluator;
import cc.mallet.grmm.inference.Inferencer;
import cc.mallet.grmm.learning.*;
import cc.mallet.pipe.*;
import cc.mallet.pipe.iterator.FileListIterator;
import cc.mallet.pipe.iterator.LineGroupIterator;
import cc.mallet.pipe.iterator.PipeInputIterator;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.*;
public class AcrfExtractorTui {
/** Logger used for the progress messages emitted by this tool. */
private static final Logger logger = MalletLogger.getLogger (AcrfExtractorTui.class.getName ());

// ---------------------------------------------------------------------------
// Command-line options.  Each option is registered against this class; main()
// reads the parsed result from the option's "value" field.
// ---------------------------------------------------------------------------

/** Directory used both for checkpoints and for the final serialized extractor. */
private static CommandOption.File outputPrefix = new CommandOption.File
  (AcrfExtractorTui.class, "output-prefix", "FILENAME", true, null,
   "Directory to write saved model to.", null);

/** Text file whose lines are evaluated one by one to build the model templates. */
private static CommandOption.File modelFile = new CommandOption.File
  (AcrfExtractorTui.class, "model-file", "FILENAME", true, null, "Text file describing model structure.", null);

/** Training data file (or a list of files when training-is-list is set). */
private static CommandOption.File trainFile = new CommandOption.File
  (AcrfExtractorTui.class, "training", "FILENAME", true, null, "File containing training data.", null);

/** Optional testing data; final evaluation only runs when testing data is available. */
private static CommandOption.File testFile = new CommandOption.File
  (AcrfExtractorTui.class, "testing", "FILENAME", true, null, "File containing testing data.", null);

/** Number of labels per input line; only consulted when the option was given explicitly. */
private static CommandOption.Integer numLabelsOption = new CommandOption.Integer
  (AcrfExtractorTui.class, "num-labels", "INT", true, -1,
   "If supplied, number of labels on each line of input file." +
   "  Otherwise, the token ---- must separate labels from features.", null);

/** Trainer specification: either a class-name shorthand or interpreter code containing '('. */
private static CommandOption.String trainerOption = new CommandOption.String
  (AcrfExtractorTui.class, "trainer", "STRING", true, "ACRFExtractorTrainer",
   "Specification of trainer type.", null);

/** Inferencer specification, handed to the trainer via setInferencer. */
private static CommandOption.String inferencerOption = new CommandOption.String
  (AcrfExtractorTui.class, "inferencer", "STRING", true, "LoopyBP",
   "Specification of inferencer.", null);

/** Inferencer specification, handed to the trainer via setViterbiInferencer. */
private static CommandOption.String maxInferencerOption = new CommandOption.String
  (AcrfExtractorTui.class, "max-inferencer", "STRING", true, "LoopyBP.createForMaxProduct()",
   "Specification of inferencer used for max-product (Viterbi) decoding.", null);

/** Evaluator specification: keyword form (LOG, SEGMENT ..., SERIAL ...) or interpreter code. */
private static CommandOption.String evalOption = new CommandOption.String
  (AcrfExtractorTui.class, "eval", "STRING", true, "LOG",
   "Evaluator to use.  Java code grokking performed.", null);

/** Extraction evaluator specification; the shorthand X is expanded to "new XEvaluator ()". */
private static CommandOption.String extractionEvalOption = new CommandOption.String
  (AcrfExtractorTui.class, "extraction-eval", "STRING", true, "PerDocumentF1",
   "Evaluator to use.  Java code grokking performed.", null);

/** Checkpoint interval in iterations, passed through to the trainer unchanged. */
private static CommandOption.Integer checkpointIterations = new CommandOption.Integer
  (AcrfExtractorTui.class, "checkpoint", "INT", true, -1, "Save a copy after every ___ iterations.", null);

/** Whether the trainer should cache unrolled graphs. */
static CommandOption.Boolean cacheUnrolledGraph = new CommandOption.Boolean
  (AcrfExtractorTui.class, "cache-graphs", "true|false", true, true,
   "Whether to use memory-intensive caching.", null);

/** Whether the trainer should use per-template training. */
static CommandOption.Boolean perTemplateTrain = new CommandOption.Boolean
  (AcrfExtractorTui.class, "per-template-train", "true|false", true, false,
   "Whether to pretrain templates before joint training.", null);

/** Iteration count for each per-template training step. */
static CommandOption.Integer pttIterations = new CommandOption.Integer
  (AcrfExtractorTui.class, "per-template-iterations", "INTEGER", false, 100,
   "How many training iterations for each step of per-template-training.", null);

/** Random seed option.  NOTE(review): not read anywhere in this file; confirm whether it is still needed. */
static CommandOption.Integer randomSeedOption = new CommandOption.Integer
  (AcrfExtractorTui.class, "random-seed", "INTEGER", true, 0,
   "The random seed for randomly selecting a proportion of the instance list for training", null);

/** When false, main() tells the input pipe not to treat the first feature as the token text. */
static CommandOption.Boolean useTokenText = new CommandOption.Boolean
  (AcrfExtractorTui.class, "use-token-text", "true|false", true, true,
   "If true, first feature in list is assumed to be token identity, and is treated specially.", null);

/** Whether labels trail the features on each input line.  Boolean option, so the argument is true|false. */
private static CommandOption.Boolean labelsAtEnd = new CommandOption.Boolean
  (AcrfExtractorTui.class, "labels-at-end", "true|false", true, false,
   "If true, then label is at end of each line, rather than beginning.", null);

/** Whether the training (and testing) file is a list of files rather than the data itself. */
static CommandOption.Boolean trainingIsList = new CommandOption.Boolean
  (AcrfExtractorTui.class, "training-is-list", "true|false", true, false,
   "If true, training option gives list of files to read for training.", null);

/** Base directory passed to the file-list iterator when training-is-list is set. */
private static CommandOption.File dataDir = new CommandOption.File
  (AcrfExtractorTui.class, "data-dir", "FILENAME", true, null, "If training-is-list, base directory in which training files located.", null);

/** Shared interpreter used to evaluate option specifications and model-file lines. */
private static BshInterpreter interpreter = setupInterpreter ();
/**
 * Command-line entry point.  Parses the options, builds the input pipes and
 * data iterators, configures a trainer from the option values, trains an
 * extractor, serializes it to <tt>extor.ser.gz</tt> under the output prefix,
 * and finally evaluates it when testing data is available.
 *
 * @param args command-line arguments, processed against this class's options
 * @throws IOException if a data or model file cannot be read, or the extractor cannot be written
 * @throws EvalError if one of the interpreted specifications fails to evaluate
 * @throws IllegalArgumentException if the model file or the training file option is missing
 */
public static void main (String[] args) throws IOException, EvalError
{
  doProcessOptions (AcrfExtractorTui.class, args);

  // Both files are opened unconditionally below; without this check a missing
  // option would only show up as a bare NullPointerException from FileReader.
  if (modelFile.value == null) {
    throw new IllegalArgumentException ("Option --model-file is required.");
  }
  if (trainFile.value == null) {
    throw new IllegalArgumentException ("Option --training is required.");
  }

  Timing timing = new Timing ();

  // The label count is only handed to the pipe when the user supplied one.
  GenericAcrfData2TokenSequence basePipe;
  if (!numLabelsOption.wasInvoked ()) {
    basePipe = new GenericAcrfData2TokenSequence ();
  } else {
    basePipe = new GenericAcrfData2TokenSequence (numLabelsOption.value);
  }

  if (!useTokenText.value) {
    basePipe.setFeaturesIncludeToken(false);
    basePipe.setIncludeTokenText(false);
  }

  basePipe.setLabelsAtEnd (labelsAtEnd.value);

  // In list mode the iterator yields files, so an extra conversion pipe runs first.
  Pipe tokPipe = new SerialPipes (new Pipe[] {
    (trainingIsList.value ? new Input2CharSequence () : (Pipe) new Noop ()),
    basePipe,
  });

  Iterator<Instance> trainSource = constructIterator(trainFile.value, dataDir.value, trainingIsList.value);
  Iterator<Instance> testSource;
  if (testFile.wasInvoked ()) {
    // The testing file is read with the same list/non-list mode as the training file.
    testSource = constructIterator (testFile.value, dataDir.value, trainingIsList.value);
  } else {
    testSource = null;
  }

  ACRF.Template[] tmpls = parseModelFile (modelFile.value);
  ACRFExtractorTrainer trainer = createTrainer (trainerOption.value);
  ACRFEvaluator eval = createEvaluator (evalOption.value);
  ExtractionEvaluator extractionEval = createExtractionEvaluator (extractionEvalOption.value);

  Inferencer inf = createInferencer (inferencerOption.value);
  Inferencer maxInf = createInferencer (maxInferencerOption.value);

  trainer.setPipes (tokPipe, new TokenSequence2FeatureVectorSequence ())
          .setDataSource (trainSource, testSource)
          .setEvaluator (eval)
          .setTemplates (tmpls)
          .setInferencer (inf)
          .setViterbiInferencer (maxInf)
          .setCheckpointDirectory (outputPrefix.value)
          .setNumCheckpointIterations (checkpointIterations.value)
          .setCacheUnrolledGraphs (cacheUnrolledGraph.value)
          .setUsePerTemplateTrain (perTemplateTrain.value)
          .setPerTemplateIterations (pttIterations.value);

  logger.info ("Starting training...");
  ACRFExtractor extor = trainer.trainExtractor ();
  timing.tick ("Training");

  FileUtils.writeGzippedObject (new File (outputPrefix.value, "extor.ser.gz"), extor);
  timing.tick ("Serializing");

  // Evaluation is skipped entirely when the trainer reports no testing data.
  InstanceList testing = trainer.getTestingData ();
  if (testing != null) {
    eval.test (extor.getAcrf (), testing, "Final results");
  }

  if ((extractionEval != null) && (testing != null)) {
    Extraction extraction = extor.extract (testing);
    extractionEval.evaluate (extraction);
    timing.tick ("Evaluating");
  }

  System.out.println ("Total time (ms) = " + timing.elapsedTime ());
}
/**
 * Obtains the interpreter from {@link CommandOption} and imports the packages
 * that short specifications (for example a bare class name given on the
 * command line) are resolved against.
 *
 * <p>The legacy <tt>edu.umass.cs.mallet</tt> package names are still imported
 * for backward compatibility, followed by the <tt>cc.mallet</tt> packages this
 * file itself lives in and imports from.
 *
 * @return the configured interpreter
 * @throws RuntimeException wrapping the {@link EvalError} if an import statement fails
 */
private static BshInterpreter setupInterpreter ()
{
  BshInterpreter interpreter = CommandOption.getInterpreter ();
  String[] packages = {
    // Legacy package names, kept so existing behavior is not reduced.
    "edu.umass.cs.mallet.base.extract",
    "edu.umass.cs.mallet.grmm.inference",
    "edu.umass.cs.mallet.grmm.learning",
    "edu.umass.cs.mallet.grmm.learning.templates",
    "edu.umass.cs.mallet.grmm.learning.extract",
    // Current package names, matching this file's own package and imports.
    // NOTE(review): the templates package is assumed to have moved along with
    // the others -- confirm it exists under cc.mallet.
    "cc.mallet.extract",
    "cc.mallet.grmm.inference",
    "cc.mallet.grmm.learning",
    "cc.mallet.grmm.learning.templates",
    "cc.mallet.grmm.learning.extract",
  };
  try {
    for (String pkg : packages) {
      interpreter.eval ("import " + pkg + ".*");
    }
  } catch (EvalError e) {
    throw new RuntimeException (e);
  }
  return interpreter;
}
/**
 * Builds the instance iterator for one data file.
 *
 * @param trainFile the data file, or the file-list file when <tt>isList</tt> is true
 * @param dataDir directory handed to the file-list iterator; unused otherwise
 * @param isList true to read <tt>trainFile</tt> as a list of files
 * @return a {@link FileListIterator} in list mode, otherwise a {@link LineGroupIterator}
 * @throws IOException if the file cannot be opened
 */
private static Iterator<Instance> constructIterator (File trainFile, File dataDir, boolean isList) throws IOException
{
  if (!isList) {
    FileReader reader = new FileReader (trainFile);
    // Matches lines that are empty or whitespace-only; presumably these
    // separate one instance's group of lines from the next.
    Pattern blankLine = Pattern.compile ("^\\s*$");
    return new LineGroupIterator (reader, blankLine, true);
  }
  return new FileListIterator (trainFile, dataDir, null, null, true);
}
/**
 * Creates an evaluator from a textual specification.
 *
 * <p>A specification containing an opening parenthesis is treated as code and
 * evaluated verbatim by the interpreter.  Anything else is split on whitespace
 * and parsed as a keyword specification.
 *
 * @param spec the evaluator specification
 * @return the evaluator described by <tt>spec</tt>
 * @throws EvalError if the interpreter rejects a code specification
 */
public static ACRFEvaluator createEvaluator (String spec) throws EvalError
{
  boolean looksLikeCode = spec.indexOf ('(') != -1;
  if (looksLikeCode) {
    // Code is passed through untouched.
    return (ACRFEvaluator) interpreter.eval (spec);
  }

  String[] words = spec.split ("\\s+");
  LinkedList<String> tokens = new LinkedList<String> (Arrays.asList (words));
  return createEvaluator (tokens);
}
/**
 * Creates an extraction evaluator from a textual specification.
 *
 * <p>A specification containing an opening parenthesis is evaluated verbatim.
 * Otherwise it is taken as a class-name prefix and expanded to a no-argument
 * constructor call on <tt>&lt;spec&gt;Evaluator</tt>.
 *
 * @param spec the extraction evaluator specification
 * @return whatever the interpreter returns for the resulting expression
 * @throws EvalError if the interpreter rejects the expression
 */
private static ExtractionEvaluator createExtractionEvaluator (String spec) throws EvalError
{
  // Code is passed through untouched; a bare name becomes a constructor call.
  String code = (spec.indexOf ('(') >= 0) ? spec : "new "+spec+"Evaluator ()";
  return (ExtractionEvaluator) interpreter.eval (code);
}
/**
 * Parses one evaluator from the front of a token list, consuming the tokens
 * it uses.  Recognized forms (keywords are case-insensitive):
 * <ul>
 *   <li><tt>SEGMENT slice start1 continue1 [start2 continue2 ...]</tt> --
 *       consumes all remaining tokens as start/continue tag pairs;</li>
 *   <li><tt>LOG</tt> -- a {@link DefaultAcrfTrainer.LogEvaluator};</li>
 *   <li><tt>SERIAL eval1 eval2 ...</tt> -- parses evaluators recursively
 *       until the tokens run out and wraps them in an
 *       {@link AcrfSerialEvaluator}.</li>
 * </ul>
 *
 * @param toks remaining specification tokens (Strings); modified in place
 * @return the evaluator described by the leading tokens
 * @throws RuntimeException if the tokens are exhausted prematurely, the tag
 *         list has odd length, or the evaluator type is unknown
 * @throws NumberFormatException if the SEGMENT slice token is not an integer
 */
private static ACRFEvaluator createEvaluator (LinkedList toks)
{
  // removeFirst() on an empty list would only produce an opaque
  // NoSuchElementException, so report the problem in terms of the option.
  if (toks.isEmpty ())
    throw new RuntimeException ("Error in --eval "+evalOption.value+": missing evaluator type.");

  String type = (String) toks.removeFirst ();

  if (type.equalsIgnoreCase ("SEGMENT")) {
    if (toks.isEmpty ())
      throw new RuntimeException ("Error in --eval "+evalOption.value+": SEGMENT requires a slice index.");
    int slice = Integer.parseInt ((String) toks.removeFirst ());

    // Everything after the slice is interpreted as (start, continue) pairs.
    if (toks.size() % 2 != 0)
      throw new RuntimeException ("Error in --eval "+evalOption.value+": Every start tag must have a continue.");
    int numTags = toks.size () / 2;

    String[] startTags = new String [numTags];
    String[] continueTags = new String [numTags];

    for (int i = 0; i < numTags; i++) {
      startTags[i] = (String) toks.removeFirst ();
      continueTags[i] = (String) toks.removeFirst ();
    }

    return new MultiSegmentationEvaluatorACRF (startTags, continueTags, slice);

  } else if (type.equalsIgnoreCase ("LOG")) {
    return new DefaultAcrfTrainer.LogEvaluator ();

  } else if (type.equalsIgnoreCase ("SERIAL")) {
    // Each recursive call consumes at least one token, so the loop terminates.
    List evals = new ArrayList ();
    while (!toks.isEmpty ()) {
      evals.add (createEvaluator (toks));
    }
    return new AcrfSerialEvaluator (evals);

  } else {
    throw new RuntimeException ("Error in --eval "+evalOption.value+": illegal evaluator "+type);
  }
}
/**
 * Creates the extractor trainer from a textual specification.
 *
 * <p>A specification containing an opening parenthesis is evaluated verbatim.
 * Otherwise it is taken as a class name, given the suffix <tt>Trainer</tt> if
 * it does not already end with it, and instantiated with no arguments.  A
 * result that is a {@link DefaultAcrfTrainer} rather than an
 * {@link ACRFExtractorTrainer} is installed as the training method of a fresh
 * {@link ACRFExtractorTrainer}.
 *
 * @param spec the trainer specification
 * @return the extractor trainer
 * @throws EvalError if the interpreter rejects the expression
 * @throws RuntimeException if the evaluated object is neither kind of trainer
 */
private static ACRFExtractorTrainer createTrainer (String spec) throws EvalError
{
  final String code;
  if (spec.indexOf ('(') >= 0) {
    // Code is passed through untouched.
    code = spec;
  } else {
    String className = spec.endsWith ("Trainer") ? spec : spec+"Trainer";
    code = "new "+className+"()";
  }

  // The interpreter decides what object comes back; dispatch on its type.
  Object result = interpreter.eval (code);

  if (result instanceof ACRFExtractorTrainer) {
    return (ACRFExtractorTrainer) result;
  }
  if (result instanceof DefaultAcrfTrainer) {
    return new ACRFExtractorTrainer ().setTrainingMethod ((ACRFTrainer) result);
  }
  throw new RuntimeException ("Don't know what to do with trainer "+result);
}
/**
 * Creates an inferencer from a textual specification.
 *
 * <p>A specification containing an opening parenthesis is evaluated verbatim.
 * Otherwise it is taken as a class name and instantiated with no arguments.
 *
 * @param spec the inferencer specification
 * @return the inferencer produced by the interpreter
 * @throws EvalError if the interpreter rejects the expression
 * @throws RuntimeException if the evaluated object is not an {@link Inferencer}
 */
private static Inferencer createInferencer (String spec) throws EvalError
{
  // A bare class name becomes a constructor call; code is left as is.
  String code = (spec.indexOf ('(') < 0) ? "new "+spec+"()" : spec;

  Object result = interpreter.eval (code);
  if (!(result instanceof Inferencer)) {
    throw new RuntimeException ("Don't know what to do with inferencer "+result);
  }
  return (Inferencer) result;
}
/**
 * Collects the command options registered for a class, processes the
 * command-line arguments against them, and logs the resulting option values
 * to the root logger.
 *
 * @param childClass the class whose options are to be processed
 * @param args the command-line arguments
 */
public static void doProcessOptions (Class childClass, String[] args)
{
  CommandOption[] noInitialOptions = new CommandOption[0];
  CommandOption.List optionList = new CommandOption.List ("", noInitialOptions);
  optionList.add (childClass);
  optionList.process (args);

  // The empty name selects the root logger.
  Logger rootLogger = Logger.getLogger ("");
  optionList.logOptions (rootLogger);
}
/**
 * Reads the model description file and evaluates each line with the
 * interpreter; every non-blank line must evaluate to an {@link ACRF.Template}.
 * Blank (empty or whitespace-only) lines are skipped.
 *
 * @param mdlFile the model description file
 * @return the templates, in file order
 * @throws IOException if the file cannot be opened or read
 * @throws EvalError if the interpreter rejects a line
 * @throws RuntimeException if a line evaluates to something other than a template
 */
private static ACRF.Template[] parseModelFile (File mdlFile) throws IOException, EvalError
{
  BufferedReader in = new BufferedReader (new FileReader (mdlFile));
  try {
    List<ACRF.Template> tmpls = new ArrayList<ACRF.Template> ();
    int lineNo = 0;   // 1-based number of the line currently being evaluated

    String line;
    while ((line = in.readLine ()) != null) {
      lineNo++;
      if (line.trim ().length () == 0) {
        continue;
      }

      Object tmpl = interpreter.eval (line);
      if (!(tmpl instanceof ACRF.Template)) {
        throw new RuntimeException ("Error in "+mdlFile+" line "+lineNo+":\n  Object "+tmpl+" not a template");
      }
      tmpls.add ((ACRF.Template) tmpl);
    }

    return tmpls.toArray (new ACRF.Template [0]);
  } finally {
    // Release the file handle on both the normal and the error path.
    in.close ();
  }
}
}