Package edu.stanford.nlp.ie.machinereading

Source Code of edu.stanford.nlp.ie.machinereading.MachineReading

package edu.stanford.nlp.ie.machinereading;

import java.io.File;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Properties;
import java.util.Set;
import java.util.logging.ConsoleHandler;
import java.util.logging.Handler;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.logging.SimpleFormatter;

import edu.stanford.nlp.ie.machinereading.structure.*;
import edu.stanford.nlp.ie.machinereading.structure.MachineReadingAnnotations.EntityMentionsAnnotation;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.CoreAnnotations.SentencesAnnotation;
import edu.stanford.nlp.ling.CoreAnnotations.TextAnnotation;
import edu.stanford.nlp.ling.CoreAnnotations.TokensAnnotation;
import edu.stanford.nlp.trees.TreeCoreAnnotations.TreeAnnotation;
import edu.stanford.nlp.pipeline.Annotation;
import edu.stanford.nlp.pipeline.StanfordCoreNLP;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.Execution;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.StringUtils;

/**
* Main driver for Machine Reading training, annotation, and evaluation. Does
* entity, relation, and event extraction for all corpora.
*
* This code has been adapted for 4 domains, all defined in the edu.stanford.nlp.ie.machinereading.domains package.
* For each domain, you need a properties file that is the only command line parameter for MachineReading.
* Minimally, for each domain you need to define a reader class that extends the GenericDataSetReader class
* and overrides the public Annotation read(String path) method.
*
* How to run: java edu.stanford.nlp.ie.machinereading.MachineReading --arguments propertiesfile
*
* This method creates an Annotation with additional objects per sentence: EntityMentions and RelationMentions.
* Using these objects, the classifiers that get called from MachineReading train entity and relation extractors.
* The simplest example domain currently is in edu.stanford.nlp.ie.machinereading.domains.roth,
* which is a simple entity and relation extraction using a dataset created by Dan Roth. The properties file for the domain is at
* projects/more/src/edu/stanford/nlp/ie/machinereading/domains/roth/roth.properties
*
* @author David McClosky
* @author mrsmith
* @author Mihai
*
*/
public class MachineReading {
 
  // Store command-line args so they can be passed to other classes
  private String[] args;
 
  /*
   * class attributes
   */
  public GenericDataSetReader reader;
  public GenericDataSetReader auxReader;
 

  public Extractor entityExtractor;
  // TODO could add an entityExtractorPostProcessor if we need one
  public Extractor relationExtractor;
  public Extractor relationExtractionPostProcessor;
  public Extractor eventExtractor;
  public Extractor consistencyChecker;
 
  protected boolean forceRetraining;
  public boolean forceParseSentences;
 
 
  /**
   * Array of pairs of datasets (training, testing)
   * If cross validation is enabled, the length of this array is the number of folds; otherwise it is 1
   * The first element in each pair is the training corpus; the second is testing
   */
  private Pair<Annotation, Annotation> [] datasets;
 
  /**
   * Stores the predictions of the extractors
   * The first index is the task: 0 - entities, 1 - relations, 2 - events
   * The second index is the partition number (of length 1 if cross validation is not enabled)
   * Note: we need to store separate predictions per task because they may not be compatible with each other.
   *       For example, we may have predicted entities in task 0 but use gold entities for task 1.
   */
  private Annotation [][] predictions;
 
  private Set<ResultsPrinter> entityResultsPrinterSet;
  private Set<ResultsPrinter> relationResultsPrinterSet;
  @SuppressWarnings("unused")
  private Set<ResultsPrinter> eventResultsPrinterSet;
 
  public static final int ENTITY_LEVEL = 0;
  public static final int RELATION_LEVEL = 1;
  public static final int EVENT_LEVEL = 2;
 
  public static void main(String[] args) throws Exception {
    MachineReading mr = makeMachineReading(args);
    mr.run();
  }
 
  public static void setLoggerLevel(Level level) {
    setConsoleLevel(Level.FINEST);
    MachineReadingProperties.logger.setLevel(level);
  }
 
  public static void setConsoleLevel(Level level) {
    // get the top Logger:
    Logger topLogger = java.util.logging.Logger.getLogger("");

    // Handler for console (reuse it if it already exists)
    Handler consoleHandler = null;
    // see if there is already a console handler
    for (Handler handler : topLogger.getHandlers()) {
        if (handler instanceof ConsoleHandler) {
            // found the console handler
            consoleHandler = handler;
            break;
        }
    }

    if (consoleHandler == null) {
        // there was no console handler found, create a new one
        consoleHandler = new ConsoleHandler();
        topLogger.addHandler(consoleHandler);
    }
    // set the console handler level:
    consoleHandler.setLevel(level);
    consoleHandler.setFormatter(new SimpleFormatter());
  }
 
  /**
   * Use the makeMachineReading* methods to create MachineReading objects!
   */
  private MachineReading(String [] args) {
    this.args = args;
  }
  protected MachineReading() {
    this.args = new String[0];
  }
 
  /**
   * Creates a MR object to be used only for annotation purposes (no training)
   * This is needed in order to integrate MachineReading with BaselineNLProcessor
   */
  public static MachineReading makeMachineReadingForAnnotation(
      GenericDataSetReader reader,
      Extractor entityExtractor,
      Extractor relationExtractor,
      Extractor eventExtractor,
      Extractor consistencyChecker,
      Extractor relationPostProcessor,
      boolean testRelationsUsingPredictedEntities,
      boolean verbose) {
    MachineReading mr = new MachineReading();
   
    // readers needed to assign syntactic heads to predicted entities
    mr.reader = reader;
    mr.auxReader = null;
   
    // no results printers needed
    mr.entityResultsPrinterSet = new HashSet<ResultsPrinter>();
    mr.setRelationResultsPrinterSet(new HashSet<ResultsPrinter>());
   
    // create the storage for the generated annotations
    mr.predictions = new Annotation[3][1];
   
    // create the entity/relation classifiers
    mr.entityExtractor = entityExtractor;
    MachineReadingProperties.extractEntities = (entityExtractor != null ? true : false);
    mr.relationExtractor = relationExtractor;
    MachineReadingProperties.extractRelations = (relationExtractor != null ? true : false);
    MachineReadingProperties.testRelationsUsingPredictedEntities = testRelationsUsingPredictedEntities;
    mr.eventExtractor = eventExtractor;
    MachineReadingProperties.extractEvents = (eventExtractor != null ? true : false);
    mr.consistencyChecker = consistencyChecker;
    mr.relationExtractionPostProcessor = relationPostProcessor;

    Level level = verbose ? Level.FINEST : Level.SEVERE;
    if (entityExtractor != null)
      entityExtractor.setLoggerLevel(level);
    if (mr.relationExtractor != null)
      mr.relationExtractor.setLoggerLevel(level);
    if (mr.eventExtractor != null)
      mr.eventExtractor.setLoggerLevel(level);
   
    return mr;
  }
 
  public static MachineReading makeMachineReading(String [] args) throws IOException {
    // install global parameters
    MachineReading mr = new MachineReading(args);
    //TODO:
    Execution.fillOptions(MachineReadingProperties.class, args);
    //Arguments.parse(args, mr);
    System.err.println("PERCENTAGE OF TRAIN: " + MachineReadingProperties.percentageOfTrain);
   
    // convert args to properties
    Properties props = StringUtils.argsToProperties(args);
    if (props == null) {
      throw new RuntimeException("ERROR: failed to find Properties in the given arguments!");
    }
   
    String logLevel = props.getProperty("logLevel", "INFO");
    setLoggerLevel(Level.parse(logLevel.toUpperCase()));
   
    // install reader specific parameters
    GenericDataSetReader reader = mr.makeReader(props);
    GenericDataSetReader auxReader = mr.makeAuxReader();
    Level readerLogLevel = Level.parse(MachineReadingProperties.readerLogLevel.toUpperCase());
    reader.setLoggerLevel(readerLogLevel);
    if (auxReader != null) {
      auxReader.setLoggerLevel(readerLogLevel);
    }
    System.err.println("The reader log level is set to " + readerLogLevel);
    //Execution.fillOptions(GenericDataSetReaderProps.class, args);
    //Arguments.parse(args, reader);
   
    // create the pre-processing pipeline
    StanfordCoreNLP pipe = new StanfordCoreNLP(props, false);
    reader.setProcessor(pipe);
    if (auxReader != null) {
      auxReader.setProcessor(pipe);
    }
   
    // create the results printers
    mr.makeResultsPrinters(args);
   
    return mr;
  }
 
  /**
   * Performs extraction. This will train a new extraction model and evaluate
   * the model on the test set. Depending on the MachineReading instance's
   * parameters, it may skip training if a model already exists or skip
   * evaluation.
   *
   * returns results string, can be compared in a utest
   */
  public List<String> run() throws Exception {
    this.forceRetraining= !MachineReadingProperties.loadModel;
   
    if (MachineReadingProperties.trainOnly) {
      this.forceRetraining= true;
    }
    List<String> retMsg = new ArrayList<String>();
    boolean haveSerializedEntityExtractor = serializedModelExists(MachineReadingProperties.serializedEntityExtractorPath);
    boolean haveSerializedRelationExtractor = serializedModelExists(MachineReadingProperties.serializedRelationExtractorPath);
    boolean haveSerializedEventExtractor = serializedModelExists(MachineReadingProperties.serializedEventExtractorPath);
    Annotation training = null;
    Annotation aux = null;
    if ((MachineReadingProperties.extractEntities && !haveSerializedEntityExtractor) ||
        (MachineReadingProperties.extractRelations && !haveSerializedRelationExtractor) ||
        (MachineReadingProperties.extractEvents && !haveSerializedEventExtractor) ||
        this.forceRetraining|| MachineReadingProperties.crossValidate){
      // load training sentences
      training = loadOrMakeSerializedSentences(MachineReadingProperties.trainPath, reader, new File(MachineReadingProperties.serializedTrainingSentencesPath));
      if (auxReader != null) {
        MachineReadingProperties.logger.severe("Reading auxiliary dataset from " + MachineReadingProperties.auxDataPath + "...");
        aux = loadOrMakeSerializedSentences(MachineReadingProperties.auxDataPath, auxReader, new File(
            MachineReadingProperties.serializedAuxTrainingSentencesPath));
        MachineReadingProperties.logger.severe("Done reading auxiliary dataset.");
      }
    }
   
    Annotation testing = null;
    if (!MachineReadingProperties.trainOnly && !MachineReadingProperties.crossValidate) {
      // load test sentences
      File serializedTestSentences = new File(MachineReadingProperties.serializedTestSentencesPath);
      testing = loadOrMakeSerializedSentences(MachineReadingProperties.testPath, reader, serializedTestSentences);
    }
     
    //
    // create the actual datasets to be used for training and annotation
    //
    makeDataSets(training, testing, aux);
   
    //
    // process (training + annotate) one partition at a time
    //
    for(int partition = 0; partition < datasets.length; partition ++){
      assert(datasets.length > partition);
      assert(datasets[partition] != null);
      assert(MachineReadingProperties.trainOnly || datasets[partition].second() != null);
     
      // train all models
      train(datasets[partition].first(), (MachineReadingProperties.crossValidate ? partition : -1));
      // annotate using all models
      if(! MachineReadingProperties.trainOnly){
        MachineReadingProperties.logger.info("annotating partition " + partition );
        annotate(datasets[partition].second(), (MachineReadingProperties.crossValidate ? partition: -1));
      }
    }
   
    //
    // now report overall results
    //
    if(! MachineReadingProperties.trainOnly){
      // merge test sets for the gold data
      Annotation gold = new Annotation("");
      for(int i = 0; i < datasets.length; i ++) AnnotationUtils.addSentences(gold, datasets[i].second().get(CoreAnnotations.SentencesAnnotation.class));
     
      // merge test sets with predicted annotations
      Annotation[] mergedPredictions = new Annotation[3];
      assert(predictions != null);
      for (int taskLevel = 0; taskLevel < mergedPredictions.length; taskLevel++) {
        mergedPredictions[taskLevel] = new Annotation("");
        for(int fold = 0; fold < predictions[taskLevel].length; fold ++){
          if (predictions[taskLevel][fold] == null) continue;
          AnnotationUtils.addSentences(mergedPredictions[taskLevel], predictions[taskLevel][fold].get(CoreAnnotations.SentencesAnnotation.class));
        }
      }
      //
      // evaluate all tasks: entity, relation, and event recognition
      //
      if(MachineReadingProperties.extractEntities && ! entityResultsPrinterSet.isEmpty()){
        retMsg.addAll(printTask("entity extraction", entityResultsPrinterSet, gold, mergedPredictions[ENTITY_LEVEL]));
      }
     
      if(MachineReadingProperties.extractRelations && ! getRelationResultsPrinterSet().isEmpty()){
        retMsg.addAll(printTask("relation extraction", getRelationResultsPrinterSet(), gold, mergedPredictions[RELATION_LEVEL]));
      }
     
      //
      // Save the sentences with the predicted annotations
      //
      if (MachineReadingProperties.extractEntities && MachineReadingProperties.serializedEntityExtractionResults != null)
        IOUtils.writeObjectToFile(mergedPredictions[ENTITY_LEVEL], MachineReadingProperties.serializedEntityExtractionResults);
      if (MachineReadingProperties.extractRelations && MachineReadingProperties.serializedRelationExtractionResults != null)
        IOUtils.writeObjectToFile(mergedPredictions[RELATION_LEVEL],MachineReadingProperties.serializedRelationExtractionResults);
      if (MachineReadingProperties.extractEvents && MachineReadingProperties.serializedEventExtractionResults != null)
        IOUtils.writeObjectToFile(mergedPredictions[EVENT_LEVEL],MachineReadingProperties.serializedEventExtractionResults);
     
    }
    
    return retMsg;
  }
 
  protected List<String> printTask(String taskName, Set<ResultsPrinter> printers, Annotation gold, Annotation pred) {
    List<String> retMsg = new ArrayList<String>();
    for (ResultsPrinter rp : printers){
      String msg = rp.printResults(gold, pred);
      retMsg.add(msg);
      MachineReadingProperties.logger.severe("Overall " + taskName + " results, using printer " + rp.getClass() + ":\n" + msg);
    }
    return retMsg;
  }
 
  protected void train(Annotation training, int partition) throws Exception {
    //
    // train entity extraction
    //
    if (MachineReadingProperties.extractEntities) {
      MachineReadingProperties.logger.info("Training entity extraction model(s)");
      if (partition != -1) MachineReadingProperties.logger.info("In partition #" + partition);
      String modelName = MachineReadingProperties.serializedEntityExtractorPath;
      if (partition != -1) modelName += "." + partition;
      File modelFile = new File(modelName);
     
      MachineReadingProperties.logger.fine("forceRetraining = " + this.forceRetraining+ ", modelFile.exists = " + modelFile.exists());
      if(! this.forceRetraining&& modelFile.exists()){
        MachineReadingProperties.logger.info("Loading entity extraction model from " + modelName + " ...");
        entityExtractor = BasicEntityExtractor.load(modelName, MachineReadingProperties.entityClassifier, false);
      } else {
        MachineReadingProperties.logger.info("Training entity extraction model...");
        entityExtractor = makeEntityExtractor(MachineReadingProperties.entityClassifier, MachineReadingProperties.entityGazetteerPath);
        entityExtractor.train(training);
        MachineReadingProperties.logger.info("Serializing entity extraction model to " + modelName + " ...");
        entityExtractor.save(modelName);
      }
    }
   
    //
    // train relation extraction
    //
    if (MachineReadingProperties.extractRelations) {
      MachineReadingProperties.logger.info("Training relation extraction model(s)");
      if (partition != -1)
        MachineReadingProperties.logger.info("In partition #" + partition);
      String modelName = MachineReadingProperties.serializedRelationExtractorPath;
      if (partition != -1)
        modelName += "." + partition;
      Annotation predicted = null;
     
      if (MachineReadingProperties.useRelationExtractionModelMerging) {
        String[] modelNames = MachineReadingProperties.serializedRelationExtractorPath.split(",");
        if (partition != -1) {
          for (int i = 0; i < modelNames.length; i++) {
            modelNames[i] += "." + partition;
          }
        }
       
        relationExtractor = ExtractorMerger.buildRelationExtractorMerger(modelNames);
      } else if (!this.forceRetraining&& new File(modelName).exists()) {
        MachineReadingProperties.logger.info("Loading relation extraction model from " + modelName + " ...");
        //TODO change this to load any type of BasicRelationExtractor
        relationExtractor = BasicRelationExtractor.load(modelName);
      } else {
        RelationFeatureFactory rff = makeRelationFeatureFactory(MachineReadingProperties.relationFeatureFactoryClass, MachineReadingProperties.relationFeatures, MachineReadingProperties.doNotLexicalizeFirstArg);
        Execution.fillOptions(rff, args);

        if (MachineReadingProperties.trainRelationsUsingPredictedEntities) {
          // generate predicted entities
          assert(entityExtractor != null);
          predicted = AnnotationUtils.deepMentionCopy(training);
          entityExtractor.annotate(predicted);
          for (ResultsPrinter rp : entityResultsPrinterSet){
            String msg = rp.printResults(training, predicted);
            MachineReadingProperties.logger.info("Training relation extraction using predicted entitities: entity scores using printer " + rp.getClass() + ":\n" + msg);
          }
         
          // change relation mentions to use predicted entity mentions rather than gold ones
          try {
            changeGoldRelationArgsToPredicted(predicted);
          } catch (Exception e) {
            // we may get here for unknown EntityMentionComparator class
            throw new RuntimeException(e);
          }
        }

        Annotation dataset;
        if (MachineReadingProperties.trainRelationsUsingPredictedEntities) {
          dataset = predicted;
        } else {
          dataset = training;
        }
       
        Set<String> relationsToSkip = new HashSet<String>(StringUtils.split(MachineReadingProperties.relationsToSkipDuringTraining, ","));
        List<List<RelationMention>> backedUpRelations = new ArrayList<List<RelationMention>>();
        if (relationsToSkip.size() > 0) {
          // we need to backup the relations since removeSkippableRelations modifies dataset in place and we can't duplicate CoreMaps safely (or can we?)
          for (CoreMap sent : dataset.get(CoreAnnotations.SentencesAnnotation.class)) {
            List<RelationMention> relationMentions = sent.get(MachineReadingAnnotations.RelationMentionsAnnotation.class);
            backedUpRelations.add(relationMentions);
          }
         
          removeSkippableRelations(dataset, relationsToSkip);
        }
       
        //relationExtractor = new BasicRelationExtractor(rff, MachineReadingProperties.createUnrelatedRelations, makeRelationMentionFactory(MachineReadingProperties.relationMentionFactoryClass));
        relationExtractor = makeRelationExtractor(MachineReadingProperties.relationClassifier, rff, MachineReadingProperties.createUnrelatedRelations,
          makeRelationMentionFactory(MachineReadingProperties.relationMentionFactoryClass));
        Execution.fillOptions(relationExtractor, args);
        //Arguments.parse(args,relationExtractor);
        MachineReadingProperties.logger.info("Training relation extraction model...");
        relationExtractor.train(dataset);
        MachineReadingProperties.logger.info("Serializing relation extraction model to " + modelName + " ...");
        relationExtractor.save(modelName);

        if (relationsToSkip.size() > 0) {
          // restore backed up relations into dataset
          int sentenceIndex = 0;
         
          for (CoreMap sentence : dataset.get(CoreAnnotations.SentencesAnnotation.class)) {
            List<RelationMention> relationMentions = backedUpRelations.get(sentenceIndex);
            sentence.set(MachineReadingAnnotations.RelationMentionsAnnotation.class, relationMentions);
            sentenceIndex++;
          }
        }
      }
    }

    //
    // train event extraction -- currently just works with MSTBasedEventExtractor
    //
    if (MachineReadingProperties.extractEvents) {
      MachineReadingProperties.logger.info("Training event extraction model(s)");
      if (partition != -1) MachineReadingProperties.logger.info("In partition #" + partition);
      String modelName = MachineReadingProperties.serializedEventExtractorPath;
      if (partition != -1) modelName += "." + partition;
      File modelFile = new File(modelName);

      Annotation predicted = null;
      if(!this.forceRetraining&& modelFile.exists()) {
        MachineReadingProperties.logger.info("Loading event extraction model from " + modelName + " ...");
        Method mstLoader = (Class.forName("MSTBasedEventExtractor")).getMethod("load", String.class);
        eventExtractor = (Extractor) mstLoader.invoke(null, modelName);
      } else {
        if (MachineReadingProperties.trainEventsUsingPredictedEntities) {
          // generate predicted entities
          assert(entityExtractor != null);
          predicted = AnnotationUtils.deepMentionCopy(training);
          entityExtractor.annotate(predicted);
          for (ResultsPrinter rp : entityResultsPrinterSet){
            String msg = rp.printResults(training, predicted);
            MachineReadingProperties.logger.info("Training event extraction using predicted entitities: entity scores using printer " + rp.getClass() + ":\n" + msg);
          }
         
          // TODO: need an equivalent of changeGoldRelationArgsToPredicted here?
        }

        Constructor<?> mstConstructor = (Class.forName("edu.stanford.nlp.ie.machinereading.MSTBasedEventExtractor")).getConstructor(boolean.class);
        eventExtractor = (Extractor) mstConstructor.newInstance(MachineReadingProperties.trainEventsUsingPredictedEntities);

        MachineReadingProperties.logger.info("Training event extraction model...");
        if (MachineReadingProperties.trainRelationsUsingPredictedEntities) {
          eventExtractor.train(predicted);
        } else {
          eventExtractor.train(training);
        }
        MachineReadingProperties.logger.info("Serializing event extraction model to " + modelName + " ...");
        eventExtractor.save(modelName);
      }
    }
  }

  /**
   * Removes any relations with relation types in relationsToSkip from a dataset.  Dataset is modified in place.
   */
  private static void removeSkippableRelations(Annotation dataset, Set<String> relationsToSkip) {
    if (relationsToSkip == null || relationsToSkip.size() == 0) {
      return;
    }
    for (CoreMap sent : dataset.get(CoreAnnotations.SentencesAnnotation.class)) {
      List<RelationMention> relationMentions = sent.get(MachineReadingAnnotations.RelationMentionsAnnotation.class);
      if (relationMentions == null) {
        continue;
      }
      List<RelationMention> newRelationMentions = new ArrayList<RelationMention>();
      for (RelationMention rm: relationMentions) {
        if (!relationsToSkip.contains(rm.getType())) {
          newRelationMentions.add(rm);
        }
      }
      sent.set(MachineReadingAnnotations.RelationMentionsAnnotation.class, newRelationMentions);
    }
  }

  /**
   * Replaces all relation arguments with predicted entities
   */
  private static void changeGoldRelationArgsToPredicted(Annotation dataset) {
    for (CoreMap sent : dataset.get(CoreAnnotations.SentencesAnnotation.class)) {
      List<EntityMention> entityMentions = sent.get(MachineReadingAnnotations.EntityMentionsAnnotation.class);
      List<RelationMention> relationMentions = sent.get(MachineReadingAnnotations.RelationMentionsAnnotation.class);
      List<RelationMention> newRels = new ArrayList<RelationMention>();
      for (RelationMention rm : relationMentions) {
        rm.setSentence(sent);
        if (rm.replaceGoldArgsWithPredicted(entityMentions)) {
          MachineReadingProperties.logger.info("Successfully mapped all arguments in relation mention: " + rm);
          newRels.add(rm);
        } else {
          MachineReadingProperties.logger.info("Dropped relation mention due to failed argument mapping: " + rm);
        }
      }
      sent.set(MachineReadingAnnotations.RelationMentionsAnnotation.class, newRels);
      // we may have added new mentions to the entity list, so let's store it again
      sent.set(MachineReadingAnnotations.EntityMentionsAnnotation.class, entityMentions);
    }
  }
 
  public Annotation annotate(Annotation testing) {
    return annotate(testing, -1);
  }
 
  protected Annotation annotate(Annotation testing, int partition) {
    int partitionIndex = (partition != -1 ? partition : 0);

    //
    // annotate entities
    //
    if (MachineReadingProperties.extractEntities) {
      assert(entityExtractor != null);
      Annotation predicted = AnnotationUtils.deepMentionCopy(testing);
      entityExtractor.annotate(predicted);
     
      for (ResultsPrinter rp : entityResultsPrinterSet){
        String msg = rp.printResults(testing, predicted);
        MachineReadingProperties.logger.info("Entity extraction results " + (partition != -1 ? "for partition #" + partition : "") + " using printer " + rp.getClass() + ":\n" + msg);
      }
      predictions[ENTITY_LEVEL][partitionIndex] = predicted;
    }
   
    //
    // annotate relations
    //
    if (MachineReadingProperties.extractRelations) {
      assert(relationExtractor != null);
     
      Annotation predicted = (MachineReadingProperties.testRelationsUsingPredictedEntities ? predictions[ENTITY_LEVEL][partitionIndex] : AnnotationUtils.deepMentionCopy(testing));
      // make sure the entities have the syntactic head and span set. we need this for relation extraction features
      assignSyntacticHeadToEntities(predicted);
      relationExtractor.annotate(predicted);
     
      if (relationExtractionPostProcessor == null) {
        relationExtractionPostProcessor = makeExtractor(MachineReadingProperties.relationExtractionPostProcessorClass);
      }
      if (relationExtractionPostProcessor != null) {
        MachineReadingProperties.logger.info("Using relation extraction post processor: " + MachineReadingProperties.relationExtractionPostProcessorClass);
        relationExtractionPostProcessor.annotate(predicted);
      }
     
      for (ResultsPrinter rp : getRelationResultsPrinterSet()){
        String msg = rp.printResults(testing, predicted);
        MachineReadingProperties.logger.info("Relation extraction results " + (partition != -1 ? "for partition #" + partition : "") + " using printer " + rp.getClass() + ":\n" + msg);
      }
     
      //
      // apply the domain-specific consistency checks
      //
      if (consistencyChecker == null) {
        consistencyChecker = makeExtractor(MachineReadingProperties.consistencyCheck);
      }
      if (consistencyChecker != null) {
        MachineReadingProperties.logger.info("Using consistency checker: " + MachineReadingProperties.consistencyCheck);
        consistencyChecker.annotate(predicted);
       
        for (ResultsPrinter rp : entityResultsPrinterSet){
          String msg = rp.printResults(testing, predicted);
          MachineReadingProperties.logger.info("Entity extraction results AFTER consistency checks " + (partition != -1 ? "for partition #" + partition : "") + " using printer " + rp.getClass() + ":\n" + msg);
        }
        for (ResultsPrinter rp : getRelationResultsPrinterSet()){
          String msg = rp.printResults(testing, predicted);
          MachineReadingProperties.logger.info("Relation extraction results AFTER consistency checks " + (partition != -1 ? "for partition #" + partition : "") + " using printer " + rp.getClass() + ":\n" + msg);
        }
      }
     
      predictions[RELATION_LEVEL][partitionIndex] = predicted;
    }
   
    //
    // TODO: annotate events
    //
   
    return predictions[RELATION_LEVEL][partitionIndex];
  }

  private void assignSyntacticHeadToEntities(Annotation corpus) {
    assert(corpus != null);
    assert(corpus.get(SentencesAnnotation.class) != null);
    for(CoreMap sent: corpus.get(SentencesAnnotation.class)){
      List<CoreLabel> tokens = sent.get(TokensAnnotation.class);
      assert(tokens != null);
      Tree tree = sent.get(TreeAnnotation.class);
      if (MachineReadingProperties.forceGenerationOfIndexSpans) {
        tree.indexSpans(0)
      }
      assert(tree != null);
      if(sent.get(EntityMentionsAnnotation.class) != null){
        for(EntityMention e: sent.get(EntityMentionsAnnotation.class)){
          reader.assignSyntacticHead(e, tree, tokens, true);
        }
      }
    }
  }
 
  public static Extractor makeExtractor(Class<Extractor> extractorClass) {
    if (extractorClass == null) return null;
    Extractor ex;
    try {
      ex = extractorClass.getConstructor().newInstance();
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
    return ex;
  }
 
  @SuppressWarnings("unchecked")
  protected void makeDataSets(Annotation training, Annotation testing, Annotation auxDataset) {
    if(! MachineReadingProperties.crossValidate){
      datasets = new Pair[1];
      Annotation trainingEnhanced = training;
      if (auxDataset != null) {
        trainingEnhanced = new Annotation(training.get(TextAnnotation.class));
        for(int i = 0; i < AnnotationUtils.sentenceCount(training); i ++){
          AnnotationUtils.addSentence(trainingEnhanced, AnnotationUtils.getSentence(training, i));
        }
        for (int ind = 0; ind < AnnotationUtils.sentenceCount(auxDataset); ind++) {
          AnnotationUtils.addSentence(trainingEnhanced, AnnotationUtils.getSentence(auxDataset, ind));
        }
      }
      datasets[0] = new Pair<Annotation, Annotation>(trainingEnhanced, testing);
     
      predictions = new Annotation[3][1];
    } else {
      assert(MachineReadingProperties.kfold > 1);
      datasets = new Pair[MachineReadingProperties.kfold];
      AnnotationUtils.shuffleSentences(training);
      for (int partition = 0; partition <MachineReadingProperties.kfold; partition++) {
        int begin = AnnotationUtils.sentenceCount(training) * partition / MachineReadingProperties.kfold;
        int end = AnnotationUtils.sentenceCount(training) * (partition + 1) / MachineReadingProperties.kfold;
        MachineReadingProperties.logger.info("Creating partition #" + partition + " using offsets [" + begin + ", " + end + ") out of " + AnnotationUtils.sentenceCount(training));
        Annotation partitionTrain = new Annotation("");
        Annotation partitionTest = new Annotation("");
        for(int i = 0; i < AnnotationUtils.sentenceCount(training); i ++){
          if(i < begin){
            AnnotationUtils.addSentence(partitionTrain, AnnotationUtils.getSentence(training, i));
          } else if(i < end){
            AnnotationUtils.addSentence(partitionTest, AnnotationUtils.getSentence(training, i));
          } else {
            AnnotationUtils.addSentence(partitionTrain, AnnotationUtils.getSentence(training, i));
          }
        }
       
        // for learning curve experiments
        // partitionTrain = keepPercentage(partitionTrain, percentageOfTrain);       
       
        if (auxDataset != null) {
          for (int ind = 0; ind < AnnotationUtils.sentenceCount(auxDataset); ind++) {
            AnnotationUtils.addSentence(partitionTrain, AnnotationUtils
                .getSentence(auxDataset, ind));
          }
        }
        datasets[partition] = new Pair<Annotation, Annotation>(partitionTrain, partitionTest);
      }
     
      predictions = new Annotation[3][MachineReadingProperties.kfold];
    }
  }
 
  /** Keeps only the first percentage sentences from the given corpus */
  static Annotation keepPercentage(Annotation corpus, double percentage) {
    System.err.println("Using percentage of train: " + percentage);
    Annotation smaller = new Annotation("");
    List<CoreMap> sents = new ArrayList<CoreMap>();
    List<CoreMap> fullSents = corpus.get(SentencesAnnotation.class);
    double smallSize = (double) fullSents.size() * percentage;
    for(int i = 0; i < smallSize; i ++){
      sents.add(fullSents.get(i));
    }
    System.err.println("TRAIN corpus size reduced from " + fullSents.size() + " to " + sents.size());
    smaller.set(SentencesAnnotation.class, sents);
    return smaller;
  }

  protected boolean serializedModelExists(String prefix) {
    if (!MachineReadingProperties.crossValidate) {
      File f = new File(prefix);
      return f.exists();
    }

    // in cross validation we serialize models to prefix.<FOLD COUNT>
    for (int i = 0; i < MachineReadingProperties.kfold; i++) {
      File f = new File(prefix + "." + Integer.toString(i));
      if (!f.exists()) {
        return false;
      }
    }
    return true;
  }
 
  /**
   * Creates ResultsPrinter instances based on the resultsPrinters argument
   * @param args
   */
  private void makeResultsPrinters(String[] args) {
    entityResultsPrinterSet = makeResultsPrinters(MachineReadingProperties.entityResultsPrinters, args);
    setRelationResultsPrinterSet(makeResultsPrinters(MachineReadingProperties.relationResultsPrinters, args));
    eventResultsPrinterSet = makeResultsPrinters(MachineReadingProperties.eventResultsPrinters, args);
  }
 
  private Set<ResultsPrinter> makeResultsPrinters(String classes, String [] args) {
    MachineReadingProperties.logger.info("Making result printers from " + classes);
    String[] printerClassNames = classes.trim().split(",\\s*");
    HashSet<ResultsPrinter> printers = new HashSet<ResultsPrinter>();
    for (String printerClassName : printerClassNames) {
      if(printerClassName.length() == 0) continue;
      ResultsPrinter rp;
      try {
        rp = (ResultsPrinter) Class.forName(printerClassName).getConstructor().newInstance();
        printers.add(rp);
      } catch (Exception e) {
        throw new RuntimeException(e);
      }
      //Execution.fillOptions(ResultsPrinterProps.class, args);
      //Arguments.parse(args,rp);
    }
    return printers;
  }
 
  /**
   * Constructs the corpus reader class and sets it as the reader for this MachineReading instance.
   *
   * @return corpus reader specified by datasetReaderClass
   */
  private GenericDataSetReader makeReader(Properties props) {
    try {
      if(reader == null){
        try {
          reader = MachineReadingProperties.datasetReaderClass.getConstructor(Properties.class).newInstance(props);
        } catch(java.lang.NoSuchMethodException e) {
          // if no c'tor with props found let's use the default one
          reader = MachineReadingProperties.datasetReaderClass.getConstructor().newInstance();
        }
      }
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
   
    reader.setUseNewHeadFinder(MachineReadingProperties.useNewHeadFinder);
    return reader;
  }
 
  /**
   * Constructs the corpus reader class and sets it as the reader for this MachineReading instance.
   *
   * @return corpus reader specified by datasetAuxReaderClass
   */
  private GenericDataSetReader makeAuxReader() {
    try {
      if (auxReader == null) {
        if (MachineReadingProperties.datasetAuxReaderClass != null) {
          auxReader = MachineReadingProperties.datasetAuxReaderClass.getConstructor().newInstance();
        }
      }
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
    return auxReader;
  }
 
  public static Extractor makeEntityExtractor(
      Class<? extends BasicEntityExtractor> entityExtractorClass,
      String gazetteerPath) {
    if (entityExtractorClass == null) return null;
    BasicEntityExtractor ex;
    try {
      ex = entityExtractorClass.getConstructor(String.class).newInstance(gazetteerPath);
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
    return ex;
  }

  public static Extractor makeRelationExtractor(
    Class<? extends BasicRelationExtractor> relationExtractorClass,RelationFeatureFactory featureFac, boolean createUnrelatedRelations, RelationMentionFactory factory) {
    if (relationExtractorClass == null) return null;
    BasicRelationExtractor ex;
    try {
      ex = relationExtractorClass.getConstructor(RelationFeatureFactory.class, Boolean.class, RelationMentionFactory.class).newInstance(featureFac, createUnrelatedRelations, factory);
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
    return ex;
  }

  public static RelationFeatureFactory makeRelationFeatureFactory(
      Class<? extends RelationFeatureFactory> relationFeatureFactoryClass,
      String relationFeatureList,
      boolean doNotLexicalizeFirstArg) {
    if (relationFeatureList == null || relationFeatureFactoryClass == null)
      return null;
    Object[] featureList = new Object [] {relationFeatureList.trim().split(",\\s*")};
    RelationFeatureFactory rff;
    try {
      rff = relationFeatureFactoryClass.getConstructor(String[].class).newInstance(featureList);
      rff.setDoNotLexicalizeFirstArgument(doNotLexicalizeFirstArg);
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
    return rff;
  }
 
  public static RelationMentionFactory makeRelationMentionFactory(
      Class<RelationMentionFactory> relationMentionFactoryClass) {
    RelationMentionFactory rmf;
    try {
      rmf = relationMentionFactoryClass.getConstructor().newInstance();
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
    return rmf;
  }
 
  /**
   * Gets the serialized sentences for a data set. If the serialized sentences
   * are already on disk, it loads them from there. Otherwise, the data set is
   * read with the corpus reader and the serialized sentences are saved to disk.
   *
   * @param sentencesPath
   *          location of the raw data set
   * @param reader
   *          the corpus reader
   * @param serializedSentences
   *          where the serialized sentences should be stored on disk
   * @return a list of RelationsSentences
   */
  protected Annotation loadOrMakeSerializedSentences(
      String sentencesPath, GenericDataSetReader reader,
      File serializedSentences) throws IOException, ClassNotFoundException {
    Annotation corpusSentences;
    // if the serialized file exists, just read it. otherwise read the source
    // and and save the serialized file to disk
    if (MachineReadingProperties.serializeCorpora && serializedSentences.exists() && !forceParseSentences) {
      MachineReadingProperties.logger.info("Loaded serialized sentences from " + serializedSentences.getAbsolutePath() + "...");
      corpusSentences = (Annotation) IOUtils.readObjectFromFile(serializedSentences);
      MachineReadingProperties.logger.info("Done. Loaded " + corpusSentences.get(CoreAnnotations.SentencesAnnotation.class).size() + " sentences.");
    } else {
      // read the corpus
      MachineReadingProperties.logger.info("Parsing corpus sentences...");
      if(MachineReadingProperties.serializeCorpora)
        MachineReadingProperties.logger.info("These sentences will be serialized to " + serializedSentences.getAbsolutePath());
      corpusSentences = reader.parse(sentencesPath);
      MachineReadingProperties.logger.info("Done. Parsed " + AnnotationUtils.sentenceCount(corpusSentences) + " sentences.");

      // save corpusSentences
      if(MachineReadingProperties.serializeCorpora){
        MachineReadingProperties.logger.info("Serializing parsed sentences to " + serializedSentences.getAbsolutePath() + "...");
        IOUtils.writeObjectToFile(corpusSentences,serializedSentences);
        MachineReadingProperties.logger.info("Done. Serialized " + AnnotationUtils.sentenceCount(corpusSentences) + " sentences.");
      }
    }
    return corpusSentences;
  }
 
  /**
   * Sets MachineReadingProperties.extractEntities.
   * <p>
   * The field is assigned through the class name, so the setting is not
   * specific to this MachineReading instance.
   *
   * @param extractEntities the new value of the flag
   */
  public void setExtractEntities(boolean extractEntities) {
    MachineReadingProperties.extractEntities = extractEntities;
  }

  /**
   * Sets MachineReadingProperties.extractRelations.
   * <p>
   * The field is assigned through the class name, so the setting is not
   * specific to this MachineReading instance.
   *
   * @param extractRelations the new value of the flag
   */
  public void setExtractRelations(boolean extractRelations) {
    MachineReadingProperties.extractRelations = extractRelations;
  }

  /**
   * Sets MachineReadingProperties.extractEvents.
   * <p>
   * The field is assigned through the class name, so the setting is not
   * specific to this MachineReading instance.
   *
   * @param extractEvents the new value of the flag
   */
  public void setExtractEvents(boolean extractEvents) {
    MachineReadingProperties.extractEvents = extractEvents;
  }
 
  /**
   * Sets the forceParseSentences flag of this instance. When the flag is on,
   * loadOrMakeSerializedSentences does not read an existing serialized corpus
   * and parses the data set instead.
   *
   * @param forceParseSentences the new value of the flag
   */
  public void setForceParseSentences(boolean forceParseSentences) {
    this.forceParseSentences = forceParseSentences;
  }

  /**
   * Replaces the array of (train, test) dataset pairs. The array is stored as
   * given, without copying.
   *
   * @param datasets the new dataset pairs
   */
  public void setDatasets(Pair<Annotation, Annotation> [] datasets) {
    this.datasets = datasets;
  }

  /**
   * Returns the array of (train, test) dataset pairs. The internal array itself
   * is returned, not a copy.
   *
   * @return the current dataset pairs
   */
  public Pair<Annotation, Annotation> [] getDatasets() {
    return datasets;
  }

  /**
   * Replaces the predictions table. The array is stored as given, without copying.
   *
   * @param predictions the new predictions table
   */
  public void setPredictions(Annotation [][] predictions) {
    this.predictions = predictions;
  }

  /**
   * Returns the predictions table. The internal array itself is returned, not a copy.
   *
   * @return the current predictions table
   */
  public Annotation [][] getPredictions() {
    return predictions;
  }

  /**
   * Sets the corpus reader of this instance. makeReader only instantiates a
   * reader when this field is null, so a reader set here is the one it returns.
   *
   * @param reader the corpus reader to use
   */
  public void setReader(GenericDataSetReader reader) {
    this.reader = reader;
  }

  /**
   * Returns the corpus reader of this instance.
   *
   * @return the current value of the reader field
   */
  public GenericDataSetReader getReader() {
    return reader;
  }

  /**
   * Sets the auxiliary corpus reader of this instance. makeAuxReader only
   * instantiates a reader when this field is null, so a reader set here is the
   * one it returns.
   *
   * @param auxReader the auxiliary corpus reader to use
   */
  public void setAuxReader(GenericDataSetReader auxReader) {
    this.auxReader = auxReader;
  }

  /**
   * Returns the auxiliary corpus reader of this instance.
   *
   * @return the current value of the auxReader field
   */
  public GenericDataSetReader getAuxReader() {
    return auxReader;
  }

  /**
   * Replaces the set of ResultsPrinter instances held for entities. The set is
   * stored as given, without copying.
   *
   * @param entityResultsPrinterSet the new entity results printers
   */
  public void setEntityResultsPrinterSet(Set<ResultsPrinter> entityResultsPrinterSet) {
    this.entityResultsPrinterSet = entityResultsPrinterSet;
  }

  /**
   * Returns the set of ResultsPrinter instances held for entities. The internal
   * set itself is returned, not a copy.
   *
   * @return the current entity results printers
   */
  public Set<ResultsPrinter> getEntityResultsPrinterSet() {
    return entityResultsPrinterSet;
  }

  /**
   * Replaces the set of ResultsPrinter instances held for relations. The set is
   * stored as given, without copying.
   *
   * @param relationResultsPrinterSet the new relation results printers
   */
  public void setRelationResultsPrinterSet(Set<ResultsPrinter> relationResultsPrinterSet) {
    this.relationResultsPrinterSet = relationResultsPrinterSet;
  }

  /**
   * Returns the set of ResultsPrinter instances held for relations. The internal
   * set itself is returned, not a copy.
   *
   * @return the current relation results printers
   */
  public Set<ResultsPrinter> getRelationResultsPrinterSet() {
    return relationResultsPrinterSet;
  }
}
TOP

Related Classes of edu.stanford.nlp.ie.machinereading.MachineReading

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.