/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.learning.extract;
/**
 * Created: Aug 23, 2005
 *
 * @author <A HREF="mailto:casutton@cs.umass.edu">casutton@cs.umass.edu</A>
 * @version $Id: AcrfExtractorTui.java,v 1.1 2007/10/22 21:38:02 mccallum Exp $
 */
import bsh.EvalError;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.logging.Logger;
import java.util.regex.Pattern;
import cc.mallet.extract.Extraction;
import cc.mallet.extract.ExtractionEvaluator;
import cc.mallet.grmm.inference.Inferencer;
import cc.mallet.grmm.learning.*;
import cc.mallet.pipe.*;
import cc.mallet.pipe.iterator.FileListIterator;
import cc.mallet.pipe.iterator.LineGroupIterator;
import cc.mallet.pipe.iterator.PipeInputIterator;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.*;
public class AcrfExtractorTui {
/** Logger used for progress messages emitted by this tool. */
private static final Logger logger = MalletLogger.getLogger (AcrfExtractorTui.class.getName ());

// ---------------------------------------------------------------------------
// Command-line options.  Each option registers itself against this class and
// is picked up by doProcessOptions() when main() starts.
// ---------------------------------------------------------------------------

/** Location passed to the trainer as checkpoint directory; the serialized extractor is written beneath it. */
private static CommandOption.File outputPrefix = new CommandOption.File
  (AcrfExtractorTui.class, "output-prefix", "FILENAME", true, null,
   "Directory to write saved model to.", null);

/** Text file whose lines are evaluated, one by one, into ACRF templates (see parseModelFile). */
private static CommandOption.File modelFile = new CommandOption.File
  (AcrfExtractorTui.class, "model-file", "FILENAME", true, null, "Text file describing model structure.", null);

/** Training data file, or a list of training files when --training-is-list is set. */
private static CommandOption.File trainFile = new CommandOption.File
  (AcrfExtractorTui.class, "training", "FILENAME", true, null, "File containing training data.", null);

/** Optional testing data; when not supplied main() passes a null test source to the trainer. */
private static CommandOption.File testFile = new CommandOption.File
  (AcrfExtractorTui.class, "testing", "FILENAME", true, null, "File containing testing data.", null);

/** When invoked, its value is handed to the GenericAcrfData2TokenSequence constructor. */
private static CommandOption.Integer numLabelsOption = new CommandOption.Integer
  (AcrfExtractorTui.class, "num-labels", "INT", true, -1,
   "If supplied, number of labels on each line of input file." +
   "  Otherwise, the token ---- must separate labels from features.", null);

/** Trainer specification, interpreted by createTrainer. */
private static CommandOption.String trainerOption = new CommandOption.String
  (AcrfExtractorTui.class, "trainer", "STRING", true, "ACRFExtractorTrainer",
   "Specification of trainer type.", null);

/** Inferencer specification, interpreted by createInferencer and installed via setInferencer. */
private static CommandOption.String inferencerOption = new CommandOption.String
  (AcrfExtractorTui.class, "inferencer", "STRING", true, "LoopyBP",
   "Specification of inferencer.", null);

/** Inferencer specification installed via setViterbiInferencer. */
private static CommandOption.String maxInferencerOption = new CommandOption.String
  (AcrfExtractorTui.class, "max-inferencer", "STRING", true, "LoopyBP.createForMaxProduct()",
   "Specification of inferencer.", null);

/** ACRF evaluator specification, interpreted by createEvaluator. */
private static CommandOption.String evalOption = new CommandOption.String
  (AcrfExtractorTui.class, "eval", "STRING", true, "LOG",
   "Evaluator to use.  Java code grokking performed.", null);

/** Extraction evaluator specification, interpreted by createExtractionEvaluator. */
private static CommandOption.String extractionEvalOption = new CommandOption.String
  (AcrfExtractorTui.class, "extraction-eval", "STRING", true, "PerDocumentF1",
   "Evaluator to use.  Java code grokking performed.", null);

/** Checkpoint interval forwarded to the trainer. */
private static CommandOption.Integer checkpointIterations = new CommandOption.Integer
  (AcrfExtractorTui.class, "checkpoint", "INT", true, -1, "Save a copy after every ___ iterations.", null);

/** Forwarded to the trainer's setCacheUnrolledGraphs. */
static CommandOption.Boolean cacheUnrolledGraph = new CommandOption.Boolean
  (AcrfExtractorTui.class, "cache-graphs", "true|false", true, true,
   "Whether to use memory-intensive caching.", null);

/** Forwarded to the trainer's setUsePerTemplateTrain. */
static CommandOption.Boolean perTemplateTrain = new CommandOption.Boolean
  (AcrfExtractorTui.class, "per-template-train", "true|false", true, false,
   "Whether to pretrain templates before joint training.", null);

/** Forwarded to the trainer's setPerTemplateIterations. */
static CommandOption.Integer pttIterations = new CommandOption.Integer
  (AcrfExtractorTui.class, "per-template-iterations", "INTEGER", false, 100,
   "How many training iterations for each step of per-template-training.", null);

// NOTE(review): this option is declared but not read anywhere in the visible part of this file.
static CommandOption.Integer randomSeedOption = new CommandOption.Integer
  (AcrfExtractorTui.class, "random-seed", "INTEGER", true, 0,
   "The random seed for randomly selecting a proportion of the instance list for training", null);

/** Consulted by main() when configuring the base pipe. */
static CommandOption.Boolean useTokenText = new CommandOption.Boolean
  (AcrfExtractorTui.class, "use-token-text", "true|false", true, true,
   "If true, first feature in list is assumed to be token identity, and is treated specially.", null);

/** Forwarded to the base pipe's setLabelsAtEnd.  This is a boolean switch, so the argument hint is "true|false". */
private static CommandOption.Boolean labelsAtEnd = new CommandOption.Boolean
  (AcrfExtractorTui.class, "labels-at-end", "true|false", true, false,
   "If true, then label is at end of each line, rather than beginning.", null);

/** Selects between a file-list source and a single-file source in constructIterator. */
static CommandOption.Boolean trainingIsList = new CommandOption.Boolean
  (AcrfExtractorTui.class, "training-is-list", "true|false", true, false,
   "If true, training option gives list of files to read for training.", null);

/** Base directory handed to the file-list iterator. */
private static CommandOption.File dataDir = new CommandOption.File
  (AcrfExtractorTui.class, "data-dir", "FILENAME", true, null, "If training-is-list, base directory in which training files located.", null);

/** Interpreter used to evaluate trainer, inferencer, evaluator and template specifications. */
private static BshInterpreter interpreter = setupInterpreter ();
/**
 * Command-line entry point: processes the options declared on this class, builds the
 * input pipe and data sources, trains an ACRF extractor, writes it to
 * "extor.ser.gz" under the output prefix, and evaluates it on the testing data.
 *
 * NOTE(review): brace-only lines are not visible in this view of the file, so the exact
 * extent of each if/else body below cannot be confirmed from here.
 *
 * @param args command-line arguments, handed to doProcessOptions
 * @throws IOException declared by the file-reading helpers called here
 * @throws EvalError declared by the interpreter-backed factory helpers called here
 */
public static void main (String[] args) throws IOException, EvalError
// Register and parse this class's command-line options, then start timing.
doProcessOptions (AcrfExtractorTui.class, args);
Timing timing = new Timing ();
// Choose the base pipe constructor depending on whether --num-labels was given.
GenericAcrfData2TokenSequence basePipe;
if (!numLabelsOption.wasInvoked ()) {
basePipe = new GenericAcrfData2TokenSequence ();
} else {
basePipe = new GenericAcrfData2TokenSequence (numLabelsOption.value);
// NOTE(review): as shown, setLabelsAtEnd is the only statement after the useTokenText test;
// confirm whether it is meant to be conditional and whether statements are missing here.
if (!useTokenText.value) {
basePipe.setLabelsAtEnd (labelsAtEnd.value);
// When the training option names a list of files, the first pipe stage is
// Input2CharSequence; otherwise it is a Noop.
// NOTE(review): the array initializer opened below is not closed in this view,
// so the remaining pipe stages are not visible — confirm against the full file.
Pipe tokPipe = new SerialPipes (new Pipe[] {
(trainingIsList.value ? new Input2CharSequence () : (Pipe) new Noop ()),
// Data sources.  The testing source uses the same dataDir and trainingIsList settings
// as the training source, and is null when --testing was not supplied.
Iterator<Instance> trainSource = constructIterator(trainFile.value, dataDir.value, trainingIsList.value);
Iterator<Instance> testSource;
if (testFile.wasInvoked ()) {
testSource = constructIterator (testFile.value, dataDir.value, trainingIsList.value);
} else {
testSource = null;
// Build templates, trainer, evaluators and inferencers from their option values.
ACRF.Template[] tmpls = parseModelFile (modelFile.value);
ACRFExtractorTrainer trainer = createTrainer (trainerOption.value);
ACRFEvaluator eval = createEvaluator (evalOption.value);
ExtractionEvaluator extractionEval = createExtractionEvaluator (extractionEvalOption.value);
Inferencer inf = createInferencer (inferencerOption.value);
Inferencer maxInf = createInferencer (maxInferencerOption.value);
// Configure the trainer through its chained setters.
trainer.setPipes (tokPipe, new TokenSequence2FeatureVectorSequence ())
.setDataSource (trainSource, testSource)
.setEvaluator (eval)
.setTemplates (tmpls)
.setInferencer (inf)
.setViterbiInferencer (maxInf)
.setCheckpointDirectory (outputPrefix.value)
.setNumCheckpointIterations (checkpointIterations.value)
.setCacheUnrolledGraphs (cacheUnrolledGraph.value)
.setUsePerTemplateTrain (perTemplateTrain.value)
.setPerTemplateIterations (pttIterations.value);
// Train, then serialize the resulting extractor beneath the output prefix.
logger.info ("Starting training...");
ACRFExtractor extor = trainer.trainExtractor ();
timing.tick ("Training");
FileUtils.writeGzippedObject (new File (outputPrefix.value, "extor.ser.gz"), extor);
timing.tick ("Serializing");
// Final evaluation, performed only when the trainer reports testing data.
InstanceList testing = trainer.getTestingData ();
if (testing != null) {
eval.test (extor.getAcrf (), testing, "Final results");
// Extraction-level evaluation additionally requires a non-null extraction evaluator.
if ((extractionEval != null) && (testing != null)) {
Extraction extraction = extor.extract (testing);
extractionEval.evaluate (extraction);
// NOTE(review): the tick label below is spelled "Evaluting" in the source.
timing.tick ("Evaluting");
System.out.println ("Total time (ms) = " + timing.elapsedTime ());
private static BshInterpreter setupInterpreter ()
BshInterpreter interpreter = CommandOption.getInterpreter ();
try {
interpreter.eval ("import edu.umass.cs.mallet.base.extract.*");
interpreter.eval ("import edu.umass.cs.mallet.grmm.inference.*");
interpreter.eval ("import edu.umass.cs.mallet.grmm.learning.*");
interpreter.eval ("import edu.umass.cs.mallet.grmm.learning.templates.*");
interpreter.eval ("import edu.umass.cs.mallet.grmm.learning.extract.*");
} catch (EvalError e) {
throw new RuntimeException (e);
return interpreter;
private static Iterator<Instance> constructIterator (File trainFile, File dataDir, boolean isList) throws IOException
if (isList) {
return new FileListIterator (trainFile, dataDir, null, null, true);
} else {
return new LineGroupIterator (new FileReader (trainFile), Pattern.compile ("^\\s*$"), true);
public static ACRFEvaluator createEvaluator (String spec) throws EvalError
if (spec.indexOf ('(') >= 0) {
// assume it's Java code, and don't screw with it.
return (ACRFEvaluator) interpreter.eval (spec);
} else {
LinkedList toks = new LinkedList (Arrays.asList (spec.split ("\\s+")));
return createEvaluator (toks);
private static ExtractionEvaluator createExtractionEvaluator (String spec) throws EvalError
if (spec.indexOf ('(') >= 0) {
// assume it's Java code, and don't screw with it.
return (ExtractionEvaluator) interpreter.eval (spec);
} else {
spec = "new "+spec+"Evaluator ()";
return (ExtractionEvaluator) interpreter.eval (spec);
private static ACRFEvaluator createEvaluator (LinkedList toks)
String type = (String) toks.removeFirst ();
if (type.equalsIgnoreCase ("SEGMENT")) {
int slice = Integer.parseInt ((String) toks.removeFirst ());
if (toks.size() % 2 != 0)
throw new RuntimeException ("Error in --eval "+evalOption.value+": Every start tag must have a continue.");
int numTags = toks.size () / 2;
String[] startTags = new String [numTags];
String[] continueTags = new String [numTags];
for (int i = 0; i < numTags; i++) {
startTags[i] = (String) toks.removeFirst ();
continueTags[i] = (String) toks.removeFirst ();
return new MultiSegmentationEvaluatorACRF (startTags, continueTags, slice);
} else if (type.equalsIgnoreCase ("LOG")) {
return new DefaultAcrfTrainer.LogEvaluator ();
} else if (type.equalsIgnoreCase ("SERIAL")) {
List evals = new ArrayList ();
while (!toks.isEmpty ()) {
evals.add (createEvaluator (toks));
return new AcrfSerialEvaluator (evals);
} else {
throw new RuntimeException ("Error in --eval "+evalOption.value+": illegal evaluator "+type);
private static ACRFExtractorTrainer createTrainer (String spec) throws EvalError
String cmd;
if (spec.indexOf ('(') >= 0) {
// assume it's Java code, and don't screw with it.
cmd = spec;
} else if (spec.endsWith ("Trainer")) {
cmd = "new "+spec+"()";
} else {
cmd = "new "+spec+"Trainer()";
// Return whatever the Java code says to
Object trainer = interpreter.eval (cmd);
if (trainer instanceof ACRFExtractorTrainer)
return (ACRFExtractorTrainer) trainer;
else if (trainer instanceof DefaultAcrfTrainer)
return new ACRFExtractorTrainer ().setTrainingMethod ((ACRFTrainer) trainer);
else throw new RuntimeException ("Don't know what to do with trainer "+trainer);
private static Inferencer createInferencer (String spec) throws EvalError
String cmd;
if (spec.indexOf ('(') >= 0) {
// assume it's Java code, and don't screw with it.
cmd = spec;
} else {
cmd = "new "+spec+"()";
// Return whatever the Java code says to
Object inf = interpreter.eval (cmd);
if (inf instanceof Inferencer)
return (Inferencer) inf;
else throw new RuntimeException ("Don't know what to do with inferencer "+inf);
public static void doProcessOptions (Class childClass, String[] args)
CommandOption.List options = new CommandOption.List ("", new CommandOption[0]);
options.add (childClass);
options.process (args);
options.logOptions (Logger.getLogger (""));
private static ACRF.Template[] parseModelFile (File mdlFile) throws IOException, EvalError
BufferedReader in = new BufferedReader (new FileReader (mdlFile));
List tmpls = new ArrayList ();
String line = in.readLine ();
while (line != null) {
Object tmpl = interpreter.eval (line);
if (!(tmpl instanceof ACRF.Template)) {
throw new RuntimeException ("Error in "+mdlFile+" line "+in.toString ()+":\n Object "+tmpl+" not a template");
tmpls.add (tmpl);
line = in.readLine ();
return (ACRF.Template[]) tmpls.toArray (new ACRF.Template [0]);