package edu.stanford.nlp.parser.lexparser;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import edu.stanford.nlp.international.Languages;
import edu.stanford.nlp.international.Languages.Language;
import edu.stanford.nlp.international.arabic.ArabicMorphoFeatureSpecification;
import edu.stanford.nlp.international.french.FrenchMorphoFeatureSpecification;
import edu.stanford.nlp.international.morph.MorphoFeatureSpecification;
import edu.stanford.nlp.international.morph.MorphoFeatureSpecification.MorphoFeatureType;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.Label;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.stats.TwoDimensionalIntCounter;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.Treebank;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;
/**
*
* @author Spence Green
*
*/
public class FactoredLexicon extends BaseLexicon {
private static final long serialVersionUID = -744693222804176489L;
private static final boolean DEBUG = false;
private MorphoFeatureSpecification morphoSpec;
private static final String NO_MORPH_ANALYSIS = "xXxNONExXx";
private Index<String> morphIndex = new HashIndex<String>();
private TwoDimensionalIntCounter<Integer,Integer> wordTag = new TwoDimensionalIntCounter<Integer,Integer>(40000);
private Counter<Integer> wordTagUnseen = new ClassicCounter<Integer>(500);
private TwoDimensionalIntCounter<Integer,Integer> lemmaTag = new TwoDimensionalIntCounter<Integer,Integer>(40000);
private Counter<Integer> lemmaTagUnseen = new ClassicCounter<Integer>(500);
private TwoDimensionalIntCounter<Integer,Integer> morphTag = new TwoDimensionalIntCounter<Integer,Integer>(500);
private Counter<Integer> morphTagUnseen = new ClassicCounter<Integer>(500);
private Counter<Integer> tagCounter = new ClassicCounter<Integer>(300);
public FactoredLexicon(MorphoFeatureSpecification morphoSpec, Index<String> wordIndex, Index<String> tagIndex) {
super(wordIndex, tagIndex);
this.morphoSpec = morphoSpec;
}
public FactoredLexicon(Options op, MorphoFeatureSpecification morphoSpec, Index<String> wordIndex, Index<String> tagIndex) {
super(op, wordIndex, tagIndex);
this.morphoSpec = morphoSpec;
}
/**
* Rule table is lemmas. So isKnown() is slightly trickier.
*/
@Override
public Iterator<IntTaggedWord> ruleIteratorByWord(int word, int loc, String featureSpec) {
if (word == wordIndex.indexOf(BOUNDARY)) {
// Deterministic tagging of the boundary symbol
return rulesWithWord[word].iterator();
} else if (isKnown(word)) {
// Strict lexical tagging for seen *lemma* types
// We need to copy the word form into the rules, which currently have lemmas in them
return rulesWithWord[word].iterator();
} else {
if (DEBUG) System.err.println("UNKNOWN WORD");
// Unknown word signatures
Set<IntTaggedWord> lexRules = Generics.newHashSet(10);
List<IntTaggedWord> uwRules = rulesWithWord[wordIndex.indexOf(UNKNOWN_WORD)];
// Inject the word into these rules instead of the UW signature
for (IntTaggedWord iTW : uwRules) {
lexRules.add(new IntTaggedWord(word, iTW.tag));
}
return lexRules.iterator();
}
}
@Override
public float score(IntTaggedWord iTW, int loc, String word, String featureSpec) {
final int wordId = iTW.word();
final int tagId = iTW.tag();
// Force 1-best path to go through the boundary symbol
// (deterministic tagging)
final int boundaryId = wordIndex.indexOf(BOUNDARY);
final int boundaryTagId = tagIndex.indexOf(BOUNDARY_TAG);
if (wordId == boundaryId && tagId == boundaryTagId) {
return 0.0f;
}
// Morphological features
String tag = tagIndex.get(iTW.tag());
Pair<String,String> lemmaMorph = MorphoFeatureSpecification.splitMorphString(word, featureSpec);
String lemma = lemmaMorph.first();
int lemmaId = wordIndex.indexOf(lemma);
String richMorphTag = lemmaMorph.second();
String reducedMorphTag = morphoSpec.strToFeatures(richMorphTag).toString().trim();
reducedMorphTag = reducedMorphTag.length() == 0 ? NO_MORPH_ANALYSIS : reducedMorphTag;
int morphId = morphIndex.addToIndex(reducedMorphTag);
// Score the factors and create the rule score p_W_T
double p_W_Tf = Math.log(probWordTag(word, loc, wordId, tagId));
// double p_L_T = Math.log(probLemmaTag(word, loc, tagId, lemmaId));
double p_L_T = 0.0;
double p_M_T = Math.log(probMorphTag(tagId, morphId));
double p_W_T = p_W_Tf + p_L_T + p_M_T;
if (DEBUG) {
// String tag = tagIndex.get(tagId);
System.err.printf("WSGDEBUG: %s --> %s %s %s || %.10f (%.5f / %.5f / %.5f)%n", tag, word, lemma,
reducedMorphTag, p_W_T, p_W_Tf, p_L_T, p_M_T);
}
// Filter low probability taggings
return p_W_T > -100.0 ? (float) p_W_T : Float.NEGATIVE_INFINITY;
}
private double probWordTag(String word, int loc, int wordId, int tagId) {
double cW = wordTag.totalCount(wordId);
double cWT = wordTag.getCount(wordId, tagId);
// p_L
double p_W = cW / wordTag.totalCount();
// p_T
double cTseen = tagCounter.getCount(tagId);
double p_T = cTseen / tagCounter.totalCount();
// p_T_L
double p_W_T = 0.0;
if (cW > 0.0) { // Seen lemma
double p_T_W = 0.0;
if (cW > 100.0 && cWT > 0.0) {
p_T_W = cWT / cW;
} else {
double cTunseen = wordTagUnseen.getCount(tagId);
// TODO p_T_U is 0?
double p_T_U = cTunseen / wordTagUnseen.totalCount();
p_T_W = (cWT + smooth[1]*p_T_U) / (cW + smooth[1]);
}
p_W_T = p_T_W * p_W / p_T;
} else { // Unseen word. Score based on the word signature (of the surface form)
IntTaggedWord iTW = new IntTaggedWord(wordId, tagId);
double c_T = tagCounter.getCount(tagId);
p_W_T = Math.exp(getUnknownWordModel().score(iTW, loc, c_T, tagCounter.totalCount(), smooth[0], word));
}
return p_W_T;
}
/**
* This method should never return 0!!
*/
private double probLemmaTag(String word, int loc, int tagId, int lemmaId) {
double cL = lemmaTag.totalCount(lemmaId);
double cLT = lemmaTag.getCount(lemmaId, tagId);
// p_L
double p_L = cL / lemmaTag.totalCount();
// p_T
double cTseen = tagCounter.getCount(tagId);
double p_T = cTseen / tagCounter.totalCount();
// p_T_L
double p_L_T = 0.0;
if (cL > 0.0) { // Seen lemma
double p_T_L = 0.0;
if (cL > 100.0 && cLT > 0.0) {
p_T_L = cLT / cL;
} else {
double cTunseen = lemmaTagUnseen.getCount(tagId);
// TODO(spenceg): p_T_U is 0??
double p_T_U = cTunseen / lemmaTagUnseen.totalCount();
p_T_L = (cLT + smooth[1]*p_T_U) / (cL + smooth[1]);
}
p_L_T = p_T_L * p_L / p_T;
} else { // Unseen lemma. Score based on the word signature (of the surface form)
// Hack
double cTunseen = lemmaTagUnseen.getCount(tagId);
p_L_T = cTunseen / tagCounter.totalCount();
// int wordId = wordIndex.indexOf(word);
// IntTaggedWord iTW = new IntTaggedWord(wordId, tagId);
// double c_T = tagCounter.getCount(tagId);
// p_L_T = Math.exp(getUnknownWordModel().score(iTW, loc, c_T, tagCounter.totalCount(), smooth[0], word));
}
return p_L_T;
}
/**
* This method should never return 0!
*/
private double probMorphTag(int tagId, int morphId) {
double cM = morphTag.totalCount(morphId);
double cMT = morphTag.getCount(morphId, tagId);
// p_M
double p_M = cM / morphTag.totalCount();
// p_T
double cTseen = tagCounter.getCount(tagId);
double p_T = cTseen / tagCounter.totalCount();
double p_M_T = 0.0;
if (cM > 100.0 && cMT > 0.0) {
double p_T_M = cMT / cM;
// else {
// double cTunseen = morphTagUnseen.getCount(tagId);
// double p_T_U = cTunseen / morphTagUnseen.totalCount();
// p_T_M = (cMT + smooth[1]*p_T_U) / (cM + smooth[1]);
// }
p_M_T = p_T_M * p_M / p_T;
} else { // Unseen morphological analysis
// Hack....unseen morph tags are extremely rare
// Add+1 smoothing
p_M_T = 1.0 / (morphTag.totalCount() + tagIndex.size() + 1.0);
}
return p_M_T;
}
/**
* This method should populate wordIndex, tagIndex, and morphIndex.
*/
@Override
public void train(Collection<Tree> trees, Collection<Tree> rawTrees) {
double weight = 1.0;
// Train uw model on words
uwModelTrainer.train(trees, weight);
final double numTrees = trees.size();
Iterator<Tree> rawTreesItr = rawTrees == null ? null : rawTrees.iterator();
Iterator<Tree> treeItr = trees.iterator();
// Train factored lexicon on lemmas and morph tags
int treeId = 0;
while (treeItr.hasNext()) {
Tree tree = treeItr.next();
// CoreLabels, with morph analysis in the originalText annotation
List<Label> yield = rawTrees == null ? tree.yield() : rawTreesItr.next().yield();
// Annotated, binarized tree for the tags (labels are usually CategoryWordTag)
List<Label> pretermYield = tree.preTerminalYield();
int yieldLen = yield.size();
for (int i = 0; i < yieldLen; ++i) {
String word = yield.get(i).value();
int wordId = wordIndex.addToIndex(word); // Don't do anything with words
String tag = pretermYield.get(i).value();
int tagId = tagIndex.addToIndex(tag);
// Use the word as backup if there is no lemma
String featureStr = ((CoreLabel) yield.get(i)).originalText();
Pair<String,String> lemmaMorph = MorphoFeatureSpecification.splitMorphString(word, featureStr);
String lemma = lemmaMorph.first();
int lemmaId = wordIndex.addToIndex(lemma);
String richMorphTag = lemmaMorph.second();
String reducedMorphTag = morphoSpec.strToFeatures(richMorphTag).toString().trim();
reducedMorphTag = reducedMorphTag.isEmpty() ? NO_MORPH_ANALYSIS : reducedMorphTag;
int morphId = morphIndex.addToIndex(reducedMorphTag);
// Seen event counts
wordTag.incrementCount(wordId, tagId);
lemmaTag.incrementCount(lemmaId, tagId);
morphTag.incrementCount(morphId, tagId);
tagCounter.incrementCount(tagId);
// Unseen event counts
if (treeId > op.trainOptions.fractionBeforeUnseenCounting*numTrees) {
if (! wordTag.firstKeySet().contains(wordId) || wordTag.getCounter(wordId).totalCount() < 2) {
wordTagUnseen.incrementCount(tagId);
}
if (! lemmaTag.firstKeySet().contains(lemmaId) || lemmaTag.getCounter(lemmaId).totalCount() < 2) {
lemmaTagUnseen.incrementCount(tagId);
}
if (! morphTag.firstKeySet().contains(morphId) || morphTag.getCounter(morphId).totalCount() < 2) {
morphTagUnseen.incrementCount(tagId);
}
}
}
++treeId;
if (DEBUG && (treeId % 100) == 0) {
System.err.printf("[%d]",treeId);
}
if (DEBUG && (treeId % 10000) == 0) {
System.err.println();
}
}
}
/**
* Rule table is lemmas!
*/
@Override
protected void initRulesWithWord() {
// Add synthetic symbols to the indices
int unkWord = wordIndex.addToIndex(UNKNOWN_WORD);
int boundaryWordId = wordIndex.addToIndex(BOUNDARY);
int boundaryTagId = tagIndex.addToIndex(BOUNDARY_TAG);
// Initialize rules table
final int numWords = wordIndex.size();
rulesWithWord = new List[numWords];
for (int w = 0; w < numWords; w++) {
rulesWithWord[w] = new ArrayList<IntTaggedWord>(1);
}
// Collect rules, indexed by word
Set<IntTaggedWord> lexRules = Generics.newHashSet(40000);
for (int wordId : wordTag.firstKeySet()) {
for (int tagId : wordTag.getCounter(wordId).keySet()) {
lexRules.add(new IntTaggedWord(wordId, tagId));
lexRules.add(new IntTaggedWord(nullWord, tagId));
}
}
// Known words and signatures
for (IntTaggedWord iTW : lexRules) {
if (iTW.word() == nullWord) {
// Mix in UW signature rules for open class types
double types = uwModel.unSeenCounter().getCount(iTW);
if (types > trainOptions.openClassTypesThreshold) {
IntTaggedWord iTU = new IntTaggedWord(unkWord, iTW.tag);
if (!rulesWithWord[unkWord].contains(iTU)) {
rulesWithWord[unkWord].add(iTU);
}
}
} else {
// Known word
rulesWithWord[iTW.word].add(iTW);
}
}
System.err.print("The " + rulesWithWord[unkWord].size() + " open class tags are: [");
for (IntTaggedWord item : rulesWithWord[unkWord]) {
System.err.print(" " + tagIndex.get(item.tag()));
}
System.err.println(" ] ");
// Boundary symbol has one tagging
rulesWithWord[boundaryWordId].add(new IntTaggedWord(boundaryWordId, boundaryTagId));
}
/**
* Convert a treebank to factored lexicon events for fast iteration in the
* optimizer.
*/
private static List<FactoredLexiconEvent> treebankToLexiconEvents(List<Tree> treebank,
FactoredLexicon lexicon) {
List<FactoredLexiconEvent> events = new ArrayList<FactoredLexiconEvent>(70000);
for (Tree tree : treebank) {
List<Label> yield = tree.yield();
List<Label> preterm = tree.preTerminalYield();
assert yield.size() == preterm.size();
int yieldLen = yield.size();
for (int i = 0; i < yieldLen; ++i) {
String tag = preterm.get(i).value();
int tagId = lexicon.tagIndex.indexOf(tag);
String word = yield.get(i).value();
int wordId = lexicon.wordIndex.indexOf(word);
// Two checks to see if we keep this example
if (tagId < 0) {
System.err.println("Discarding training example: " + word + " " + tag);
continue;
}
// if (counts.probWordTag(wordId, tagId) == 0.0) {
// System.err.println("Discarding low counts <w,t> pair: " + word + " " + tag);
// continue;
// }
String featureStr = ((CoreLabel) yield.get(i)).originalText();
Pair<String,String> lemmaMorph = MorphoFeatureSpecification.splitMorphString(word, featureStr);
String lemma = lemmaMorph.first();
String richTag = lemmaMorph.second();
String reducedTag = lexicon.morphoSpec.strToFeatures(richTag).toString();
reducedTag = reducedTag.length() == 0 ? NO_MORPH_ANALYSIS : reducedTag;
int lemmaId = lexicon.wordIndex.indexOf(lemma);
int morphId = lexicon.morphIndex.indexOf(reducedTag);
FactoredLexiconEvent event = new FactoredLexiconEvent(wordId, tagId, lemmaId, morphId, i, word, featureStr);
events.add(event);
}
}
return events;
}
private static List<FactoredLexiconEvent> getTuningSet(Treebank devTreebank,
FactoredLexicon lexicon, TreebankLangParserParams tlpp) {
List<Tree> devTrees = new ArrayList<Tree>(3000);
for (Tree tree : devTreebank) {
for (Tree subTree : tree) {
if (!subTree.isLeaf()) {
tlpp.transformTree(subTree, tree);
}
}
devTrees.add(tree);
}
List<FactoredLexiconEvent> tuningSet = treebankToLexiconEvents(devTrees, lexicon);
return tuningSet;
}
private static Options getOptions(Language language) {
Options options = new Options();
if (language.equals(Language.Arabic)) {
options.lexOptions.useUnknownWordSignatures = 9;
options.lexOptions.unknownPrefixSize = 1;
options.lexOptions.unknownSuffixSize = 1;
options.lexOptions.uwModelTrainer = "edu.stanford.nlp.parser.lexparser.ArabicUnknownWordModelTrainer";
} else if (language.equals(Language.French)) {
options.lexOptions.useUnknownWordSignatures = 1;
options.lexOptions.unknownPrefixSize = 1;
options.lexOptions.unknownSuffixSize = 2;
options.lexOptions.uwModelTrainer = "edu.stanford.nlp.parser.lexparser.FrenchUnknownWordModelTrainer";
} else {
throw new UnsupportedOperationException();
}
return options;
}
/**
* @param args
*/
public static void main(String[] args) {
if (args.length != 4) {
System.err.printf("Usage: java %s language features train_file dev_file%n", FactoredLexicon.class.getName());
System.exit(-1);
}
// Command line options
Language language = Language.valueOf(args[0]);
TreebankLangParserParams tlpp = Languages.getLanguageParams(language);
Treebank trainTreebank = tlpp.diskTreebank();
trainTreebank.loadPath(args[2]);
Treebank devTreebank = tlpp.diskTreebank();
devTreebank.loadPath(args[3]);
MorphoFeatureSpecification morphoSpec;
Options options = getOptions(language);
if (language.equals(Language.Arabic)) {
morphoSpec = new ArabicMorphoFeatureSpecification();
String[] languageOptions = {"-arabicFactored"};
tlpp.setOptionFlag(languageOptions, 0);
} else if (language.equals(Language.French)) {
morphoSpec = new FrenchMorphoFeatureSpecification();
String[] languageOptions = {"-frenchFactored"};
tlpp.setOptionFlag(languageOptions, 0);
} else {
throw new UnsupportedOperationException();
}
String featureList = args[1];
String[] features = featureList.trim().split(",");
for (String feature : features) {
morphoSpec.activate(MorphoFeatureType.valueOf(feature));
}
System.out.println("Language: " + language.toString());
System.out.println("Features: " + args[1]);
// Create word and tag indices
// Save trees in a collection since the interface requires that....
System.out.print("Loading training trees...");
List<Tree> trainTrees = new ArrayList<Tree>(19000);
Index<String> wordIndex = new HashIndex<String>();
Index<String> tagIndex = new HashIndex<String>();
for (Tree tree : trainTreebank) {
for (Tree subTree : tree) {
if (!subTree.isLeaf()) {
tlpp.transformTree(subTree, tree);
}
}
trainTrees.add(tree);
}
System.out.printf("Done! (%d trees)%n", trainTrees.size());
// Setup and train the lexicon.
System.out.print("Collecting sufficient statistics for lexicon...");
FactoredLexicon lexicon = new FactoredLexicon(options, morphoSpec, wordIndex, tagIndex);
lexicon.initializeTraining(trainTrees.size());
lexicon.train(trainTrees, null);
lexicon.finishTraining();
System.out.println("Done!");
trainTrees = null;
// Load the tuning set
System.out.print("Loading tuning set...");
List<FactoredLexiconEvent> tuningSet = getTuningSet(devTreebank, lexicon, tlpp);
System.out.printf("...Done! (%d events)%n", tuningSet.size());
// Print the probabilities that we obtain
// TODO(spenceg): Implement tagging accuracy with FactLex
int nCorrect = 0;
Counter<String> errors = new ClassicCounter<String>();
for (FactoredLexiconEvent event : tuningSet) {
Iterator<IntTaggedWord> itr = lexicon.ruleIteratorByWord(event.word(), event.getLoc(), event.featureStr());
Counter<Integer> logScores = new ClassicCounter<Integer>();
boolean noRules = true;
int goldTagId = -1;
while (itr.hasNext()) {
noRules = false;
IntTaggedWord iTW = itr.next();
if (iTW.tag() == event.tagId()) {
System.err.print("GOLD-");
goldTagId = iTW.tag();
}
float tagScore = lexicon.score(iTW, event.getLoc(), event.word(), event.featureStr());
logScores.incrementCount(iTW.tag(), tagScore);
}
if (noRules) {
System.err.printf("NO TAGGINGS: %s %s%n", event.word(), event.featureStr());
} else {
// Score the tagging
int hypTagId = Counters.argmax(logScores);
if (hypTagId == goldTagId) {
++nCorrect;
} else {
String goldTag = goldTagId < 0 ? "UNSEEN" : lexicon.tagIndex.get(goldTagId);
errors.incrementCount(goldTag);
}
}
System.err.println();
}
// Output accuracy
double acc = (double) nCorrect / (double) tuningSet.size();
System.err.printf("%n%nACCURACY: %.2f%n%n", acc*100.0);
System.err.println("% of errors by type:");
List<String> biggestKeys = new ArrayList<String>(errors.keySet());
Collections.sort(biggestKeys, Counters.toComparator(errors, false, true));
Counters.normalize(errors);
for (String key : biggestKeys) {
System.err.printf("%s\t%.2f%n", key, errors.getCount(key)*100.0);
}
}
}