// CRFClassifier -- a probabilistic (CRF) sequence model, mainly used for NER.
// Copyright (c) 2002-2008 The Board of Trustees of
// The Leland Stanford Junior University. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 2
// of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
//
// For more information, bug reports, fixes, contact:
// Christopher Manning
// Dept of Computer Science, Gates 1A
// Stanford CA 94305-9010
// USA
// Support/Questions: java-nlp-user@lists.stanford.edu
// Licensing: java-nlp-support@lists.stanford.edu
package edu.stanford.nlp.ie.crf;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.optimization.*;
import edu.stanford.nlp.sequences.*;
import edu.stanford.nlp.util.*;
import java.io.*;
import java.util.*;
import java.util.zip.GZIPInputStream;
/**
* Subclass of {@link edu.stanford.nlp.ie.crf.CRFClassifier} that implements the nonlinear architecture described in Wang and Manning, "Effect of Non-linear Deep Architecture in Sequence Labeling" (IJCNLP 2013).
*
* @author Mengqiu Wang
*/
public class CRFClassifierNonlinear<IN extends CoreMap> extends CRFClassifier<IN> {

  /** Linear weights; trained, saved and loaded only when {@code flags.secondOrderNonLinear} is false. */
  double[][] linearWeights;

  /** Input-layer weights for edge cliques; used only when {@code flags.secondOrderNonLinear} is true. */
  double[][] inputLayerWeights4Edge;

  /** Output-layer weights for edge cliques; used only when {@code flags.secondOrderNonLinear} is true. */
  double[][] outputLayerWeights4Edge;

  /** Input-layer weights; used by both the first-order and the second-order model. */
  double[][] inputLayerWeights;

  /** Output-layer weights; used by both the first-order and the second-order model. */
  double[][] outputLayerWeights;

  /** Creates a classifier configured with default {@link SeqClassifierFlags}. */
  protected CRFClassifierNonlinear() {
    super(new SeqClassifierFlags());
  }

  /**
   * Creates a classifier configured from the given properties.
   *
   * @param props properties forwarded to the superclass constructor
   */
  public CRFClassifierNonlinear(Properties props) {
    super(props);
  }

  /**
   * Creates a classifier configured with the given flags.
   *
   * @param flags flags forwarded to the superclass constructor
   */
  public CRFClassifierNonlinear(SeqClassifierFlags flags) {
    super(flags);
  }

  /**
   * Builds the data/label triple via the superclass and then remaps every feature index in the
   * data through {@code nodeFeatureIndicesMap} / {@code edgeFeatureIndicesMap}.
   *
   * @param document the document to convert
   * @return a new triple holding the remapped data plus the unchanged second and third components
   * @throws RuntimeException if a feature index is missing from the corresponding map
   */
  @Override
  public Triple<int[][][], int[], double[][][]> documentToDataAndLabels(List<IN> document) {
    Triple<int[][][], int[], double[][][]> result = super.documentToDataAndLabels(document);
    int[][][] data = transformDocData(result.first());
    return new Triple<int[][][], int[], double[][][]>(data, result.second(), result.third());
  }

  /**
   * Returns a copy of {@code docData} in which each feature index is replaced by its position in
   * the node map (second dimension index 0) or in the edge map (any other index).
   *
   * @param docData feature indices; the input array is not modified
   * @return a freshly allocated array with the same shape and remapped indices
   * @throws RuntimeException if a feature index is not present in the relevant map
   */
  private int[][][] transformDocData(int[][][] docData) {
    int[][][] transData = new int[docData.length][][];
    for (int i = 0; i < docData.length; i++) {
      transData[i] = new int[docData[i].length][];
      for (int j = 0; j < docData[i].length; j++) {
        int[] cliqueFeatures = docData[i][j];
        // Index 0 is looked up in the node map, everything else in the edge map.
        boolean isNode = (j == 0);
        Index<Integer> featureMap = isNode ? nodeFeatureIndicesMap : edgeFeatureIndicesMap;
        int[] remapped = new int[cliqueFeatures.length];
        for (int n = 0; n < cliqueFeatures.length; n++) {
          int transFeatureIndex = featureMap.indexOf(cliqueFeatures[n]);
          if (transFeatureIndex == -1) {
            String kind = isNode ? "node" : "edge";
            throw new RuntimeException(kind + " cliqueFeatures[n]=" + cliqueFeatures[n]
                + " not found, " + kind + "FeatureIndicesMap.size=" + featureMap.size());
          }
          remapped[n] = transFeatureIndex;
        }
        transData[i][j] = remapped;
      }
    }
    return transData;
  }

  /**
   * Returns the clique potential function used at test time, creating and caching it in
   * {@code cliquePotentialFunction} on first use. The concrete type depends on
   * {@code flags.secondOrderNonLinear}.
   */
  @Override
  protected CliquePotentialFunction getCliquePotentialFunctionForTest() {
    if (cliquePotentialFunction == null) {
      if (flags.secondOrderNonLinear) {
        cliquePotentialFunction = new NonLinearSecondOrderCliquePotentialFunction(
            inputLayerWeights4Edge, outputLayerWeights4Edge, inputLayerWeights, outputLayerWeights, flags);
      } else {
        cliquePotentialFunction = new NonLinearCliquePotentialFunction(
            linearWeights, inputLayerWeights, outputLayerWeights, flags);
      }
    }
    return cliquePotentialFunction;
  }

  /**
   * Trains the nonlinear model and stores the learned parameters in this object's weight fields.
   *
   * @param data training data passed to the objective function
   * @param labels training labels passed to the objective function
   * @param evaluators evaluators handed to the minimizer
   * @param pruneFeatureItr not used by this implementation
   * @param featureVals feature values; only passed on in the first-order case
   * @return always {@code null}; the trained weights are stored in the weight fields instead
   */
  @Override
  protected double[] trainWeights(int[][][][] data, int[][] labels, Evaluator[] evaluators, int pruneFeatureItr, double[][][][] featureVals) {
    if (flags.secondOrderNonLinear) {
      CRFNonLinearSecondOrderLogConditionalObjectiveFunction func =
          new CRFNonLinearSecondOrderLogConditionalObjectiveFunction(data, labels,
              windowSize, classIndex, labelIndices, map, flags,
              nodeFeatureIndicesMap.size(), edgeFeatureIndicesMap.size());
      cliquePotentialFunctionHelper = func;

      double[] allWeights = trainWeightsUsingNonLinearCRF(func, evaluators);
      Quadruple<double[][], double[][], double[][], double[][]> params = func.separateWeights(allWeights);
      this.inputLayerWeights4Edge = params.first();
      this.outputLayerWeights4Edge = params.second();
      this.inputLayerWeights = params.third();
      this.outputLayerWeights = params.fourth();
    } else {
      CRFNonLinearLogConditionalObjectiveFunction func =
          new CRFNonLinearLogConditionalObjectiveFunction(data, labels,
              windowSize, classIndex, labelIndices, map, flags,
              nodeFeatureIndicesMap.size(), edgeFeatureIndicesMap.size(), featureVals);
      if (flags.useAdaGradFOBOS) {
        func.gradientsOnly = true;
      }
      cliquePotentialFunctionHelper = func;

      double[] allWeights = trainWeightsUsingNonLinearCRF(func, evaluators);
      Triple<double[][], double[][], double[][]> params = func.separateWeights(allWeights);
      this.linearWeights = params.first();
      this.inputLayerWeights = params.second();
      this.outputLayerWeights = params.third();
    }
    return null;
  }

  /**
   * Minimizes {@code func}, starting either from {@code func.initial()} or, when
   * {@code flags.initialWeights} is set, from weights read from that file.
   *
   * @param func the objective function to minimize
   * @param evaluators evaluators handed to the minimizer
   * @return the flat weight vector returned by the minimizer
   * @throws RuntimeException if the initial weight file cannot be read or the gradient check fails
   */
  private double[] trainWeightsUsingNonLinearCRF(AbstractCachingDiffFunction func, Evaluator[] evaluators) {
    Minimizer minimizer = getMinimizer(0, evaluators);

    double[] initialWeights;
    if (flags.initialWeights == null) {
      initialWeights = func.initial();
    } else {
      initialWeights = readInitialWeights();
    }
    System.err.println("numWeights: " + initialWeights.length);

    if (flags.testObjFunction) {
      StochasticDiffFunctionTester tester = new StochasticDiffFunctionTester(func);
      // NOTE(review): the JVM exits with status 1 on success as well as on failure; this is
      // kept as-is because external scripts may rely on it -- confirm whether 0 was intended.
      if (tester.testSumOfBatches(initialWeights, 1e-4)) {
        System.err.println("Testing complete... exiting");
        System.exit(1);
      } else {
        System.err.println("Testing failed....exiting");
        System.exit(1);
      }
    }

    // check gradient
    if (flags.checkGradient) {
      if (func.gradientCheck()) {
        System.err.println("gradient check passed");
      } else {
        throw new RuntimeException("gradient check failed");
      }
    }
    return minimizer.minimize(func, flags.tolerance, initialWeights);
  }

  /**
   * Reads the initial weight vector from the gzip-compressed file named by
   * {@code flags.initialWeights}, always closing the underlying streams.
   *
   * @return the weights read from the file
   * @throws RuntimeException (with the I/O failure as its cause) if the file cannot be read
   */
  private double[] readInitialWeights() {
    FileInputStream fis = null;
    DataInputStream dis = null;
    try {
      System.err.println("Reading initial weights from file " + flags.initialWeights);
      fis = new FileInputStream(flags.initialWeights);
      dis = new DataInputStream(new BufferedInputStream(new GZIPInputStream(fis)));
      return ConvertByteArray.readDoubleArr(dis);
    } catch (IOException e) {
      throw new RuntimeException("Could not read from double initial weight file " + flags.initialWeights, e);
    } finally {
      // Closing the outermost stream closes the whole chain; fall back to the raw file
      // stream when wrapping failed (e.g. the file is not valid gzip).
      if (dis != null) {
        closeQuietly(dis);
      } else {
        closeQuietly(fis);
      }
    }
  }

  /** Closes {@code c} if it is non-null, ignoring any failure raised by {@code close()}. */
  private static void closeQuietly(Closeable c) {
    if (c == null) {
      return;
    }
    try {
      c.close();
    } catch (IOException ignored) {
      // Nothing useful can be done if closing a read-only stream fails.
    }
  }

  /**
   * Writes the superclass text model followed by the two feature index maps and the weight
   * matrices selected by {@code flags.secondOrderNonLinear}.
   *
   * @param pw destination writer
   * @throws Exception if the superclass serialization fails
   */
  @Override
  protected void serializeTextClassifier(PrintWriter pw) throws Exception {
    super.serializeTextClassifier(pw);

    writeFeatureIndexMap(pw, "nodeFeatureIndicesMap", nodeFeatureIndicesMap);
    writeFeatureIndexMap(pw, "edgeFeatureIndicesMap", edgeFeatureIndicesMap);

    if (flags.secondOrderNonLinear) {
      writeWeightMatrix(pw, "inputLayerWeights4Edge", inputLayerWeights4Edge);
      writeWeightMatrix(pw, "outputLayerWeights4Edge", outputLayerWeights4Edge);
    } else {
      writeWeightMatrix(pw, "linearWeights", linearWeights);
    }
    writeWeightMatrix(pw, "inputLayerWeights", inputLayerWeights);
    writeWeightMatrix(pw, "outputLayerWeights", outputLayerWeights);
  }

  /** Writes a {@code <name>.size()=} header followed by one {@code position<TAB>value} line per entry. */
  private static void writeFeatureIndexMap(PrintWriter pw, String name, Index<Integer> index) {
    pw.printf("%s.size()=\t%d%n", name, index.size());
    for (int i = 0; i < index.size(); i++) {
      pw.printf("%d\t%d%n", i, index.get(i));
    }
  }

  /** Writes a {@code <name>.length=} header followed by one {@code rowLength<TAB>space-separated values} line per row. */
  private static void writeWeightMatrix(PrintWriter pw, String name, double[][] weights) {
    pw.printf("%s.length=\t%d%n", name, weights.length);
    for (double[] row : weights) {
      StringBuilder joined = new StringBuilder();
      for (int i = 0; i < row.length; i++) {
        if (i > 0) {
          joined.append(' ');
        }
        joined.append(Double.toString(row[i]));
      }
      pw.printf("%d\t%s%n", row.length, joined.toString());
    }
  }

  /**
   * Reads the superclass text model followed by the feature index maps and weight matrices, in
   * the layout produced by {@link #serializeTextClassifier(PrintWriter)}.
   *
   * @param br source reader
   * @throws Exception if the superclass fails to load or reading from {@code br} fails
   * @throws RuntimeException if the input is truncated or malformed
   */
  @Override
  protected void loadTextClassifier(BufferedReader br) throws Exception {
    super.loadTextClassifier(br);

    nodeFeatureIndicesMap = readFeatureIndexMap(br, "nodeFeatureIndicesMap");
    edgeFeatureIndicesMap = readFeatureIndexMap(br, "edgeFeatureIndicesMap");

    if (flags.secondOrderNonLinear) {
      inputLayerWeights4Edge = readWeightMatrix(br, "inputLayerWeights4Edge");
      outputLayerWeights4Edge = readWeightMatrix(br, "outputLayerWeights4Edge");
    } else {
      linearWeights = readWeightMatrix(br, "linearWeights");
    }
    inputLayerWeights = readWeightMatrix(br, "inputLayerWeights");
    outputLayerWeights = readWeightMatrix(br, "outputLayerWeights");
  }

  /** Reads one line, turning end-of-input into a descriptive exception instead of a later NPE. */
  private static String readRequiredLine(BufferedReader br, String what) throws IOException {
    String line = br.readLine();
    if (line == null) {
      throw new RuntimeException("format error: unexpected end of input while reading " + what);
    }
    return line;
  }

  /** Reads a {@code <expectedKey><TAB>count} header line and returns the non-negative count. */
  private static int readCountHeader(BufferedReader br, String expectedKey) throws IOException {
    String[] toks = readRequiredLine(br, expectedKey).split("\\t");
    if (!toks[0].equals(expectedKey)) {
      throw new RuntimeException("format error: expected '" + expectedKey + "' but found '" + toks[0] + "'");
    }
    if (toks.length < 2) {
      throw new RuntimeException("format error: missing count after '" + expectedKey + "'");
    }
    int count = Integer.parseInt(toks[1]);
    if (count < 0) {
      throw new RuntimeException("format error: negative count " + count + " after '" + expectedKey + "'");
    }
    return count;
  }

  /** Reads a feature index map written as a size header plus consecutive {@code position<TAB>value} lines. */
  private static Index<Integer> readFeatureIndexMap(BufferedReader br, String name) throws IOException {
    int size = readCountHeader(br, name + ".size()=");
    Index<Integer> index = new HashIndex<Integer>();
    for (int count = 0; count < size; count++) {
      String[] toks = readRequiredLine(br, name).split("\\t");
      if (toks.length < 2) {
        throw new RuntimeException("format error: malformed " + name + " entry " + count);
      }
      int idx = Integer.parseInt(toks[0]);
      if (idx != count) {
        throw new RuntimeException("format error: expected " + name + " entry " + count + " but found " + idx);
      }
      index.add(Integer.parseInt(toks[1]));
    }
    return index;
  }

  /** Reads a weight matrix written as a length header plus one {@code rowLength<TAB>values} line per row. */
  private static double[][] readWeightMatrix(BufferedReader br, String name) throws IOException {
    int numRows = readCountHeader(br, name + ".length=");
    double[][] weights = new double[numRows][];
    for (int row = 0; row < numRows; row++) {
      String[] toks = readRequiredLine(br, name).split("\\t");
      int rowLength = Integer.parseInt(toks[0]);
      if (rowLength < 0) {
        throw new RuntimeException("weights format error: negative row length in " + name + " row " + row);
      }
      // An empty row is serialized as "0<TAB>"; split() drops the trailing empty field, and
      // "".split(" ") would yield one empty token, so empty rows need their own branch.
      String values = (toks.length > 1) ? toks[1] : "";
      String[] weightsValue = (values.length() == 0) ? new String[0] : values.split(" ");
      if (rowLength != weightsValue.length) {
        throw new RuntimeException("weights format error: " + name + " row " + row + " declares "
            + rowLength + " values but has " + weightsValue.length);
      }
      weights[row] = new double[rowLength];
      for (int col = 0; col < rowLength; col++) {
        weights[row][col] = Double.parseDouble(weightsValue[col]);
      }
    }
    return weights;
  }

  /**
   * Writes the superclass state followed by the feature index maps and the weight matrices
   * selected by {@code flags.secondOrderNonLinear}.
   *
   * @param oos destination stream
   * @throws RuntimeIOException if writing fails
   */
  @Override
  public void serializeClassifier(ObjectOutputStream oos) {
    try {
      super.serializeClassifier(oos);
      oos.writeObject(nodeFeatureIndicesMap);
      oos.writeObject(edgeFeatureIndicesMap);
      if (flags.secondOrderNonLinear) {
        oos.writeObject(inputLayerWeights4Edge);
        oos.writeObject(outputLayerWeights4Edge);
      } else {
        oos.writeObject(linearWeights);
      }
      oos.writeObject(inputLayerWeights);
      oos.writeObject(outputLayerWeights);
    } catch (IOException e) {
      throw new RuntimeIOException(e);
    }
  }

  /**
   * Reads the superclass state followed by the feature index maps and weight matrices, in the
   * order written by {@link #serializeClassifier(ObjectOutputStream)}.
   *
   * <p>NOTE(review): this uses Java native deserialization, so model files should only be
   * loaded from trusted sources.
   *
   * @param ois source stream
   * @param props properties forwarded to the superclass
   * @throws ClassCastException if a stored object has an unexpected type
   * @throws IOException if reading fails
   * @throws ClassNotFoundException if a serialized class cannot be resolved
   */
  @Override
  @SuppressWarnings( { "unchecked" })
  // can't have right types in deserialization
  public void loadClassifier(ObjectInputStream ois, Properties props) throws ClassCastException, IOException,
      ClassNotFoundException {
    super.loadClassifier(ois, props);
    nodeFeatureIndicesMap = (Index<Integer>) ois.readObject();
    edgeFeatureIndicesMap = (Index<Integer>) ois.readObject();
    if (flags.secondOrderNonLinear) {
      inputLayerWeights4Edge = (double[][]) ois.readObject();
      outputLayerWeights4Edge = (double[][]) ois.readObject();
    } else {
      linearWeights = (double[][]) ois.readObject();
    }
    inputLayerWeights = (double[][]) ois.readObject();
    outputLayerWeights = (double[][]) ois.readObject();
  }
} // end class CRFClassifierNonlinear