package edu.stanford.nlp.ie.machinereading;
import java.io.File;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Properties;
import java.util.Set;
import java.util.logging.ConsoleHandler;
import java.util.logging.Handler;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.logging.SimpleFormatter;
import edu.stanford.nlp.ie.machinereading.structure.*;
import edu.stanford.nlp.ie.machinereading.structure.MachineReadingAnnotations.EntityMentionsAnnotation;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.CoreAnnotations.SentencesAnnotation;
import edu.stanford.nlp.ling.CoreAnnotations.TextAnnotation;
import edu.stanford.nlp.ling.CoreAnnotations.TokensAnnotation;
import edu.stanford.nlp.trees.TreeCoreAnnotations.TreeAnnotation;
import edu.stanford.nlp.pipeline.Annotation;
import edu.stanford.nlp.pipeline.StanfordCoreNLP;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.Execution;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.StringUtils;
/**
* Main driver for Machine Reading training, annotation, and evaluation. Does
* entity, relation, and event extraction for all corpora.
*
* This code has been adapted for 4 domains, all defined in the edu.stanford.nlp.ie.machinereading.domains package.
* For each domain, you need a properties file that is the only command line parameter for MachineReading.
* Minimally, for each domain you need to define a reader class that extends the GenericDataSetReader class
* and overrides the public Annotation read(String path) method.
*
* How to run: java edu.stanford.nlp.ie.machinereading.MachineReading --arguments propertiesfile
*
* The read method creates an Annotation with additional objects per sentence: EntityMentions and RelationMentions.
* Using these objects, the classifiers called from MachineReading train entity and relation extractors.
* The simplest example domain currently is in edu.stanford.nlp.ie.machinereading.domains.roth,
* which implements simple entity and relation extraction using a dataset created by Dan Roth. The properties file for the domain is at
* projects/more/src/edu/stanford/nlp/ie/machinereading/domains/roth/roth.properties
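*
* A minimal properties file might look like the sketch below. The key names are
* assumed to mirror the MachineReadingProperties fields filled in through
* Execution.fillOptions, and all paths are placeholders:
* <pre>
* datasetReaderClass = edu.stanford.nlp.ie.machinereading.domains.roth.RothCONLL04Reader
* trainPath = /path/to/train/corpus
* extractEntities = true
* extractRelations = true
* logLevel = INFO
* </pre>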
*
* @author David McClosky
* @author mrsmith
* @author Mihai
*
*/
public class MachineReading {
// Store command-line args so they can be passed to other classes
private String[] args;
/*
* class attributes
*/
public GenericDataSetReader reader;
public GenericDataSetReader auxReader;
public Extractor entityExtractor;
// TODO could add an entityExtractorPostProcessor if we need one
public Extractor relationExtractor;
public Extractor relationExtractionPostProcessor;
public Extractor eventExtractor;
public Extractor consistencyChecker;
protected boolean forceRetraining;
public boolean forceParseSentences;
/**
* Array of pairs of datasets (training, testing)
* If cross validation is enabled, the length of this array is the number of folds; otherwise it is 1
* The first element in each pair is the training corpus; the second is testing
*/
private Pair<Annotation, Annotation> [] datasets;
/**
* Stores the predictions of the extractors
* The first index is the task: 0 - entities, 1 - relations, 2 - events
* The second index is the partition number (of length 1 if cross validation is not enabled)
* Note: we need to store separate predictions per task because they may not be compatible with each other.
* For example, we may have predicted entities in task 0 but use gold entities for task 1.
*/
private Annotation [][] predictions;
private Set<ResultsPrinter> entityResultsPrinterSet;
private Set<ResultsPrinter> relationResultsPrinterSet;
@SuppressWarnings("unused")
private Set<ResultsPrinter> eventResultsPrinterSet;
public static final int ENTITY_LEVEL = 0;
public static final int RELATION_LEVEL = 1;
public static final int EVENT_LEVEL = 2;
public static void main(String[] args) throws Exception {
MachineReading mr = makeMachineReading(args);
mr.run();
}
public static void setLoggerLevel(Level level) {
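// open the console handler all the way down to FINEST; the logger's own level then controls what is actually printed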
setConsoleLevel(Level.FINEST);
MachineReadingProperties.logger.setLevel(level);
}
public static void setConsoleLevel(Level level) {
// get the top Logger:
Logger topLogger = java.util.logging.Logger.getLogger("");
// Handler for console (reuse it if it already exists)
Handler consoleHandler = null;
// see if there is already a console handler
for (Handler handler : topLogger.getHandlers()) {
if (handler instanceof ConsoleHandler) {
// found the console handler
consoleHandler = handler;
break;
}
}
if (consoleHandler == null) {
// there was no console handler found, create a new one
consoleHandler = new ConsoleHandler();
topLogger.addHandler(consoleHandler);
}
// set the console handler level:
consoleHandler.setLevel(level);
consoleHandler.setFormatter(new SimpleFormatter());
}
/**
* Use the makeMachineReading* methods to create MachineReading objects!
*/
private MachineReading(String [] args) {
this.args = args;
}
protected MachineReading() {
this.args = new String[0];
}
/**
* Creates a MachineReading object to be used only for annotation purposes (no training).
* This is needed in order to integrate MachineReading with BaselineNLProcessor.
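*
* A sketch of typical usage, assuming serialized extractors already exist (the
* model paths, reader, and corpus are placeholders; the load calls mirror how
* this class itself loads serialized models):
* <pre>{@code
* Extractor entityExtractor = BasicEntityExtractor.load(entityModelPath, MachineReadingProperties.entityClassifier, false);
* Extractor relationExtractor = BasicRelationExtractor.load(relationModelPath);
* MachineReading mr = MachineReading.makeMachineReadingForAnnotation(
*     reader, entityExtractor, relationExtractor, null, null, null, true, false);
* Annotation annotated = mr.annotate(corpus);
* }</pre>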
*/
public static MachineReading makeMachineReadingForAnnotation(
GenericDataSetReader reader,
Extractor entityExtractor,
Extractor relationExtractor,
Extractor eventExtractor,
Extractor consistencyChecker,
Extractor relationPostProcessor,
boolean testRelationsUsingPredictedEntities,
boolean verbose) {
MachineReading mr = new MachineReading();
// readers needed to assign syntactic heads to predicted entities
mr.reader = reader;
mr.auxReader = null;
// no results printers needed
mr.entityResultsPrinterSet = new HashSet<ResultsPrinter>();
mr.setRelationResultsPrinterSet(new HashSet<ResultsPrinter>());
// create the storage for the generated annotations
mr.predictions = new Annotation[3][1];
// create the entity/relation classifiers
mr.entityExtractor = entityExtractor;
MachineReadingProperties.extractEntities = (entityExtractor != null);
mr.relationExtractor = relationExtractor;
MachineReadingProperties.extractRelations = (relationExtractor != null);
MachineReadingProperties.testRelationsUsingPredictedEntities = testRelationsUsingPredictedEntities;
mr.eventExtractor = eventExtractor;
MachineReadingProperties.extractEvents = (eventExtractor != null);
mr.consistencyChecker = consistencyChecker;
mr.relationExtractionPostProcessor = relationPostProcessor;
Level level = verbose ? Level.FINEST : Level.SEVERE;
if (entityExtractor != null)
entityExtractor.setLoggerLevel(level);
if (mr.relationExtractor != null)
mr.relationExtractor.setLoggerLevel(level);
if (mr.eventExtractor != null)
mr.eventExtractor.setLoggerLevel(level);
return mr;
}
public static MachineReading makeMachineReading(String [] args) throws IOException {
// install global parameters
MachineReading mr = new MachineReading(args);
//TODO:
Execution.fillOptions(MachineReadingProperties.class, args);
//Arguments.parse(args, mr);
System.err.println("PERCENTAGE OF TRAIN: " + MachineReadingProperties.percentageOfTrain);
// convert args to properties
Properties props = StringUtils.argsToProperties(args);
if (props == null) {
throw new RuntimeException("ERROR: failed to find Properties in the given arguments!");
}
String logLevel = props.getProperty("logLevel", "INFO");
setLoggerLevel(Level.parse(logLevel.toUpperCase()));
// install reader specific parameters
GenericDataSetReader reader = mr.makeReader(props);
GenericDataSetReader auxReader = mr.makeAuxReader();
Level readerLogLevel = Level.parse(MachineReadingProperties.readerLogLevel.toUpperCase());
reader.setLoggerLevel(readerLogLevel);
if (auxReader != null) {
auxReader.setLoggerLevel(readerLogLevel);
}
System.err.println("The reader log level is set to " + readerLogLevel);
//Execution.fillOptions(GenericDataSetReaderProps.class, args);
//Arguments.parse(args, reader);
// create the pre-processing pipeline
StanfordCoreNLP pipe = new StanfordCoreNLP(props, false);
reader.setProcessor(pipe);
if (auxReader != null) {
auxReader.setProcessor(pipe);
}
// create the results printers
mr.makeResultsPrinters(args);
return mr;
}
/**
* Performs extraction. This will train a new extraction model and evaluate
* the model on the test set. Depending on the MachineReading instance's
* parameters, it may skip training if a model already exists or skip
* evaluation.
*
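* Typical usage, mirroring {@link #main(String[])}:
* <pre>{@code
* MachineReading mr = MachineReading.makeMachineReading(args);
* List<String> results = mr.run();
* }</pre>
*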
* @return the results strings, which can be compared in a unit test
*/
public List<String> run() throws Exception {
this.forceRetraining = !MachineReadingProperties.loadModel;
if (MachineReadingProperties.trainOnly) {
this.forceRetraining = true;
}
List<String> retMsg = new ArrayList<String>();
boolean haveSerializedEntityExtractor = serializedModelExists(MachineReadingProperties.serializedEntityExtractorPath);
boolean haveSerializedRelationExtractor = serializedModelExists(MachineReadingProperties.serializedRelationExtractorPath);
boolean haveSerializedEventExtractor = serializedModelExists(MachineReadingProperties.serializedEventExtractorPath);
Annotation training = null;
Annotation aux = null;
if ((MachineReadingProperties.extractEntities && !haveSerializedEntityExtractor) ||
(MachineReadingProperties.extractRelations && !haveSerializedRelationExtractor) ||
(MachineReadingProperties.extractEvents && !haveSerializedEventExtractor) ||
this.forceRetraining || MachineReadingProperties.crossValidate) {
// load training sentences
training = loadOrMakeSerializedSentences(MachineReadingProperties.trainPath, reader, new File(MachineReadingProperties.serializedTrainingSentencesPath));
if (auxReader != null) {
MachineReadingProperties.logger.severe("Reading auxiliary dataset from " + MachineReadingProperties.auxDataPath + "...");
aux = loadOrMakeSerializedSentences(MachineReadingProperties.auxDataPath, auxReader, new File(
MachineReadingProperties.serializedAuxTrainingSentencesPath));
MachineReadingProperties.logger.severe("Done reading auxiliary dataset.");
}
}
Annotation testing = null;
if (!MachineReadingProperties.trainOnly && !MachineReadingProperties.crossValidate) {
// load test sentences
File serializedTestSentences = new File(MachineReadingProperties.serializedTestSentencesPath);
testing = loadOrMakeSerializedSentences(MachineReadingProperties.testPath, reader, serializedTestSentences);
}
//
// create the actual datasets to be used for training and annotation
//
makeDataSets(training, testing, aux);
//
// process (training + annotate) one partition at a time
//
for(int partition = 0; partition < datasets.length; partition ++){
assert(datasets.length > partition);
assert(datasets[partition] != null);
assert(MachineReadingProperties.trainOnly || datasets[partition].second() != null);
// train all models
train(datasets[partition].first(), (MachineReadingProperties.crossValidate ? partition : -1));
// annotate using all models
if(! MachineReadingProperties.trainOnly){
MachineReadingProperties.logger.info("annotating partition " + partition );
annotate(datasets[partition].second(), (MachineReadingProperties.crossValidate ? partition: -1));
}
}
//
// now report overall results
//
if(! MachineReadingProperties.trainOnly){
// merge test sets for the gold data
Annotation gold = new Annotation("");
for(int i = 0; i < datasets.length; i ++) AnnotationUtils.addSentences(gold, datasets[i].second().get(CoreAnnotations.SentencesAnnotation.class));
// merge test sets with predicted annotations
Annotation[] mergedPredictions = new Annotation[3];
assert(predictions != null);
for (int taskLevel = 0; taskLevel < mergedPredictions.length; taskLevel++) {
mergedPredictions[taskLevel] = new Annotation("");
for(int fold = 0; fold < predictions[taskLevel].length; fold ++){
if (predictions[taskLevel][fold] == null) continue;
AnnotationUtils.addSentences(mergedPredictions[taskLevel], predictions[taskLevel][fold].get(CoreAnnotations.SentencesAnnotation.class));
}
}
//
// evaluate all tasks: entity, relation, and event recognition
//
if(MachineReadingProperties.extractEntities && ! entityResultsPrinterSet.isEmpty()){
retMsg.addAll(printTask("entity extraction", entityResultsPrinterSet, gold, mergedPredictions[ENTITY_LEVEL]));
}
if(MachineReadingProperties.extractRelations && ! getRelationResultsPrinterSet().isEmpty()){
retMsg.addAll(printTask("relation extraction", getRelationResultsPrinterSet(), gold, mergedPredictions[RELATION_LEVEL]));
}
//
// Save the sentences with the predicted annotations
//
if (MachineReadingProperties.extractEntities && MachineReadingProperties.serializedEntityExtractionResults != null)
IOUtils.writeObjectToFile(mergedPredictions[ENTITY_LEVEL], MachineReadingProperties.serializedEntityExtractionResults);
if (MachineReadingProperties.extractRelations && MachineReadingProperties.serializedRelationExtractionResults != null)
IOUtils.writeObjectToFile(mergedPredictions[RELATION_LEVEL],MachineReadingProperties.serializedRelationExtractionResults);
if (MachineReadingProperties.extractEvents && MachineReadingProperties.serializedEventExtractionResults != null)
IOUtils.writeObjectToFile(mergedPredictions[EVENT_LEVEL],MachineReadingProperties.serializedEventExtractionResults);
}
return retMsg;
}
protected List<String> printTask(String taskName, Set<ResultsPrinter> printers, Annotation gold, Annotation pred) {
List<String> retMsg = new ArrayList<String>();
for (ResultsPrinter rp : printers){
String msg = rp.printResults(gold, pred);
retMsg.add(msg);
MachineReadingProperties.logger.severe("Overall " + taskName + " results, using printer " + rp.getClass() + ":\n" + msg);
}
return retMsg;
}
protected void train(Annotation training, int partition) throws Exception {
//
// train entity extraction
//
if (MachineReadingProperties.extractEntities) {
MachineReadingProperties.logger.info("Training entity extraction model(s)");
if (partition != -1) MachineReadingProperties.logger.info("In partition #" + partition);
String modelName = MachineReadingProperties.serializedEntityExtractorPath;
if (partition != -1) modelName += "." + partition;
File modelFile = new File(modelName);
MachineReadingProperties.logger.fine("forceRetraining = " + this.forceRetraining + ", modelFile.exists = " + modelFile.exists());
if (!this.forceRetraining && modelFile.exists()) {
MachineReadingProperties.logger.info("Loading entity extraction model from " + modelName + " ...");
entityExtractor = BasicEntityExtractor.load(modelName, MachineReadingProperties.entityClassifier, false);
} else {
MachineReadingProperties.logger.info("Training entity extraction model...");
entityExtractor = makeEntityExtractor(MachineReadingProperties.entityClassifier, MachineReadingProperties.entityGazetteerPath);
entityExtractor.train(training);
MachineReadingProperties.logger.info("Serializing entity extraction model to " + modelName + " ...");
entityExtractor.save(modelName);
}
}
//
// train relation extraction
//
if (MachineReadingProperties.extractRelations) {
MachineReadingProperties.logger.info("Training relation extraction model(s)");
if (partition != -1)
MachineReadingProperties.logger.info("In partition #" + partition);
String modelName = MachineReadingProperties.serializedRelationExtractorPath;
if (partition != -1)
modelName += "." + partition;
Annotation predicted = null;
if (MachineReadingProperties.useRelationExtractionModelMerging) {
String[] modelNames = MachineReadingProperties.serializedRelationExtractorPath.split(",");
if (partition != -1) {
for (int i = 0; i < modelNames.length; i++) {
modelNames[i] += "." + partition;
}
}
relationExtractor = ExtractorMerger.buildRelationExtractorMerger(modelNames);
} else if (!this.forceRetraining && new File(modelName).exists()) {
MachineReadingProperties.logger.info("Loading relation extraction model from " + modelName + " ...");
//TODO change this to load any type of BasicRelationExtractor
relationExtractor = BasicRelationExtractor.load(modelName);
} else {
RelationFeatureFactory rff = makeRelationFeatureFactory(MachineReadingProperties.relationFeatureFactoryClass, MachineReadingProperties.relationFeatures, MachineReadingProperties.doNotLexicalizeFirstArg);
Execution.fillOptions(rff, args);
if (MachineReadingProperties.trainRelationsUsingPredictedEntities) {
// generate predicted entities
assert(entityExtractor != null);
predicted = AnnotationUtils.deepMentionCopy(training);
entityExtractor.annotate(predicted);
for (ResultsPrinter rp : entityResultsPrinterSet){
String msg = rp.printResults(training, predicted);
MachineReadingProperties.logger.info("Training relation extraction using predicted entitities: entity scores using printer " + rp.getClass() + ":\n" + msg);
}
// change relation mentions to use predicted entity mentions rather than gold ones
try {
changeGoldRelationArgsToPredicted(predicted);
} catch (Exception e) {
// we may get here for unknown EntityMentionComparator class
throw new RuntimeException(e);
}
}
Annotation dataset;
if (MachineReadingProperties.trainRelationsUsingPredictedEntities) {
dataset = predicted;
} else {
dataset = training;
}
Set<String> relationsToSkip = new HashSet<String>(StringUtils.split(MachineReadingProperties.relationsToSkipDuringTraining, ","));
List<List<RelationMention>> backedUpRelations = new ArrayList<List<RelationMention>>();
if (relationsToSkip.size() > 0) {
// we need to backup the relations since removeSkippableRelations modifies dataset in place and we can't duplicate CoreMaps safely (or can we?)
for (CoreMap sent : dataset.get(CoreAnnotations.SentencesAnnotation.class)) {
List<RelationMention> relationMentions = sent.get(MachineReadingAnnotations.RelationMentionsAnnotation.class);
backedUpRelations.add(relationMentions);
}
removeSkippableRelations(dataset, relationsToSkip);
}
//relationExtractor = new BasicRelationExtractor(rff, MachineReadingProperties.createUnrelatedRelations, makeRelationMentionFactory(MachineReadingProperties.relationMentionFactoryClass));
relationExtractor = makeRelationExtractor(MachineReadingProperties.relationClassifier, rff, MachineReadingProperties.createUnrelatedRelations,
makeRelationMentionFactory(MachineReadingProperties.relationMentionFactoryClass));
Execution.fillOptions(relationExtractor, args);
//Arguments.parse(args,relationExtractor);
MachineReadingProperties.logger.info("Training relation extraction model...");
relationExtractor.train(dataset);
MachineReadingProperties.logger.info("Serializing relation extraction model to " + modelName + " ...");
relationExtractor.save(modelName);
if (relationsToSkip.size() > 0) {
// restore backed up relations into dataset
int sentenceIndex = 0;
for (CoreMap sentence : dataset.get(CoreAnnotations.SentencesAnnotation.class)) {
List<RelationMention> relationMentions = backedUpRelations.get(sentenceIndex);
sentence.set(MachineReadingAnnotations.RelationMentionsAnnotation.class, relationMentions);
sentenceIndex++;
}
}
}
}
//
// train event extraction -- currently just works with MSTBasedEventExtractor
//
if (MachineReadingProperties.extractEvents) {
MachineReadingProperties.logger.info("Training event extraction model(s)");
if (partition != -1) MachineReadingProperties.logger.info("In partition #" + partition);
String modelName = MachineReadingProperties.serializedEventExtractorPath;
if (partition != -1) modelName += "." + partition;
File modelFile = new File(modelName);
Annotation predicted = null;
if (!this.forceRetraining && modelFile.exists()) {
MachineReadingProperties.logger.info("Loading event extraction model from " + modelName + " ...");
Method mstLoader = (Class.forName("edu.stanford.nlp.ie.machinereading.MSTBasedEventExtractor")).getMethod("load", String.class);
eventExtractor = (Extractor) mstLoader.invoke(null, modelName);
} else {
if (MachineReadingProperties.trainEventsUsingPredictedEntities) {
// generate predicted entities
assert(entityExtractor != null);
predicted = AnnotationUtils.deepMentionCopy(training);
entityExtractor.annotate(predicted);
for (ResultsPrinter rp : entityResultsPrinterSet){
String msg = rp.printResults(training, predicted);
MachineReadingProperties.logger.info("Training event extraction using predicted entitities: entity scores using printer " + rp.getClass() + ":\n" + msg);
}
// TODO: need an equivalent of changeGoldRelationArgsToPredicted here?
}
Constructor<?> mstConstructor = (Class.forName("edu.stanford.nlp.ie.machinereading.MSTBasedEventExtractor")).getConstructor(boolean.class);
eventExtractor = (Extractor) mstConstructor.newInstance(MachineReadingProperties.trainEventsUsingPredictedEntities);
MachineReadingProperties.logger.info("Training event extraction model...");
if (MachineReadingProperties.trainEventsUsingPredictedEntities) {
eventExtractor.train(predicted);
} else {
eventExtractor.train(training);
}
MachineReadingProperties.logger.info("Serializing event extraction model to " + modelName + " ...");
eventExtractor.save(modelName);
}
}
}
/**
* Removes any relations with relation types in relationsToSkip from a dataset. Dataset is modified in place.
*/
private static void removeSkippableRelations(Annotation dataset, Set<String> relationsToSkip) {
if (relationsToSkip == null || relationsToSkip.size() == 0) {
return;
}
for (CoreMap sent : dataset.get(CoreAnnotations.SentencesAnnotation.class)) {
List<RelationMention> relationMentions = sent.get(MachineReadingAnnotations.RelationMentionsAnnotation.class);
if (relationMentions == null) {
continue;
}
List<RelationMention> newRelationMentions = new ArrayList<RelationMention>();
for (RelationMention rm: relationMentions) {
if (!relationsToSkip.contains(rm.getType())) {
newRelationMentions.add(rm);
}
}
sent.set(MachineReadingAnnotations.RelationMentionsAnnotation.class, newRelationMentions);
}
}
/**
* Replaces all relation arguments with predicted entities
*/
private static void changeGoldRelationArgsToPredicted(Annotation dataset) {
for (CoreMap sent : dataset.get(CoreAnnotations.SentencesAnnotation.class)) {
List<EntityMention> entityMentions = sent.get(MachineReadingAnnotations.EntityMentionsAnnotation.class);
List<RelationMention> relationMentions = sent.get(MachineReadingAnnotations.RelationMentionsAnnotation.class);
List<RelationMention> newRels = new ArrayList<RelationMention>();
for (RelationMention rm : relationMentions) {
rm.setSentence(sent);
if (rm.replaceGoldArgsWithPredicted(entityMentions)) {
MachineReadingProperties.logger.info("Successfully mapped all arguments in relation mention: " + rm);
newRels.add(rm);
} else {
MachineReadingProperties.logger.info("Dropped relation mention due to failed argument mapping: " + rm);
}
}
sent.set(MachineReadingAnnotations.RelationMentionsAnnotation.class, newRels);
// we may have added new mentions to the entity list, so let's store it again
sent.set(MachineReadingAnnotations.EntityMentionsAnnotation.class, entityMentions);
}
}
public Annotation annotate(Annotation testing) {
return annotate(testing, -1);
}
protected Annotation annotate(Annotation testing, int partition) {
int partitionIndex = (partition != -1 ? partition : 0);
//
// annotate entities
//
if (MachineReadingProperties.extractEntities) {
assert(entityExtractor != null);
Annotation predicted = AnnotationUtils.deepMentionCopy(testing);
entityExtractor.annotate(predicted);
for (ResultsPrinter rp : entityResultsPrinterSet){
String msg = rp.printResults(testing, predicted);
MachineReadingProperties.logger.info("Entity extraction results " + (partition != -1 ? "for partition #" + partition : "") + " using printer " + rp.getClass() + ":\n" + msg);
}
predictions[ENTITY_LEVEL][partitionIndex] = predicted;
}
//
// annotate relations
//
if (MachineReadingProperties.extractRelations) {
assert(relationExtractor != null);
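// when testing relations over predicted entities, start from the entity-level predictions; otherwise work on a fresh copy of the gold mentions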
Annotation predicted = (MachineReadingProperties.testRelationsUsingPredictedEntities ? predictions[ENTITY_LEVEL][partitionIndex] : AnnotationUtils.deepMentionCopy(testing));
// make sure the entities have the syntactic head and span set. we need this for relation extraction features
assignSyntacticHeadToEntities(predicted);
relationExtractor.annotate(predicted);
if (relationExtractionPostProcessor == null) {
relationExtractionPostProcessor = makeExtractor(MachineReadingProperties.relationExtractionPostProcessorClass);
}
if (relationExtractionPostProcessor != null) {
MachineReadingProperties.logger.info("Using relation extraction post processor: " + MachineReadingProperties.relationExtractionPostProcessorClass);
relationExtractionPostProcessor.annotate(predicted);
}
for (ResultsPrinter rp : getRelationResultsPrinterSet()){
String msg = rp.printResults(testing, predicted);
MachineReadingProperties.logger.info("Relation extraction results " + (partition != -1 ? "for partition #" + partition : "") + " using printer " + rp.getClass() + ":\n" + msg);
}
//
// apply the domain-specific consistency checks
//
if (consistencyChecker == null) {
consistencyChecker = makeExtractor(MachineReadingProperties.consistencyCheck);
}
if (consistencyChecker != null) {
MachineReadingProperties.logger.info("Using consistency checker: " + MachineReadingProperties.consistencyCheck);
consistencyChecker.annotate(predicted);
for (ResultsPrinter rp : entityResultsPrinterSet){
String msg = rp.printResults(testing, predicted);
MachineReadingProperties.logger.info("Entity extraction results AFTER consistency checks " + (partition != -1 ? "for partition #" + partition : "") + " using printer " + rp.getClass() + ":\n" + msg);
}
for (ResultsPrinter rp : getRelationResultsPrinterSet()){
String msg = rp.printResults(testing, predicted);
MachineReadingProperties.logger.info("Relation extraction results AFTER consistency checks " + (partition != -1 ? "for partition #" + partition : "") + " using printer " + rp.getClass() + ":\n" + msg);
}
}
predictions[RELATION_LEVEL][partitionIndex] = predicted;
}
//
// TODO: annotate events
//
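// return the relation-level predictions for this partition; note this may be null if relation extraction is disabled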
return predictions[RELATION_LEVEL][partitionIndex];
}
private void assignSyntacticHeadToEntities(Annotation corpus) {
assert(corpus != null);
assert(corpus.get(SentencesAnnotation.class) != null);
for(CoreMap sent: corpus.get(SentencesAnnotation.class)){
List<CoreLabel> tokens = sent.get(TokensAnnotation.class);
assert(tokens != null);
Tree tree = sent.get(TreeAnnotation.class);
assert(tree != null);
if (MachineReadingProperties.forceGenerationOfIndexSpans) {
tree.indexSpans(0);
}
if(sent.get(EntityMentionsAnnotation.class) != null){
for(EntityMention e: sent.get(EntityMentionsAnnotation.class)){
reader.assignSyntacticHead(e, tree, tokens, true);
}
}
}
}
public static Extractor makeExtractor(Class<Extractor> extractorClass) {
if (extractorClass == null) return null;
Extractor ex;
try {
ex = extractorClass.getConstructor().newInstance();
} catch (Exception e) {
throw new RuntimeException(e);
}
return ex;
}
@SuppressWarnings("unchecked")
protected void makeDataSets(Annotation training, Annotation testing, Annotation auxDataset) {
if(! MachineReadingProperties.crossValidate){
datasets = new Pair[1];
Annotation trainingEnhanced = training;
if (auxDataset != null) {
trainingEnhanced = new Annotation(training.get(TextAnnotation.class));
for(int i = 0; i < AnnotationUtils.sentenceCount(training); i ++){
AnnotationUtils.addSentence(trainingEnhanced, AnnotationUtils.getSentence(training, i));
}
for (int ind = 0; ind < AnnotationUtils.sentenceCount(auxDataset); ind++) {
AnnotationUtils.addSentence(trainingEnhanced, AnnotationUtils.getSentence(auxDataset, ind));
}
}
datasets[0] = new Pair<Annotation, Annotation>(trainingEnhanced, testing);
predictions = new Annotation[3][1];
} else {
assert(MachineReadingProperties.kfold > 1);
datasets = new Pair[MachineReadingProperties.kfold];
AnnotationUtils.shuffleSentences(training);
for (int partition = 0; partition <MachineReadingProperties.kfold; partition++) {
int begin = AnnotationUtils.sentenceCount(training) * partition / MachineReadingProperties.kfold;
int end = AnnotationUtils.sentenceCount(training) * (partition + 1) / MachineReadingProperties.kfold;
MachineReadingProperties.logger.info("Creating partition #" + partition + " using offsets [" + begin + ", " + end + ") out of " + AnnotationUtils.sentenceCount(training));
Annotation partitionTrain = new Annotation("");
Annotation partitionTest = new Annotation("");
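// sentences with indices in [begin, end) form the test set of this fold; all other sentences go to train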
for(int i = 0; i < AnnotationUtils.sentenceCount(training); i ++){
if(i < begin){
AnnotationUtils.addSentence(partitionTrain, AnnotationUtils.getSentence(training, i));
} else if(i < end){
AnnotationUtils.addSentence(partitionTest, AnnotationUtils.getSentence(training, i));
} else {
AnnotationUtils.addSentence(partitionTrain, AnnotationUtils.getSentence(training, i));
}
}
// for learning curve experiments
// partitionTrain = keepPercentage(partitionTrain, percentageOfTrain);
if (auxDataset != null) {
for (int ind = 0; ind < AnnotationUtils.sentenceCount(auxDataset); ind++) {
AnnotationUtils.addSentence(partitionTrain, AnnotationUtils
.getSentence(auxDataset, ind));
}
}
datasets[partition] = new Pair<Annotation, Annotation>(partitionTrain, partitionTest);
}
predictions = new Annotation[3][MachineReadingProperties.kfold];
}
}
/** Keeps only the first {@code percentage} fraction of the sentences in the given corpus */
static Annotation keepPercentage(Annotation corpus, double percentage) {
System.err.println("Using percentage of train: " + percentage);
Annotation smaller = new Annotation("");
List<CoreMap> sents = new ArrayList<CoreMap>();
List<CoreMap> fullSents = corpus.get(SentencesAnnotation.class);
double smallSize = (double) fullSents.size() * percentage;
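// the loop below keeps the first ceil(fullSents.size() * percentage) sentences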
for(int i = 0; i < smallSize; i ++){
sents.add(fullSents.get(i));
}
System.err.println("TRAIN corpus size reduced from " + fullSents.size() + " to " + sents.size());
smaller.set(SentencesAnnotation.class, sents);
return smaller;
}
protected boolean serializedModelExists(String prefix) {
if (!MachineReadingProperties.crossValidate) {
File f = new File(prefix);
return f.exists();
}
// in cross validation we serialize models to prefix.<fold index>
for (int i = 0; i < MachineReadingProperties.kfold; i++) {
File f = new File(prefix + "." + Integer.toString(i));
if (!f.exists()) {
return false;
}
}
return true;
}
/**
* Creates ResultsPrinter instances based on the resultsPrinters argument
* @param args
*/
private void makeResultsPrinters(String[] args) {
entityResultsPrinterSet = makeResultsPrinters(MachineReadingProperties.entityResultsPrinters, args);
setRelationResultsPrinterSet(makeResultsPrinters(MachineReadingProperties.relationResultsPrinters, args));
eventResultsPrinterSet = makeResultsPrinters(MachineReadingProperties.eventResultsPrinters, args);
}
private Set<ResultsPrinter> makeResultsPrinters(String classes, String [] args) {
MachineReadingProperties.logger.info("Making result printers from " + classes);
String[] printerClassNames = classes.trim().split(",\\s*");
HashSet<ResultsPrinter> printers = new HashSet<ResultsPrinter>();
for (String printerClassName : printerClassNames) {
if(printerClassName.length() == 0) continue;
ResultsPrinter rp;
try {
rp = (ResultsPrinter) Class.forName(printerClassName).getConstructor().newInstance();
printers.add(rp);
} catch (Exception e) {
throw new RuntimeException(e);
}
//Execution.fillOptions(ResultsPrinterProps.class, args);
//Arguments.parse(args,rp);
}
return printers;
}
/**
* Constructs the corpus reader class and sets it as the reader for this MachineReading instance.
*
* @return corpus reader specified by datasetReaderClass
*/
private GenericDataSetReader makeReader(Properties props) {
try {
if(reader == null){
try {
reader = MachineReadingProperties.datasetReaderClass.getConstructor(Properties.class).newInstance(props);
} catch(java.lang.NoSuchMethodException e) {
// if no c'tor with props found let's use the default one
reader = MachineReadingProperties.datasetReaderClass.getConstructor().newInstance();
}
}
} catch (Exception e) {
throw new RuntimeException(e);
}
reader.setUseNewHeadFinder(MachineReadingProperties.useNewHeadFinder);
return reader;
}
/**
* Constructs the corpus reader class and sets it as the reader for this MachineReading instance.
*
* @return corpus reader specified by datasetAuxReaderClass
*/
private GenericDataSetReader makeAuxReader() {
try {
if (auxReader == null) {
if (MachineReadingProperties.datasetAuxReaderClass != null) {
auxReader = MachineReadingProperties.datasetAuxReaderClass.getConstructor().newInstance();
}
}
} catch (Exception e) {
throw new RuntimeException(e);
}
return auxReader;
}
public static Extractor makeEntityExtractor(
Class<? extends BasicEntityExtractor> entityExtractorClass,
String gazetteerPath) {
if (entityExtractorClass == null) return null;
BasicEntityExtractor ex;
try {
ex = entityExtractorClass.getConstructor(String.class).newInstance(gazetteerPath);
} catch (Exception e) {
throw new RuntimeException(e);
}
return ex;
}
public static Extractor makeRelationExtractor(
Class<? extends BasicRelationExtractor> relationExtractorClass,RelationFeatureFactory featureFac, boolean createUnrelatedRelations, RelationMentionFactory factory) {
if (relationExtractorClass == null) return null;
BasicRelationExtractor ex;
try {
ex = relationExtractorClass.getConstructor(RelationFeatureFactory.class, Boolean.class, RelationMentionFactory.class).newInstance(featureFac, createUnrelatedRelations, factory);
} catch (Exception e) {
throw new RuntimeException(e);
}
return ex;
}
public static RelationFeatureFactory makeRelationFeatureFactory(
Class<? extends RelationFeatureFactory> relationFeatureFactoryClass,
String relationFeatureList,
boolean doNotLexicalizeFirstArg) {
if (relationFeatureList == null || relationFeatureFactoryClass == null)
return null;
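// wrap the String[] in an Object[] so that newInstance receives it as a single String[] argument instead of spreading it as varargs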
Object[] featureList = new Object [] {relationFeatureList.trim().split(",\\s*")};
RelationFeatureFactory rff;
try {
rff = relationFeatureFactoryClass.getConstructor(String[].class).newInstance(featureList);
rff.setDoNotLexicalizeFirstArgument(doNotLexicalizeFirstArg);
} catch (Exception e) {
throw new RuntimeException(e);
}
return rff;
}
public static RelationMentionFactory makeRelationMentionFactory(
Class<RelationMentionFactory> relationMentionFactoryClass) {
RelationMentionFactory rmf;
try {
rmf = relationMentionFactoryClass.getConstructor().newInstance();
} catch (Exception e) {
throw new RuntimeException(e);
}
return rmf;
}
/**
* Gets the serialized sentences for a data set. If the serialized sentences
* are already on disk, it loads them from there. Otherwise, the data set is
* read with the corpus reader and the serialized sentences are saved to disk.
*
* @param sentencesPath
* location of the raw data set
* @param reader
* the corpus reader
* @param serializedSentences
* where the serialized sentences should be stored on disk
* @return an Annotation containing the parsed corpus sentences
*/
protected Annotation loadOrMakeSerializedSentences(
String sentencesPath, GenericDataSetReader reader,
File serializedSentences) throws IOException, ClassNotFoundException {
Annotation corpusSentences;
// if the serialized file exists, just read it. otherwise read the source
// and save the serialized file to disk
if (MachineReadingProperties.serializeCorpora && serializedSentences.exists() && !forceParseSentences) {
MachineReadingProperties.logger.info("Loaded serialized sentences from " + serializedSentences.getAbsolutePath() + "...");
corpusSentences = (Annotation) IOUtils.readObjectFromFile(serializedSentences);
MachineReadingProperties.logger.info("Done. Loaded " + corpusSentences.get(CoreAnnotations.SentencesAnnotation.class).size() + " sentences.");
} else {
// read the corpus
MachineReadingProperties.logger.info("Parsing corpus sentences...");
if(MachineReadingProperties.serializeCorpora)
MachineReadingProperties.logger.info("These sentences will be serialized to " + serializedSentences.getAbsolutePath());
corpusSentences = reader.parse(sentencesPath);
MachineReadingProperties.logger.info("Done. Parsed " + AnnotationUtils.sentenceCount(corpusSentences) + " sentences.");
// save corpusSentences
if(MachineReadingProperties.serializeCorpora){
MachineReadingProperties.logger.info("Serializing parsed sentences to " + serializedSentences.getAbsolutePath() + "...");
IOUtils.writeObjectToFile(corpusSentences,serializedSentences);
MachineReadingProperties.logger.info("Done. Serialized " + AnnotationUtils.sentenceCount(corpusSentences) + " sentences.");
}
}
return corpusSentences;
}
public void setExtractEntities(boolean extractEntities) {
MachineReadingProperties.extractEntities = extractEntities;
}
public void setExtractRelations(boolean extractRelations) {
MachineReadingProperties.extractRelations = extractRelations;
}
public void setExtractEvents(boolean extractEvents) {
MachineReadingProperties.extractEvents = extractEvents;
}
public void setForceParseSentences(boolean forceParseSentences) {
this.forceParseSentences = forceParseSentences;
}
public void setDatasets(Pair<Annotation, Annotation> [] datasets) {
this.datasets = datasets;
}
public Pair<Annotation, Annotation> [] getDatasets() {
return datasets;
}
public void setPredictions(Annotation [][] predictions) {
this.predictions = predictions;
}
public Annotation [][] getPredictions() {
return predictions;
}
public void setReader(GenericDataSetReader reader) {
this.reader = reader;
}
public GenericDataSetReader getReader() {
return reader;
}
public void setAuxReader(GenericDataSetReader auxReader) {
this.auxReader = auxReader;
}
public GenericDataSetReader getAuxReader() {
return auxReader;
}
public void setEntityResultsPrinterSet(Set<ResultsPrinter> entityResultsPrinterSet) {
this.entityResultsPrinterSet = entityResultsPrinterSet;
}
public Set<ResultsPrinter> getEntityResultsPrinterSet() {
return entityResultsPrinterSet;
}
public void setRelationResultsPrinterSet(Set<ResultsPrinter> relationResultsPrinterSet) {
this.relationResultsPrinterSet = relationResultsPrinterSet;
}
public Set<ResultsPrinter> getRelationResultsPrinterSet() {
return relationResultsPrinterSet;
}
}