// CRFClassifier -- a probabilistic (CRF) sequence model, mainly used for NER.
// Copyright (c) 2002-2014 The Board of Trustees of
// The Leland Stanford Junior University. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 2
// of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software Foundation,
// Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
//
// For more information, bug reports, fixes, contact:
// Christopher Manning
// Dept of Computer Science, Gates 1A
// Stanford CA 94305-9010
// USA
// Support/Questions: java-nlp-user@lists.stanford.edu
// Licensing: java-nlp-support@lists.stanford.edu
package edu.stanford.nlp.ie.crf;
import edu.stanford.nlp.ie.*;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.objectbank.ObjectBank;
import edu.stanford.nlp.optimization.*;
import edu.stanford.nlp.optimization.Function;
import edu.stanford.nlp.sequences.*;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.*;
import java.io.*;
import java.lang.reflect.InvocationTargetException;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.*;
import java.util.regex.*;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
/**
 * Class for sequence classification using a Conditional Random Field model.
* The code has functionality for different document formats, but when
* using the standard {@link edu.stanford.nlp.sequences.ColumnDocumentReaderAndWriter} for training
* or testing models, input files are expected to
* be one token per line with the columns indicating things like the word,
* POS, chunk, and answer class. The default for
* <code>ColumnDocumentReaderAndWriter</code> training data is 3 column input,
* with the columns containing a word, its POS, and its gold class, but
* this can be specified via the <code>map</code> property.
* </p><p>
* When run on a file with <code>-textFile</code>,
* the file is assumed to be plain English text (or perhaps simple HTML/XML),
* and a reasonable attempt is made at English tokenization by
* {@link PlainTextDocumentReaderAndWriter}. The class used to read
* the text can be changed with -plainTextDocumentReaderAndWriter.
* Extra options can be supplied to the tokenizer using the
* -tokenizeOptions flag.
* </p><p>
* To read from stdin, use the flag -readStdin. The same
* reader/writer will be used as for -textFile.
* </p>
* <b>Typical command-line usage</b>
* <p>For running a trained model with a provided serialized classifier on a
* text file: <p>
* <code>
* java -mx500m edu.stanford.nlp.ie.crf.CRFClassifier -loadClassifier
* conll.ner.gz -textFile samplesentences.txt
* </code>
* <p>
* When specifying all parameters in a properties file (train, test, or
* runtime):
* <p>
* <code>
* java -mx1g edu.stanford.nlp.ie.crf.CRFClassifier -prop propFile
* </code>
* <p>
* To train and test a simple NER model from the command line:<br>
* <code>java -mx1000m edu.stanford.nlp.ie.crf.CRFClassifier
* -trainFile trainFile -testFile testFile -macro > output </code>
* </p>
* <p>
* To train with multiple files: <br>
* <code>java -mx1000m edu.stanford.nlp.ie.crf.CRFClassifier
* -trainFileList file1,file2,... -testFile testFile -macro > output</code>
* </p>
* <p>
* To test on multiple files, use the -testFiles option and a comma
* separated list.
* </p>
* Features are defined by a {@link edu.stanford.nlp.sequences.FeatureFactory}.
* {@link NERFeatureFactory} is used by default, and you should look
* there for feature templates and properties or flags that will cause
* certain features to be used when training an NER classifier. There
* are also various feature factories for Chinese word segmentation
* such as {@link edu.stanford.nlp.wordseg.ChineseSegmenterFeatureFactory}.
* Features are specified either
* by a Properties file (which is the recommended method) or by flags on the
* command line. The flags are read into a {@link SeqClassifierFlags} object,
* which the user need not be concerned with, unless wishing to add new
 * features. </p>
 * <p>
 * CRFClassifier may also be used programmatically. When creating
 * a new instance, you <i>must</i> specify a Properties object. You may then
 * call train methods to train a classifier, or load a classifier. The other way
 * to get a CRFClassifier is to deserialize one via the static
 * {@link CRFClassifier#getClassifier(String)} methods, which return a
 * deserialized classifier. You may then tag (classify the items of) documents
 * using either the assorted <code>classify()</code> methods in this class or
 * the assorted <code>classify</code> methods in {@link AbstractSequenceClassifier}.
* Probabilities assigned by the CRF can be interrogated using either the
* <code>printProbsDocument()</code> or <code>getCliqueTrees()</code> methods.
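 * <p>
 * As a minimal programmatic sketch (the model path is illustrative and
 * exception handling is elided):
 * <pre>{@code
 * CRFClassifier<CoreLabel> crf = CRFClassifier.getClassifier(
 *     "edu/stanford/nlp/models/ner/english.all.3class.distsim.crf.ser.gz");
 * String tagged = crf.classifyToString("Stanford University is located in California.");
 * }</pre>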
*
* @author Jenny Finkel
* @author Sonal Gupta (made the class generic)
* @author Mengqiu Wang (LOP implementation and non-linear CRF implementation)
* TODO(mengqiu) need to move the embedding lookup and capitalization features into a FeatureFactory
*/
public class CRFClassifier<IN extends CoreMap> extends AbstractSequenceClassifier<IN> {
List<Index<CRFLabel>> labelIndices;
Index<String> tagIndex;
Pair<double[][], double[][]> entityMatrices;
CliquePotentialFunction cliquePotentialFunction;
HasCliquePotentialFunction cliquePotentialFunctionHelper;
/** Parameter weights of the classifier. */
double[][] weights;
  /** Index of the features of the CRF. */
  Index<String> featureIndex;
  /** Maps each feature index to the clique-size index used to look up its labels in {@code labelIndices}. */
  int[] map;
Random random = new Random(2147483647L);
Index<Integer> nodeFeatureIndicesMap;
Index<Integer> edgeFeatureIndicesMap;
Map<String, double[]> embeddings; // = null;
/**
* Name of default serialized classifier resource to look for in a jar file.
*/
public static final String DEFAULT_CLASSIFIER = "/edu/stanford/nlp/models/ner/english.all.3class.distsim.crf.ser.gz";
private static final boolean VERBOSE = false;
/**
* Fields for grouping features
*/
Pattern suffixPatt = Pattern.compile(".+?((?:-[A-Z]+)+)\\|.*C");
Index<String> templateGroupIndex;
Map<Integer, Integer> featureIndexToTemplateIndex;
// Label dictionary for fast decoding
LabelDictionary labelDictionary;
// List selftraindatums = new ArrayList();
protected CRFClassifier() {
super(new SeqClassifierFlags());
}
public CRFClassifier(Properties props) {
super(props);
}
public CRFClassifier(SeqClassifierFlags flags) {
super(flags);
}
/**
* Makes a copy of the crf classifier
*/
public CRFClassifier(CRFClassifier<IN> crf) {
super(crf.flags);
this.windowSize = crf.windowSize;
this.featureFactories = crf.featureFactories;
this.pad = crf.pad;
this.knownLCWords = (crf.knownLCWords != null) ? Generics.<String>newHashSet(crf.knownLCWords) : null;
this.featureIndex = (crf.featureIndex != null) ? new HashIndex<String>(crf.featureIndex.objectsList()) : null;
this.classIndex = (crf.classIndex != null) ? new HashIndex<String>(crf.classIndex.objectsList()) : null;
if (crf.labelIndices != null) {
this.labelIndices = new ArrayList<Index<CRFLabel>>(crf.labelIndices.size());
for (int i = 0; i < crf.labelIndices.size(); i++) {
this.labelIndices.add((crf.labelIndices.get(i) != null) ? new HashIndex<CRFLabel>(crf.labelIndices.get(i).objectsList()) : null);
}
} else {
this.labelIndices = null;
}
this.cliquePotentialFunction = crf.cliquePotentialFunction;
}
/**
* Returns the total number of weights associated with this classifier.
*
* @return number of weights
*/
public int getNumWeights() {
if (weights == null) return 0;
int numWeights = 0;
for (double[] wts : weights) {
numWeights += wts.length;
}
return numWeights;
}
/**
* Get index of featureType for feature indexed by i. (featureType index is
* used to index labelIndices to get labels.)
*
* @param i
* feature index
* @return index of featureType
*/
private int getFeatureTypeIndex(int i) {
return getFeatureTypeIndex(featureIndex.get(i));
}
/**
* Get index of featureType for feature based on the feature string
* (featureType index used to index labelIndices to get labels)
*
* @param feature
* feature string
* @return index of featureType
*/
private static int getFeatureTypeIndex(String feature) {
if (feature.endsWith("|C")) {
return 0;
} else if (feature.endsWith("|CpC")) {
return 1;
} else if (feature.endsWith("|Cp2C")) {
return 2;
} else if (feature.endsWith("|Cp3C")) {
return 3;
} else if (feature.endsWith("|Cp4C")) {
return 4;
} else if (feature.endsWith("|Cp5C")) {
return 5;
} else {
throw new RuntimeException("Unknown feature type " + feature);
}
}
/**
   * Scales the weights of this CRFClassifier by the specified scale factor.
*
* @param scale The scale to multiply by
*/
public void scaleWeights(double scale) {
for (int i = 0; i < weights.length; i++) {
for (int j = 0; j < weights[i].length; j++) {
weights[i][j] *= scale;
}
}
}
/**
* Combines weights from another crf (scaled by weight) into this CRF's
* weights (assumes that this CRF's indices have already been updated to
* include features/labels from the other crf)
*
* @param crf Other CRF whose weights to combine into this CRF
* @param weight Amount to scale the other CRF's weights by
*/
private void combineWeights(CRFClassifier<IN> crf, double weight) {
int numFeatures = featureIndex.size();
int oldNumFeatures = weights.length;
// Create a map of other crf labels to this crf labels
Map<CRFLabel, CRFLabel> crfLabelMap = Generics.newHashMap();
for (int i = 0; i < crf.labelIndices.size(); i++) {
for (int j = 0; j < crf.labelIndices.get(i).size(); j++) {
CRFLabel labels = crf.labelIndices.get(i).get(j);
int[] newLabelIndices = new int[i + 1];
for (int ci = 0; ci <= i; ci++) {
String classLabel = crf.classIndex.get(labels.getLabel()[ci]);
newLabelIndices[ci] = this.classIndex.indexOf(classLabel);
}
CRFLabel newLabels = new CRFLabel(newLabelIndices);
crfLabelMap.put(labels, newLabels);
int k = this.labelIndices.get(i).indexOf(newLabels); // IMPORTANT: the indexing is needed, even when not printed out!
// System.err.println("LabelIndices " + i + " " + labels + ": " + j +
// " mapped to " + k);
}
}
// Create map of featureIndex to featureTypeIndex
map = new int[numFeatures];
for (int i = 0; i < numFeatures; i++) {
map[i] = getFeatureTypeIndex(i);
}
// Create new weights
double[][] newWeights = new double[numFeatures][];
for (int i = 0; i < numFeatures; i++) {
int length = labelIndices.get(map[i]).size();
newWeights[i] = new double[length];
if (i < oldNumFeatures) {
assert (length >= weights[i].length);
System.arraycopy(weights[i], 0, newWeights[i], 0, weights[i].length);
}
}
weights = newWeights;
    // Get the original weight indices from the other crf and weight them in.
    // Depending on the type of the feature, a different number of weights is
    // associated with it.
for (int i = 0; i < crf.weights.length; i++) {
String feature = crf.featureIndex.get(i);
int newIndex = featureIndex.indexOf(feature);
      // Check that the weight arrays have compatible dimensions
if (weights[newIndex].length < crf.weights[i].length) {
throw new RuntimeException("Incompatible CRFClassifier: weight length mismatch for feature " + newIndex + ": "
+ featureIndex.get(newIndex) + " (also feature " + i + ": " + crf.featureIndex.get(i) + ") " + ", len1="
+ weights[newIndex].length + ", len2=" + crf.weights[i].length);
}
int featureTypeIndex = map[newIndex];
for (int j = 0; j < crf.weights[i].length; j++) {
CRFLabel labels = crf.labelIndices.get(featureTypeIndex).get(j);
CRFLabel newLabels = crfLabelMap.get(labels);
int k = this.labelIndices.get(featureTypeIndex).indexOf(newLabels);
weights[newIndex][k] += crf.weights[i][j] * weight;
}
}
}
  /**
   * Combines a weighted crf with this crf.
   *
   * @param crf The other CRF whose weights should be merged into this CRF
   * @param weight The amount by which to scale the other CRF's weights
   */
public void combine(CRFClassifier<IN> crf, double weight) {
Timing timer = new Timing();
// Check the CRFClassifiers are compatible
if (!this.pad.equals(crf.pad)) {
throw new RuntimeException("Incompatible CRFClassifier: pad does not match");
}
if (this.windowSize != crf.windowSize) {
throw new RuntimeException("Incompatible CRFClassifier: windowSize does not match");
}
if (this.labelIndices.size() != crf.labelIndices.size()) {
      // Should match, since both should equal the windowSize
throw new RuntimeException("Incompatible CRFClassifier: labelIndices length does not match");
}
this.classIndex.addAll(crf.classIndex.objectsList());
// Combine weights of the other classifier with this classifier,
    // weighting the other classifier's weights by weight.
    // First merge the feature indices
int oldNumFeatures1 = this.featureIndex.size();
int oldNumFeatures2 = crf.featureIndex.size();
int oldNumWeights1 = this.getNumWeights();
int oldNumWeights2 = crf.getNumWeights();
this.featureIndex.addAll(crf.featureIndex.objectsList());
this.knownLCWords.addAll(crf.knownLCWords);
assert (weights.length == oldNumFeatures1);
// Combine weights of this classifier with other classifier
for (int i = 0; i < labelIndices.size(); i++) {
this.labelIndices.get(i).addAll(crf.labelIndices.get(i).objectsList());
}
System.err.println("Combining weights: will automatically match labelIndices");
combineWeights(crf, weight);
int numFeatures = featureIndex.size();
int numWeights = getNumWeights();
long elapsedMs = timer.stop();
System.err.println("numFeatures: orig1=" + oldNumFeatures1 + ", orig2=" + oldNumFeatures2 + ", combined="
+ numFeatures);
System.err
.println("numWeights: orig1=" + oldNumWeights1 + ", orig2=" + oldNumWeights2 + ", combined=" + numWeights);
System.err.println("Time to combine CRFClassifier: " + Timing.toSecondsString(elapsedMs) + " seconds");
}
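  /**
   * Rebuilds the feature index, keeping only those features whose weight
   * spread (largest minus smallest weight across labels) exceeds the given
   * threshold; the clique map is rebuilt to match.
   *
   * @param threshold Minimum weight spread for a feature to be retained
   */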
public void dropFeaturesBelowThreshold(double threshold) {
Index<String> newFeatureIndex = new HashIndex<String>();
for (int i = 0; i < weights.length; i++) {
double smallest = weights[i][0];
double biggest = weights[i][0];
for (int j = 1; j < weights[i].length; j++) {
if (weights[i][j] > biggest) {
biggest = weights[i][j];
}
if (weights[i][j] < smallest) {
smallest = weights[i][j];
}
if (biggest - smallest > threshold) {
newFeatureIndex.add(featureIndex.get(i));
break;
}
}
}
int[] newMap = new int[newFeatureIndex.size()];
for (int i = 0; i < newMap.length; i++) {
int index = featureIndex.indexOf(newFeatureIndex.get(i));
newMap[i] = map[index];
}
map = newMap;
featureIndex = newFeatureIndex;
}
/**
* Convert a document List into arrays storing the data features and labels.
* This is used at test time.
*
* @param document Testing documents
* @return A Triple, where the first element is an int[][][] representing the
* data, the second element is an int[] representing the labels, and
* the third element is a double[][][] representing the feature values (optionally null)
*/
public Triple<int[][][], int[], double[][][]> documentToDataAndLabels(List<IN> document) {
int docSize = document.size();
// first index is position in the document also the index of the
// clique/factor table
// second index is the number of elements in the clique/window these
// features are for (starting with last element)
// third index is position of the feature in the array that holds them.
// An element in data[j][k][m] is the feature index of the mth feature occurring in
// position k of the jth clique
int[][][] data = new int[docSize][windowSize][];
double[][][] featureVals = new double[docSize][windowSize][];
// index is the position in the document.
// element in labels[j] is the index of the correct label (if it exists) at
// position j of document
int[] labels = new int[docSize];
if (flags.useReverse) {
Collections.reverse(document);
}
// System.err.println("docSize:"+docSize);
for (int j = 0; j < docSize; j++) {
CRFDatum<List<String>, CRFLabel> d = makeDatum(document, j, featureFactories);
List<List<String>> features = d.asFeatures();
List<double[]> featureValList = d.asFeatureVals();
for (int k = 0, fSize = features.size(); k < fSize; k++) {
Collection<String> cliqueFeatures = features.get(k);
data[j][k] = new int[cliqueFeatures.size()];
if(featureValList != null) { // CRFBiasedClassifier.makeDatum causes null
featureVals[j][k] = featureValList.get(k);
}
int m = 0;
for (String feature : cliqueFeatures) {
int index = featureIndex.indexOf(feature);
if (index >= 0) {
data[j][k][m] = index;
m++;
} else {
// this is where we end up when we do feature threshold cutoffs
}
}
if (m < data[j][k].length) {
int[] f = new int[m];
System.arraycopy(data[j][k], 0, f, 0, m);
data[j][k] = f;
if (featureVals[j][k] != null) {
double[] fVal = new double[m];
System.arraycopy(featureVals[j][k], 0, fVal, 0, m);
featureVals[j][k] = fVal;
}
}
}
IN wi = document.get(j);
labels[j] = classIndex.indexOf(wi.get(CoreAnnotations.AnswerAnnotation.class));
}
if (flags.useReverse) {
Collections.reverse(document);
}
return new Triple<int[][][], int[], double[][][]>(data, labels, featureVals);
}
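  /**
   * Remaps the global feature indices in docData into the separate node- and
   * edge-feature index spaces: features of the singleton clique (j == 0) are
   * looked up in nodeFeatureIndicesMap, all larger cliques in edgeFeatureIndicesMap.
   */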
private int[][][] transformDocData(int[][][] docData) {
int[][][] transData = new int[docData.length][][];
for (int i = 0; i < docData.length; i++) {
transData[i] = new int[docData[i].length][];
for (int j = 0; j < docData[i].length; j++) {
int[] cliqueFeatures = docData[i][j];
transData[i][j] = new int[cliqueFeatures.length];
for (int n = 0; n < cliqueFeatures.length; n++) {
int transFeatureIndex = -1;
if (j == 0) {
transFeatureIndex = nodeFeatureIndicesMap.indexOf(cliqueFeatures[n]);
if (transFeatureIndex == -1)
throw new RuntimeException("node cliqueFeatures[n]="+cliqueFeatures[n]+" not found, nodeFeatureIndicesMap.size="+nodeFeatureIndicesMap.size());
} else {
transFeatureIndex = edgeFeatureIndicesMap.indexOf(cliqueFeatures[n]);
if (transFeatureIndex == -1)
throw new RuntimeException("edge cliqueFeatures[n]="+cliqueFeatures[n]+" not found, edgeFeatureIndicesMap.size="+edgeFeatureIndicesMap.size());
}
transData[i][j][n] = transFeatureIndex;
}
}
}
return transData;
}
public void printLabelInformation(String testFile, DocumentReaderAndWriter<IN> readerAndWriter) throws Exception {
ObjectBank<List<IN>> documents = makeObjectBankFromFile(testFile, readerAndWriter);
for (List<IN> document : documents) {
printLabelValue(document);
}
}
public void printLabelValue(List<IN> document) {
if (flags.useReverse) {
Collections.reverse(document);
}
NumberFormat nf = new DecimalFormat();
List<String> classes = new ArrayList<String>();
for (int i = 0; i < classIndex.size(); i++) {
classes.add(classIndex.get(i));
}
String[] columnHeaders = classes.toArray(new String[classes.size()]);
// System.err.println("docSize:"+docSize);
for (int j = 0; j < document.size(); j++) {
System.out.println("--== " + document.get(j).get(CoreAnnotations.TextAnnotation.class) + " ==--");
List<String[]> lines = new ArrayList<String[]>();
List<String> rowHeaders = new ArrayList<String>();
List<String> line = new ArrayList<String>();
for (int p = 0; p < labelIndices.size(); p++) {
if (j + p >= document.size()) {
continue;
}
CRFDatum<List<String>, CRFLabel> d = makeDatum(document, j + p, featureFactories);
List<List<String>> features = d.asFeatures();
for (int k = p, fSize = features.size(); k < fSize; k++) {
Collection<String> cliqueFeatures = features.get(k);
for (String feature : cliqueFeatures) {
int index = featureIndex.indexOf(feature);
if (index >= 0) {
// line.add(feature+"["+(-p)+"]");
rowHeaders.add(feature + '[' + (-p) + ']');
double[] values = new double[labelIndices.get(0).size()];
for (CRFLabel label : labelIndices.get(k)) {
int[] l = label.getLabel();
double v = weights[index][labelIndices.get(k).indexOf(label)];
values[l[l.length - 1 - p]] += v;
}
for (double value : values) {
line.add(nf.format(value));
}
lines.add(line.toArray(new String[line.size()]));
line = new ArrayList<String>();
}
}
}
// lines.add(Collections.<String>emptyList());
      System.out.println(StringUtils.makeTextTable(lines.toArray(new String[lines.size()][0]),
          rowHeaders.toArray(new String[rowHeaders.size()]), columnHeaders, 0, 1, true));
System.out.println();
}
// System.err.println(edu.stanford.nlp.util.StringUtils.join(lines,"\n"));
}
if (flags.useReverse) {
Collections.reverse(document);
}
}
/**
* Convert an ObjectBank to arrays of data features and labels.
* This version is used at training time.
*
* @return A Triple, where the first element is an int[][][][] representing the
* data, the second element is an int[][] representing the labels, and
* the third element is a double[][][][] representing the feature values
* which could be optionally left as null.
*/
public Triple<int[][][][], int[][], double[][][][]> documentsToDataAndLabels(Collection<List<IN>> documents) {
// first index is the number of the document
// second index is position in the document also the index of the
// clique/factor table
// third index is the number of elements in the clique/window these features
// are for (starting with last element)
// fourth index is position of the feature in the array that holds them
// element in data[i][j][k][m] is the index of the mth feature occurring in
// position k of the jth clique of the ith document
// int[][][][] data = new int[documentsSize][][][];
List<int[][][]> data = new ArrayList<int[][][]>();
List<double[][][]> featureVal = new ArrayList<double[][][]>();
// first index is the number of the document
// second index is the position in the document
// element in labels[i][j] is the index of the correct label (if it exists)
// at position j in document i
// int[][] labels = new int[documentsSize][];
List<int[]> labels = new ArrayList<int[]>();
int numDatums = 0;
for (List<IN> doc : documents) {
Triple<int[][][], int[], double[][][]> docTriple = documentToDataAndLabels(doc);
data.add(docTriple.first());
labels.add(docTriple.second());
if (flags.useEmbedding)
featureVal.add(docTriple.third());
numDatums += doc.size();
}
System.err.println("numClasses: " + classIndex.size() + ' ' + classIndex);
System.err.println("numDocuments: " + data.size());
System.err.println("numDatums: " + numDatums);
System.err.println("numFeatures: " + featureIndex.size());
printFeatures();
double[][][][] featureValArr = null;
if (flags.useEmbedding)
featureValArr = featureVal.toArray(new double[data.size()][][][]);
return new Triple<int[][][][], int[][], double[][][][]>(
data.toArray(new int[data.size()][][][]),
labels.toArray(new int[labels.size()][]),
featureValArr);
}
/**
* Convert an ObjectBank to corresponding collection of data features and
* labels. This version is used at test time.
*
   * @return A List of Triples, one for each document, where the first element
   *         is an int[][][] representing the data, the second element is an
   *         int[] representing the labels, and the third element is a
   *         double[][][] representing the feature values (optionally null)
*/
public List<Triple<int[][][], int[], double[][][]>> documentsToDataAndLabelsList(Collection<List<IN>> documents) {
int numDatums = 0;
List<Triple<int[][][], int[], double[][][]>> docList = new ArrayList<Triple<int[][][], int[], double[][][]>>();
for (List<IN> doc : documents) {
Triple<int[][][], int[], double[][][]> docTriple = documentToDataAndLabels(doc);
docList.add(docTriple);
numDatums += doc.size();
}
System.err.println("numClasses: " + classIndex.size() + ' ' + classIndex);
System.err.println("numDocuments: " + docList.size());
System.err.println("numDatums: " + numDatums);
System.err.println("numFeatures: " + featureIndex.size());
return docList;
}
protected void printFeatures() {
if (flags.printFeatures == null) {
return;
}
try {
String enc = flags.inputEncoding;
if (flags.inputEncoding == null) {
System.err.println("flags.inputEncoding doesn't exist, using UTF-8 as default");
enc = "UTF-8";
}
PrintWriter pw = new PrintWriter(new OutputStreamWriter(new FileOutputStream("features-" + flags.printFeatures
+ ".txt"), enc), true);
for (String feat : featureIndex) {
pw.println(feat);
}
pw.close();
} catch (IOException ioe) {
ioe.printStackTrace();
}
}
/**
* This routine builds the {@code labelIndices} which give the
* empirically legal label sequences (of length (order) at most
* {@code windowSize}) and the {@code classIndex}, which indexes
* known answer classes.
*
* @param ob The training data: Read from an ObjectBank, each item in it is a
* {@code List<CoreLabel>}.
*/
protected void makeAnswerArraysAndTagIndex(Collection<List<IN>> ob) {
boolean useFeatureCountThresh = flags.featureCountThresh > 1;
Set<String>[] featureIndices = new HashSet[windowSize];
Map<String, Integer>[] featureCountIndices = null;
for (int i = 0; i < windowSize; i++) {
featureIndices[i] = Generics.newHashSet();
}
if (useFeatureCountThresh) {
featureCountIndices = new HashMap[windowSize];
for (int i = 0; i < windowSize; i++) {
featureCountIndices[i] = Generics.newHashMap();
}
}
labelIndices = new ArrayList<Index<CRFLabel>>(windowSize);
for (int i = 0; i < windowSize; i++) {
labelIndices.add(new HashIndex<CRFLabel>());
}
Index<CRFLabel> labelIndex = labelIndices.get(windowSize - 1);
if (classIndex == null)
classIndex = new HashIndex<String>();
// classIndex.add("O");
classIndex.add(flags.backgroundSymbol);
Set<String>[] seenBackgroundFeatures = new HashSet[2];
seenBackgroundFeatures[0] = Generics.newHashSet();
seenBackgroundFeatures[1] = Generics.newHashSet();
int wordCount = 0;
if (flags.labelDictionaryCutoff > 0) {
this.labelDictionary = new LabelDictionary();
}
for (List<IN> doc : ob) {
if (flags.useReverse) {
Collections.reverse(doc);
}
// create the full set of labels in classIndex
// note: update to use addAll later
for (IN token : doc) {
wordCount++;
String ans = token.get(CoreAnnotations.AnswerAnnotation.class);
if (ans == null || ans.equals("")) {
throw new IllegalArgumentException("Word " + wordCount + " (\"" + token.get(CoreAnnotations.TextAnnotation.class) + "\") has a blank answer");
}
classIndex.add(ans);
if (labelDictionary != null) {
String observation = token.get(CoreAnnotations.TextAnnotation.class);
labelDictionary.increment(observation, ans);
}
}
for (int j = 0, docSize = doc.size(); j < docSize; j++) {
CRFDatum<List<String>, CRFLabel> d = makeDatum(doc, j, featureFactories);
labelIndex.add(d.label());
List<List<String>> features = d.asFeatures();
for (int k = 0, fSize = features.size(); k < fSize; k++) {
Collection<String> cliqueFeatures = features.get(k);
if (k < 2 && flags.removeBackgroundSingletonFeatures) {
String ans = doc.get(j).get(CoreAnnotations.AnswerAnnotation.class);
boolean background = ans.equals(flags.backgroundSymbol);
if (k == 1 && j > 0 && background) {
ans = doc.get(j - 1).get(CoreAnnotations.AnswerAnnotation.class);
background = ans.equals(flags.backgroundSymbol);
}
if (background) {
for (String f : cliqueFeatures) {
if (useFeatureCountThresh) {
if (!featureCountIndices[k].containsKey(f)) {
if (seenBackgroundFeatures[k].contains(f)) {
seenBackgroundFeatures[k].remove(f);
featureCountIndices[k].put(f, 1);
} else {
seenBackgroundFeatures[k].add(f);
}
}
} else {
if (!featureIndices[k].contains(f)) {
if (seenBackgroundFeatures[k].contains(f)) {
seenBackgroundFeatures[k].remove(f);
featureIndices[k].add(f);
} else {
seenBackgroundFeatures[k].add(f);
}
}
}
}
} else {
seenBackgroundFeatures[k].removeAll(cliqueFeatures);
if (useFeatureCountThresh) {
Map<String, Integer> fCountIndex = featureCountIndices[k];
for (String f: cliqueFeatures) {
if (fCountIndex.containsKey(f))
fCountIndex.put(f, fCountIndex.get(f)+1);
else
fCountIndex.put(f, 1);
}
} else {
featureIndices[k].addAll(cliqueFeatures);
}
}
} else {
if (useFeatureCountThresh) {
Map<String, Integer> fCountIndex = featureCountIndices[k];
for (String f: cliqueFeatures) {
if (fCountIndex.containsKey(f))
fCountIndex.put(f, fCountIndex.get(f)+1);
else
fCountIndex.put(f, 1);
}
} else {
featureIndices[k].addAll(cliqueFeatures);
}
}
}
}
if (flags.useReverse) {
Collections.reverse(doc);
}
}
if (useFeatureCountThresh) {
int numFeatures = 0;
for (int i = 0; i < windowSize; i++) {
numFeatures += featureCountIndices[i].size();
}
System.err.println("Before feature count thresholding, numFeatures = " + numFeatures);
for (int i = 0; i < windowSize; i++) {
for(Iterator<Map.Entry<String, Integer>> it = featureCountIndices[i].entrySet().iterator(); it.hasNext(); ) {
Map.Entry<String, Integer> entry = it.next();
if(entry.getValue() < flags.featureCountThresh) {
it.remove();
}
}
featureIndices[i].addAll(featureCountIndices[i].keySet());
featureCountIndices[i] = null;
}
}
int numFeatures = 0;
for (int i = 0; i < windowSize; i++) {
numFeatures += featureIndices[i].size();
}
System.err.println("numFeatures = " + numFeatures);
featureIndex = new HashIndex<String>();
map = new int[numFeatures];
if (flags.groupByFeatureTemplate) {
templateGroupIndex = new HashIndex<String>();
featureIndexToTemplateIndex = new HashMap<Integer, Integer>();
}
for (int i = 0; i < windowSize; i++) {
Index<Integer> featureIndexMap = new HashIndex<Integer>();
featureIndex.addAll(featureIndices[i]);
for (String str : featureIndices[i]) {
int index = featureIndex.indexOf(str);
map[index] = i;
featureIndexMap.add(index);
// grouping features by template
if (flags.groupByFeatureTemplate) {
Matcher m = suffixPatt.matcher(str);
String groupSuffix = "NoTemplate";
if (m.matches()) {
groupSuffix = m.group(1);
}
groupSuffix += "-c:"+i;
int groupIndex = templateGroupIndex.addToIndex(groupSuffix);
featureIndexToTemplateIndex.put(index, groupIndex);
}
}
// todo [cdm 2014]: Talk to Mengqiu about this; it seems like it only supports first order CRF
if (i == 0) {
nodeFeatureIndicesMap = featureIndexMap;
System.err.println("setting nodeFeatureIndicesMap, size="+nodeFeatureIndicesMap.size());
} else {
edgeFeatureIndicesMap = featureIndexMap;
System.err.println("setting edgeFeatureIndicesMap, size="+edgeFeatureIndicesMap.size());
}
}
if (flags.numOfFeatureSlices > 0) {
System.err.println("Taking " + flags.numOfFeatureSlices + " out of " + flags.totalFeatureSlice + " slices of node features for training");
pruneNodeFeatureIndices(flags.totalFeatureSlice, flags.numOfFeatureSlices);
}
if (flags.useObservedSequencesOnly) {
for (int i = 0, liSize = labelIndex.size(); i < liSize; i++) {
CRFLabel label = labelIndex.get(i);
for (int j = windowSize - 2; j >= 0; j--) {
label = label.getOneSmallerLabel();
labelIndices.get(j).add(label);
}
}
} else {
for (int i = 0; i < labelIndices.size(); i++) {
labelIndices.set(i, allLabels(i + 1, classIndex));
}
}
if (VERBOSE) {
for (int i = 0, fiSize = featureIndex.size(); i < fiSize; i++) {
System.out.println(i + ": " + featureIndex.get(i));
}
}
if (labelDictionary != null) {
labelDictionary.lock(flags.labelDictionaryCutoff, classIndex);
}
}
protected static Index<CRFLabel> allLabels(int window, Index<String> classIndex) {
int[] label = new int[window];
// cdm 2005: array initialization isn't necessary: JLS (3rd ed.) 4.12.5
// Arrays.fill(label, 0);
int numClasses = classIndex.size();
Index<CRFLabel> labelIndex = new HashIndex<CRFLabel>();
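    // Enumerate every label tuple of length 'window' by counting in base
    // numClasses: treat 'label' as an odometer and increment with carry.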
OUTER: while (true) {
CRFLabel l = new CRFLabel(label);
labelIndex.add(l);
int[] label1 = new int[window];
System.arraycopy(label, 0, label1, 0, label.length);
label = label1;
for (int j = 0; j < label.length; j++) {
label[j]++;
if (label[j] >= numClasses) {
label[j] = 0;
if (j == label.length - 1) {
break OUTER;
}
} else {
break;
}
}
}
return labelIndex;
}
/**
* Makes a CRFDatum by producing features and a label from input data at a
* specific position, using the provided factory.
*
* @param info The input data
* @param loc The position to build a datum at
* @param featureFactories The FeatureFactories to use to extract features
* @return The constructed CRFDatum
*/
public CRFDatum<List<String>, CRFLabel> makeDatum(List<IN> info, int loc,
List<FeatureFactory<IN>> featureFactories) {
// pad.set(CoreAnnotations.AnswerAnnotation.class, flags.backgroundSymbol); // cdm: isn't this unnecessary, as this is how it's initialized in AbstractSequenceClassifier.reinit?
PaddedList<IN> pInfo = new PaddedList<IN>(info, pad);
ArrayList<List<String>> features = new ArrayList<List<String>>();
List<double[]> featureVals = new ArrayList<double[]>();
// for (int i = 0; i < windowSize; i++) {
// List featuresC = new ArrayList();
// for (int j = 0; j < FeatureFactory.win[i].length; j++) {
// featuresC.addAll(featureFactory.features(info, loc,
// FeatureFactory.win[i][j]));
// }
// features.add(featuresC);
// }
    // todo [cdm Aug 2012]: Since getCliques returns all cliques within its bounds, can't the for loop here be eliminated? But my first attempt to remove it failed to produce identical results....
Collection<Clique> done = Generics.newHashSet();
for (int i = 0; i < windowSize; i++) {
List<String> featuresC = new ArrayList<String>();
List<Clique> windowCliques = FeatureFactory.getCliques(i, 0);
windowCliques.removeAll(done);
done.addAll(windowCliques);
double[] featureValArr = null;
if (flags.useEmbedding && i == 0) { // only activated for node features
featureValArr = makeDatumUsingEmbedding(info, loc, featureFactories, pInfo, featuresC, windowCliques);
} else {
for (Clique c : windowCliques) {
for (FeatureFactory featureFactory : featureFactories) {
featuresC.addAll(featureFactory.getCliqueFeatures(pInfo, loc, c)); //todo useless copy because of typing reasons
}
}
}
features.add(featuresC);
featureVals.add(featureValArr);
}
int[] labels = new int[windowSize];
for (int i = 0; i < windowSize; i++) {
String answer = pInfo.get(loc + i - windowSize + 1).get(CoreAnnotations.AnswerAnnotation.class);
labels[i] = classIndex.indexOf(answer);
}
printFeatureLists(pInfo.get(loc), features);
CRFDatum<List<String>, CRFLabel> d = new CRFDatum<List<String>, CRFLabel>(features, new CRFLabel(labels), featureVals);
// System.err.println(d);
return d;
}
private double[] makeDatumUsingEmbedding(List<IN> info, int loc, List<FeatureFactory<IN>> featureFactories, PaddedList<IN> pInfo, List<String> featuresC, List<Clique> windowCliques) {
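    // Builds one concatenated feature-value vector from the word embeddings of
    // the five-token window [loc-2, loc+2], padding out-of-range positions and
    // optionally appending four one-hot capitalization dimensions per token.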
double[] featureValArr;
List<double[]> embeddingList = new ArrayList<double[]>();
int concatEmbeddingLen = 0;
String currentWord = null;
for (int currLoc = loc-2; currLoc <= loc+2; currLoc++) {
double[] embedding = null;
if (currLoc >=0 && currLoc < info.size()) {
        currentWord = info.get(currLoc).get(CoreAnnotations.TextAnnotation.class);
String word = currentWord.toLowerCase();
word = word.replaceAll("(-)?\\d+(\\.\\d*)?", "0");
if (embeddings.containsKey(word))
embedding = embeddings.get(word);
else
embedding = embeddings.get("UNKNOWN");
} else {
embedding = embeddings.get("PADDING");
}
for (int e = 0; e < embedding.length; e++) {
featuresC.add("EMBEDDING-(" + (currLoc-loc) + ")-" + e);
}
if (flags.addCapitalFeatures) {
int numOfCapitalFeatures = 4;
double[] newEmbedding = new double[embedding.length + numOfCapitalFeatures];
int currLen = embedding.length;
System.arraycopy(embedding, 0, newEmbedding, 0, currLen);
for (int e = 0; e < numOfCapitalFeatures; e++)
featuresC.add("CAPITAL-(" + (currLoc-loc) + ")-" + e);
if (currLoc >=0 && currLoc < info.size()) { // skip PADDING
// check if word is all caps
if (currentWord.toUpperCase().equals(currentWord))
newEmbedding[currLen] = 1;
else {
currLen += 1;
// check if word is all lower
if (currentWord.toLowerCase().equals(currentWord))
newEmbedding[currLen] = 1;
else {
currLen += 1;
// check first letter cap
if (Character.isUpperCase(currentWord.charAt(0)))
newEmbedding[currLen] = 1;
else {
currLen += 1;
// check if at least one non-initial letter is cap
String remainder = currentWord.substring(1);
if (!remainder.toLowerCase().equals(remainder))
newEmbedding[currLen] = 1;
}
}
}
}
embedding = newEmbedding;
}
embeddingList.add(embedding);
concatEmbeddingLen += embedding.length;
}
double[] concatEmbedding = new double[concatEmbeddingLen];
int currPos = 0;
for (double[] em: embeddingList) {
System.arraycopy(em, 0, concatEmbedding, currPos, em.length);
currPos += em.length;
}
if (flags.prependEmbedding) {
int additionalFeatureCount = 0;
for (Clique c : windowCliques) {
for (FeatureFactory featureFactory : featureFactories) {
Collection<String> fCol = featureFactory.getCliqueFeatures(pInfo, loc, c); //todo useless copy because of typing reasons
featuresC.addAll(fCol);
additionalFeatureCount += fCol.size();
}
}
featureValArr = new double[concatEmbedding.length + additionalFeatureCount];
System.arraycopy(concatEmbedding, 0, featureValArr, 0, concatEmbedding.length);
Arrays.fill(featureValArr, concatEmbedding.length, featureValArr.length, 1.0);
} else {
featureValArr = concatEmbedding;
}
if (flags.addBiasToEmbedding) {
featuresC.add("BIAS-FEATURE");
double[] newFeatureValArr = new double[featureValArr.length + 1];
System.arraycopy(featureValArr, 0, newFeatureValArr, 0, featureValArr.length);
newFeatureValArr[newFeatureValArr.length-1] = 1;
featureValArr = newFeatureValArr;
}
return featureValArr;
}
@Override
public void dumpFeatures(Collection<List<IN>> docs) {
if (flags.exportFeatures != null) {
Timing timer = new Timing();
timer.start();
CRFFeatureExporter<IN> featureExporter = new CRFFeatureExporter<IN>(this);
featureExporter.printFeatures(flags.exportFeatures, docs);
long elapsedMs = timer.stop();
System.err.println("Time to export features: " + Timing.toSecondsString(elapsedMs) + " seconds");
}
}
@Override
public List<IN> classify(List<IN> document) {
if (flags.doGibbs) {
try {
return classifyGibbs(document);
} catch (Exception e) {
throw new RuntimeException("Error running testGibbs inference!", e);
}
} else if (flags.crfType.equalsIgnoreCase("maxent")) {
return classifyMaxEnt(document);
} else {
throw new RuntimeException("Unsupported inference type: " + flags.crfType);
}
}
private List<IN> classify(List<IN> document, Triple<int[][][], int[], double[][][]> documentDataAndLabels) {
if (flags.doGibbs) {
try {
return classifyGibbs(document, documentDataAndLabels);
} catch (Exception e) {
throw new RuntimeException("Error running testGibbs inference!", e);
}
} else if (flags.crfType.equalsIgnoreCase("maxent")) {
return classifyMaxEnt(document, documentDataAndLabels);
} else {
throw new RuntimeException("Unsupported inference type: " + flags.crfType);
}
}
/**
   * This method is supposed to be used by CRFClassifierEvaluator only and should not have global visibility.
* The generic {@code classifyAndWriteAnswers} omits the second argument {@code documentDataAndLabels}.
*/
void classifyAndWriteAnswers(Collection<List<IN>> documents,
List<Triple<int[][][], int[], double[][][]>> documentDataAndLabels,
PrintWriter printWriter,
DocumentReaderAndWriter<IN> readerAndWriter) throws IOException {
Timing timer = new Timing();
Counter<String> entityTP = new ClassicCounter<String>();
Counter<String> entityFP = new ClassicCounter<String>();
Counter<String> entityFN = new ClassicCounter<String>();
boolean resultsCounted = true;
int numWords = 0;
int numDocs = 0;
for (List<IN> doc : documents) {
classify(doc, documentDataAndLabels.get(numDocs));
numWords += doc.size();
writeAnswers(doc, printWriter, readerAndWriter);
resultsCounted = resultsCounted && countResults(doc, entityTP, entityFP, entityFN);
numDocs++;
}
long millis = timer.stop();
double wordspersec = numWords / (((double) millis) / 1000);
NumberFormat nf = new DecimalFormat("0.00"); // easier way!
if (!flags.suppressTestDebug)
System.err.println(StringUtils.getShortClassName(this) + " tagged " + numWords + " words in " + numDocs
+ " documents at " + nf.format(wordspersec) + " words per second.");
if (resultsCounted && !flags.suppressTestDebug) {
printResults(entityTP, entityFP, entityFN);
}
}
@Override
public SequenceModel getSequenceModel(List<IN> doc) {
Triple<int[][][], int[], double[][][]> p = documentToDataAndLabels(doc);
return getSequenceModel(p, doc);
}
private SequenceModel getSequenceModel(Triple<int[][][], int[], double[][][]> documentDataAndLabels, List<IN> document) {
return labelDictionary == null ? new TestSequenceModel(getCliqueTree(documentDataAndLabels)) :
new TestSequenceModel(getCliqueTree(documentDataAndLabels), labelDictionary, document);
}
protected CliquePotentialFunction getCliquePotentialFunctionForTest() {
if (cliquePotentialFunction == null) {
cliquePotentialFunction = new LinearCliquePotentialFunction(weights);
}
return cliquePotentialFunction;
}
public void updateWeightsForTest(double[] x) {
cliquePotentialFunction = cliquePotentialFunctionHelper.getCliquePotentialFunction(x);
}
/**
* Do standard sequence inference, using either Viterbi or Beam inference
* depending on the value of {@code flags.inferenceType}.
*
* @param document Document to classify. Classification happens in place.
* This document is modified.
* @return The classified document
*/
public List<IN> classifyMaxEnt(List<IN> document) {
if (document.isEmpty()) {
return document;
}
SequenceModel model = getSequenceModel(document);
return classifyMaxEnt(document, model);
}
private List<IN> classifyMaxEnt(List<IN> document, Triple<int[][][], int[], double[][][]> documentDataAndLabels) {
if (document.isEmpty()) {
return document;
}
SequenceModel model = getSequenceModel(documentDataAndLabels, document);
return classifyMaxEnt(document, model);
}
private List<IN> classifyMaxEnt(List<IN> document, SequenceModel model) {
if (document.isEmpty()) {
return document;
}
if (flags.inferenceType == null) {
flags.inferenceType = "Viterbi";
}
BestSequenceFinder tagInference;
if (flags.inferenceType.equalsIgnoreCase("Viterbi")) {
tagInference = new ExactBestSequenceFinder();
} else if (flags.inferenceType.equalsIgnoreCase("Beam")) {
tagInference = new BeamBestSequenceFinder(flags.beamSize);
} else {
throw new RuntimeException("Unknown inference type: " + flags.inferenceType + ". Your options are Viterbi|Beam.");
}
int[] bestSequence = tagInference.bestSequence(model);
if (flags.useReverse) {
Collections.reverse(document);
}
for (int j = 0, docSize = document.size(); j < docSize; j++) {
IN wi = document.get(j);
String guess = classIndex.get(bestSequence[j + windowSize - 1]);
wi.set(CoreAnnotations.AnswerAnnotation.class, guess);
}
if (flags.useReverse) {
Collections.reverse(document);
}
return document;
}
public List<IN> classifyGibbs(List<IN> document) throws ClassNotFoundException, SecurityException,
NoSuchMethodException, IllegalArgumentException, InstantiationException, IllegalAccessException,
InvocationTargetException {
Triple<int[][][], int[], double[][][]> p = documentToDataAndLabels(document);
return classifyGibbs(document, p);
}
public List<IN> classifyGibbs(List<IN> document, Triple<int[][][], int[], double[][][]> documentDataAndLabels)
throws ClassNotFoundException, SecurityException, NoSuchMethodException, IllegalArgumentException,
InstantiationException, IllegalAccessException, InvocationTargetException {
// System.err.println("Testing using Gibbs sampling.");
List<IN> newDocument = document; // reversed if necessary
if (flags.useReverse) {
Collections.reverse(document);
newDocument = new ArrayList<IN>(document);
Collections.reverse(document);
}
CRFCliqueTree<? extends CharSequence> cliqueTree = getCliqueTree(documentDataAndLabels);
PriorModelFactory<IN> pmf = (PriorModelFactory<IN>) Class.forName(flags.priorModelFactory).newInstance();
ListeningSequenceModel prior = pmf.getInstance(flags.backgroundSymbol, classIndex, tagIndex, newDocument, entityMatrices, flags);
    if (!flags.useUniformPrior) {
      throw new RuntimeException("no prior specified");
    }
SequenceModel model = new FactoredSequenceModel(cliqueTree, prior);
SequenceListener listener = new FactoredSequenceListener(cliqueTree, prior);
SequenceGibbsSampler sampler = new SequenceGibbsSampler(0, 0, listener);
int[] sequence = new int[cliqueTree.length()];
if (flags.initViterbi) {
TestSequenceModel testSequenceModel = new TestSequenceModel(cliqueTree);
ExactBestSequenceFinder tagInference = new ExactBestSequenceFinder();
int[] bestSequence = tagInference.bestSequence(testSequenceModel);
System.arraycopy(bestSequence, windowSize - 1, sequence, 0, sequence.length);
} else {
int[] initialSequence = SequenceGibbsSampler.getRandomSequence(model);
System.arraycopy(initialSequence, 0, sequence, 0, sequence.length);
}
sampler.verbose = 0;
if (flags.annealingType.equalsIgnoreCase("linear")) {
sequence = sampler.findBestUsingAnnealing(model, CoolingSchedule.getLinearSchedule(1.0, flags.numSamples),
sequence);
} else if (flags.annealingType.equalsIgnoreCase("exp") || flags.annealingType.equalsIgnoreCase("exponential")) {
sequence = sampler.findBestUsingAnnealing(model, CoolingSchedule.getExponentialSchedule(1.0, flags.annealingRate,
flags.numSamples), sequence);
} else {
throw new RuntimeException("No annealing type specified");
}
if (flags.useReverse) {
Collections.reverse(document);
}
for (int j = 0, dsize = newDocument.size(); j < dsize; j++) {
IN wi = document.get(j);
      if (wi == null) throw new RuntimeException("Null token at position " + j + " in document");
      if (classIndex == null) throw new RuntimeException("Null classIndex");
wi.set(CoreAnnotations.AnswerAnnotation.class, classIndex.get(sequence[j]));
}
if (flags.useReverse) {
Collections.reverse(document);
}
return document;
}
/**
* Takes a {@link List} of something that extends {@link CoreMap} and prints
* the likelihood of each possible label at each point.
*
* @param document
* A {@link List} of something that extends CoreMap.
*/
@Override
public void printProbsDocument(List<IN> document) {
Triple<int[][][], int[], double[][][]> p = documentToDataAndLabels(document);
CRFCliqueTree<String> cliqueTree = getCliqueTree(p);
// for (int i = 0; i < factorTables.length; i++) {
for (int i = 0; i < cliqueTree.length(); i++) {
IN wi = document.get(i);
System.out.print(wi.get(CoreAnnotations.TextAnnotation.class) + '\t');
for (Iterator<String> iter = classIndex.iterator(); iter.hasNext();) {
String label = iter.next();
int index = classIndex.indexOf(label);
// double prob = Math.pow(Math.E, factorTables[i].logProbEnd(index));
double prob = cliqueTree.prob(i, index);
System.out.print(label + '=' + prob);
if (iter.hasNext()) {
System.out.print("\t");
} else {
System.out.print("\n");
}
}
}
}
/**
* Takes the file, reads it in, and prints out the likelihood of each possible
* label at each point. This gives a simple way to examine the probability
* distributions of the CRF. See <code>getCliqueTrees()</code> for more.
*
* @param filename
* The path to the specified file
*/
public void printFirstOrderProbs(String filename, DocumentReaderAndWriter<IN> readerAndWriter) {
// only for the OCR data does this matter
flags.ocrTrain = false;
ObjectBank<List<IN>> docs = makeObjectBankFromFile(filename, readerAndWriter);
printFirstOrderProbsDocuments(docs);
}
/**
* Takes a {@link List} of documents and prints the likelihood of each
* possible label at each point.
*
* @param documents
* A {@link List} of {@link List} of INs.
*/
public void printFirstOrderProbsDocuments(ObjectBank<List<IN>> documents) {
for (List<IN> doc : documents) {
printFirstOrderProbsDocument(doc);
System.out.println();
}
}
/**
* Takes the file, reads it in, and prints out the factor table at each position.
*
* @param filename
* The path to the specified file
*/
public void printFactorTable(String filename, DocumentReaderAndWriter<IN> readerAndWriter) {
// only for the OCR data does this matter
flags.ocrTrain = false;
ObjectBank<List<IN>> docs = makeObjectBankFromFile(filename, readerAndWriter);
printFactorTableDocuments(docs);
}
/**
* Takes a {@link List} of documents and prints the factor table
* at each point.
*
* @param documents
* A {@link List} of {@link List} of INs.
*/
public void printFactorTableDocuments(ObjectBank<List<IN>> documents) {
for (List<IN> doc : documents) {
printFactorTableDocument(doc);
System.out.println();
}
}
/**
* Want to make arbitrary probability queries? Then this is the method for
* you. Given the filename, it reads it in and breaks it into documents, and
 * then makes a CRFCliqueTree for each document. You can then ask the clique
* tree for marginals and conditional probabilities of almost anything you
* want.
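 * <p>
 * A usage sketch (assuming a trained classifier {@code crf} and a suitable
 * {@code readerAndWriter}), querying the marginal probability of class 0 at
 * each position:
 * <pre>{@code
 * List<CRFCliqueTree<String>> trees = crf.getCliqueTrees(filename, readerAndWriter);
 * for (CRFCliqueTree<String> tree : trees) {
 *   for (int i = 0; i < tree.length(); i++) {
 *     double p = tree.prob(i, 0);  // marginal probability of class 0 at position i
 *   }
 * }
 * }</pre>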
*/
public List<CRFCliqueTree<String>> getCliqueTrees(String filename, DocumentReaderAndWriter<IN> readerAndWriter) {
// only for the OCR data does this matter
flags.ocrTrain = false;
List<CRFCliqueTree<String>> cts = new ArrayList<CRFCliqueTree<String>>();
ObjectBank<List<IN>> docs = makeObjectBankFromFile(filename, readerAndWriter);
for (List<IN> doc : docs) {
cts.add(getCliqueTree(doc));
}
return cts;
}
public CRFCliqueTree<String> getCliqueTree(Triple<int[][][], int[], double[][][]> p) {
int[][][] data = p.first();
double[][][] featureVal = p.third();
return CRFCliqueTree.getCalibratedCliqueTree(data, labelIndices, classIndex.size(), classIndex,
flags.backgroundSymbol, getCliquePotentialFunctionForTest(), featureVal);
}
public CRFCliqueTree<String> getCliqueTree(List<IN> document) {
Triple<int[][][], int[], double[][][]> p = documentToDataAndLabels(document);
return getCliqueTree(p);
}
/**
* Takes a {@link List} of something that extends {@link CoreMap} and prints
* the factor table at each point.
*
* @param document
* A {@link List} of something that extends {@link CoreMap}.
*/
public void printFactorTableDocument(List<IN> document) {
CRFCliqueTree<String> cliqueTree = getCliqueTree(document);
FactorTable[] factorTables = cliqueTree.getFactorTables();
StringBuilder sb = new StringBuilder();
for (int i=0; i < factorTables.length; i++) {
IN wi = document.get(i);
sb.append(wi.get(CoreAnnotations.TextAnnotation.class));
sb.append("\t");
FactorTable table = factorTables[i];
for (int j = 0; j < table.size(); j++) {
int[] arr = table.toArray(j);
sb.append(classIndex.get(arr[0]));
sb.append(":");
sb.append(classIndex.get(arr[1]));
sb.append(":");
sb.append(cliqueTree.logProb(i, arr));
sb.append(" ");
}
sb.append("\n");
}
System.out.print(sb.toString());
}
/**
* Takes a {@link List} of something that extends {@link CoreMap} and prints
* the likelihood of each possible label at each point.
*
* @param document
* A {@link List} of something that extends {@link CoreMap}.
*/
public void printFirstOrderProbsDocument(List<IN> document) {
CRFCliqueTree<String> cliqueTree = getCliqueTree(document);
// for (int i = 0; i < factorTables.length; i++) {
for (int i = 0; i < cliqueTree.length(); i++) {
IN wi = document.get(i);
System.out.print(wi.get(CoreAnnotations.TextAnnotation.class) + '\t');
for (Iterator<String> iter = classIndex.iterator(); iter.hasNext();) {
String label = iter.next();
int index = classIndex.indexOf(label);
if (i == 0) {
// double prob = Math.pow(Math.E, factorTables[i].logProbEnd(index));
double prob = cliqueTree.prob(i, index);
System.out.print(label + '=' + prob);
if (iter.hasNext()) {
System.out.print("\t");
} else {
System.out.print("\n");
}
} else {
for (Iterator<String> iter1 = classIndex.iterator(); iter1.hasNext();) {
String label1 = iter1.next();
int index1 = classIndex.indexOf(label1);
// double prob = Math.pow(Math.E, factorTables[i].logProbEnd(new
// int[]{index1, index}));
double prob = cliqueTree.prob(i, new int[] { index1, index });
System.out.print(label1 + '_' + label + '=' + prob);
if (iter.hasNext() || iter1.hasNext()) {
System.out.print("\t");
} else {
System.out.print("\n");
}
}
}
}
}
}
/**
 * Loads auxiliary data to be used in constructing features and labels.
 * Intended to be overridden by subclasses.
*/
protected Collection<List<IN>> loadAuxiliaryData(Collection<List<IN>> docs, DocumentReaderAndWriter<IN> readerAndWriter) {
return docs;
}
/** {@inheritDoc} */
@Override
public void train(Collection<List<IN>> objectBankWrapper, DocumentReaderAndWriter<IN> readerAndWriter) {
Timing timer = new Timing();
timer.start();
Collection<List<IN>> docs = new ArrayList<List<IN>>();
for (List<IN> doc : objectBankWrapper) {
docs.add(doc);
}
if (flags.numOfSlices > 0) {
System.err.println("Taking " + flags.numOfSlices + " out of " + flags.totalDataSlice + " slices of data for training");
List<List<IN>> docsToShuffle = new ArrayList<List<IN>>();
for (List<IN> doc : docs) {
docsToShuffle.add(doc);
}
Collections.shuffle(docsToShuffle, random);
int cutOff = (int)(docsToShuffle.size() / (flags.totalDataSlice + 0.0) * flags.numOfSlices);
docs = docsToShuffle.subList(0, cutOff);
}
Collection<List<IN>> totalDocs = loadAuxiliaryData(docs, readerAndWriter);
makeAnswerArraysAndTagIndex(totalDocs);
long elapsedMs = timer.stop();
System.err.println("Time to convert docs to feature indices: " + Timing.toSecondsString(elapsedMs) + " seconds");
if (flags.serializeClassIndexTo != null) {
timer.start();
serializeClassIndex(flags.serializeClassIndexTo);
elapsedMs = timer.stop();
System.err.println("Time to export class index : " + Timing.toSecondsString(elapsedMs) + " seconds");
}
if (flags.exportFeatures != null) {
dumpFeatures(docs);
}
for (int i = 0; i <= flags.numTimesPruneFeatures; i++) {
timer.start();
Triple<int[][][][], int[][], double[][][][]> dataAndLabelsAndFeatureVals = documentsToDataAndLabels(docs);
elapsedMs = timer.stop();
System.err.println("Time to convert docs to data/labels: " + Timing.toSecondsString(elapsedMs) + " seconds");
Evaluator[] evaluators = null;
if (flags.evaluateIters > 0 || flags.terminateOnEvalImprovement) {
List<Evaluator> evaluatorList = new ArrayList<Evaluator>();
if (flags.useMemoryEvaluator)
evaluatorList.add(new MemoryEvaluator());
if (flags.evaluateTrain) {
CRFClassifierEvaluator<IN> crfEvaluator = new CRFClassifierEvaluator<IN>("Train set", this);
List<Triple<int[][][], int[], double[][][]>> trainDataAndLabels = new ArrayList<Triple<int[][][], int[], double[][][]>>();
int[][][][] data = dataAndLabelsAndFeatureVals.first();
int[][] labels = dataAndLabelsAndFeatureVals.second();
double[][][][] featureVal = dataAndLabelsAndFeatureVals.third();
for (int j = 0; j < data.length; j++) {
Triple<int[][][], int[], double[][][]> p = new Triple<int[][][], int[], double[][][]>(data[j], labels[j], featureVal[j]);
trainDataAndLabels.add(p);
}
crfEvaluator.setTestData(docs, trainDataAndLabels);
if (flags.evalCmd.length() > 0)
crfEvaluator.setEvalCmd(flags.evalCmd);
evaluatorList.add(crfEvaluator);
}
if (flags.testFile != null) {
CRFClassifierEvaluator<IN> crfEvaluator = new CRFClassifierEvaluator<IN>("Test set (" + flags.testFile + ")",
this);
ObjectBank<List<IN>> testObjBank = makeObjectBankFromFile(flags.testFile, readerAndWriter);
List<List<IN>> testDocs = new ArrayList<List<IN>>();
for (List<IN> doc : testObjBank) {
testDocs.add(doc);
}
List<Triple<int[][][], int[], double[][][]>> testDataAndLabels = documentsToDataAndLabelsList(testDocs);
crfEvaluator.setTestData(testDocs, testDataAndLabels);
if (flags.evalCmd.length() > 0)
crfEvaluator.setEvalCmd(flags.evalCmd);
evaluatorList.add(crfEvaluator);
}
if (flags.testFiles != null) {
String[] testFiles = flags.testFiles.split(",");
for (String testFile : testFiles) {
CRFClassifierEvaluator<IN> crfEvaluator = new CRFClassifierEvaluator<IN>("Test set ("
+ testFile + ")", this);
ObjectBank<List<IN>> testObjBank = makeObjectBankFromFile(testFile, readerAndWriter);
List<Triple<int[][][], int[], double[][][]>> testDataAndLabels = documentsToDataAndLabelsList(testObjBank);
crfEvaluator.setTestData(testObjBank, testDataAndLabels);
if (flags.evalCmd.length() > 0)
crfEvaluator.setEvalCmd(flags.evalCmd);
evaluatorList.add(crfEvaluator);
}
}
evaluators = new Evaluator[evaluatorList.size()];
evaluatorList.toArray(evaluators);
}
if (flags.numTimesPruneFeatures == i) {
docs = null; // hopefully saves memory
}
// save feature index to disk and read in later
File featIndexFile = null;
// CRFLogConditionalObjectiveFunction.featureIndex = featureIndex;
int numFeatures = featureIndex.size();
if (flags.saveFeatureIndexToDisk) {
try {
System.err.println("Writing feature index to temporary file.");
featIndexFile = IOUtils.writeObjectToTempFile(featureIndex, "featIndex" + i + ".tmp");
// featureIndex = null;
} catch (IOException e) {
throw new RuntimeException("Could not open temporary feature index file for writing.");
}
}
// first index is the number of the document
// second index is position in the document also the index of the
// clique/factor table
      // third index is the number of elements in the clique/window these
// features are for (starting with last element)
// fourth index is position of the feature in the array that holds them
// element in data[i][j][k][m] is the index of the mth feature occurring
// in position k of the jth clique of the ith document
int[][][][] data = dataAndLabelsAndFeatureVals.first();
// first index is the number of the document
// second index is the position in the document
// element in labels[i][j] is the index of the correct label (if it
// exists) at position j in document i
int[][] labels = dataAndLabelsAndFeatureVals.second();
double[][][][] featureVals = dataAndLabelsAndFeatureVals.third();
if (flags.loadProcessedData != null) {
List<List<CRFDatum<Collection<String>, String>>> processedData = loadProcessedData(flags.loadProcessedData);
if (processedData != null) {
// enlarge the data and labels array
int[][][][] allData = new int[data.length + processedData.size()][][][];
double[][][][] allFeatureVals = new double[featureVals.length + processedData.size()][][][];
int[][] allLabels = new int[labels.length + processedData.size()][];
System.arraycopy(data, 0, allData, 0, data.length);
System.arraycopy(labels, 0, allLabels, 0, labels.length);
System.arraycopy(featureVals, 0, allFeatureVals, 0, featureVals.length);
// add to the data and labels array
addProcessedData(processedData, allData, allLabels, allFeatureVals, data.length);
data = allData;
labels = allLabels;
featureVals = allFeatureVals;
}
}
double[] oneDimWeights = trainWeights(data, labels, evaluators, i, featureVals);
if (oneDimWeights != null) {
this.weights = to2D(oneDimWeights, labelIndices, map);
}
// if (flags.useFloat) {
// oneDimWeights = trainWeightsUsingFloatCRF(data, labels, evaluators, i, featureVals);
// } else if (flags.numLopExpert > 1) {
// oneDimWeights = trainWeightsUsingLopCRF(data, labels, evaluators, i, featureVals);
// } else {
// oneDimWeights = trainWeightsUsingDoubleCRF(data, labels, evaluators, i, featureVals);
// }
// save feature index to disk and read in later
if (flags.saveFeatureIndexToDisk) {
try {
System.err.println("Reading temporary feature index file.");
featureIndex = (Index<String>) IOUtils.readObjectFromFile(featIndexFile);
} catch (Exception e) {
throw new RuntimeException("Could not open temporary feature index file for reading.");
}
}
if (i != flags.numTimesPruneFeatures) {
dropFeaturesBelowThreshold(flags.featureDiffThresh);
System.err.println("Removing features with weight below " + flags.featureDiffThresh + " and retraining...");
}
}
}
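/**
 * Unflattens a 1-D weight vector into the 2-D [feature][cliqueLabel] layout,
 * reading labelIndices.get(map[i]).size() consecutive entries for feature i.
 * A sketch with illustrative sizes: if map = {0, 0, 1}, the clique-0 label
 * index has 3 entries, and the clique-1 label index has 9, then weights
 * 0-2 go to feature 0, 3-5 to feature 1, and 6-14 to feature 2.
 */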
public double[][] to2D(double[] weights, List<Index<CRFLabel>> labelIndices, int[] map) {
double[][] newWeights = new double[map.length][];
int index = 0;
for (int i = 0; i < map.length; i++) {
newWeights[i] = new double[labelIndices.get(map[i]).size()];
System.arraycopy(weights, index, newWeights[i], 0, labelIndices.get(map[i]).size());
index += labelIndices.get(map[i]).size();
}
return newWeights;
}
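/**
 * Keeps only the first numOfFeatureSlices / totalNumOfFeatureSlices fraction
 * of the node features (all edge features are kept) and rebuilds
 * featureIndex, the node/edge feature index maps, and map to match.
 * For example, with 1000 node features, totalNumOfFeatureSlices = 4, and
 * numOfFeatureSlices = 1, endIndex is 250, so node features 0-249 survive.
 */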
protected void pruneNodeFeatureIndices(int totalNumOfFeatureSlices, int numOfFeatureSlices) {
int numOfNodeFeatures = nodeFeatureIndicesMap.size();
int beginIndex = 0;
int endIndex = Math.min((int) (numOfNodeFeatures / (double) totalNumOfFeatureSlices * numOfFeatureSlices), numOfNodeFeatures);
List<Integer> nodeFeatureOriginalIndices = nodeFeatureIndicesMap.objectsList();
List<Integer> edgeFeatureOriginalIndices = edgeFeatureIndicesMap.objectsList();
Index<Integer> newNodeFeatureIndex = new HashIndex<Integer>();
Index<Integer> newEdgeFeatureIndex = new HashIndex<Integer>();
Index<String> newFeatureIndex = new HashIndex<String>();
for (int i = beginIndex; i < endIndex; i++) {
int oldIndex = nodeFeatureOriginalIndices.get(i);
String f = featureIndex.get(oldIndex);
int index = newFeatureIndex.addToIndex(f);
newNodeFeatureIndex.add(index);
}
for (Integer edgeFIndex: edgeFeatureOriginalIndices) {
String f = featureIndex.get(edgeFIndex);
int index = newFeatureIndex.addToIndex(f);
newEdgeFeatureIndex.add(index);
}
nodeFeatureIndicesMap = newNodeFeatureIndex;
edgeFeatureIndicesMap = newEdgeFeatureIndex;
int[] newMap = new int[newFeatureIndex.size()];
for (int i = 0; i < newMap.length; i++) {
int index = featureIndex.indexOf(newFeatureIndex.get(i));
newMap[i] = map[index];
}
map = newMap;
featureIndex = newFeatureIndex;
}
protected CRFLogConditionalObjectiveFunction getObjectiveFunction(int[][][][] data, int[][] labels) {
return new CRFLogConditionalObjectiveFunction(data, labels, windowSize, classIndex,
labelIndices, map, flags.priorType, flags.backgroundSymbol, flags.sigma, null, flags.multiThreadGrad);
}
protected double[] trainWeights(int[][][][] data, int[][] labels, Evaluator[] evaluators, int pruneFeatureItr, double[][][][] featureVals) {
CRFLogConditionalObjectiveFunction func = getObjectiveFunction(data, labels);
cliquePotentialFunctionHelper = func;
// create feature grouping
Map<String, Set<Integer>> featureSets = null;
if (flags.groupByOutputClass) {
featureSets = new HashMap<String, Set<Integer>>();
if (flags.groupByFeatureTemplate) {
int pIndex = 0;
for (int fIndex = 0; fIndex < map.length; fIndex++) {
int cliqueType = map[fIndex];
int numCliqueTypeOutputClass = labelIndices.get(map[fIndex]).size();
for (int cliqueOutClass = 0; cliqueOutClass < numCliqueTypeOutputClass; cliqueOutClass++) {
String name = "c:"+cliqueType+"-o:"+cliqueOutClass+"-g:"+featureIndexToTemplateIndex.get(fIndex);
if (featureSets.containsKey(name)) {
featureSets.get(name).add(pIndex);
} else {
Set<Integer> newSet = new HashSet<Integer>();
newSet.add(pIndex);
featureSets.put(name, newSet);
}
pIndex++;
}
}
} else {
int pIndex = 0;
for (int cliqueType : map) {
int numCliqueTypeOutputClass = labelIndices.get(cliqueType).size();
for (int cliqueOutClass = 0; cliqueOutClass < numCliqueTypeOutputClass; cliqueOutClass++) {
String name = "c:" + cliqueType + "-o:" + cliqueOutClass;
if (featureSets.containsKey(name)) {
featureSets.get(name).add(pIndex);
} else {
Set<Integer> newSet = new HashSet<Integer>();
newSet.add(pIndex);
featureSets.put(name, newSet);
}
pIndex++;
}
}
}
} else if (flags.groupByFeatureTemplate) {
featureSets = new HashMap<String, Set<Integer>>();
int pIndex = 0;
for (int fIndex = 0; fIndex < map.length; fIndex++) {
int cliqueType = map[fIndex];
int numCliqueTypeOutputClass = labelIndices.get(map[fIndex]).size();
for (int cliqueOutClass = 0; cliqueOutClass < numCliqueTypeOutputClass; cliqueOutClass++) {
String name = "c:"+cliqueType+"-g:"+featureIndexToTemplateIndex.get(fIndex);
if (featureSets.containsKey(name)) {
featureSets.get(name).add(pIndex);
} else {
Set<Integer> newSet = new HashSet<Integer>();
newSet.add(pIndex);
featureSets.put(name, newSet);
}
pIndex++;
}
}
}
if (featureSets != null) {
int[][] fg = new int[featureSets.size()][];
System.err.println("After feature grouping, total of "+fg.length+" groups");
int count = 0;
for (Set<Integer> aSet: featureSets.values()) {
fg[count] = new int[aSet.size()];
int i = 0;
for (Integer val : aSet)
fg[count][i++] = val;
count++;
}
func.setFeatureGrouping(fg);
}
Minimizer<DiffFunction> minimizer = getMinimizer(pruneFeatureItr, evaluators);
double[] initialWeights;
if (flags.initialWeights == null) {
initialWeights = func.initial();
} else {
try {
System.err.println("Reading initial weights from file " + flags.initialWeights);
DataInputStream dis = new DataInputStream(new BufferedInputStream(new GZIPInputStream(new FileInputStream(
flags.initialWeights))));
initialWeights = ConvertByteArray.readDoubleArr(dis);
} catch (IOException e) {
throw new RuntimeException("Could not read from double initial weight file " + flags.initialWeights);
}
}
System.err.println("numWeights: " + initialWeights.length);
if (flags.testObjFunction) {
StochasticDiffFunctionTester tester = new StochasticDiffFunctionTester(func);
if (tester.testSumOfBatches(initialWeights, 1e-4)) {
System.err.println("Successfully tested stochastic objective function.");
} else {
throw new IllegalStateException("Testing of stochastic objective function failed.");
}
}
//check gradient
if (flags.checkGradient) {
if (func.gradientCheck()) {
System.err.println("gradient check passed");
} else {
throw new RuntimeException("gradient check failed");
}
}
return minimizer.minimize(func, flags.tolerance, initialWeights);
}
public Minimizer<DiffFunction> getMinimizer() {
return getMinimizer(0, null);
}
public Minimizer<DiffFunction> getMinimizer(int featurePruneIteration, Evaluator[] evaluators) {
Minimizer<DiffFunction> minimizer = null;
if (flags.useQN) {
int QNmem;
if (featurePruneIteration == 0) {
QNmem = flags.QNsize;
} else {
QNmem = flags.QNsize2;
}
if (flags.interimOutputFreq != 0) {
Function monitor = new ResultStoringMonitor(flags.interimOutputFreq, flags.serializeTo);
minimizer = new QNMinimizer(monitor, QNmem, flags.useRobustQN);
} else {
minimizer = new QNMinimizer(QNmem, flags.useRobustQN);
}
((QNMinimizer) minimizer).terminateOnMaxItr(flags.maxQNItr);
((QNMinimizer) minimizer).terminateOnEvalImprovement(flags.terminateOnEvalImprovement);
((QNMinimizer) minimizer).setTerminateOnEvalImprovementNumOfEpoch(flags.terminateOnEvalImprovementNumOfEpoch);
((QNMinimizer) minimizer).suppressTestPrompt(flags.suppressTestDebug);
if (flags.useOWLQN) {
((QNMinimizer) minimizer).useOWLQN(flags.useOWLQN, flags.priorLambda);
}
} else if (flags.useInPlaceSGD) {
SGDMinimizer<DiffFunction> sgdMinimizer =
new SGDMinimizer<DiffFunction>(flags.sigma, flags.SGDPasses, flags.tuneSampleSize, flags.stochasticBatchSize);
if (flags.useSGDtoQN) {
QNMinimizer qnMinimizer;
int QNmem;
if (featurePruneIteration == 0) {
QNmem = flags.QNsize;
} else {
QNmem = flags.QNsize2;
}
if (flags.interimOutputFreq != 0) {
Function monitor = new ResultStoringMonitor(flags.interimOutputFreq, flags.serializeTo);
qnMinimizer = new QNMinimizer(monitor, QNmem, flags.useRobustQN);
} else {
qnMinimizer = new QNMinimizer(QNmem, flags.useRobustQN);
}
minimizer = new HybridMinimizer(sgdMinimizer, qnMinimizer, flags.SGDPasses);
} else {
minimizer = sgdMinimizer;
}
} else if (flags.useAdaGradFOBOS) {
double lambda = 0.5 / (flags.sigma * flags.sigma);
minimizer = new SGDWithAdaGradAndFOBOS<DiffFunction>(
flags.initRate, lambda, flags.SGDPasses, flags.stochasticBatchSize,
flags.priorType, flags.priorAlpha, flags.useAdaDelta, flags.useAdaDiff, flags.adaGradEps, flags.adaDeltaRho);
((SGDWithAdaGradAndFOBOS) minimizer).terminateOnEvalImprovement(flags.terminateOnEvalImprovement);
((SGDWithAdaGradAndFOBOS) minimizer).terminateOnAvgImprovement(flags.terminateOnAvgImprovement, flags.tolerance);
((SGDWithAdaGradAndFOBOS) minimizer).setTerminateOnEvalImprovementNumOfEpoch(flags.terminateOnEvalImprovementNumOfEpoch);
((SGDWithAdaGradAndFOBOS) minimizer).suppressTestPrompt(flags.suppressTestDebug);
} else if (flags.useSGDtoQN) {
minimizer = new SGDToQNMinimizer(flags.initialGain, flags.stochasticBatchSize,
flags.SGDPasses, flags.QNPasses, flags.SGD2QNhessSamples,
flags.QNsize, flags.outputIterationsToFile);
} else if (flags.useSMD) {
minimizer = new SMDMinimizer<DiffFunction>(flags.initialGain, flags.stochasticBatchSize, flags.stochasticMethod,
flags.SGDPasses);
} else if (flags.useSGD) {
minimizer = new InefficientSGDMinimizer<DiffFunction>(flags.initialGain, flags.stochasticBatchSize);
} else if (flags.useScaledSGD) {
minimizer = new ScaledSGDMinimizer(flags.initialGain, flags.stochasticBatchSize, flags.SGDPasses,
flags.scaledSGDMethod);
} else if (flags.l1reg > 0.0) {
minimizer = ReflectionLoading.loadByReflection("edu.stanford.nlp.optimization.OWLQNMinimizer", flags.l1reg);
}
if (minimizer == null) {
throw new RuntimeException("No minimizer assigned!");
}
if (minimizer instanceof HasEvaluators) {
if (minimizer instanceof QNMinimizer) {
((QNMinimizer) minimizer).setEvaluators(flags.evaluateIters, flags.startEvaluateIters, evaluators);
} else {
((HasEvaluators) minimizer).setEvaluators(flags.evaluateIters, evaluators);
}
}
return minimizer;
}
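// A sketch of optimizer selection via properties (values illustrative; only
// the flags read above matter):
//   useQN = true       # quasi-Newton, the first branch above
//   QNsize = 25        # memory used on the first training pass
//   tolerance = 1e-4   # convergence tolerance passed to minimize()
// Setting useOWLQN = true additionally switches the QNMinimizer to OWL-QN
// with flags.priorLambda as the L1 regularization weight.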
/**
* Creates a sequence of CRFDatums from the preprocessed allData format, given
* the begin and end positions within a document and the labeled tokens.
*
* @return A List of CRFDatums, one per position from beginContext through
*         endPosition, with featureless dummy datums padding the front
*/
protected List<CRFDatum<? extends Collection<String>, ? extends CharSequence>> extractDatumSequence(int[][][] allData, int beginPosition, int endPosition,
List<IN> labeledWordInfos) {
List<CRFDatum<? extends Collection<String>, ? extends CharSequence>> result = new ArrayList<CRFDatum<? extends Collection<String>, ? extends CharSequence>>();
int beginContext = beginPosition - windowSize + 1;
if (beginContext < 0) {
beginContext = 0;
}
// for the beginning context, add some dummy datums with no features!
// TODO: is there any better way to do this?
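// For example, with windowSize 2 and beginPosition 3, beginContext is
// 3 - 2 + 1 = 2, so exactly one featureless dummy datum (for position 2)
// precedes the real datums added below.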
for (int position = beginContext; position < beginPosition; position++) {
List<Collection<String>> cliqueFeatures = new ArrayList<Collection<String>>();
List<double[]> featureVals = new ArrayList<double[]>();
for (int i = 0; i < windowSize; i++) {
// create a feature list
cliqueFeatures.add(Collections.<String>emptyList());
featureVals.add(null);
}
CRFDatum<Collection<String>, String> datum = new CRFDatum<Collection<String>, String>(cliqueFeatures,
labeledWordInfos.get(position).get(CoreAnnotations.AnswerAnnotation.class), featureVals);
result.add(datum);
}
// now add the real datums
for (int position = beginPosition; position <= endPosition; position++) {
List<Collection<String>> cliqueFeatures = new ArrayList<Collection<String>>();
List<double[]> featureVals = new ArrayList<double[]>();
for (int i = 0; i < windowSize; i++) {
// create a feature list
Collection<String> features = new ArrayList<String>();
for (int j = 0; j < allData[position][i].length; j++) {
features.add(featureIndex.get(allData[position][i][j]));
}
cliqueFeatures.add(features);
featureVals.add(null);
}
CRFDatum<Collection<String>,String> datum = new CRFDatum<Collection<String>,String>(cliqueFeatures,
labeledWordInfos.get(position).get(CoreAnnotations.AnswerAnnotation.class), featureVals);
result.add(datum);
}
return result;
}
/**
* Adds the List of Lists of CRFDatums to the data and labels arrays, treating
* each List of CRFDatums as its own document. Note that extractDatumSequence
* pads each sequence with up to windowSize - 1 featureless context datums at
* the front, so the stored labels cover those context positions as well as
* the target labels.
*
* @param processedData
* a List of Lists of CRFDatums
*/
protected void addProcessedData(List<List<CRFDatum<Collection<String>, String>>> processedData, int[][][][] data,
int[][] labels, double[][][][] featureVals, int offset) {
for (int i = 0, pdSize = processedData.size(); i < pdSize; i++) {
int dataIndex = i + offset;
List<CRFDatum<Collection<String>, String>> document = processedData.get(i);
int dsize = document.size();
labels[dataIndex] = new int[dsize];
data[dataIndex] = new int[dsize][][];
if (featureVals != null)
featureVals[dataIndex] = new double[dsize][][];
for (int j = 0; j < dsize; j++) {
CRFDatum<Collection<String>, String> crfDatum = document.get(j);
// add label, they are offset by extra context
labels[dataIndex][j] = classIndex.indexOf(crfDatum.label());
// add featureVals
List<double[]> featureValList = null;
if (featureVals != null)
featureValList = crfDatum.asFeatureVals();
// add features
List<Collection<String>> cliques = crfDatum.asFeatures();
int csize = cliques.size();
data[dataIndex][j] = new int[csize][];
if (featureVals != null)
featureVals[dataIndex][j] = new double[csize][];
for (int k = 0; k < csize; k++) {
Collection<String> features = cliques.get(k);
data[dataIndex][j][k] = new int[features.size()];
if (featureVals != null)
featureVals[dataIndex][j][k] = featureValList.get(k);
int m = 0;
try {
for (String feature : features) {
// System.err.println("feature " + feature);
// if (featureIndex.indexOf(feature)) ;
if (featureIndex == null) {
System.out.println("Feature is NULL!");
}
data[dataIndex][j][k][m] = featureIndex.indexOf(feature);
m++;
}
} catch (Exception e) {
e.printStackTrace();
System.err.printf("[index=%d, j=%d, k=%d, m=%d]%n", dataIndex, j, k, m);
System.err.println("data.length " + data.length);
System.err.println("data[dataIndex].length " + data[dataIndex].length);
System.err.println("data[dataIndex][j].length " + data[dataIndex][j].length);
System.err.println("data[dataIndex][j][k].length " + data[dataIndex][j].length);
System.err.println("data[dataIndex][j][k][m] " + data[dataIndex][j][k][m]);
return;
}
}
}
}
}
protected static void saveProcessedData(List datums, String filename) {
System.err.print("Saving processed data of size " + datums.size() + " to serialized file...");
ObjectOutputStream oos = null;
try {
oos = new ObjectOutputStream(new FileOutputStream(filename));
oos.writeObject(datums);
} catch (IOException e) {
e.printStackTrace(); // report the failure rather than silently swallowing it
} finally {
IOUtils.closeIgnoringExceptions(oos);
}
System.err.println("done.");
}
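// A sketch of the round trip (hypothetical file name): saveProcessedData
// above and loadProcessedData below serialize the same List of Lists of
// CRFDatums via plain Java serialization:
//   saveProcessedData(datums, "processed.ser");
//   List<List<CRFDatum<Collection<String>, String>>> reloaded =
//       loadProcessedData("processed.ser");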
protected static List<List<CRFDatum<Collection<String>, String>>> loadProcessedData(String filename) {
System.err.print("Loading processed data from serialized file...");
ObjectInputStream ois = null;
List<List<CRFDatum<Collection<String>, String>>> result = Collections.emptyList();
try {
ois = new ObjectInputStream(new FileInputStream(filename));
result = (List<List<CRFDatum<Collection<String>, String>>>) ois.readObject();
} catch (Exception e) {
e.printStackTrace();
} finally {
IOUtils.closeIgnoringExceptions(ois);
}
System.err.println("done. Got " + result.size() + " datums.");
return result;
}
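/**
 * Loads a classifier from the human-readable text format written by
 * {@link #serializeTextClassifier(PrintWriter)}. A sketch of the expected
 * layout (sizes and values illustrative; keys and values are tab-separated):
 * <pre>
 * labelIndices.length=  2
 * labelIndices[0].size()=  3
 * 0  0
 * ...
 * classIndex.size()=  3
 * ...
 * featureIndex.size()=  150000
 * ...
 * &lt;flags&gt;
 * ...
 * &lt;/flags&gt;
 * &lt;featureFactory&gt; edu.stanford.nlp.wordseg.Gale2007ChineseSegmenterFeatureFactory &lt;/featureFactory&gt;
 * &lt;windowSize&gt; 2 &lt;/windowSize&gt;
 * weights.length=  2655170
 * ...
 * </pre>
 */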
protected void loadTextClassifier(BufferedReader br) throws Exception {
String line = br.readLine();
// first line should be in this format:
// labelIndices.length=\t%d
String[] toks = line.split("\\t");
if (!toks[0].equals("labelIndices.length=")) {
throw new RuntimeException("format error");
}
int size = Integer.parseInt(toks[1]);
labelIndices = new ArrayList<Index<CRFLabel>>(size);
for (int labelIndicesIdx = 0; labelIndicesIdx < size; labelIndicesIdx++) {
line = br.readLine();
// each labelIndices[i] block begins with a line in this format:
// labelIndices[i].size()=\t%d
toks = line.split("\\t");
if (!(toks[0].startsWith("labelIndices[") && toks[0].endsWith("].size()="))) {
throw new RuntimeException("format error");
}
int labelIndexSize = Integer.parseInt(toks[1]);
labelIndices.add(new HashIndex<CRFLabel>());
int count = 0;
while (count < labelIndexSize) {
line = br.readLine();
toks = line.split("\\t");
int idx = Integer.parseInt(toks[0]);
if (count != idx) {
throw new RuntimeException("format error");
}
String[] crflabelstr = toks[1].split(" ");
int[] crflabel = new int[crflabelstr.length];
for (int i = 0; i < crflabelstr.length; i++) {
crflabel[i] = Integer.parseInt(crflabelstr[i]);
}
CRFLabel crfL = new CRFLabel(crflabel);
labelIndices.get(labelIndicesIdx).add(crfL);
count++;
}
}
line = br.readLine();
toks = line.split("\\t");
if (!toks[0].equals("classIndex.size()=")) {
throw new RuntimeException("format error");
}
int classIndexSize = Integer.parseInt(toks[1]);
classIndex = new HashIndex<String>();
int count = 0;
while (count < classIndexSize) {
line = br.readLine();
toks = line.split("\\t");
int idx = Integer.parseInt(toks[0]);
if (count != idx) {
throw new RuntimeException("format error");
}
classIndex.add(toks[1]);
count++;
}
line = br.readLine();
toks = line.split("\\t");
if (!toks[0].equals("featureIndex.size()=")) {
throw new RuntimeException("format error");
}
int featureIndexSize = Integer.parseInt(toks[1]);
featureIndex = new HashIndex<String>();
count = 0;
while (count < featureIndexSize) {
line = br.readLine();
toks = line.split("\\t");
int idx = Integer.parseInt(toks[0]);
if (count != idx) {
throw new RuntimeException("format error");
}
featureIndex.add(toks[1]);
count++;
}
line = br.readLine();
if (!line.equals("<flags>")) {
throw new RuntimeException("format error");
}
Properties p = new Properties();
line = br.readLine();
while (!line.equals("</flags>")) {
// System.err.println("DEBUG: flags line: "+line);
String[] keyValue = line.split("=");
// System.err.printf("DEBUG: p.setProperty(%s,%s)%n", keyValue[0],
// keyValue[1]);
p.setProperty(keyValue[0], keyValue[1]);
line = br.readLine();
}
// System.err.println("DEBUG: out from flags");
flags = new SeqClassifierFlags(p);
if (flags.useEmbedding) {
line = br.readLine();
toks = line.split("\\t");
if (!toks[0].equals("embeddings.size()=")) {
throw new RuntimeException("format error in embeddings");
}
int embeddingSize = Integer.parseInt(toks[1]);
embeddings = Generics.newHashMap(embeddingSize);
count = 0;
while (count < embeddingSize) {
line = br.readLine().trim();
toks = line.split("\\t");
String word = toks[0];
double[] arr = ArrayUtils.toDoubleArray(toks[1].split(" "));
embeddings.put(word, arr);
count++;
}
}
// <featureFactory>
// edu.stanford.nlp.wordseg.Gale2007ChineseSegmenterFeatureFactory
// </featureFactory>
line = br.readLine();
String[] featureFactoryName = line.split(" ");
if (featureFactoryName.length < 2 || !featureFactoryName[0].equals("<featureFactory>") || !featureFactoryName[featureFactoryName.length - 1].equals("</featureFactory>")) {
throw new RuntimeException("format error unexpected featureFactory line: " + line);
}
featureFactories = Generics.newArrayList();
for (int ff = 1; ff < featureFactoryName.length - 1; ++ff) {
FeatureFactory featureFactory = (edu.stanford.nlp.sequences.FeatureFactory<IN>) Class.forName(featureFactoryName[ff]).newInstance();
featureFactory.init(flags);
featureFactories.add(featureFactory);
}
reinit();
// <windowSize> 2 </windowSize>
line = br.readLine();
String[] windowSizeName = line.split(" ");
if (!windowSizeName[0].equals("<windowSize>") || !windowSizeName[2].equals("</windowSize>")) {
throw new RuntimeException("format error");
}
windowSize = Integer.parseInt(windowSizeName[1]);
// weights.length= 2655170
line = br.readLine();
toks = line.split("\\t");
if (!toks[0].equals("weights.length=")) {
throw new RuntimeException("format error");
}
int weightsLength = Integer.parseInt(toks[1]);
weights = new double[weightsLength][];
count = 0;
while (count < weightsLength) {
line = br.readLine();
toks = line.split("\\t");
int weights2Length = Integer.parseInt(toks[0]);
weights[count] = new double[weights2Length];
String[] weightsValue = toks[1].split(" ");
if (weights2Length != weightsValue.length) {
throw new RuntimeException("weights format error");
}
for (int i2 = 0; i2 < weights2Length; i2++) {
weights[count][i2] = Double.parseDouble(weightsValue[i2]);
}
count++;
}
System.err.printf("DEBUG: double[%d][] weights loaded%n", weightsLength);
line = br.readLine();
if (line != null) {
throw new RuntimeException("weights format error");
}
}
public void loadTextClassifier(String text, Properties props) throws ClassCastException, IOException,
ClassNotFoundException, InstantiationException, IllegalAccessException {
// System.err.println("DEBUG: in loadTextClassifier");
System.err.println("Loading Text Classifier from " + text);
BufferedReader br = IOUtils.readerFromString(text);
try {
loadTextClassifier(br);
} catch (Exception ex) {
System.err.println("Exception in loading text classifier from " + text);
ex.printStackTrace();
} finally {
br.close(); // close the reader even if loading fails
}
}
protected void serializeTextClassifier(PrintWriter pw) throws Exception {
pw.printf("labelIndices.length=\t%d%n", labelIndices.size());
for (int i = 0; i < labelIndices.size(); i++) {
pw.printf("labelIndices[%d].size()=\t%d%n", i, labelIndices.get(i).size());
for (int j = 0; j < labelIndices.get(i).size(); j++) {
int[] label = labelIndices.get(i).get(j).getLabel();
List<Integer> list = new ArrayList<Integer>();
for (int l : label) {
list.add(l);
}
pw.printf("%d\t%s%n", j, StringUtils.join(list, " "));
}
}
pw.printf("classIndex.size()=\t%d%n", classIndex.size());
for (int i = 0; i < classIndex.size(); i++) {
pw.printf("%d\t%s%n", i, classIndex.get(i));
}
// pw.printf("</classIndex>%n");
pw.printf("featureIndex.size()=\t%d%n", featureIndex.size());
for (int i = 0; i < featureIndex.size(); i++) {
pw.printf("%d\t%s%n", i, featureIndex.get(i));
}
// pw.printf("</featureIndex>%n");
pw.println("<flags>");
pw.print(flags.toString());
pw.println("</flags>");
if (flags.useEmbedding) {
pw.printf("embeddings.size()=\t%d%n", embeddings.size());
for (String word: embeddings.keySet()) {
double[] arr = embeddings.get(word);
Double[] arrBoxed = new Double[arr.length]; // box for StringUtils.join
for (int i = 0; i < arr.length; i++)
arrBoxed[i] = arr[i];
pw.printf("%s\t%s%n", word, StringUtils.join(arrBoxed, " "));
}
}
pw.printf("<featureFactory>");
for (FeatureFactory featureFactory : featureFactories) {
pw.printf(" %s ", featureFactory.getClass().getName());
}
pw.printf("</featureFactory>%n");
pw.printf("<windowSize> %d </windowSize>%n", windowSize);
pw.printf("weights.length=\t%d%n", weights.length);
for (double[] ws : weights) {
ArrayList<Double> list = new ArrayList<Double>();
for (double w : ws) {
list.add(w);
}
pw.printf("%d\t%s%n", ws.length, StringUtils.join(list, " "));
}
}
/**
* Serializes the model to a human-readable format. This is not yet complete,
* but it should work for the Chinese segmenter. TODO: check things in
* serializeClassifier and add other necessary serialization back.
*
* @param serializePath
* File to write text format of classifier to.
*/
public void serializeTextClassifier(String serializePath) {
System.err.print("Serializing Text classifier to " + serializePath + "...");
try {
PrintWriter pw = new PrintWriter(new GZIPOutputStream(new FileOutputStream(serializePath)));
serializeTextClassifier(pw);
pw.close();
System.err.println("done.");
} catch (Exception e) {
System.err.println("Failed");
e.printStackTrace();
}
}
public void serializeClassIndex(String serializePath) {
System.err.print("Serializing class index to " + serializePath + "...");
ObjectOutputStream oos = null;
try {
oos = IOUtils.writeStreamFromString(serializePath);
oos.writeObject(classIndex);
System.err.println("done.");
} catch (Exception e) {
System.err.println("Failed");
e.printStackTrace();
} finally {
IOUtils.closeIgnoringExceptions(oos);
}
}
public static Index<String> loadClassIndexFromFile(String serializePath) {
System.err.print("Reading class index from " + serializePath + "...");
ObjectInputStream ois = null;
Index<String> c = null;
try {
ois = IOUtils.readStreamFromString(serializePath);
c = (Index<String>) ois.readObject();
System.err.println("done.");
} catch (Exception e) {
System.err.println("Failed");
e.printStackTrace();
} finally {
IOUtils.closeIgnoringExceptions(ois);
}
return c;
}
public void serializeWeights(String serializePath) {
System.err.print("Serializing weights to " + serializePath + "...");
ObjectOutputStream oos = null;
try {
oos = IOUtils.writeStreamFromString(serializePath);
oos.writeObject(weights);
System.err.println("done.");
} catch (Exception e) {
System.err.println("Failed");
e.printStackTrace();
} finally {
IOUtils.closeIgnoringExceptions(oos);
}
}
public static double[][] loadWeightsFromFile(String serializePath) {
System.err.print("Reading weights from " + serializePath + "...");
ObjectInputStream ois = null;
double[][] w = null;
try {
ois = IOUtils.readStreamFromString(serializePath);
w = (double[][]) ois.readObject();
System.err.println("done.");
} catch (Exception e) {
System.err.println("Failed");
e.printStackTrace();
} finally {
IOUtils.closeIgnoringExceptions(ois);
}
return w;
}
public void serializeFeatureIndex(String serializePath) {
System.err.print("Serializing FeatureIndex to " + serializePath + "...");
ObjectOutputStream oos = null;
try {
oos = IOUtils.writeStreamFromString(serializePath);
oos.writeObject(featureIndex);
System.err.println("done.");
} catch (Exception e) {
System.err.println("Failed");
e.printStackTrace();
} finally {
IOUtils.closeIgnoringExceptions(oos);
}
}
public static Index<String> loadFeatureIndexFromFile(String serializePath) {
System.err.print("Reading FeatureIndex from " + serializePath + "...");
ObjectInputStream ois = null;
Index<String> f = null;
try {
ois = IOUtils.readStreamFromString(serializePath);
f = (Index<String>) ois.readObject();
System.err.println("done.");
} catch (Exception e) {
System.err.println("Failed");
e.printStackTrace();
} finally {
IOUtils.closeIgnoringExceptions(ois);
}
return f;
}
/**
* {@inheritDoc}
*/
@Override
public void serializeClassifier(String serializePath) {
System.err.print("Serializing classifier to " + serializePath + "...");
ObjectOutputStream oos = null;
try {
oos = IOUtils.writeStreamFromString(serializePath);
serializeClassifier(oos);
System.err.println("done.");
} catch (Exception e) {
throw new RuntimeIOException("Failed to save classifier", e);
} finally {
IOUtils.closeIgnoringExceptions(oos);
}
}
/**
* Serialize the classifier to the given ObjectOutputStream.
* <br>
* (Since the classifier is a processor, we don't want to serialize the
* whole classifier but just the data that represents a classifier model.)
*/
public void serializeClassifier(ObjectOutputStream oos) {
try {
oos.writeObject(labelIndices);
oos.writeObject(classIndex);
oos.writeObject(featureIndex);
oos.writeObject(flags);
if (flags.useEmbedding) {
oos.writeObject(embeddings);
}
// For some reason, writing out the array of FeatureFactory objects
// doesn't seem to work: the resulting classifier doesn't have the lexicon
// (distsim object) correctly saved. So we write the list size and then
// each element by hand.
oos.writeObject(featureFactories.size());
for (FeatureFactory ff : featureFactories) {
oos.writeObject(ff);
}
oos.writeInt(windowSize);
oos.writeObject(weights);
// oos.writeObject(WordShapeClassifier.getKnownLowerCaseWords());
oos.writeObject(knownLCWords);
if (labelDictionary != null) {
oos.writeObject(labelDictionary);
}
} catch (IOException e) {
throw new RuntimeIOException(e);
}
}
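// A sketch of a save/load round trip (hypothetical path). The read order in
// loadClassifier(ObjectInputStream, Properties) below must mirror the write
// order above:
//   crf.serializeClassifier("ner-model.ser.gz");
//   CRFClassifier<CoreLabel> reloaded = CRFClassifier.getClassifier("ner-model.ser.gz");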
/**
* Loads a classifier from the specified InputStream. This version works
* quietly (unless VERBOSE is true). If props is non-null then any properties
* it specifies override those in the serialized file. However, only some
* properties are sensible to change (you shouldn't change how features are
* defined).
* <p>
* <i>Note:</i> This method does not close the ObjectInputStream. (But earlier
* versions of the code used to, so beware....)
*/
@Override
@SuppressWarnings( { "unchecked" })
// can't have right types in deserialization
public void loadClassifier(ObjectInputStream ois, Properties props) throws ClassCastException, IOException,
ClassNotFoundException {
Object o = ois.readObject();
// TODO: when we next break serialization, get rid of this fork and only read the List<Index> (i.e., keep first case)
if (o instanceof List) {
labelIndices = (List<Index<CRFLabel>>) o;
} else {
Index<CRFLabel>[] indexArray = (Index<CRFLabel>[]) o;
labelIndices = new ArrayList<Index<CRFLabel>>(indexArray.length);
Collections.addAll(labelIndices, indexArray);
}
classIndex = (Index<String>) ois.readObject();
featureIndex = (Index<String>) ois.readObject();
flags = (SeqClassifierFlags) ois.readObject();
if (flags.useEmbedding) {
embeddings = (Map<String, double[]>) ois.readObject();
}
Object featureFactory = ois.readObject();
if (featureFactory instanceof List) {
featureFactories = ErasureUtils.uncheckedCast(featureFactory);
} else if (featureFactory instanceof FeatureFactory) {
featureFactories = Generics.newArrayList();
featureFactories.add((FeatureFactory) featureFactory);
} else if (featureFactory instanceof Integer) {
// this is the current format (2014) since writing list didn't work (see note in save).
int size = (Integer) featureFactory;
featureFactories = Generics.newArrayList(size);
for (int i = 0; i < size; ++i) {
featureFactory = ois.readObject();
if (!(featureFactory instanceof FeatureFactory)) {
throw new RuntimeIOException("Should have FeatureFactory but got " + featureFactory.getClass());
}
featureFactories.add((FeatureFactory) featureFactory);
}
}
if (props != null) {
flags.setProperties(props, false);
}
reinit();
windowSize = ois.readInt();
weights = (double[][]) ois.readObject();
// WordShapeClassifier.setKnownLowerCaseWords((Set) ois.readObject());
knownLCWords = (Set<String>) ois.readObject();
if (flags.labelDictionaryCutoff > 0) {
labelDictionary = (LabelDictionary) ois.readObject();
}
if (VERBOSE) {
System.err.println("windowSize=" + windowSize);
System.err.println("flags=\n" + flags);
}
}
/**
* This is used to load the default supplied classifier stored within the jar
* file. THIS FUNCTION WILL ONLY WORK IF THE CODE WAS LOADED FROM A JAR FILE
* WHICH HAS A SERIALIZED CLASSIFIER STORED INSIDE IT.
*/
public void loadDefaultClassifier() {
loadJarClassifier(DEFAULT_CLASSIFIER, null);
}
public void loadTagIndex() {
if (tagIndex == null) {
tagIndex = new HashIndex<String>();
for (String tag: classIndex.objectsList()) {
String[] parts = tag.split("-");
// if (parts.length > 1)
tagIndex.add(parts[parts.length-1]);
}
tagIndex.add(flags.backgroundSymbol);
}
if (flags.useNERPriorBIO) {
if (entityMatrices == null)
entityMatrices = readEntityMatrices(flags.entityMatrix, tagIndex);
}
}
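// For example (a sketch), with classIndex entries {"O", "B-PER", "I-PER"},
// the derived tagIndex contains {"O", "PER"} plus flags.backgroundSymbol,
// since only the part after the last '-' of each class name is kept.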
static double[][] parseMatrix(String[] lines, Index<String> tagIndex, int matrixSize, boolean smooth) {
return parseMatrix(lines, tagIndex, matrixSize, smooth, true);
}
/**
* @return a matrix where each entry m[i][j] is logP(j|i)
* in other words, each row vector is normalized log conditional likelihood
*/
static double[][] parseMatrix(String[] lines, Index<String> tagIndex, int matrixSize, boolean smooth, boolean useLogProb) {
double[][] matrix = new double[matrixSize][matrixSize];
for (int i = 0; i < matrix.length; i++)
matrix[i] = new double[matrixSize];
for (String line: lines) {
String[] parts = line.split("\t");
for (String part: parts) {
String[] subparts = part.split(" ");
String[] subsubparts = subparts[0].split(":");
double counts = Double.parseDouble(subparts[1]);
if (counts == 0.0 && smooth) // smoothing
counts = 1.0;
int tagIndex1 = tagIndex.indexOf(subsubparts[0]);
int tagIndex2 = tagIndex.indexOf(subsubparts[1]);
matrix[tagIndex1][tagIndex2] = counts;
}
}
for (int i = 0; i < matrix.length; i++) {
double sum = ArrayMath.sum(matrix[i]);
for (int j = 0; j < matrix[i].length; j++) {
// log conditional probability
if (useLogProb)
matrix[i][j] = Math.log(matrix[i][j] / sum);
else
matrix[i][j] = matrix[i][j] / sum;
}
}
return matrix;
}
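// A sketch of the input format parsed above (hypothetical tags and counts):
// each line holds tab-separated "rowTag:colTag count" entries, e.g.
//   PER:PER 8\tPER:ORG 3\tPER:LOC 1
// fills row PER; after row normalization each entry is log P(colTag | rowTag)
// (or the plain conditional probability when useLogProb is false).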
static Pair<double[][], double[][]> readEntityMatrices(String fileName, Index<String> tagIndex) {
int numTags = tagIndex.size();
int matrixSize = numTags-1;
String[] matrixLines = new String[matrixSize];
String[] subMatrixLines = new String[matrixSize];
try {
BufferedReader br = IOUtils.readerFromString(fileName);
int lineCount = 0;
for (String line; (line = br.readLine()) != null; ) {
line = line.trim();
if (lineCount < matrixSize)
matrixLines[lineCount] = line;
else
subMatrixLines[lineCount-matrixSize] = line;
lineCount++;
}
} catch (Exception ex) {
throw new RuntimeIOException(ex);
}
double[][] matrix = parseMatrix(matrixLines, tagIndex, matrixSize, true);
double[][] subMatrix = parseMatrix(subMatrixLines, tagIndex, matrixSize, true);
// In Jenny's paper, the matrix (but not the subMatrix) uses the square root
// of the non-log probability; since these entries are log probs, halving
// them below is equivalent.
for (int i = 0; i < matrix.length; i++) {
for (int j = 0; j < matrix[i].length; j++)
matrix[i][j] = matrix[i][j] / 2;
}
System.err.println("Matrix: ");
System.err.println(ArrayUtils.toString(matrix));
System.err.println("SubMatrix: ");
System.err.println(ArrayUtils.toString(subMatrix));
return new Pair<double[][], double[][]>(matrix, subMatrix);
}
/**
* This is used to load the default supplied classifier stored within the jar
* file. THIS FUNCTION WILL ONLY WORK IF THE CODE WAS LOADED FROM A JAR FILE
* WHICH HAS A SERIALIZED CLASSIFIER STORED INSIDE IT.
*/
public void loadDefaultClassifier(Properties props) {
loadJarClassifier(DEFAULT_CLASSIFIER, props);
}
/**
* Used to get the default supplied classifier inside the jar file. THIS
* FUNCTION WILL ONLY WORK IF THE CODE WAS LOADED FROM A JAR FILE WHICH HAS A
* SERIALIZED CLASSIFIER STORED INSIDE IT.
*
* @return The default CRFClassifier in the jar file (if there is one)
*/
public static <INN extends CoreMap> CRFClassifier<INN> getDefaultClassifier() {
CRFClassifier<INN> crf = new CRFClassifier<INN>();
crf.loadDefaultClassifier();
return crf;
}
/**
* Used to get the default supplied classifier inside the jar file. THIS
* FUNCTION WILL ONLY WORK IF THE CODE WAS LOADED FROM A JAR FILE WHICH HAS A
* SERIALIZED CLASSIFIER STORED INSIDE IT.
*
* @return The default CRFClassifier in the jar file (if there is one)
*/
public static <INN extends CoreMap> CRFClassifier<INN> getDefaultClassifier(Properties props) {
CRFClassifier<INN> crf = new CRFClassifier<INN>();
crf.loadDefaultClassifier(props);
return crf;
}
/**
* Used to load a classifier stored as a resource inside a jar file. THIS
* FUNCTION WILL ONLY WORK IF THE CODE WAS LOADED FROM A JAR FILE WHICH HAS A
* SERIALIZED CLASSIFIER STORED INSIDE IT.
*
* @param resourceName Name of classifier resource inside the jar file.
* @return A CRFClassifier stored in the jar file
*/
public static <INN extends CoreMap> CRFClassifier<INN> getJarClassifier(String resourceName, Properties props) {
CRFClassifier<INN> crf = new CRFClassifier<INN>();
crf.loadJarClassifier(resourceName, props);
return crf;
}
/**
* Loads a CRF classifier from a filepath, and returns it.
*
* @param file
* File to load classifier from
* @return The CRF classifier
*
* @throws IOException
* If there are problems accessing the input stream
* @throws ClassCastException
* If there are problems interpreting the serialized data
* @throws ClassNotFoundException
* If there are problems interpreting the serialized data
*/
public static <INN extends CoreMap> CRFClassifier<INN> getClassifier(File file) throws IOException, ClassCastException,
ClassNotFoundException {
CRFClassifier<INN> crf = new CRFClassifier<INN>();
crf.loadClassifier(file);
return crf;
}
/**
* Loads a CRF classifier from an InputStream, and returns it. This method
* does not buffer the InputStream, so you should have buffered it before
* calling this method.
*
* @param in InputStream to load classifier from
* @return The CRF classifier
*
* @throws IOException If there are problems accessing the input stream
* @throws ClassCastException If there are problems interpreting the serialized data
* @throws ClassNotFoundException If there are problems interpreting the serialized data
*/
public static <INN extends CoreMap> CRFClassifier<INN> getClassifier(InputStream in) throws IOException, ClassCastException,
ClassNotFoundException {
CRFClassifier<INN> crf = new CRFClassifier<INN>();
crf.loadClassifier(in);
return crf;
}
public static <INN extends CoreMap> CRFClassifier<INN> getClassifierNoExceptions(String loadPath) {
CRFClassifier<INN> crf = new CRFClassifier<INN>();
crf.loadClassifierNoExceptions(loadPath);
return crf;
}
public static CRFClassifier<CoreLabel> getClassifier(String loadPath) throws IOException, ClassCastException,
ClassNotFoundException {
CRFClassifier<CoreLabel> crf = new CRFClassifier<CoreLabel>();
crf.loadClassifier(loadPath);
return crf;
}
public static <INN extends CoreMap> CRFClassifier<INN> getClassifier(String loadPath, Properties props) throws IOException, ClassCastException,
ClassNotFoundException {
CRFClassifier<INN> crf = new CRFClassifier<INN>();
crf.loadClassifier(loadPath, props);
return crf;
}
private static CRFClassifier<CoreLabel> chooseCRFClassifier(SeqClassifierFlags flags) {
CRFClassifier<CoreLabel> crf; // initialized in if/else
if (flags.useFloat) {
crf = new CRFClassifierFloat<CoreLabel>(flags);
} else if (flags.nonLinearCRF) {
crf = new CRFClassifierNonlinear<CoreLabel>(flags);
} else if (flags.numLopExpert > 1) {
crf = new CRFClassifierWithLOP<CoreLabel>(flags);
} else if (flags.priorType.equals("DROPOUT")) {
crf = new CRFClassifierWithDropout<CoreLabel>(flags);
} else if (flags.useNoisyLabel) {
crf = new CRFClassifierNoisyLabel<CoreLabel>(flags);
} else {
crf = new CRFClassifier<CoreLabel>(flags);
}
return crf;
}
/** The main method. See the class documentation. */
public static void main(String[] args) throws Exception {
StringUtils.printErrInvocationString("CRFClassifier", args);
Properties props = StringUtils.argsToProperties(args);
SeqClassifierFlags flags = new SeqClassifierFlags(props);
CRFClassifier<CoreLabel> crf = chooseCRFClassifier(flags);
String testFile = flags.testFile;
String testFiles = flags.testFiles;
String textFile = flags.textFile;
String textFiles = flags.textFiles;
String loadPath = flags.loadClassifier;
String loadTextPath = flags.loadTextClassifier;
String serializeTo = flags.serializeTo;
String serializeToText = flags.serializeToText;
if (crf.flags.useEmbedding && crf.flags.embeddingWords != null && crf.flags.embeddingVectors != null) {
System.err.println("Reading Embedding Files");
BufferedReader br = IOUtils.readerFromString(crf.flags.embeddingWords);
List<String> wordList = new ArrayList<String>();
for (String line ; (line = br.readLine()) != null; ) {
wordList.add(line.trim());
}
System.err.println("Found a dictionary of size " + wordList.size());
br.close();
crf.embeddings = Generics.newHashMap();
int count = 0;
int vectorSize = -1;
boolean warned = false;
br = IOUtils.readerFromString(crf.flags.embeddingVectors);
for (String line ; (line = br.readLine()) != null; ) {
double[] vector = ArrayUtils.toDoubleArray(line.trim().split(" "));
if (vectorSize < 0) {
vectorSize = vector.length;
} else {
if (vectorSize != vector.length && ! warned) {
System.err.println("Inconsistent vector lengths: " + vectorSize + " vs. " + vector.length);
warned = true;
}
}
crf.embeddings.put(wordList.get(count++), vector);
}
System.err.println("Found " + count + " matching embeddings of dimension " + vectorSize);
}
if (crf.flags.loadClassIndexFrom != null) {
crf.classIndex = loadClassIndexFromFile(crf.flags.loadClassIndexFrom);
}
if (loadPath != null) {
crf.loadClassifierNoExceptions(loadPath, props);
} else if (loadTextPath != null) {
System.err.println("Warning: this is now only tested for Chinese Segmenter");
System.err.println("(Sun Dec 23 00:59:39 2007) (pichuan)");
try {
crf.loadTextClassifier(loadTextPath, props);
// System.err.println("DEBUG: out from crf.loadTextClassifier");
} catch (Exception e) {
throw new RuntimeException("error loading " + loadTextPath, e);
}
} else if (crf.flags.loadJarClassifier != null) {
crf.loadJarClassifier(crf.flags.loadJarClassifier, props);
} else if (crf.flags.trainFile != null || crf.flags.trainFileList != null) {
Timing timing = new Timing();
crf.train();
timing.done("CRFClassifier training");
} else {
crf.loadDefaultClassifier();
}
crf.loadTagIndex();
if (serializeTo != null) {
crf.serializeClassifier(serializeTo);
}
if (crf.flags.serializeWeightsTo != null) {
crf.serializeWeights(crf.flags.serializeWeightsTo);
}
if (crf.flags.serializeFeatureIndexTo != null) {
crf.serializeFeatureIndex(crf.flags.serializeFeatureIndexTo);
}
if (serializeToText != null) {
crf.serializeTextClassifier(serializeToText);
}
if (testFile != null) {
DocumentReaderAndWriter<CoreLabel> readerAndWriter = crf.defaultReaderAndWriter();
if (crf.flags.searchGraphPrefix != null) {
crf.classifyAndWriteViterbiSearchGraph(testFile, crf.flags.searchGraphPrefix, crf.makeReaderAndWriter());
} else if (crf.flags.printFirstOrderProbs) {
crf.printFirstOrderProbs(testFile, readerAndWriter);
} else if (crf.flags.printFactorTable) {
crf.printFactorTable(testFile, readerAndWriter);
} else if (crf.flags.printProbs) {
crf.printProbs(testFile, readerAndWriter);
} else if (crf.flags.useKBest) {
int k = crf.flags.kBest;
crf.classifyAndWriteAnswersKBest(testFile, k, readerAndWriter);
} else if (crf.flags.printLabelValue) {
crf.printLabelInformation(testFile, readerAndWriter);
} else {
crf.classifyAndWriteAnswers(testFile, readerAndWriter, true);
}
}
if (testFiles != null) {
List<File> files = new ArrayList<File>();
for (String filename : testFiles.split(",")) {
files.add(new File(filename));
}
crf.classifyFilesAndWriteAnswers(files, crf.defaultReaderAndWriter(), true);
}
if (textFile != null) {
crf.classifyAndWriteAnswers(textFile);
}
if (textFiles != null) {
List<File> files = new ArrayList<File>();
for (String filename : textFiles.split(",")) {
files.add(new File(filename));
}
crf.classifyFilesAndWriteAnswers(files);
}
if (crf.flags.readStdin) {
crf.classifyStdin();
}
} // end main
@Override
public List<IN> classifyWithGlobalInformation(List<IN> tokenSeq, final CoreMap doc, final CoreMap sent) {
return classify(tokenSeq);
}
public void writeWeights(PrintStream p) {
for (String feature : featureIndex) {
int index = featureIndex.indexOf(feature);
// line.add(feature+"["+(-p)+"]");
// rowHeaders.add(feature + '[' + (-p) + ']');
double[] v = weights[index];
Index<CRFLabel> l = this.labelIndices.get(0);
p.println(feature + "\t\t");
for (CRFLabel label : l) {
p.print(label.toString(classIndex) + ":" + v[l.indexOf(label)] + "\t");
}
p.println();
}
}
public Map<String, Counter<String>> topWeights() {
Map<String, Counter<String>> w = new HashMap<String, Counter<String>>();
for (String feature : featureIndex) {
int index = featureIndex.indexOf(feature);
// line.add(feature+"["+(-p)+"]");
// rowHeaders.add(feature + '[' + (-p) + ']');
double[] v = weights[index];
Index<CRFLabel> l = this.labelIndices.get(0);
for (CRFLabel label : l) {
if(!w.containsKey(label.toString(classIndex)))
w.put(label.toString(classIndex), new ClassicCounter<String>());
w.get(label.toString(classIndex)).setCount(feature, v[l.indexOf(label)]);
}
}
return w;
}
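// Usage sketch: inspect the learned node-feature weights per output class.
// (Counters.toSortedString is an assumed convenience from
// edu.stanford.nlp.stats.Counters; any way of ranking the counter works.)
//   Map<String, Counter<String>> tops = crf.topWeights();
//   for (Map.Entry<String, Counter<String>> e : tops.entrySet()) {
//     System.err.println(e.getKey() + "\t" + Counters.toSortedString(e.getValue(), 10, "%s=%.3f", ", "));
//   }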
} // end class CRFClassifier