Package edu.stanford.nlp.parser.dvparser

Source Code of edu.stanford.nlp.parser.dvparser.CombineDVModels

package edu.stanford.nlp.parser.dvparser;

import java.io.FileFilter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import edu.stanford.nlp.parser.common.ArgUtils;
import edu.stanford.nlp.parser.lexparser.EvaluateTreebank;
import edu.stanford.nlp.parser.lexparser.LexicalizedParser;
import edu.stanford.nlp.parser.lexparser.Options;
import edu.stanford.nlp.parser.lexparser.Reranker;
import edu.stanford.nlp.trees.Treebank;
import edu.stanford.nlp.util.Pair;

/**
* This class combines multiple DVParsers into one by adding their
* scores.  If the models are somewhat different but have similar
* accuracy, this gives a slight accuracy increase.
*
* @author John Bauer
*/
public class CombineDVModels {

  public static void main(String[] args)
    throws IOException, ClassNotFoundException
  {
    String modelPath = null;

    List<String> baseModelPaths = null;

    String testTreebankPath = null;
    FileFilter testTreebankFilter = null;

    List<String> unusedArgs = new ArrayList<String>();

    for (int argIndex = 0; argIndex < args.length; ) {
      if (args[argIndex].equalsIgnoreCase("-model")) {
        modelPath = args[argIndex + 1];
        argIndex += 2;
      } else if (args[argIndex].equalsIgnoreCase("-testTreebank")) {
        Pair<String, FileFilter> treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
        argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
        testTreebankPath = treebankDescription.first();
        testTreebankFilter = treebankDescription.second();
      } else if (args[argIndex].equalsIgnoreCase("-baseModels")) {
        argIndex++;
        baseModelPaths = new ArrayList<String>();
        while (argIndex < args.length && args[argIndex].charAt(0) != '-') {
          baseModelPaths.add(args[argIndex++]);
        }
        if (baseModelPaths.size() == 0) {
          throw new IllegalArgumentException("Found an argument -baseModels with no actual models named");
        }
      } else {
        unusedArgs.add(args[argIndex++]);
      }
    }

    String[] newArgs = unusedArgs.toArray(new String[unusedArgs.size()]);
    LexicalizedParser underlyingParser = null;
    Options options = null;
    LexicalizedParser combinedParser = null;
    if (baseModelPaths != null) {
      List<DVModel> dvparsers = new ArrayList<DVModel>();
      for (String baseModelPath : baseModelPaths) {
        System.err.println("Loading serialized DVParser from " + baseModelPath);
        LexicalizedParser dvparser = LexicalizedParser.loadModel(baseModelPath);
        Reranker reranker = dvparser.reranker;
        if (!(reranker instanceof DVModelReranker)) {
          throw new IllegalArgumentException("Expected parsers with DVModel embedded");
        }
        dvparsers.add(((DVModelReranker) reranker).getModel());
        if (underlyingParser == null) {
          underlyingParser = dvparser;
          options = underlyingParser.getOp();
          // TODO: other parser's options?
          options.setOptions(newArgs);
        }
        System.err.println("... done");
      }
      combinedParser = LexicalizedParser.copyLexicalizedParser(underlyingParser);
      CombinedDVModelReranker reranker = new CombinedDVModelReranker(options, dvparsers);
      combinedParser.reranker = reranker;
      combinedParser.saveParserToSerialized(modelPath);
    } else {
      throw new IllegalArgumentException("Need to specify -model to load an already prepared CombinedParser");
    }

    Treebank testTreebank = null;
    if (testTreebankPath != null) {
      System.err.println("Reading in trees from " + testTreebankPath);
      if (testTreebankFilter != null) {
        System.err.println("Filtering on " + testTreebankFilter);
      }
      testTreebank = combinedParser.getOp().tlpParams.memoryTreebank();;
      testTreebank.loadPath(testTreebankPath, testTreebankFilter);
      System.err.println("Read in " + testTreebank.size() + " trees for testing");

      EvaluateTreebank evaluator = new EvaluateTreebank(combinedParser.getOp(), null, combinedParser);
      evaluator.testOnTreebank(testTreebank);
    }
  }
}
TOP

Related Classes of edu.stanford.nlp.parser.dvparser.CombineDVModels

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., owned by Oracle Inc. Contact coftware#gmail.com.