// CMMClassifier -- a conditional maximum-entropy Markov model, mainly used for NER.
// Copyright (c) 2002-2014 The Board of Trustees of
// The Leland Stanford Junior University. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 2
// of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software Foundation,
// Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
//
// For more information, bug reports, fixes, contact:
//    Christopher Manning
//    Dept of Computer Science, Gates 1A
//    Stanford CA 94305-9010
//    USA
//    Support/Questions: java-nlp-user@lists.stanford.edu
//    Licensing: java-nlp-support@lists.stanford.edu

package edu.stanford.nlp.ie.ner;

import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Properties;
import java.util.Set;
import java.util.regex.Pattern;

import edu.stanford.nlp.classify.Dataset;
import edu.stanford.nlp.classify.LinearClassifier;
import edu.stanford.nlp.classify.LinearClassifierFactory;
import edu.stanford.nlp.classify.LogPrior;
import edu.stanford.nlp.classify.NBLinearClassifierFactory;
import edu.stanford.nlp.classify.ProbabilisticClassifier;
import edu.stanford.nlp.classify.SVMLightClassifierFactory;
import edu.stanford.nlp.ie.AbstractSequenceClassifier;
import edu.stanford.nlp.ie.NERFeatureFactory;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.ling.BasicDatum;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.math.SloppyMath;
import edu.stanford.nlp.objectbank.ObjectBank;
import edu.stanford.nlp.sequences.BeamBestSequenceFinder;
import edu.stanford.nlp.sequences.Clique;
import edu.stanford.nlp.sequences.DocumentReaderAndWriter;
import edu.stanford.nlp.sequences.ExactBestSequenceFinder;
import edu.stanford.nlp.sequences.FeatureFactory;
import edu.stanford.nlp.sequences.PlainTextDocumentReaderAndWriter;
import edu.stanford.nlp.sequences.SeqClassifierFlags;
import edu.stanford.nlp.sequences.SequenceModel;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.ErasureUtils;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.PaddedList;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.StringUtils;


/**
* Does Sequence Classification using a Conditional Markov Model.
* It could be used for other purposes, but the provided features
* are aimed at doing Named Entity Recognition.
* The code has functionality for different document encodings, but when
* using the standard <code>ColumnDocumentReaderAndWriter</code>,
* input files are expected to
* be one word per line with the columns indicating things like the word,
* POS, chunk, and class.
* <p/>
* <b>Typical usage</b>
* <p>For running a trained model with a provided serialized classifier: <p>
* <code>
* java -server -mx1000m edu.stanford.nlp.ie.ner.CMMClassifier -loadClassifier
* conll.ner.gz -textFile samplesentences.txt
* </code><p>
* When specifying all parameters in a properties file (train, test, or
* runtime):<p>
* <code>
* java -mx1000m edu.stanford.nlp.ie.ner.CMMClassifier -prop propFile
* </code><p>
* To train and test a model from the command line:<p>
* <code>java -mx1000m edu.stanford.nlp.ie.ner.CMMClassifier
* -trainFile trainFile -testFile testFile -goodCoNLL &gt; output </code>
* <p/>
* Features are defined by a {@link FeatureFactory}; the
* {@link FeatureFactory} which is used by default is
* {@link NERFeatureFactory}, and you should look there for feature templates.
* Features are specified either by a Properties file (which is the
* recommended method) or on the command line.  The features are read into
* a {@link SeqClassifierFlags} object, which the
* user need not know much about, unless one wishes to add new features.
* <p/>
* CMMClassifier may also be used programmatically.  When creating a new instance, you
* <i>must</i> specify a properties file.  Alternatively, you can deserialize an
* existing classifier via {@link CMMClassifier#getClassifier(String)}.  You may then tag
* sentences using either the assorted <code>test</code> or <code>testSentence</code> methods.
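* <p/>
* For example, a minimal programmatic sketch (the model path here is illustrative,
* not a file guaranteed to ship with every distribution):
* <pre>
* CMMClassifier&lt;CoreLabel&gt; cmm = CMMClassifier.getClassifierNoExceptions("conll.ner.gz");
* String tagged = cmm.classifyToString("John Smith works for Stanford University .");
* System.out.println(tagged);
* </pre>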
*
* @author Dan Klein
* @author Jenny Finkel
* @author Christopher Manning
* @author Shipra Dingare
* @author Huy Nguyen
* @author Sarah Spikes (sdspikes@cs.stanford.edu) - cleanup and filling in types
*/

public class CMMClassifier<IN extends CoreLabel> extends AbstractSequenceClassifier<IN> {

  private ProbabilisticClassifier<String, String> classifier;

  /** The set of empirically legal label sequences (of length (order) at most
   *  <code>flags.maxLeft</code>).  Used to filter valid class sequences if
   *  <code>useObservedSequencesOnly</code> is set.
   */
  Set<List<String>> answerArrays;

  /** Default place to look in Jar file for classifier. */
  public static final String DEFAULT_CLASSIFIER = "/classifiers/ner-eng-ie.cmm-3-all2006.ser.gz";

  protected CMMClassifier() {
    super(new SeqClassifierFlags());
  }

  public CMMClassifier(Properties props) {
    super(props);
  }


  public CMMClassifier(SeqClassifierFlags flags) {
    super(flags);
  }

  /**
   * Returns the Set of entities recognized by this Classifier.
   *
   * @return The Set of entities recognized by this Classifier.
   */
  public Set<String> getTags() {
    Set<String> tags = Generics.newHashSet(classIndex.objectsList());
    tags.remove(flags.backgroundSymbol);
    return tags;
  }

  /**
   * Classify a {@link List} of {@link CoreLabel}s.
   *
   * @param document A {@link List} of {@link CoreLabel}s
   *                 to be classified.
   */
  @Override
  public List<IN> classify(List<IN> document) {
    if (flags.useSequences) {
      classifySeq(document);
    } else {
      classifyNoSeq(document);
    }
    return document;
  }

  /**
   * Classify a List of {@link CoreLabel}s without using sequence information
   * (i.e. no Viterbi algorithm, just distribution over next class).
   *
   * @param document a List of {@link CoreLabel}s to be classified
   */
  private void classifyNoSeq(List<IN> document) {
    if (flags.useReverse) {
      Collections.reverse(document);
    }

    if (flags.lowerNewgeneThreshold) {
      // Used to raise recall for task 1B
      System.err.println("Using NEWGENE threshold: " + flags.newgeneThreshold);
      for (int i = 0, docSize = document.size(); i < docSize; i++) {
        CoreLabel wordInfo = document.get(i);
        Datum<String, String> d = makeDatum(document, i, featureFactories);
        Counter<String> scores = classifier.scoresOf(d);
        //String answer = BACKGROUND;
        String answer = flags.backgroundSymbol;
        // HN: The evaluation of scoresOf seems to result in some
        // kind of side effect.  Specifically, the symptom is that
        // if scoresOf is not evaluated at every position, the
        // answers are different
        if ("NEWGENE".equals(wordInfo.get(CoreAnnotations.GazAnnotation.class))) {
          for (String label : scores.keySet()) {
            if ("G".equals(label)) {
              System.err.println(wordInfo.word() + ':' + scores.getCount(label));
              if (scores.getCount(label) > flags.newgeneThreshold) {
                answer = label;
              }
            }
          }
        }
        wordInfo.set(CoreAnnotations.AnswerAnnotation.class, answer);
      }
    } else {
      for (int i = 0, listSize = document.size(); i < listSize; i++) {
        String answer = classOf(document, i);
        CoreLabel wordInfo = document.get(i);
        //System.err.println("XXX answer for " +
        //        wordInfo.word() + " is " + answer);
        wordInfo.set(CoreAnnotations.AnswerAnnotation.class, answer);
      }
      if (flags.justify && (classifier instanceof LinearClassifier)) {
        LinearClassifier<String, String> lc = (LinearClassifier<String, String>) classifier;
        for (int i = 0, lsize = document.size(); i < lsize; i++) {
          CoreLabel lineInfo = document.get(i);
          System.err.print("@@ Position " + i + ": ");
          System.err.println(lineInfo.word() + " chose " + lineInfo.get(CoreAnnotations.AnswerAnnotation.class));
          lc.justificationOf(makeDatum(document, i, featureFactories));
        }
      }
    }
    if (flags.useReverse) {
      Collections.reverse(document);
    }
  }

  /**
   * Returns the most likely class for the word at the given position.
   */
  protected String classOf(List<IN> lineInfos, int pos) {
    Datum<String, String> d = makeDatum(lineInfos, pos, featureFactories);
    return classifier.classOf(d);
  }

  /**
   * Returns the log conditional likelihood of the given dataset.
   *
   * @return The log conditional likelihood of the given dataset.
   */
  public double loglikelihood(List<IN> lineInfos) {
    double cll = 0.0;

    for (int i = 0; i < lineInfos.size(); i++) {
      Datum<String, String> d = makeDatum(lineInfos, i, featureFactories);
      Counter<String> c = classifier.logProbabilityOf(d);

      double total = Double.NEGATIVE_INFINITY;
      for (String s : c.keySet()) {
        total = SloppyMath.logAdd(total, c.getCount(s));
      }
      cll -= c.getCount(d.label()) - total;
    }
    // quadratic prior
    // HN: TODO: add other priors
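    // With a Gaussian (quadratic) prior, each weight w adds w^2 / (2 * sigma^2)
    // to the negative log conditional likelihood, as computed below.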

    if (classifier instanceof LinearClassifier) {
      double sigmaSq = flags.sigma * flags.sigma;
      LinearClassifier<String, String> lc = (LinearClassifier<String, String>)classifier;
      for (String feature: lc.features()) {
        for (String classLabel: classIndex) {
          double w = lc.weight(feature, classLabel);
          cll += w * w / 2.0 / sigmaSq;
        }
      }
    }
    return cll;
  }

  @Override
  public SequenceModel getSequenceModel(List<IN> document) {
    //System.err.println(flags.useReverse);

    if (flags.useReverse) {
      Collections.reverse(document);
    }

    // cdm Aug 2005: why is this next line needed?  Seems really ugly!!!  [2006: it broke things! removed]
    // document.add(0, new CoreLabel());

    SequenceModel ts = new Scorer<IN>(document,
                                      classIndex,
                                      this,
                                      (!flags.useTaggySequences ? (flags.usePrevSequences ? 1 : 0) : flags.maxLeft),
                                      (flags.useNextSequences ? 1 : 0),
                                      answerArrays);

    return ts;
  }

  /**
   * Classify a List of {@link CoreLabel}s using sequence information
   * (i.e. Viterbi or Beam Search).
   *
   * @param document A List of {@link CoreLabel}s to be classified
   */
  private void classifySeq(List<IN> document) {

    if (document.isEmpty()) {
      return;
    }

    SequenceModel ts = getSequenceModel(document);

    //    TagScorer ts = new PrevOnlyScorer(document, tagIndex, this, (!flags.useTaggySequences ? (flags.usePrevSequences ? 1 : 0) : flags.maxLeft), 0, answerArrays);

    int[] tags;
    //System.err.println("***begin test***");
    if (flags.useViterbi) {
      ExactBestSequenceFinder ti = new ExactBestSequenceFinder();
      tags = ti.bestSequence(ts);
    } else {
      BeamBestSequenceFinder ti = new BeamBestSequenceFinder(flags.beamSize, true, true);
      tags = ti.bestSequence(ts, document.size());
    }
    //System.err.println("***end test***");

    // used to improve recall in task 1b
    if (flags.lowerNewgeneThreshold) {
      System.err.println("Using NEWGENE threshold: " + flags.newgeneThreshold);

      int[] copy = new int[tags.length];
      System.arraycopy(tags, 0, copy, 0, tags.length);

      // for each sequence marked as NEWGENE in the gazette
      // tag the entire sequence as NEWGENE and sum the score
      // if the score is greater than newgeneThreshold, accept
      int ngTag = classIndex.indexOf("G");
      //int bgTag = classIndex.indexOf(BACKGROUND);
      int bgTag = classIndex.indexOf(flags.backgroundSymbol);

      for (int i = 0, dSize = document.size(); i < dSize; i++) {
        CoreLabel wordInfo = document.get(i);

        if ("NEWGENE".equals(wordInfo.get(CoreAnnotations.GazAnnotation.class))) {
          int start = i;
          int j;
          for (j = i; j < document.size(); j++) {
            wordInfo = document.get(j);
            if (!"NEWGENE".equals(wordInfo.get(CoreAnnotations.GazAnnotation.class))) {
              break;
            }
          }
          int end = j;
          //int end = i + 1;

          int winStart = Math.max(0, start - 4);
          int winEnd = Math.min(tags.length, end + 4);
          // clear a window around the sequences
          for (j = winStart; j < winEnd; j++) {
            copy[j] = bgTag;
          }

          // score as nongene
          double bgScore = 0.0;
          for (j = start; j < end; j++) {
            double[] scores = ts.scoresOf(copy, j);
            scores = Scorer.recenter(scores);
            bgScore += scores[bgTag];
          }

          // first pass, compute all of the scores
          ClassicCounter<Pair<Integer,Integer>> prevScores = new ClassicCounter<Pair<Integer,Integer>>();
          for (j = start; j < end; j++) {
            // clear the sequence
            for (int k = start; k < end; k++) {
              copy[k] = bgTag;
            }

            // grow the sequence from j until the end
            for (int k = j; k < end; k++) {
              copy[k] = ngTag;
              // score the sequence
              double ngScore = 0.0;
              for (int m = start; m < end; m++) {
                double[] scores = ts.scoresOf(copy, m);
                scores = Scorer.recenter(scores);
                ngScore += scores[tags[m]];
              }
              prevScores.incrementCount(new Pair<Integer,Integer>(Integer.valueOf(j), Integer.valueOf(k)), ngScore - bgScore);
            }
          }
          for (j = start; j < end; j++) {
            // grow the sequence from j until the end
            for (int k = j; k < end; k++) {
              double score = prevScores.getCount(new Pair<Integer,Integer>(Integer.valueOf(j), Integer.valueOf(k)));
              Pair<Integer, Integer> al = new Pair<Integer,Integer>(Integer.valueOf(j - 1), Integer.valueOf(k)); // adding a word to the left
              Pair<Integer, Integer> ar = new Pair<Integer,Integer>(Integer.valueOf(j), Integer.valueOf(k + 1)); // adding a word to the right
              Pair<Integer, Integer> sl = new Pair<Integer,Integer>(Integer.valueOf(j + 1), Integer.valueOf(k)); // subtracting word from left
              Pair<Integer, Integer> sr = new Pair<Integer,Integer>(Integer.valueOf(j), Integer.valueOf(k - 1)); // subtracting word from right

              // make sure the score is greater than all its neighbors (one add or subtract)
              if (score >= flags.newgeneThreshold && (!prevScores.containsKey(al) || score > prevScores.getCount(al)) && (!prevScores.containsKey(ar) || score > prevScores.getCount(ar)) && (!prevScores.containsKey(sl) || score > prevScores.getCount(sl)) && (!prevScores.containsKey(sr) || score > prevScores.getCount(sr))) {
                StringBuilder sb = new StringBuilder();
                wordInfo = document.get(j);
                String docId = wordInfo.get(CoreAnnotations.IDAnnotation.class);
                String startIndex = wordInfo.get(CoreAnnotations.PositionAnnotation.class);
                wordInfo = document.get(k);
                String endIndex = wordInfo.get(CoreAnnotations.PositionAnnotation.class);
                for (int m = j; m <= k; m++) {
                  wordInfo = document.get(m);
                  sb.append(wordInfo.word());
                  sb.append(' ');
                }
                /*System.err.println(sb.toString()+"score:"+score+
                  " al:"+prevScores.getCount(al)+
                  " ar:"+prevScores.getCount(ar)+
                  "  sl:"+prevScores.getCount(sl)+" sr:"+ prevScores.getCount(sr));*/
                System.out.println(docId + '|' + startIndex + ' ' + endIndex + '|' + sb.toString().trim());
              }
            }
          }

          // restore the original tags
          for (j = winStart; j < winEnd; j++) {
            copy[j] = tags[j];
          }
          i = end;
        }
      }
    }

    for (int i = 0, docSize = document.size(); i < docSize; i++) {
      CoreLabel lineInfo = document.get(i);
      String answer = classIndex.get(tags[i]);
      lineInfo.set(CoreAnnotations.AnswerAnnotation.class, answer);
    }

    if (flags.justify && classifier instanceof LinearClassifier) {
      LinearClassifier<String, String> lc = (LinearClassifier<String, String>) classifier;
      if (flags.dump) {
        lc.dump();
      }
      for (int i = 0, docSize = document.size(); i < docSize; i++) {
        CoreLabel lineInfo = document.get(i);
        System.err.print("@@ Position " + i + ": ");
        System.err.println(lineInfo.word() + ' ' + lineInfo.get(CoreAnnotations.AnswerAnnotation.class));
        lc.justificationOf(makeDatum(document, i, featureFactories));
      }
    }

//    document.remove(0);

    if (flags.useReverse) {
      Collections.reverse(document);

    }
  } // end testSeq


  /**
   * @param filename adaptation file
   * @param trainDataset original dataset (used in training)
   */
  public void adapt(String filename, Dataset<String, String> trainDataset,
                    DocumentReaderAndWriter<IN> readerWriter) {
    flags.ocrTrain = false; // ?? Do we need this? (Pi-Chuan Sat Nov  5 15:42:49 2005)
    ObjectBank<List<IN>> docs =
      makeObjectBankFromFile(filename, readerWriter);
    adapt(docs, trainDataset);
  }

  /**
   * @param featureLabels adaptation docs
   * @param trainDataset original dataset (used in training)
   */
  public void adapt(ObjectBank<List<IN>> featureLabels, Dataset<String, String> trainDataset) {
    Dataset<String, String> adapt = getDataset(featureLabels, trainDataset);
    adapt(adapt);
  }

  /**
   * @param featureLabels retrain docs
   * @param featureIndex featureIndex of original dataset (used in training)
   * @param labelIndex labelIndex of original dataset (used in training)
   */
  public void retrain(ObjectBank<List<IN>> featureLabels, Index<String> featureIndex, Index<String> labelIndex) {
    int fs = featureIndex.size(); // old dim
    int ls = labelIndex.size();   // old dim

    Dataset<String, String> adapt = getDataset(featureLabels, featureIndex, labelIndex);

    int prior = LogPrior.LogPriorType.QUADRATIC.ordinal();
    LinearClassifier<String, String> lc = (LinearClassifier<String, String>) classifier;
    LinearClassifierFactory<String, String> lcf = new LinearClassifierFactory<String, String>(flags.tolerance, flags.useSum, prior, flags.sigma, flags.epsilon, flags.QNsize);

    double[][] weights = lc.weights(); // old dim
    Index<String> newF = adapt.featureIndex;
    Index<String> newL = adapt.labelIndex;
    int newFS = newF.size();
    int newLS = newL.size();
    double[] x = new double[newFS*newLS]; // new dim
    //System.err.println("old  ["+fs+"]"+"["+ls+"]");
    //System.err.println("new  ["+newFS+"]"+"["+newLS+"]");
    //System.err.println("new  ["+newFS*newLS+"]");
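    // Transplant the old weights into the flattened array expected by the factory:
    // the weight for (feature f, label l) lands at newF.indexOf(f) * newLS + newL.indexOf(l).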
    for (int i = 0; i < fs; i++) {
      for (int j = 0; j < ls; j++) {
        String f = featureIndex.get(i);
        String l = labelIndex.get(j);
        int newi = newF.indexOf(f)*newLS+newL.indexOf(l);
        x[newi] = weights[i][j];
        //if (newi == 144745*2) {
        //System.err.println("What??"+i+"\t"+j);
        //}
      }
    }
    //System.err.println("x[144745*2]"+x[144745*2]);
    weights = lcf.trainWeights(adapt, x);
    //System.err.println("x[144745*2]"+x[144745*2]);
    //System.err.println("weights[144745]"+"[0]="+weights[144745][0]);

    lc.setWeights(weights);
    /*
    int delme = 0;
    if (true) {
      for (double[] dd : weights) {
        delme++;
        for (double d : dd) {
        }
      }
    }
    System.err.println(weights[delme-1][0]);
    System.err.println("size of weights: "+delme);
    */
  }


  public void retrain(ObjectBank<List<IN>> doc) {
    if (classifier == null) {
      throw new UnsupportedOperationException("Cannot retrain before you train!");
    }
    Index<String> findex = ((LinearClassifier<String, String>)classifier).featureIndex();
    Index<String> lindex = ((LinearClassifier<String, String>)classifier).labelIndex();
    System.err.println("Starting retrain:\t# of original features: " + findex.size() + ", # of original labels: " + lindex.size());
    retrain(doc, findex, lindex);
  }


  @Override
  public void train(Collection<List<IN>> wordInfos,
                    DocumentReaderAndWriter<IN> readerAndWriter) {
    Dataset<String, String> train = getDataset(wordInfos);
    //train.summaryStatistics();
    //train.printSVMLightFormat();
    // wordInfos = null;  // cdm: I think this does no good as ptr exists in caller (could empty the list or better refactor so conversion done earlier?)
    train(train);

    for (int i = 0; i < flags.numTimesPruneFeatures; i++) {

      Index<String> featuresAboveThreshhold = getFeaturesAboveThreshhold(train, flags.featureDiffThresh);
      System.err.println("Removing features whose weight spread is below " + flags.featureDiffThresh + " and retraining...");
      train = getDataset(train, featuresAboveThreshhold);

      int tmp = flags.QNsize;
      flags.QNsize = flags.QNsize2;
      train(train);
      flags.QNsize = tmp;
    }

    if (flags.doAdaptation && flags.adaptFile != null) {
      adapt(flags.adaptFile, train, readerAndWriter);
    }

    System.err.print("Built this classifier: ");
    if (classifier instanceof LinearClassifier) {
      String classString = ((LinearClassifier<String, String>)classifier).toString(flags.printClassifier, flags.printClassifierParam);
      System.err.println(classString);
    } else {
      String classString = classifier.toString();
      System.err.println(classString);
    }
  }

  public Index<String> getFeaturesAboveThreshhold(Dataset<String, String> dataset, double thresh) {
    if (!(classifier instanceof LinearClassifier)) {
      throw new RuntimeException("Attempting to remove features based on weight from a non-linear classifier");
    }
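    // A feature is kept iff the spread between its largest and smallest per-label
    // weight exceeds thresh; features with near-constant weights are dropped.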
    Index<String> featureIndex = dataset.featureIndex;
    Index<String> labelIndex = dataset.labelIndex;

    Index<String> features = new HashIndex<String>();
    Iterator<String> featureIt = featureIndex.iterator();
    LinearClassifier<String, String> lc = (LinearClassifier<String, String>)classifier;
    LOOP:
    while (featureIt.hasNext()) {
      String f = featureIt.next();
      Iterator<String> labelIt = labelIndex.iterator();
      double smallest = Double.POSITIVE_INFINITY;
      double biggest = Double.NEGATIVE_INFINITY;
      while (labelIt.hasNext()) {
        String l = labelIt.next();
        double weight = lc.weight(f, l);
        if (weight < smallest) {
          smallest = weight;
        }
        if (weight > biggest) {
          biggest = weight;
        }
        if (biggest - smallest > thresh) {
          features.add(f);
          continue LOOP;
        }
      }
    }
    return features;
  }

  /**
   * Build a Dataset from some data. Used for training a classifier.
   *
   * @param data This variable is a list of lists of CoreLabel.  That is,
   *             it is a collection of documents, each of which is represented
   *             as a sequence of CoreLabel objects.
   * @return The Dataset which is an efficient encoding of the information
   *         in a List of Datums
   */
  public Dataset<String, String> getDataset(Collection<List<IN>> data) {
    return getDataset(data, null, null);
  }

  /**
   * Build a Dataset from some data. Used for training a classifier.
   *
   * By passing in an existing featureIndex and classIndex, you can get a Dataset
   * whose indexing is consistent with an existing Dataset.
   *
   * @param data This variable is a list of lists of CoreLabel.  That is,
   *             it is a collection of documents, each of which is represented
   *             as a sequence of CoreLabel objects.
   * @param featureIndex the featureIndex of an existing Dataset, if you want the
   *                     new Dataset indexed consistently with it
   * @param classIndex the classIndex of an existing Dataset, if you want the
   *                    new Dataset indexed consistently with it
   * @return The Dataset which is an efficient encoding of the information
   *         in a List of Datums
   */
  public Dataset<String, String> getDataset(Collection<List<IN>> data, Index<String> featureIndex, Index<String> classIndex) {
    makeAnswerArraysAndTagIndex(data);

    int size = 0;
    for (List<IN> doc : data) {
      size += doc.size();
    }

    System.err.println("Making Dataset...");
    Dataset<String, String> train;
    if (featureIndex != null && classIndex != null) {
      System.err.println("Using feature/class Index from existing Dataset...");
      System.err.println("(This is used when getting Dataset from adaptation set. We want to make the index consistent.)"); //pichuan
      train = new Dataset<String, String>(size, featureIndex, classIndex);
    } else {
      train = new Dataset<String, String>(size);
    }

    for (List<IN> doc : data) {
      if (flags.useReverse) {
        Collections.reverse(doc);
      }

      for (int i = 0, dsize = doc.size(); i < dsize; i++) {
        Datum<String, String> d = makeDatum(doc, i, featureFactories);

        //CoreLabel fl = doc.get(i);

        train.add(d);
      }

      if (flags.useReverse) {
        Collections.reverse(doc);
      }
    }

    System.err.println("done.");

    if (flags.featThreshFile != null) {
      System.err.println("applying thresholds...");
      List<Pair<Pattern, Integer>> thresh = getThresholds(flags.featThreshFile);
      train.applyFeatureCountThreshold(thresh);
    } else if (flags.featureThreshold > 1) {
      System.err.println("Removing Features with counts < " + flags.featureThreshold);
      train.applyFeatureCountThreshold(flags.featureThreshold);
    }
    train.summaryStatistics();
    return train;
  }

  public Dataset<String, String> getBiasedDataset(ObjectBank<List<IN>> data, Index<String> featureIndex, Index<String> classIndex) {
    makeAnswerArraysAndTagIndex(data);

    Index<String> origFeatIndex = new HashIndex<String>(featureIndex.objectsList()); // mg2009: TODO: check

    int size = 0;
    for (List<IN> doc : data) {
      size += doc.size();
    }

    System.err.println("Making Dataset...");
    Dataset<String, String> train = new Dataset<String, String>(size, featureIndex, classIndex);

    for (List<IN> doc : data) {
      if (flags.useReverse) {
        Collections.reverse(doc);
      }

      for (int i = 0, dsize = doc.size(); i < dsize; i++) {
        Datum<String, String> d = makeDatum(doc, i, featureFactories);
        Collection<String> newFeats = new ArrayList<String>();
        for (String f : d.asFeatures()) {
          if ( ! origFeatIndex.contains(f)) {
            newFeats.add(f);
          }
        }
//        System.err.println(d.label()+"\t"+d.asFeatures()+"\n\t"+newFeats);
//        d = new BasicDatum(newFeats, d.label());
        train.add(d);
      }

      if (flags.useReverse) {
        Collections.reverse(doc);
      }
    }

    System.err.println("done.");

    if (flags.featThreshFile != null) {
      System.err.println("applying thresholds...");
      List<Pair<Pattern, Integer>> thresh = getThresholds(flags.featThreshFile);
      train.applyFeatureCountThreshold(thresh);
    } else if (flags.featureThreshold > 1) {
      System.err.println("Removing Features with counts < " + flags.featureThreshold);
      train.applyFeatureCountThreshold(flags.featureThreshold);
    }
    train.summaryStatistics();
    return train;
  }




  /**
   * Build a Dataset from some data. Used for training a classifier.
   *
   * By passing in an extra origDataset, you can get a Dataset based on featureIndex and
   * classIndex in an existing origDataset.
   *
   * @param data This variable is a list of lists of CoreLabel.  That is,
   *             it is a collection of documents, each of which is represented
   *             as a sequence of CoreLabel objects.
   * @param origDataset if you want to get a Dataset based on featureIndex and
   *                    classIndex in an existing origDataset
   * @return The Dataset which is an efficient encoding of the information
   *         in a List of Datums
   */
  public Dataset<String, String> getDataset(ObjectBank<List<IN>> data, Dataset<String, String> origDataset) {
    if(origDataset == null) {
      return getDataset(data);
    }
    return getDataset(data, origDataset.featureIndex, origDataset.labelIndex);
  }


  /**
   * Build a Dataset from some data.
   *
   * @param oldData      This {@link Dataset} represents data for which we wish to
   *                     remove some features, specifically those features not in the {@link edu.stanford.nlp.util.Index}
   *                     goodFeatures.
   * @param goodFeatures An {@link edu.stanford.nlp.util.Index} of features we wish to retain.
   * @return A new {@link Dataset} where each data point contains only features
   *         which were in goodFeatures.
   */
  public Dataset<String, String> getDataset(Dataset<String, String> oldData, Index<String> goodFeatures) {
    //public Dataset getDataset(List data, Collection goodFeatures) {
    //makeAnswerArraysAndTagIndex(data);

    int[][] oldDataArray = oldData.getDataArray();
    int[] oldLabelArray = oldData.getLabelsArray();
    Index<String> oldFeatureIndex = oldData.featureIndex;

    int[] oldToNewFeatureMap = new int[oldFeatureIndex.size()];

    int[][] newDataArray = new int[oldDataArray.length][];

    System.err.print("Building reduced dataset...");

    int size = oldFeatureIndex.size();
    int max = 0;
    for (int i = 0; i < size; i++) {
      oldToNewFeatureMap[i] = goodFeatures.indexOf(oldFeatureIndex.get(i));
      if (oldToNewFeatureMap[i] > max) {
        max = oldToNewFeatureMap[i];
      }
    }

    for (int i = 0; i < oldDataArray.length; i++) {
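      // Remap this datum's features: indexOf returns -1 for features absent from
      // goodFeatures, and the > 0 test below also drops the feature that maps to
      // new index 0.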
      int[] data = oldDataArray[i];
      size = 0;
      for (int oldF : data) {
        if (oldToNewFeatureMap[oldF] > 0) {
          size++;
        }
      }
      int[] newData = new int[size];
      int index = 0;
      for (int oldF : data) {
        int f = oldToNewFeatureMap[oldF];
        if (f > 0) {
          newData[index++] = f;
        }
      }
      newDataArray[i] = newData;
    }

    Dataset<String, String> train = new Dataset<String, String>(oldData.labelIndex, oldLabelArray, goodFeatures, newDataArray, newDataArray.length);

    System.err.println("done.");
    if (flags.featThreshFile != null) {
      System.err.println("applying thresholds...");
      List<Pair<Pattern,Integer>> thresh = getThresholds(flags.featThreshFile);
      train.applyFeatureCountThreshold(thresh);
    } else if (flags.featureThreshold > 1) {
      System.err.println("Removing Features with counts < " + flags.featureThreshold);
      train.applyFeatureCountThreshold(flags.featureThreshold);
    }
    train.summaryStatistics();
    return train;
  }

  private void adapt(Dataset<String, String> adapt) {
    if (flags.classifierType.equalsIgnoreCase("SVM")) {
      throw new UnsupportedOperationException();
    }
    adaptMaxEnt(adapt);
  }

  private void adaptMaxEnt(Dataset<String, String> adapt) {
    if (classifier instanceof LinearClassifier) {
      // So far the adaptation is only done on Gaussian Prior. Haven't checked how it'll work on other kinds of priors. -pichuan
      int prior = LogPrior.LogPriorType.QUADRATIC.ordinal();
      if (flags.useHuber) {
        throw new UnsupportedOperationException();
      } else if (flags.useQuartic) {
        throw new UnsupportedOperationException();
      }

      LinearClassifierFactory<String, String> lcf = new LinearClassifierFactory<String, String>(flags.tolerance, flags.useSum, prior, flags.adaptSigma, flags.epsilon, flags.QNsize);
      ((LinearClassifier<String, String>)classifier).adaptWeights(adapt,lcf);
    } else {
      throw new UnsupportedOperationException();
    }
  }

  private void train(Dataset<String, String> train) {
    if (flags.classifierType.equalsIgnoreCase("SVM")) {
      trainSVM(train);
    } else {
      trainMaxEnt(train);
    }
  }

  private void trainSVM(Dataset<String, String> train) {
    SVMLightClassifierFactory<String, String> fact = new SVMLightClassifierFactory<String, String>();
    classifier = fact.trainClassifier(train);

  }

  private void trainMaxEnt(Dataset<String, String> train) {
    int prior = LogPrior.LogPriorType.QUADRATIC.ordinal();
    if (flags.useHuber) {
      prior = LogPrior.LogPriorType.HUBER.ordinal();
    } else if (flags.useQuartic) {
      prior = LogPrior.LogPriorType.QUARTIC.ordinal();
    }

    LinearClassifier<String, String> lc;
    if (flags.useNB) {
      lc = new NBLinearClassifierFactory<String, String>(flags.sigma).trainClassifier(train);
    } else {
      LinearClassifierFactory<String, String> lcf = new LinearClassifierFactory<String, String>(flags.tolerance, flags.useSum, prior, flags.sigma, flags.epsilon, flags.QNsize);
      if (flags.useQN) {
        lcf.useQuasiNewton(flags.useRobustQN);
      } else if (flags.useStochasticQN) {
        lcf.useStochasticQN(flags.initialGain, flags.stochasticBatchSize);
      } else if (flags.useSMD) {
        lcf.useStochasticMetaDescent(flags.initialGain, flags.stochasticBatchSize, flags.stochasticMethod, flags.SGDPasses);
      } else if (flags.useSGD) {
        lcf.useStochasticGradientDescent(flags.gainSGD, flags.stochasticBatchSize);
      } else if (flags.useSGDtoQN) {
        lcf.useStochasticGradientDescentToQuasiNewton(flags.initialGain, flags.stochasticBatchSize,
                                       flags.SGDPasses, flags.QNPasses, flags.SGD2QNhessSamples,
                                       flags.QNsize, flags.outputIterationsToFile);
      } else if (flags.useHybrid) {
        lcf.useHybridMinimizer(flags.initialGain, flags.stochasticBatchSize, flags.stochasticMethod, flags.hybridCutoffIteration);
      } else {
        lcf.useConjugateGradientAscent();
      }
      lc = lcf.trainClassifier(train);
    }
    this.classifier = lc;
  }

  private void trainSemiSup(Dataset<String, String> data, Dataset<String, String> biasedData, double[][] confusionMatrix) {
    int prior = LogPrior.LogPriorType.QUADRATIC.ordinal();
    if (flags.useHuber) {
      prior = LogPrior.LogPriorType.HUBER.ordinal();
    } else if (flags.useQuartic) {
      prior = LogPrior.LogPriorType.QUARTIC.ordinal();
    }

    LinearClassifierFactory<String, String> lcf;
    lcf = new LinearClassifierFactory<String, String>(flags.tolerance, flags.useSum, prior, flags.sigma, flags.epsilon, flags.QNsize);
    if (flags.useQN) {
      lcf.useQuasiNewton();
    } else{
      lcf.useConjugateGradientAscent();
    }

    this.classifier = (LinearClassifier<String, String>) lcf.trainClassifierSemiSup(data, biasedData, confusionMatrix, null);
  }


//   public void crossValidateTrainAndTest() throws Exception {
//     crossValidateTrainAndTest(flags.trainFile);
//   }

//   public void crossValidateTrainAndTest(String filename) throws Exception {
//     // wordshapes

//     for (int fold = flags.startFold; fold <= flags.endFold; fold++) {
//       System.err.println("fold " + fold + " of " + flags.endFold);
//       // train

//       List data = makeObjectBank(filename);
//       List folds = split(data, flags.numFolds);
//       data = null;

//       List train = new ArrayList();

//       for (int i = 0; i < flags.numFolds; i++) {
//         List docs = (List) folds.get(i);
//         if (i != fold) {
//           train.addAll(docs);
//         }
//       }
//       folds = null;
//       train(train);
//       train = null;

//       List test = new ArrayList();
//       data = makeObjectBank(filename);
//       folds = split(data, flags.numFolds);
//       data = null;

//       for (int i = 0; i < flags.numFolds; i++) {
//         List docs = (List) folds.get(i);
//         if (i == fold) {
//           test.addAll(docs);
//         }
//       }
//       folds = null;
//       // test
//       test(test);
//       writeAnswers(test);
//     }
//   }

//   /**
//    * Splits the given train corpus into a train and a test corpus based on the fold number.
//    * 1 / numFolds documents are held out for test, with the offset determined by the fold number.
//    *
//    * @param data     The original data
//    * @param numFolds The number of folds to split the data into
//    * @return A list of folds giving the new training set
//    */
//   private List split(List data, int numFolds) {
//     List folds = new ArrayList();
//     int foldSize = data.size() / numFolds;
//     int r = data.size() - (numFolds * foldSize);

//     int index = 0;
//     for (int i = 0; i < numFolds; i++) {
//       List fold = new ArrayList();
//       int end = (i < r ? foldSize + 1 : foldSize);
//       for (int j = 0; j < end; j++) {
//         fold.add(data.get(index++));
//       }
//       folds.add(fold);
//     }

//     return folds;
//   }

  @Override
  public void serializeClassifier(String serializePath) {

    System.err.print("Serializing classifier to " + serializePath + "...");

    try {
      ObjectOutputStream oos = IOUtils.writeStreamFromString(serializePath);

      oos.writeObject(classifier);
      oos.writeObject(flags);
      oos.writeObject(featureFactories);
      oos.writeObject(classIndex);
      oos.writeObject(answerArrays);
      //oos.writeObject(WordShapeClassifier.getKnownLowerCaseWords());

      oos.writeObject(knownLCWords);

      oos.close();
      System.err.println("Done.");

    } catch (Exception e) {
      System.err.println("Error serializing to " + serializePath);
      e.printStackTrace();
    }
  }


  /**
   * Used to load the default supplied classifier.  <i>THIS FUNCTION
   * WILL ONLY WORK IF RUN INSIDE A JAR FILE.</i>
   */
  public void loadDefaultClassifier() {
    loadJarClassifier(DEFAULT_CLASSIFIER, null);
  }

  /**
   * Used to obtain the default classifier which is
   * stored inside a jar file.  <i>THIS FUNCTION
   * WILL ONLY WORK IF RUN INSIDE A JAR FILE.</i>
   *
   * @return A Default CMMClassifier from a jar file
   */
  public static CMMClassifier<? extends CoreLabel> getDefaultClassifier() {

    CMMClassifier<? extends CoreLabel> cmm = new CMMClassifier<CoreLabel>();
    cmm.loadDefaultClassifier();
    return cmm;

  }

  /** Load a classifier from the given Stream.
   *  <i>Implementation note: </i> This method <i>does not</i> close the
   *  Stream that it reads from.
   *
   *  @param ois The ObjectInputStream to load the serialized classifier from
   *
   *  @throws IOException If there are problems accessing the input stream
   *  @throws ClassCastException If there are problems interpreting the serialized data
   *  @throws ClassNotFoundException If there are problems interpreting the serialized data
   */
  @SuppressWarnings("unchecked")
  @Override
  public void loadClassifier(ObjectInputStream ois, Properties props) throws ClassCastException, IOException, ClassNotFoundException {
    classifier = (LinearClassifier<String, String>) ois.readObject();
    flags = (SeqClassifierFlags) ois.readObject();
    Object featureFactory = ois.readObject();
    if (featureFactory instanceof List) {
      featureFactories = ErasureUtils.uncheckedCast(featureFactory);
    } else if (featureFactory instanceof FeatureFactory) {
      featureFactories = Generics.newArrayList();
      featureFactories.add((FeatureFactory) featureFactory);
    }

    if (props != null) {
      flags.setProperties(props);
    }
    reinit();

    classIndex = (Index<String>) ois.readObject();
    answerArrays = (Set<List<String>>) ois.readObject();

    knownLCWords = (Set<String>) ois.readObject();
  }


  public static CMMClassifier<? extends CoreLabel> getClassifierNoExceptions(File file) {
    CMMClassifier<? extends CoreLabel> cmm = new CMMClassifier<CoreLabel>();
    cmm.loadClassifierNoExceptions(file);
    return cmm;

  }

  public static CMMClassifier<? extends CoreLabel> getClassifier(File file) throws IOException, ClassCastException, ClassNotFoundException {

    CMMClassifier<? extends CoreLabel> cmm = new CMMClassifier<CoreLabel>();
    cmm.loadClassifier(file);
    return cmm;
  }

  public static CMMClassifier<CoreLabel> getClassifierNoExceptions(String loadPath) {
    CMMClassifier<CoreLabel> cmm = new CMMClassifier<CoreLabel>();
    cmm.loadClassifierNoExceptions(loadPath);
    return cmm;

  }

  public static CMMClassifier<? extends CoreLabel> getClassifier(String loadPath) throws IOException, ClassCastException, ClassNotFoundException {

    CMMClassifier<? extends CoreLabel> cmm = new CMMClassifier<CoreLabel>();
    cmm.loadClassifier(loadPath);
    return cmm;
  }

  public static CMMClassifier<? extends CoreLabel> getClassifierNoExceptions(InputStream in) {
    CMMClassifier<? extends CoreLabel> cmm = new CMMClassifier<CoreLabel>();
    cmm.loadClassifierNoExceptions(new BufferedInputStream(in), null);
    return cmm;
  }

  public static CMMClassifier<? extends CoreLabel> getClassifier(InputStream in) throws IOException, ClassCastException, ClassNotFoundException {
    CMMClassifier<? extends CoreLabel> cmm = new CMMClassifier<CoreLabel>();
    cmm.loadClassifier(new BufferedInputStream(in));
    return cmm;
  }

  /** This routine builds the <code>answerArrays</code> which give the
   *  empirically legal label sequences (of length (order) at most
   *  <code>flags.maxLeft</code>) and the <code>classIndex</code>,
   *  which indexes known answer classes.
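   *  For example, with <code>flags.maxLeft</code> = 2 and a document labeled
   *  [O, PER, PER], the recorded sequences are [O], [PER], [O, PER], and [PER, PER].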
   *
   * @param docs The training data: A List of List of CoreLabel
   */
  private void makeAnswerArraysAndTagIndex(Collection<List<IN>> docs) {
    if (answerArrays == null) {
      answerArrays = Generics.newHashSet();
    }
    if (classIndex == null) {
      classIndex = new HashIndex<String>();
    }

    for (List<IN> doc : docs) {
      if (flags.useReverse) {
        Collections.reverse(doc);
      }

      int leng = doc.size();
      for (int start = 0; start < leng; start++) {
        for (int diff = 1; diff <= flags.maxLeft && start + diff <= leng; diff++) {
          String[] seq = new String[diff];
          for (int i = start; i < start + diff; i++) {
            seq[i - start] = doc.get(i).get(CoreAnnotations.AnswerAnnotation.class);
          }
          answerArrays.add(Arrays.asList(seq));
        }
      }
      for (int i = 0; i < leng; i++) {
        CoreLabel wordInfo = doc.get(i);
        classIndex.add(wordInfo.get(CoreAnnotations.AnswerAnnotation.class));
      }

      if (flags.useReverse) {
        Collections.reverse(doc);
      }
    }
  }

  /** Make an individual Datum out of the data list info, focused at position
   *  loc.
   *  @param info A List of IN objects
   *  @param loc  The position in the info list to focus feature creation on
   *  @param featureFactories The factories that construct features out of the item
   *  @return A Datum (BasicDatum) representing this data instance
   */
  public Datum<String, String> makeDatum(List<IN> info, int loc, List<FeatureFactory<IN>> featureFactories) {
    PaddedList<IN> pInfo = new PaddedList<IN>(info, pad);

    Collection<String> features = new ArrayList<String>();
    for (FeatureFactory<IN> featureFactory : featureFactories) {
      List<Clique> cliques = featureFactory.getCliques();
      for (Clique c : cliques) {
        Collection<String> feats = featureFactory.getCliqueFeatures(pInfo, loc, c);
        feats = addOtherClasses(feats, pInfo, loc, c);
        features.addAll(feats);
      }
    }

    printFeatures(pInfo.get(loc), features);
    CoreLabel c = info.get(loc);
    return new BasicDatum<String, String>(features, c.get(CoreAnnotations.AnswerAnnotation.class));
  }


  /** This adds to the feature name the name of classes that are other than
   *  the current class that are involved in the clique.  In the CMM, these
   *  other classes become part of the conditioning feature, and only the
   *  class of the current position is being predicted.
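   *  For example, for the clique <code>cliqueCpC</code> with previous answer
   *  <code>PER</code>, a feature <code>W-Smith</code> would become <code>W-Smith|PER</code>.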
   *
   *  @return A collection of features with extra class information put
   *          into the feature name.
   */
  private static Collection<String> addOtherClasses(Collection<String> feats, List<? extends CoreLabel> info,
                                     int loc, Clique c) {
    String addend = null;
    String pAnswer = info.get(loc - 1).get(CoreAnnotations.AnswerAnnotation.class);
    String p2Answer = info.get(loc - 2).get(CoreAnnotations.AnswerAnnotation.class);
    String p3Answer = info.get(loc - 3).get(CoreAnnotations.AnswerAnnotation.class);
    String p4Answer = info.get(loc - 4).get(CoreAnnotations.AnswerAnnotation.class);
    String p5Answer = info.get(loc - 5).get(CoreAnnotations.AnswerAnnotation.class);
    String nAnswer = info.get(loc + 1).get(CoreAnnotations.AnswerAnnotation.class);
    // cdm 2009: Is this really right? Do we not need to differentiate names that would collide???
    if (c == FeatureFactory.cliqueCpC) {
      addend = '|' + pAnswer;
    } else if (c == FeatureFactory.cliqueCp2C) {
      addend = '|' + p2Answer;
    } else if (c == FeatureFactory.cliqueCp3C) {
      addend = '|' + p3Answer;
    } else if (c == FeatureFactory.cliqueCp4C) {
      addend = '|' + p4Answer;
    } else if (c == FeatureFactory.cliqueCp5C) {
      addend = '|' + p5Answer;
    } else if (c == FeatureFactory.cliqueCpCp2C) {
      addend = '|' + pAnswer + '-' + p2Answer;
    } else if (c == FeatureFactory.cliqueCpCp2Cp3C) {
      addend = '|' + pAnswer + '-' + p2Answer + '-' + p3Answer;
    } else if (c == FeatureFactory.cliqueCpCp2Cp3Cp4C) {
      addend = '|' + pAnswer + '-' + p2Answer + '-' + p3Answer + '-' + p4Answer;
    } else if (c == FeatureFactory.cliqueCpCp2Cp3Cp4Cp5C) {
      addend = '|' + pAnswer + '-' + p2Answer + '-' + p3Answer + '-' + p4Answer + '-' + p5Answer;
    } else if (c == FeatureFactory.cliqueCnC) {
      addend = '|' + nAnswer;
    } else if (c == FeatureFactory.cliqueCpCnC) {
      addend = '|' + pAnswer + '-' + nAnswer;
    }
    if (addend == null) {
      return feats;
    }
    Collection<String> newFeats = Generics.newHashSet();
    for (String feat : feats) {
      String newFeat = feat + addend;
      newFeats.add(newFeat);
    }
    return newFeats;
  }


  private static List<Pair<Pattern, Integer>> getThresholds(String filename) {
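    // Each line of the threshold file has the form "<feature regex> <count>",
    // split on the last space; e.g. a hypothetical line "Cp.* 5" would threshold
    // features matching Cp.* at count 5.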
    try {
      BufferedReader in = new BufferedReader(new FileReader(filename));
      List<Pair<Pattern, Integer>> thresholds = new ArrayList<Pair<Pattern, Integer>>();
      String line;
      while ((line = in.readLine()) != null) {
        int i = line.lastIndexOf(' ');
        Pattern p = Pattern.compile(line.substring(0, i));
        //System.err.println(":"+line.substring(0,i)+":");
        Integer t = Integer.valueOf(line.substring(i + 1));
        Pair<Pattern, Integer> pair = new Pair<Pattern, Integer>(p, t);
        thresholds.add(pair);
      }
      in.close();
      return thresholds;
    } catch (Exception e) {
      throw new RuntimeException("Error reading threshold file", e);
    }
  }

  public void trainSemiSup() {
    DocumentReaderAndWriter<IN> readerAndWriter = makeReaderAndWriter();

    String filename = flags.trainFile;
    String biasedFilename = flags.biasedTrainFile;

    ObjectBank<List<IN>> data =
      makeObjectBankFromFile(filename, readerAndWriter);
    ObjectBank<List<IN>> biasedData =
      makeObjectBankFromFile(biasedFilename, readerAndWriter);

    Index<String> featureIndex = new HashIndex<String>();
    Index<String> classIndex = new HashIndex<String>();

    Dataset<String, String> dataset = getDataset(data, featureIndex, classIndex);
    Dataset<String, String> biasedDataset = getBiasedDataset(biasedData, featureIndex, classIndex);

    double[][] confusionMatrix = new double[classIndex.size()][classIndex.size()];

    for (int i = 0; i < confusionMatrix.length; i++) {
      Arrays.fill(confusionMatrix[i], 0.0);
      confusionMatrix[i][i] = 1.0;
    }

    String cm = flags.confusionMatrix;
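    // The confusionMatrix flag is a colon-separated list of "label1|label2|value"
    // triples used to seed the matrix, which is then row-normalized and transposed.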
    String[] bits = cm.split(":");
    for (String bit : bits) {
      String[] bits1 = bit.split("\\|");
      int i1 = classIndex.indexOf(bits1[0]);
      int i2 = classIndex.indexOf(bits1[1]);
      double d = Double.parseDouble(bits1[2]);
      confusionMatrix[i2][i1] = d;
    }

    for (double[] row : confusionMatrix) {
      ArrayMath.normalize(row);
    }

    for (int i = 0; i < confusionMatrix.length; i++) {
      for (int j = 0; j < i; j++) {
        double d = confusionMatrix[i][j];
        confusionMatrix[i][j] = confusionMatrix[j][i];
        confusionMatrix[j][i] = d;
      }
    }

    for (int i = 0; i < confusionMatrix.length; i++) {
      for (int j = 0; j < confusionMatrix.length; j++) {
        System.err.println("P("+classIndex.get(j)+ '|' +classIndex.get(i)+") = "+confusionMatrix[j][i]);
      }
    }

    trainSemiSup(dataset, biasedDataset, confusionMatrix);
  }

  static class Scorer<INN extends CoreLabel> implements SequenceModel {
    private CMMClassifier<INN> classifier = null;

    private int[] tagArray = null;
    private int[] backgroundTags = null;
    private Index<String> tagIndex = null;
    private List<INN> lineInfos = null;
    private int pre = 0;
    private int post = 0;
    private Set<List<String>> legalTags = null;

    private static final boolean VERBOSE = false;

    void buildTagArray() {
      int sz = tagIndex.size();
      tagArray = new int[sz];
      for (int i = 0; i < sz; i++) {
        tagArray[i] = i;
      }
    }

    @Override
    public int length() {
      return lineInfos.size() - pre - post;
    }

    @Override
    public int leftWindow() {
      return pre;
    }

    @Override
    public int rightWindow() {
      return post;
    }

    @Override
    public int[] getPossibleValues(int position) {
      //             if (position == 0 || position == lineInfos.size() - 1) {
      //                 int[] a = new int[1];
      //                 a[0] = tagIndex.indexOf(BACKGROUND);
      //                 return a;
      //             }
      if (tagArray == null) {
        buildTagArray();
      }
      if (position < pre) {
        return backgroundTags;
      }
      return tagArray;
    }

    @Override
    public double scoreOf(int[] sequence) {
      throw new UnsupportedOperationException();
    }

    private double[] scoreCache = null;
    private int[] lastWindow = null;
    //private int lastPos = -1;

    @Override
    public double scoreOf(int[] tags, int pos) {
      if (false) {
        return scoresOf(tags, pos)[tags[pos]];
      }
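      // Cache: if the window of conditioning tags around pos is unchanged since
      // the last call, reuse the cached score vector instead of recomputing.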
      if (lastWindow == null) {
        lastWindow = new int[leftWindow() + rightWindow() + 1];
        Arrays.fill(lastWindow, -1);
      }
      boolean match = (pos == lastPos);
      for (int i = pos - leftWindow(); i <= pos + rightWindow(); i++) {
        if (i == pos || i < 0) {
          continue;
        }
        /*System.err.println("p:"+pos);
        System.err.println("lw:"+leftWindow());
        System.err.println("i:"+i);*/
        match &= tags[i] == lastWindow[i - pos + leftWindow()];
      }
      if (!match) {
        scoreCache = scoresOf(tags, pos);
        for (int i = pos - leftWindow(); i <= pos + rightWindow(); i++) {
          if (i < 0) {
            continue;
          }
          lastWindow[i - pos + leftWindow()] = tags[i];
        }
        lastPos = pos;
      }
      return scoreCache[tags[pos]];
    }

    private int percent = -1;
    private int num = 0;
    private long secs = System.currentTimeMillis();
    private long hit = 0;
    private long tot = 0;

    @Override
    public double[] scoresOf(int[] tags, int pos) {
      if (VERBOSE) {
        int p = (100 * pos) / length();
        if (p > percent) {
          long secs2 = System.currentTimeMillis();
          System.err.println(StringUtils.padLeft(p, 3) + "%   " + ((secs2 - secs == 0) ? 0 : (num * 1000 / (secs2 - secs))) + " hits per sec, position=" + pos + ", legal=" + ((tot == 0) ? 100 : ((100 * hit) / tot)));
          // + "% [hit=" + hit + ", tot=" + tot + "]");
          percent = p;
          num = 0;
          secs = secs2;
        }
        tot++;
      }
      String[] answers = new String[1 + leftWindow() + rightWindow()];
      String[] pre = new String[leftWindow()];
      for (int i = 0; i < 1 + leftWindow() + rightWindow(); i++) {
        int absPos = pos - leftWindow() + i;
        if (absPos < 0) {
          continue;
        }
        answers[i] = tagIndex.get(tags[absPos]);
        CoreLabel li = lineInfos.get(absPos);
        li.set(CoreAnnotations.AnswerAnnotation.class, answers[i]);
        if (i < leftWindow()) {
          pre[i] = answers[i];
        }
      }
      double[] scores = new double[tagIndex.size()];
      //System.out.println("Considering: "+Arrays.asList(pre));
      if (!legalTags.contains(Arrays.asList(pre)) && classifier.flags.useObservedSequencesOnly) {
        // System.out.println("Rejecting: " + Arrays.asList(pre));
        // System.out.println(legalTags);
        Arrays.fill(scores, -1000); // was Double.NEGATIVE_INFINITY
        return scores;
      }
      num++;
      hit++;
      Counter<String> c = classifier.scoresOf(lineInfos, pos);
      //System.out.println("Pos "+pos+" hist "+Arrays.asList(pre)+" result "+c);
      //System.out.println(c);
      //if (false && flags.justify) {
      //    System.out.println("Considering position " + pos + ", word is " + ((CoreLabel) lineInfos.get(pos)).word());
      //    //System.out.println("Datum is "+d.asFeatures());
      //    System.out.println("History: " + Arrays.asList(pre));
      //}
      for (String s : c.keySet()) {
        int t = tagIndex.indexOf(s);
        if (t > -1) {
          int[] tA = getPossibleValues(pos);
          for (int j = 0; j < tA.length; j++) {
            if (tA[j] == t) {
              scores[j] = c.getCount(s);
              //if (false && flags.justify) {
              //    System.out.println("Label " + s + " got score " + scores[j]);
              //}
            }
          }
        }
      }
      // normalize?
      if (classifier.normalize()) {
        ArrayMath.logNormalize(scores);
      }
      return scores;
    }

    static double[] recenter(double[] x) {
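      // Subtract the log-sum-exp of x from each entry, turning unnormalized log
      // scores into normalized log probabilities.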
      double[] r = new double[x.length];
      // double logTotal = Double.NEGATIVE_INFINITY;
      // for (int i = 0; i < x.length; i++)
      //    logTotal = SloppyMath.logAdd(logTotal, x[i]);
      double logTotal = ArrayMath.logSum(x);
      for (int i = 0; i < x.length; i++) {
        r[i] = x[i] - logTotal;
      }
      return r;
    }

    /**
     * Build a Scorer.
     *
     * @param lineInfos  List of INN data items to classify
     * @param classifier The trained Classifier
     * @param pre        Number of previous tags that condition current tag
     * @param post       Number of following tags that condition previous tag
     *                   (if pre and post are both nonzero, then you have a
     *                   dependency network tagger)
     */
    Scorer(List<INN> lineInfos, Index<String> tagIndex, CMMClassifier<INN> classifier, int pre, int post, Set<List<String>> legalTags) {
      if (VERBOSE) {
        System.err.println("Built Scorer for " + lineInfos.size() + " words, clique pre=" + pre + " post=" + post);
      }
      this.pre = pre;
      this.post = post;
      this.lineInfos = lineInfos;
      this.tagIndex = tagIndex;
      this.classifier = classifier;
      this.legalTags = legalTags;
      backgroundTags = new int[]{tagIndex.indexOf(classifier.flags.backgroundSymbol)};
    }

  } // end class Scorer

  private boolean normalize() {
    return flags.normalize;
  }

  static int lastPos = -1;

  public Counter<String> scoresOf(List<IN> lineInfos, int pos) {
//     if (pos != lastPos) {
//       System.err.print(pos+".");
//       lastPos = pos;
//     }
//     System.err.print("!");
    Datum<String, String> d = makeDatum(lineInfos, pos, featureFactories);
    return classifier.logProbabilityOf(d);
  }


  /**
   * Takes a {@link List} of {@link CoreLabel}s and prints the likelihood
   * of each possible label at each point.
   * TODO: Finish or delete this method!
   *
   * @param document A {@link List} of {@link CoreLabel}s.
   */
  @Override
  public void printProbsDocument(List<IN> document) {

    //ClassicCounter<String> c = scoresOf(document, 0);

  }

  /** Command-line version of the classifier.  See the class
   *  comments for examples of use, and SeqClassifierFlags
   *  for more information on supported flags.
   */
  public static void main(String[] args) throws Exception {
    StringUtils.printErrInvocationString("CMMClassifier", args);

    Properties props = StringUtils.argsToProperties(args);
    CMMClassifier<CoreLabel> cmm = new CMMClassifier<CoreLabel>(props);
    String testFile = cmm.flags.testFile;
    String textFile = cmm.flags.textFile;
    String loadPath = cmm.flags.loadClassifier;
    String serializeTo = cmm.flags.serializeTo;

    // cmm.crossValidateTrainAndTest(trainFile);
    if (loadPath != null) {
      cmm.loadClassifierNoExceptions(loadPath, props);
    } else if (cmm.flags.loadJarClassifier != null) {
      cmm.loadJarClassifier(cmm.flags.loadJarClassifier, props);
    } else if (cmm.flags.trainFile != null) {
      if (cmm.flags.biasedTrainFile != null) {
        cmm.trainSemiSup();
      } else {
        cmm.train();
      }
    } else {
      cmm.loadDefaultClassifier();
    }

    if (serializeTo != null) {
      cmm.serializeClassifier(serializeTo);
    }

    if (testFile != null) {
      cmm.classifyAndWriteAnswers(testFile, cmm.makeReaderAndWriter(), true);
    } else if (cmm.flags.testFiles != null) {
      cmm.classifyAndWriteAnswers(cmm.flags.baseTestDir, cmm.flags.testFiles, cmm.makeReaderAndWriter(), true);
    }

    if (textFile != null) {
      DocumentReaderAndWriter<CoreLabel> readerAndWriter =
        new PlainTextDocumentReaderAndWriter<CoreLabel>();
      cmm.classifyAndWriteAnswers(textFile, readerAndWriter, false);
    }
  } // end main


  public double weight(String feature, String label) {
    return ((LinearClassifier<String, String>)classifier).weight(feature, label);
  }

  public double[][] weights() {
    return ((LinearClassifier<String, String>)classifier).weights();
  }

  @Override
  public List<IN> classifyWithGlobalInformation(List<IN> tokenSeq, final CoreMap doc, final CoreMap sent) {
    return classify(tokenSeq);
  }

} // end class CMMClassifier