package edu.stanford.nlp.parser.eval;
import java.io.File;
import java.io.FileFilter;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import java.util.Stack;
import edu.stanford.nlp.international.Languages;
import edu.stanford.nlp.international.Languages.Language;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.parser.lexparser.TreebankLangParserParams;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.trees.DiskTreebank;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.PropertiesUtils;
import edu.stanford.nlp.util.StringUtils;
/**
 * Utility class for extracting a variety of statistics (tree depth, breadth,
 * yield length, branching factor, tag and word distributions, and OOV rate)
 * from multilingual treebanks.
 *
 * TODO(spenceg) Add sample standard deviation
 *
 * @author Spence Green
 */
public class TreebankStats {
private final Language languageName;
private final TreebankLangParserParams tlpp;
private final List<String> pathNames;
private enum Split {Train, Dev, Test}
private Map<Split,Set<String>> splitFileLists;
private boolean useSplit = false;
private boolean makeVocab = false;
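// Shared across the run: the vocabulary of the training portion, consulted
// by ObservedCorpusStats.computeFinalValues() to compute OOV rates.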
private static Set<String> trainVocab = null;
public TreebankStats(Language langName, List<String> paths, TreebankLangParserParams tlpp) {
languageName = langName;
pathNames = paths;
this.tlpp = tlpp;
}
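/**
 * Configures a three-way split from the files prefix.train, prefix.dev, and
 * prefix.test, each of which lists one treebank file name per line. For
 * example (hypothetical names), given -s /corpora/ftb, the file
 * /corpora/ftb.train might contain:
 *
 *   ftb_train_01.mrg
 *   ftb_train_02.mrg
 *
 * Entries are matched against File.getName() by SplitFilter, so they should
 * be bare file names rather than paths.
 *
 * @return true if all three split files exist; false otherwise
 */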
public boolean useSplit(String prefix) {
Map<Split,File> splitMap = Generics.newHashMap();
splitMap.put(Split.Train,new File(prefix + ".train"));
splitMap.put(Split.Test,new File(prefix + ".test"));
splitMap.put(Split.Dev,new File(prefix + ".dev"));
splitFileLists = Generics.newHashMap();
for(Map.Entry<Split, File> entry : splitMap.entrySet()) {
File f = entry.getValue();
if(!f.exists()) return false;
Set<String> files = Generics.newHashSet();
for(String fileName : IOUtils.readLines(f))
files.add(fileName);
splitFileLists.put(entry.getKey(), files);
}
useSplit = true;
return true;
}
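/**
 * Reads every tree in the treebank, accumulating length, depth, breadth,
 * tag, and word statistics. When makeVocab is set, the trees read here also
 * define the training vocabulary used for OOV computation.
 */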
private ObservedCorpusStats gatherStats(DiskTreebank tb, String name) {
ObservedCorpusStats ocs = new ObservedCorpusStats(name);
if(makeVocab) trainVocab = Generics.newHashSet();
System.out.println("Reading treebank:");
for(Tree t : tb) {
Pair<Integer,Integer> treeFacts = dissectTree(t, ocs, makeVocab);
ocs.addStatsForTree(t.yield().size(), treeFacts.first(), treeFacts.second());
// Progress indicator: a dot every 100 trees, a line break every 8,000.
if (ocs.numTrees % 100 == 0) System.out.print(".");
if (ocs.numTrees % 8000 == 0) System.out.println();
}
ocs.computeFinalValues();
System.out.println("done!");
return ocs;
}
/**
 * Returns the (depth, breadth) of a tree, computed with an iterative
 * depth-first traversal. Breadth is a heuristic: the maximum, over phrasal
 * nodes, of the number of nodes awaiting expansion (including the current
 * one). For example, (S (NP (DT the) (NN dog)) (VP (VBZ barks))) gives
 * depth 3 (leaves sit three edges below the root) and breadth 2.
 *
 * @param t the tree to dissect (must not be null)
 * @param ocs accumulates tag, word, and phrasal branching counts
 * @param addToVocab if true, leaf tokens are added to trainVocab
 * @return a (depth, breadth) pair
 */
private Pair<Integer,Integer> dissectTree(Tree t, ObservedCorpusStats ocs, boolean addToVocab) {
  if (t == null) {
    throw new RuntimeException("Null tree passed to dissectTree()");
  }
  final Stack<Pair<Integer,Tree>> stack = new Stack<Pair<Integer,Tree>>();
  stack.push(new Pair<Integer,Tree>(0, t));
  int maxBreadth = 0;
  int maxDepth = -1;
  while (!stack.isEmpty()) {
    Pair<Integer,Tree> depthNode = stack.pop();
    final int nodeDepth = depthNode.first();
    final Tree node = depthNode.second();
    if (nodeDepth > maxDepth) maxDepth = nodeDepth;
    // Breadth heuristic: the current node plus everything still awaiting
    // expansion bounds the width of the widest level encountered so far.
    if (node.isPhrasal() && stack.size() + 1 > maxBreadth) maxBreadth = stack.size() + 1;
if(node.isPhrasal()) {
ocs.addPhrasalBranch(node.value(), node.children().length);
} else if(node.isPreTerminal())
ocs.posTags.incrementCount(node.value());
else if(node.isLeaf()) {
ocs.words.incrementCount(node.value());
if(addToVocab) trainVocab.add(node.value());
}
for(Tree kid : node.children())
stack.push(new Pair<Integer,Tree>(nodeDepth + 1, kid));
  }
  return new Pair<Integer,Integer>(maxDepth, maxBreadth);
}
private void display(ObservedCorpusStats corpStats, boolean displayWords, boolean displayOOV) {
System.out.println("####################################################################");
System.out.println("## " + corpStats.getName());
System.out.println("####################################################################");
System.out.println();
corpStats.display(displayWords, displayOOV);
}
private ObservedCorpusStats aggregateStats(List<ObservedCorpusStats> allStats) {
  if (allStats.isEmpty()) return null;
  if (allStats.size() == 1) return allStats.get(0);
ObservedCorpusStats agStats = new ObservedCorpusStats("CORPUS");
for(ObservedCorpusStats ocs : allStats) {
agStats.numTrees += ocs.numTrees;
agStats.breadth2 += ocs.breadth2;
agStats.breadths.addAll(ocs.breadths);
agStats.depth2 += ocs.depth2;
agStats.depths.addAll(ocs.depths);
agStats.length2 += ocs.length2;
agStats.lengths.addAll(ocs.lengths);
if(ocs.minLength < agStats.minLength) agStats.minLength = ocs.minLength;
if(ocs.maxLength > agStats.maxLength) agStats.maxLength = ocs.maxLength;
if(ocs.minBreadth < agStats.minBreadth) agStats.minBreadth = ocs.minBreadth;
if(ocs.maxBreadth > agStats.maxBreadth) agStats.maxBreadth = ocs.maxBreadth;
if(ocs.minDepth < agStats.minDepth) agStats.minDepth = ocs.minDepth;
if(ocs.maxDepth > agStats.maxDepth) agStats.maxDepth = ocs.maxDepth;
agStats.words.addAll(ocs.words);
agStats.posTags.addAll(ocs.posTags);
agStats.phrasalBranching2.addAll(ocs.phrasalBranching2);
agStats.phrasalBranchingNum2.addAll(ocs.phrasalBranchingNum2);
}
agStats.computeFinalValues();
return agStats;
}
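/**
 * Gathers and prints statistics in one of three modes: per split (when
 * useSplit has succeeded), per file (pathsAreFiles, where the first file
 * defines the training vocabulary), or over all paths pooled as a single
 * corpus. A minimal programmatic sketch of the last mode (the path is
 * hypothetical):
 *
 * <pre>{@code
 * TreebankLangParserParams tlpp = Languages.getLanguageParams(Language.French);
 * TreebankStats stats = new TreebankStats(Language.French,
 *     Collections.singletonList("/corpora/ftb/trees"), tlpp);
 * stats.run(false, false, false);
 * }</pre>
 */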
public void run(boolean pathsAreFiles, boolean displayWords, boolean displayOOV) {
if (useSplit) {
  List<ObservedCorpusStats> allSplitStats = new ArrayList<ObservedCorpusStats>();
  makeVocab = true;
  // Iterate in enum order (Train, Dev, Test) so that the vocabulary is
  // always built from the training split rather than from whichever split
  // a HashMap happens to yield first.
  for (Split split : Split.values()) {
    DiskTreebank tb = tlpp.diskTreebank();
    FileFilter splitFilter = new SplitFilter(splitFileLists.get(split));
    for (String path : pathNames)
      tb.loadPath(path, splitFilter);
    ObservedCorpusStats splitStats = gatherStats(tb, languageName.toString() + "." + split.toString());
    allSplitStats.add(splitStats);
    makeVocab = false;
  }
display(aggregateStats(allSplitStats), displayWords, displayOOV);
for(ObservedCorpusStats ocs : allSplitStats)
display(ocs, displayWords, displayOOV);
} else if (pathsAreFiles) {
  makeVocab = true;
  for (String path : pathNames) {
    DiskTreebank tb = tlpp.diskTreebank();
    tb.loadPath(path, pathname -> true);
    // Per the usage message, the first file is the training set, so only
    // it contributes to trainVocab.
    ObservedCorpusStats stats = gatherStats(tb, languageName.toString() + " " + path);
    display(stats, displayWords, displayOOV);
    makeVocab = false;
  }
} else {
  // Single-corpus mode: no training vocabulary is defined, so the
  // reported OOV rate is not meaningful.
  trainVocab = Generics.newHashSet();
DiskTreebank tb = tlpp.diskTreebank();
for(String path : pathNames)
tb.loadPath(path, pathname -> !pathname.isDirectory());
ObservedCorpusStats allStats = gatherStats(tb, languageName.toString());
display(allStats, displayWords, displayOOV);
}
}
protected class SplitFilter implements FileFilter {
private final Set<String> filterMap;
public SplitFilter(Set<String> fileList) {
filterMap = fileList;
}
@Override
public boolean accept(File f) {
return filterMap.contains(f.getName());
}
}
protected static class ObservedCorpusStats {
private final String corpusName;
public ObservedCorpusStats(String name) {
corpusName = name;
words = new ClassicCounter<String>();
posTags = new ClassicCounter<String>();
phrasalBranching2 = new ClassicCounter<String>();
phrasalBranchingNum2 = new ClassicCounter<String>();
lengths = new ArrayList<Integer>();
depths = new ArrayList<Integer>();
breadths = new ArrayList<Integer>();
}
public String getName() { return corpusName; }
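/**
 * Adds the per-tree observations (yield length in tokens plus the depth
 * and breadth estimates from dissectTree) and updates the running sums
 * and min/max values.
 */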
public void addStatsForTree(int yieldLength, int depth, int breadth) {
numTrees++;
breadths.add(breadth);
breadth2 += breadth;
lengths.add(yieldLength);
length2 += yieldLength;
depths.add(depth);
depth2 += depth;
if(depth < minDepth) minDepth = depth;
else if(depth > maxDepth) maxDepth = depth;
if(yieldLength < minLength) minLength = yieldLength;
else if(yieldLength > maxLength) maxLength = yieldLength;
if(breadth < minBreadth) minBreadth = breadth;
else if(breadth > maxBreadth) maxBreadth = breadth;
}
/** Returns the fraction of sentences whose yield is at most maxLen tokens (the bound is inclusive). */
public double getPercLensLessThan(int maxLen) {
int lens = 0;
for(Integer len : lengths)
if(len <= maxLen)
lens++;
return (double) lens / (double) lengths.size();
}
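/**
 * Records one phrasal node: phrasalBranching2 accumulates the total
 * number of children seen per label, while phrasalBranchingNum2 counts
 * instances, so their ratio is the label's mean branching factor. For
 * example, two binary NPs and one ternary NP give phrasalBranching2(NP)
 * = 7 and phrasalBranchingNum2(NP) = 3, i.e. a mean of 2.33.
 */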
public void addPhrasalBranch(String label, int factor) {
phrasalBranching2.incrementCount(label, factor);
phrasalBranchingNum2.incrementCount(label);
}
public void display(boolean displayWords, boolean displayOOV) {
NumberFormat nf = new DecimalFormat("0.00");
System.out.println("======================================================");
System.out.println(">>> " + corpusName);
System.out.println(" trees:\t\t" + numTrees);
System.out.println(" words:\t\t" + words.keySet().size());
System.out.println(" tokens:\t" + (int) words.totalCount());
System.out.println(" tags:\t\t" + posTags.size());
System.out.println(" phrasal types:\t" + phrasalBranchingNum2.keySet().size());
System.out.println(" phrasal nodes:\t" + (int) phrasalBranchingNum2.totalCount());
System.out.println(" OOV rate:\t" + nf.format(OOVRate * 100.0) + "%");
System.out.println("======================================================");
System.out.println(">>> Per tree means");
System.out.printf(" depth:\t\t%s\t{min:%d\tmax:%d}\t\ts: %s\n",nf.format(meanDepth),minDepth,maxDepth,nf.format(stddevDepth));
System.out.printf(" breadth:\t%s\t{min:%d\tmax:%d}\ts: %s\n",nf.format(meanBreadth),minBreadth,maxBreadth,nf.format(stddevBreadth));
System.out.printf(" length:\t%s\t{min:%d\tmax:%d}\ts: %s\n",nf.format(meanLength),minLength,maxLength,nf.format(stddevLength));
System.out.println(" branching:\t" + nf.format(meanBranchingFactor));
System.out.println(" constituents:\t" + nf.format(meanConstituents));
System.out.println("======================================================");
System.out.println(">>> Branching factor means by phrasal tag:");
List<String> sortedKeys = new ArrayList<String>(meanBranchingByLabel.keySet());
Collections.sort(sortedKeys, Counters.toComparator(phrasalBranchingNum2,false,true));
for(String label : sortedKeys)
System.out.printf(" %s:\t\t%s / %d instances\n", label,nf.format(meanBranchingByLabel.getCount(label)), (int) phrasalBranchingNum2.getCount(label));
System.out.println("======================================================");
System.out.println(">>> Phrasal tag counts");
sortedKeys = new ArrayList<String>(phrasalBranchingNum2.keySet());
Collections.sort(sortedKeys, Counters.toComparator(phrasalBranchingNum2,false,true));
for(String label : sortedKeys)
System.out.println(" " + label + ":\t\t" + (int) phrasalBranchingNum2.getCount(label));
System.out.println("======================================================");
System.out.println(">>> POS tag counts");
sortedKeys = new ArrayList<String>(posTags.keySet());
Collections.sort(sortedKeys, Counters.toComparator(posTags,false,true));
for(String posTag : sortedKeys)
System.out.println(" " + posTag + ":\t\t" + (int) posTags.getCount(posTag));
System.out.println("======================================================");
if(displayWords) {
System.out.println(">>> Word counts");
sortedKeys = new ArrayList<String>(words.keySet());
Collections.sort(sortedKeys, Counters.toComparator(words,false,true));
for(String word : sortedKeys)
System.out.println(" " + word + ":\t\t" + (int) words.getCount(word));
System.out.println("======================================================");
}
if(displayOOV) {
System.out.println(">>> OOV word types");
for(String word : oovWords)
System.out.println(" " + word);
System.out.println("======================================================");
}
}
public void computeFinalValues() {
final double denom = (double) numTrees;
meanDepth = depth2 / denom;
meanLength = length2 / denom;
meanBreadth = breadth2 / denom;
meanConstituents = phrasalBranchingNum2.totalCount() / denom;
meanBranchingFactor = phrasalBranching2.totalCount() / phrasalBranchingNum2.totalCount();
// Compute the population standard deviation (divide by N); the sample
// standard deviation flagged in the class TODO would divide by N - 1.
for(int d : depths)
stddevDepth += Math.pow(d - meanDepth, 2);
stddevDepth = Math.sqrt(stddevDepth / denom);
for(int l : lengths)
stddevLength += Math.pow(l - meanLength, 2);
stddevLength = Math.sqrt(stddevLength / denom);
for(int b : breadths)
stddevBreadth += Math.pow(b - meanBreadth, 2);
stddevBreadth = Math.sqrt(stddevBreadth / denom);
meanBranchingByLabel = new ClassicCounter<String>();
for(String label : phrasalBranching2.keySet()) {
double mean = phrasalBranching2.getCount(label) / phrasalBranchingNum2.getCount(label);
meanBranchingByLabel.incrementCount(label, mean);
}
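// The OOV rate is computed over word *types*: the fraction of this
// corpus's vocabulary that never occurs in the training vocabulary.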
oovWords = Generics.newHashSet(words.keySet());
oovWords.removeAll(trainVocab);
OOVRate = (double) oovWords.size() / (double) words.keySet().size();
}
//Corpus wide
public final Counter<String> words;
public final Counter<String> posTags;
private final Counter<String> phrasalBranching2;
private final Counter<String> phrasalBranchingNum2;
public int numTrees = 0;
private double depth2 = 0.0;
private double breadth2 = 0.0;
private double length2 = 0.0;
private final List<Integer> lengths;
private final List<Integer> breadths;
private final List<Integer> depths;
//Tree-level Averages
private Counter<String> meanBranchingByLabel;
private double meanDepth = 0.0;
private double stddevDepth = 0.0;
private double meanBranchingFactor = 0.0;
private double meanConstituents = 0.0;
private double meanLength = 0.0;
private double stddevLength = 0.0;
private double meanBreadth = 0.0;
private double stddevBreadth = 0.0;
private double OOVRate = 0.0;
private Set<String> oovWords;
//Mins and maxes
public int minLength = Integer.MAX_VALUE;
public int maxLength = Integer.MIN_VALUE;
public int minDepth = Integer.MAX_VALUE;
public int maxDepth = Integer.MIN_VALUE;
public int minBreadth = Integer.MAX_VALUE;
public int maxBreadth = Integer.MIN_VALUE;
}
private static final int MIN_ARGS = 2;
private static String usage() {
StringBuilder usage = new StringBuilder();
String nl = System.getProperty("line.separator");
usage.append(String.format("Usage: java %s [OPTS] LANG paths%n%n",TreebankStats.class.getName()));
usage.append("Options:").append(nl);
usage.append(" LANG is one of " + Languages.listOfLanguages()).append(nl);
usage.append(" -s prefix : Use a split (extensions must be dev/test/train)").append(nl);
usage.append(" -w : Show word distribution").append(nl);
usage.append(" -f : Path list is a set of files, and the first file is the training set").append(nl);
usage.append(" -o : Print OOV words.").append(nl);
return usage.toString();
}
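// Example invocation (the paths are hypothetical): collect split-level
// statistics for a French treebank, also printing the word distribution:
//
//   java edu.stanford.nlp.parser.eval.TreebankStats -s /corpora/ftb -w French /corpora/ftb/trees
//
// This expects /corpora/ftb.train, /corpora/ftb.dev, and /corpora/ftb.test
// to exist, each listing one treebank file name per line (see useSplit).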
private static Map<String,Integer> optArgDefs() {
Map<String,Integer> optArgDefs = Generics.newHashMap(4);
optArgDefs.put("s", 1);
optArgDefs.put("w", 0);
optArgDefs.put("f", 0);
optArgDefs.put("o", 0);
return optArgDefs;
}
/**
 * Reads the treebank(s) named on the command line and prints their
 * statistics to standard output.
 *
 * @param args command-line arguments; see usage()
 */
public static void main(String[] args) {
if(args.length < MIN_ARGS) {
System.err.println(usage());
System.exit(-1);
}
Properties options = StringUtils.argsToProperties(args, optArgDefs());
String splitPrefix = options.getProperty("s", null);
boolean showWords = PropertiesUtils.getBool(options, "w", false);
boolean pathsAreFiles = PropertiesUtils.getBool(options, "f", false);
boolean showOOV = PropertiesUtils.getBool(options, "o", false);
String[] parsedArgs = options.getProperty("","").split("\\s+");
if (parsedArgs.length < MIN_ARGS) {
System.err.println(usage());
System.exit(-1);
}
Language language = Language.valueOf(parsedArgs[0]);
List<String> corpusPaths = new ArrayList<String>(parsedArgs.length-1);
for (int i = 1; i < parsedArgs.length; ++i) {
corpusPaths.add(parsedArgs[i]);
}
TreebankLangParserParams tlpp = Languages.getLanguageParams(language);
TreebankStats cs = new TreebankStats(language,corpusPaths,tlpp);
if(splitPrefix != null) {
if(!cs.useSplit(splitPrefix)) System.err.println("Could not load split!");
}
cs.run(pathsAreFiles, showWords, showOOV);
}
}