package edu.stanford.nlp.parser.lexparser;
import java.io.*;
import java.util.*;
import edu.stanford.nlp.ling.TaggedWord;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.stats.Distribution;
import edu.stanford.nlp.stats.GeneralizedCounter;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.international.pennchinese.RadicalMap;
import java.util.function.Function;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Interner;
import edu.stanford.nlp.util.Timing;
/**
* A lexicon for Chinese that models words as sequences of characters:
* each word is generated character by character from POS-specific
* character n-gram distributions (with up to CONTEXT_LENGTH characters
* of context), smoothed and backed off to shorter contexts. Rare
* characters are mapped to unknown-character classes (optionally their
* radicals via RadicalMap), and an optional length penalty discourages
* long words.
*
* @author Galen Andrew
*/
public class ChineseCharacterBasedLexicon implements Lexicon {
private final double lengthPenalty;
// penaltyType should be set as follows:
// 0: no length penalty
// 1: quadratic length penalty
// 2: penalty for continuation chars only
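// For a word of n characters, score() subtracts lengthPenalty * n*(n+1)/2
// under type 1, and lengthPenalty * (n-1) under type 2, from the log score.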
private final int penaltyType;
private Map<List<Serializable>,Distribution<Symbol>> charDistributions;
private Set<Symbol> knownChars;
private Distribution<String> POSDistribution;
private final boolean useUnknownCharacterModel;
private static final int CONTEXT_LENGTH = 2;
private final Index<String> wordIndex;
private final Index<String> tagIndex;
public ChineseCharacterBasedLexicon(ChineseTreebankParserParams params,
Index<String> wordIndex,
Index<String> tagIndex) {
this.wordIndex = wordIndex;
this.tagIndex = tagIndex;
this.lengthPenalty = params.lengthPenalty;
this.penaltyType = params.penaltyType;
this.useUnknownCharacterModel = params.useUnknownCharacterModel;
}
// We need to make two passes over the data, whereas the calling
// routines only pass in the sentences or trees once, so we keep all
// the sentences and then process them at the end
private transient List<List<TaggedWord>> trainingSentences;
@Override
public void initializeTraining(double numTrees) {
trainingSentences = new ArrayList<List<TaggedWord>>();
}
/**
* Train this lexicon on the given set of trees.
*/
@Override
public void train(Collection<Tree> trees) {
for (Tree tree : trees) {
train(tree, 1.0);
}
}
/**
* Train this lexicon on the given set of trees.
*/
@Override
public void train(Collection<Tree> trees, double weight) {
for (Tree tree : trees) {
train(tree, weight);
}
}
/**
* Trains this lexicon on a single tree by storing its tagged yield for
* processing in finishTraining().
* TODO: make this method do something with the weight
*/
@Override
public void train(Tree tree, double weight) {
trainingSentences.add(tree.taggedYield());
}
@Override
public void trainUnannotated(List<TaggedWord> sentence, double weight) {
// TODO: for now we just punt on these
throw new UnsupportedOperationException("This version of the parser does not support non-tree training data");
}
@Override
public void incrementTreesRead(double weight) {
throw new UnsupportedOperationException();
}
@Override
public void train(TaggedWord tw, int loc, double weight) {
throw new UnsupportedOperationException();
}
@Override
public void train(List<TaggedWord> sentence, double weight) {
trainingSentences.add(sentence);
}
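/**
* Builds the character model from the accumulated training sentences:
* counts characters (mapping singletons to unknown-character classes),
* counts POS-specific character n-grams with up to CONTEXT_LENGTH
* characters of context, and then converts the counts to distributions,
* smoothing each n-gram distribution by interpolation with its
* (n-1)-gram backoff as a Dirichlet prior, grounding out in a
* Good-Turing smoothed character prior.
*/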
@Override
public void finishTraining() {
Timing.tick("Counting characters...");
ClassicCounter<Symbol> charCounter = new ClassicCounter<Symbol>();
// first find all chars that occur only once
for (List<TaggedWord> labels : trainingSentences) {
for (TaggedWord label : labels) {
String word = label.word();
if (word.equals(BOUNDARY)) {
continue;
}
for (int j = 0, length = word.length(); j < length; j++) {
Symbol sym = Symbol.cannonicalSymbol(word.charAt(j));
charCounter.incrementCount(sym);
}
charCounter.incrementCount(Symbol.END_WORD);
}
}
Set<Symbol> singletons = Counters.keysBelow(charCounter, 1.5);
knownChars = Generics.newHashSet(charCounter.keySet());
Timing.tick("Counting nGrams...");
GeneralizedCounter[] POSspecificCharNGrams = new GeneralizedCounter[CONTEXT_LENGTH + 1];
for (int i = 0; i <= CONTEXT_LENGTH; i++) {
POSspecificCharNGrams[i] = new GeneralizedCounter(i + 2);
}
ClassicCounter<String> POSCounter = new ClassicCounter<String>();
List<Serializable> context = new ArrayList<Serializable>(CONTEXT_LENGTH + 1);
for (List<TaggedWord> words : trainingSentences) {
for (TaggedWord taggedWord : words) {
String word = taggedWord.word();
String tag = taggedWord.tag();
tagIndex.add(tag);
if (word.equals(BOUNDARY)) {
continue;
}
POSCounter.incrementCount(tag);
for (int i = 0, size = word.length(); i <= size; i++) { // i == size generates END_WORD
Symbol sym;
Symbol unknownCharClass = null;
context.clear();
context.add(tag);
if (i < size) {
char thisCh = word.charAt(i);
sym = Symbol.cannonicalSymbol(thisCh);
if (singletons.contains(sym)) {
unknownCharClass = unknownCharClass(sym);
charCounter.incrementCount(unknownCharClass);
}
} else {
sym = Symbol.END_WORD;
}
POSspecificCharNGrams[0].incrementCount(context, sym); // POS-specific 1-gram
if (unknownCharClass != null) {
POSspecificCharNGrams[0].incrementCount(context, unknownCharClass); // for unknown ch model
}
// context is constructed incrementally:
// tag prevChar prevPrevChar
// this could be made faster using .sublist like in score
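// e.g. with CONTEXT_LENGTH == 2, char c_i is counted under the contexts
// [tag], [tag, c_{i-1}], and [tag, c_{i-1}, c_{i-2}], padding with
// BEGIN_WORD where the word starts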
for (int j = 1; j <= CONTEXT_LENGTH; j++) { // higher-order n-grams
if (i - j < 0) {
context.add(Symbol.BEGIN_WORD);
POSspecificCharNGrams[j].incrementCount(context, sym);
if (unknownCharClass != null) {
POSspecificCharNGrams[j].incrementCount(context, unknownCharClass); // for unknown ch model
}
break;
} else {
Symbol prev = Symbol.cannonicalSymbol(word.charAt(i - j));
if (singletons.contains(prev)) {
context.add(unknownCharClass(prev));
} else {
context.add(prev);
}
POSspecificCharNGrams[j].incrementCount(context, sym);
if (unknownCharClass != null) {
POSspecificCharNGrams[j].incrementCount(context, unknownCharClass); // for unknown ch model
}
}
}
}
}
}
POSDistribution = Distribution.getDistribution(POSCounter);
Timing.tick("Creating character prior distribution...");
charDistributions = Generics.newHashMap();
// Estimated total number of character types, seen and unseen, for
// Good-Turing smoothing: one extra potential type per singleton.
int numberOfKeys = charCounter.size() + singletons.size();
Distribution<Symbol> prior = Distribution.goodTuringSmoothedCounter(charCounter, numberOfKeys);
charDistributions.put(Collections.<Serializable>emptyList(), prior);
for (int i = 0; i <= CONTEXT_LENGTH; i++) {
Set<Map.Entry<List<Serializable>, ClassicCounter<Symbol>>> counterEntries = POSspecificCharNGrams[i].lowestLevelCounterEntrySet();
Timing.tick("Creating " + counterEntries.size() + " character " + (i + 1) + "-gram distributions...");
for (Map.Entry<List<Serializable>, ClassicCounter<Symbol>> entry : counterEntries) {
context = entry.getKey();
ClassicCounter<Symbol> c = entry.getValue();
Distribution<Symbol> thisPrior = charDistributions.get(context.subList(0, context.size() - 1));
double priorWeight = thisPrior.getNumberOfKeys() / 200.0;
Distribution<Symbol> newDist = Distribution.dynamicCounterWithDirichletPrior(c, thisPrior, priorWeight);
charDistributions.put(context, newDist);
}
}
}
public Distribution<String> getPOSDistribution() {
return POSDistribution;
}
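/**
* Returns true if every character of the string has a Unicode numeric
* value between 10 and 35, the range Character.getNumericValue assigns
* to the Latin letters a-z/A-Z, so the string looks like a foreign
* (non-Chinese) word.
*/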
public static boolean isForeign(String s) {
for (int i = 0; i < s.length(); i++) {
int num = Character.getNumericValue(s.charAt(i));
if (num < 10 || num > 35) {
return false;
}
}
return true;
}
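// Maps a rare or unseen character to its equivalence class: its radical
// (via RadicalMap) if the unknown-character model is on, otherwise the
// single catch-all class Symbol.UNKNOWN.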
private Symbol unknownCharClass(Symbol ch) {
if (useUnknownCharacterModel) {
return new Symbol(Character.toString(RadicalMap.getRadical(ch.getCh()))).intern();
} else {
return Symbol.UNKNOWN;
}
}
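/**
* Scores a word-tag pair as the log probability of generating the word's
* characters (plus a terminating END_WORD) one at a time from the
* tag-conditioned character n-gram distributions, backing off to shorter
* contexts where necessary, then applying the configured length penalty.
*/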
@Override
public float score(IntTaggedWord iTW, int loc, String word, String featureSpec) {
String tag = tagIndex.get(iTW.tag);
assert !word.equals(BOUNDARY);
char[] chars = word.toCharArray();
List<Serializable> charList = new ArrayList<Serializable>(chars.length + CONTEXT_LENGTH + 1); // this starts off storing Symbols and then starts storing Strings (tags). Clean this up someday!
// charList is constructed backward
// END_WORD char[length-1] char[length-2] ... char[0] BEGIN_WORD BEGIN_WORD
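// e.g. for a two-character word "XY" with CONTEXT_LENGTH == 2:
// [END_WORD, Y, X, BEGIN_WORD, BEGIN_WORD]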
charList.add(Symbol.END_WORD);
for (int i = chars.length - 1; i >= 0; i--) {
Symbol ch = Symbol.cannonicalSymbol(chars[i]);
if (knownChars.contains(ch)) {
charList.add(ch);
} else {
charList.add(unknownCharClass(ch));
}
}
for (int i = 0; i < CONTEXT_LENGTH; i++) {
charList.add(Symbol.BEGIN_WORD);
}
double score = 0.0;
for (int i = 0, size = charList.size(); i < size - CONTEXT_LENGTH; i++) {
Symbol nextChar = (Symbol) charList.get(i);
charList.set(i, tag); // window [i, i+CONTEXT_LENGTH] now reads [tag, prevChar, prevPrevChar]
double charScore = getBackedOffDist(charList.subList(i, i + CONTEXT_LENGTH + 1)).probabilityOf(nextChar);
score += Math.log(charScore);
}
switch (penaltyType) {
case 0:
break;
case 1:
score -= (chars.length * (chars.length + 1)) * (lengthPenalty / 2);
break;
case 2:
score -= (chars.length - 1) * lengthPenalty;
break;
}
return (float) score;
}
// this is where we do backing off for unseen contexts
// (backing off for rarely seen contexts is done implicitly
// because the distributions are smoothed)
private Distribution<Symbol> getBackedOffDist(List<Serializable> context) {
// context contains [tag prevChar prevPrevChar]
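// try successively shorter prefixes of the context:
// [tag prevChar prevPrevChar], then [tag prevChar], then [tag], and
// finally the empty list, which always maps to the character prior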
for (int i = CONTEXT_LENGTH + 1; i >= 0; i--) {
List<Serializable> l = context.subList(0, i);
if (charDistributions.containsKey(l)) {
return charDistributions.get(l);
}
}
throw new RuntimeException("OOPS... no prior distribution...?");
}
/**
* Samples from the distribution over words with this POS according to the lexicon.
*
* @param tag the POS of the word to sample
* @return a sampled word
*/
public String sampleFrom(String tag) {
StringBuilder buf = new StringBuilder();
List<Serializable> context = new ArrayList<Serializable>(CONTEXT_LENGTH + 1);
// context must contain [tag prevChar prevPrevChar]
context.add(tag);
for (int i = 0; i < CONTEXT_LENGTH; i++) {
context.add(Symbol.BEGIN_WORD);
}
Distribution<Symbol> d = getBackedOffDist(context);
Symbol gen = d.sampleFrom();
genLoop:
while (gen != Symbol.END_WORD) {
buf.append(gen.getCh());
switch (penaltyType) {
case 1:
if (Math.random() > Math.pow(lengthPenalty, buf.length())) {
break genLoop;
}
break;
case 2:
if (Math.random() > lengthPenalty) {
break genLoop;
}
break;
}
// shift the context characters over by one (oldest first, so nothing is
// clobbered if CONTEXT_LENGTH > 2) and put the new char after the tag
for (int i = CONTEXT_LENGTH - 1; i >= 1; i--) {
context.set(i + 1, context.get(i));
}
context.set(1, gen);
d = getBackedOffDist(context);
gen = d.sampleFrom();
}
return buf.toString();
}
/**
* Samples a word without conditioning on a POS: first samples a POS from
* the POS distribution, then samples a word for that POS.
*
* @return a sampled word
*/
public String sampleFrom() {
String POS = POSDistribution.sampleFrom();
return sampleFrom(POS);
}
// don't think this should be used, but just in case...
@Override
public Iterator<IntTaggedWord> ruleIteratorByWord(int word, int loc, String featureSpec) {
throw new UnsupportedOperationException("ChineseCharacterBasedLexicon has no rule iterator!");
}
// don't think this should be used, but just in case...
@Override
public Iterator<IntTaggedWord> ruleIteratorByWord(String word, int loc, String featureSpec) {
throw new UnsupportedOperationException("ChineseCharacterBasedLexicon has no rule iterator!");
}
/** Returns the number of rules (tag rewrites as word) in the Lexicon.
* This method isn't yet implemented in this class.
* It currently just returns 0, which may or may not be helpful.
*/
@Override
public int numRules() {
return 0;
}
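// Diagnostic helper: estimates the model's distribution over generated
// word lengths by drawing 10,000 samples from sampleFrom().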
private Distribution<Integer> getWordLengthDistribution() {
int samples = 0;
ClassicCounter<Integer> c = new ClassicCounter<Integer>();
while (samples++ < 10000) {
String s = sampleFrom();
c.incrementCount(Integer.valueOf(s.length()));
if (samples % 1000 == 0) {
System.out.print(".");
}
}
System.out.println();
Distribution<Integer> genWordLengthDist = Distribution.getDistribution(c);
return genWordLengthDist;
}
@Override
public void readData(BufferedReader in) throws IOException {
throw new UnsupportedOperationException();
}
@Override
public void writeData(Writer w) throws IOException {
throw new UnsupportedOperationException();
}
@Override
public boolean isKnown(int word) {
throw new UnsupportedOperationException();
}
@Override
public boolean isKnown(String word) {
throw new UnsupportedOperationException();
}
/** {@inheritDoc} */
@Override
public Set<String> tagSet(Function<String,String> basicCategoryFunction) {
Set<String> tagSet = new HashSet<String>();
for (String tag : tagIndex.objectsList()) {
tagSet.add(basicCategoryFunction.apply(tag));
}
return tagSet;
}
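/**
* A canonicalized character symbol: ordinary characters are stored as
* themselves, while all digits collapse to DIGIT, all Latin letters to
* LETTER, and word boundaries are marked by BEGIN_WORD and END_WORD;
* UNK_CLASS_TYPE symbols represent unknown-character classes such as
* radicals.
*/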
static class Symbol implements Serializable {
private static final int UNKNOWN_TYPE = 0;
private static final int DIGIT_TYPE = 1;
private static final int LETTER_TYPE = 2;
private static final int BEGIN_WORD_TYPE = 3;
private static final int END_WORD_TYPE = 4;
private static final int CHAR_TYPE = 5;
private static final int UNK_CLASS_TYPE = 6;
private char ch;
private String unkClass;
private int type;
public static final Symbol UNKNOWN = new Symbol(UNKNOWN_TYPE);
public static final Symbol DIGIT = new Symbol(DIGIT_TYPE);
public static final Symbol LETTER = new Symbol(LETTER_TYPE);
public static final Symbol BEGIN_WORD = new Symbol(BEGIN_WORD_TYPE);
public static final Symbol END_WORD = new Symbol(END_WORD_TYPE);
public static final Interner<Symbol> interner = new Interner<Symbol>();
public Symbol(char ch) {
type = CHAR_TYPE;
this.ch = ch;
}
public Symbol(String unkClass) {
type = UNK_CLASS_TYPE;
this.unkClass = unkClass;
}
public Symbol(int type) {
assert type != CHAR_TYPE;
this.type = type;
}
public static Symbol cannonicalSymbol(char ch) {
if (Character.isDigit(ch)) {
return DIGIT;
}
if (Character.getNumericValue(ch) >= 10 && Character.getNumericValue(ch) <= 35) {
return LETTER;
}
return new Symbol(ch);
}
public char getCh() {
if (type == CHAR_TYPE) {
return ch;
} else {
return '*';
}
}
public Symbol intern() {
return interner.intern(this);
}
@Override
public String toString() {
if (type == CHAR_TYPE) {
return "[u" + (int) ch + "]";
} else if (type == UNK_CLASS_TYPE) {
return "UNK:" + unkClass;
} else {
return Integer.toString(type);
}
}
protected Object readResolve() throws ObjectStreamException {
switch (type) {
case CHAR_TYPE:
case UNK_CLASS_TYPE:
return intern();
case UNKNOWN_TYPE:
return UNKNOWN;
case DIGIT_TYPE:
return DIGIT;
case LETTER_TYPE:
return LETTER;
case BEGIN_WORD_TYPE:
return BEGIN_WORD;
case END_WORD_TYPE:
return END_WORD;
default: // impossible...
throw new InvalidObjectException("ILLEGAL VALUE IN SERIALIZED SYMBOL");
}
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (!(o instanceof Symbol)) {
return false;
}
final Symbol symbol = (Symbol) o;
if (ch != symbol.ch) {
return false;
}
if (type != symbol.type) {
return false;
}
if (unkClass != null ? !unkClass.equals(symbol.unkClass) : symbol.unkClass != null) {
return false;
}
return true;
}
@Override
public int hashCode() {
int result;
result = ch;
result = 29 * result + (unkClass != null ? unkClass.hashCode() : 0);
result = 29 * result + type;
return result;
}
private static final long serialVersionUID = 8925032621317022510L;
} // end class Symbol
private static final long serialVersionUID = -5357655683145854069L;
@Override
public UnknownWordModel getUnknownWordModel() {
// unknown characters are handled internally (see unknownCharClass);
// this lexicon has no separate UnknownWordModel
return null;
}
@Override
public void setUnknownWordModel(UnknownWordModel uwm) {
// no-op: this lexicon has no separate UnknownWordModel (see getUnknownWordModel)
}
@Override
public void train(Collection<Tree> trees, Collection<Tree> rawTrees) {
train(trees);
}
} // end class ChineseCharacterBasedLexicon