Package edu.stanford.nlp.parser.lexparser

Source Code of edu.stanford.nlp.parser.lexparser.ChineseMaxentLexicon

package edu.stanford.nlp.parser.lexparser;

import edu.stanford.nlp.classify.LinearClassifier;
import edu.stanford.nlp.classify.LinearClassifierFactory;
import edu.stanford.nlp.classify.WeightedDataset;
import edu.stanford.nlp.io.NumberRangesFileFilter;
import edu.stanford.nlp.ling.BasicDatum;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.ling.TaggedWord;
import edu.stanford.nlp.optimization.QNMinimizer;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.Treebank;
import edu.stanford.nlp.trees.TreebankLanguagePack;
import edu.stanford.nlp.util.*;
import edu.stanford.nlp.stats.*;

import java.io.*;
import java.util.*;
import java.util.function.Function;
import java.util.regex.Pattern;

/**
* A Lexicon class that computes the score of word|tag according to a maxent model
* of tag|word (divided by MLE estimate of P(tag)).
* <p/>
* It would be nice to factor out a superclass MaxentLexicon that takes a WordFeatureExtractor
*
* @author Galen Andrew
*/
public class ChineseMaxentLexicon implements Lexicon {

  private static final long serialVersionUID = 238834703409896852L;
  private static final boolean verbose = true;
  public static final boolean seenTagsOnly = false;
  private ChineseWordFeatureExtractor featExtractor;
  public static final boolean fixUnkFunctionWords = false;

  private static final Pattern wordPattern = Pattern.compile(".*-W");
  private static final Pattern charPattern = Pattern.compile(".*-.C");
  private static final Pattern bigramPattern = Pattern.compile(".*-.B");
  private static final Pattern conjPattern = Pattern.compile(".*&&.*");

  private final Pair<Pattern, Integer> wordThreshold = new Pair<Pattern, Integer>(wordPattern, 0);
  private final Pair<Pattern, Integer> charThreshold = new Pair<Pattern, Integer>(charPattern, 2);
  private final Pair<Pattern, Integer> bigramThreshold = new Pair<Pattern, Integer>(bigramPattern, 3);
  private final Pair<Pattern, Integer> conjThreshold = new Pair<Pattern, Integer>(conjPattern, 3);

  private final List<Pair<Pattern, Integer>> featureThresholds = new ArrayList<Pair<Pattern, Integer>>();
  private final int universalThreshold = 0;

  private LinearClassifier scorer;
  private Map<String, String> functionWordTags = Generics.newHashMap();
  private Distribution<String> tagDist;
  private final Index<String> wordIndex;
  private final Index<String> tagIndex;
  private transient Counter<String> logProbs;
  private double iteratorCutoffFactor = 4;
  private transient int lastWord = -1;
  String initialWeightFile = null;
  boolean trainFloat = false;
  private static final String featureDir = "gbfeatures";

  private double tol = 1e-4;
  private double sigma = 0.4;

  static final boolean tuneSigma = false;
  static final int trainCountThreshold = 5;
  final int featureLevel;
  static final int DEFAULT_FEATURE_LEVEL = 2;
  private boolean trainOnLowCount = false;
  private boolean trainByType = false;
  private final TreebankLangParserParams tlpParams;
  private final TreebankLanguagePack ctlp;
  private final Options op;

  public boolean isKnown(int word) {
    return isKnown(wordIndex.get(word));
  }

  public boolean isKnown(String word) {
    return tagsForWord.containsKey(word);
  }

  /** {@inheritDoc} */
  @Override
  public Set<String> tagSet(Function<String,String> basicCategoryFunction) {
    Set<String> tagSet = new HashSet<String>();
    for (String tag : tagIndex.objectsList()) {
      tagSet.add(basicCategoryFunction.apply(tag));
    }
    return tagSet;
  }


  private void ensureProbs(int word) {
    ensureProbs(word, true);
  }

  private void ensureProbs(int word, boolean subtractTagScore) {
    if (word == lastWord) {
      return;
    }
    lastWord = word;
    if (functionWordTags.containsKey(wordIndex.get(word))) {
      logProbs = new ClassicCounter<String>();
      String trueTag = functionWordTags.get(wordIndex.get(word));
      for (String tag : tagIndex.objectsList()) {
        if (ctlp.basicCategory(tag).equals(trueTag)) {
          logProbs.setCount(tag, 0);
        } else {
          logProbs.setCount(tag, Double.NEGATIVE_INFINITY);
        }
      }
      return;
    }
    Datum datum = new BasicDatum(featExtractor.makeFeatures(wordIndex.get(word)));
    logProbs = scorer.logProbabilityOf(datum);
    if (subtractTagScore) {
      Set<String> tagSet = logProbs.keySet();
      for (String tag : tagSet) {
        logProbs.incrementCount(tag, -Math.log(tagDist.probabilityOf(tag)));
      }
    }
  }

  public CollectionValuedMap<String, String> tagsForWord = new CollectionValuedMap<String, String>();

  public Iterator<IntTaggedWord> ruleIteratorByWord(int word, int loc, String featureSpec) {
    ensureProbs(word);
    List<IntTaggedWord> rules = new ArrayList<IntTaggedWord>();
    if (seenTagsOnly) {
      String wordString = wordIndex.get(word);
      Collection<String> tags = tagsForWord.get(wordString);
      for (String tag : tags) {
        rules.add(new IntTaggedWord(wordString, tag, wordIndex, tagIndex));
      }
    } else {
      double max = Counters.max(logProbs);
      for (int tag = 0; tag < tagIndex.size(); tag++) {
        IntTaggedWord iTW = new IntTaggedWord(word, tag);
        double score = logProbs.getCount(tagIndex.get(tag));
        if (score > max - iteratorCutoffFactor) {
          rules.add(iTW);
        }
      }
    }
    return rules.iterator();
  }

  public Iterator<IntTaggedWord> ruleIteratorByWord(String word, int loc, String featureSpec) {
    return ruleIteratorByWord(wordIndex.indexOf(word), loc, featureSpec);
  }

  /** Returns the number of rules (tag rewrites as word) in the Lexicon.
   *  This method isn't yet implemented in this class.
   *  It currently just returns 0, which may or may not be helpful.
   */
  public int numRules() {
    int accumulated = 0;
    for (int w = 0, tot = wordIndex.size(); w < tot; w++) {
      Iterator<IntTaggedWord> iter = ruleIteratorByWord(w, 0, null);
      while (iter.hasNext()) {
        iter.next();
        accumulated++;
      }
    }
    return accumulated;
  }

  private String getTag(String word) {
    int iW = wordIndex.addToIndex(word);
    ensureProbs(iW, false);
    return Counters.argmax(logProbs);
  }


  private void verbose(String s) {
    if (verbose) {
      System.err.println(s);
    }
  }

  public ChineseMaxentLexicon(Options op, Index<String> wordIndex, Index<String> tagIndex, int featureLevel) {
    this.op = op;
    this.tlpParams = op.tlpParams;
    this.ctlp = op.tlpParams.treebankLanguagePack();;
    this.wordIndex = wordIndex;
    this.tagIndex = tagIndex;
    this.featureLevel = featureLevel;
    if (fixUnkFunctionWords) {
      String filename = "unknown_function_word-simple.gb";
      try {
        BufferedReader in = new BufferedReader(new InputStreamReader(new FileInputStream(filename), "GB18030"));
        for (String line = in.readLine(); line != null; line = in.readLine()) {
          String[] parts = line.split("\\s+", 2);
          functionWordTags.put(parts[0], parts[1]);
        }
      } catch (IOException e) {
        throw new RuntimeException("Couldn't read function word file " + filename);
      }
    }
  }

  // only used at training time
  transient IntCounter<TaggedWord> datumCounter;

  @Override
  public void initializeTraining(double numTrees) {
    verbose("Training ChineseMaxentLexicon.");
    verbose("trainOnLowCount = " + trainOnLowCount + ", trainByType = " + trainByType + ", featureLevel = " + featureLevel + ", tuneSigma = " + tuneSigma);
    verbose("Making dataset...");

    if (featExtractor == null) {
      featExtractor = new ChineseWordFeatureExtractor(featureLevel);
    }

    this.datumCounter = new IntCounter<TaggedWord>();
  }

  /**
   * Add the given collection of trees to the statistics counted.  Can
   * be called multiple times with different trees.
   */
  public final void train(Collection<Tree> trees) {
    train(trees, 1.0);
  }

  /**
   * Add the given collection of trees to the statistics counted.  Can
   * be called multiple times with different trees.
   */
  @Override
  public void train(Collection<Tree> trees, double weight) {
    for (Tree tree : trees) {
      train(tree, weight);
    }
  }

  /**
   * Add the given tree to the statistics counted.  Can
   * be called multiple times with different trees.
   */
  @Override
  public void train(Tree tree, double weight) {
    train(tree.taggedYield(), weight);
  }

  /**
   * Add the given sentence to the statistics counted.  Can
   * be called multiple times with different sentences.
   */
  @Override
  public void train(List<TaggedWord> sentence, double weight) {
    featExtractor.train(sentence, weight);
    for (TaggedWord word : sentence) {
      datumCounter.incrementCount(word, weight);
      tagsForWord.add(word.word(), word.tag());
    }
  }

  @Override
  public void trainUnannotated(List<TaggedWord> sentence, double weight) {
    // TODO: for now we just punt on these
    throw new UnsupportedOperationException("This version of the parser does not support non-tree training data");
  }

  @Override
  public void incrementTreesRead(double weight) {
    throw new UnsupportedOperationException();
  }

  @Override
  public void train(TaggedWord tw, int loc, double weight) {
    throw new UnsupportedOperationException();
  }

  @Override
  public void finishTraining() {
    IntCounter<String> tagCounter = new IntCounter<String>();

    WeightedDataset data = new WeightedDataset(datumCounter.size());

    for (TaggedWord word : datumCounter.keySet()) {
      int count = datumCounter.getIntCount(word);
      if (trainOnLowCount && count > trainCountThreshold) {
        continue;
      }
      if (functionWordTags.containsKey(word.word())) {
        continue;
      }
      tagCounter.incrementCount(word.tag());
      if (trainByType) {
        count = 1;
      }
      data.add(new BasicDatum(featExtractor.makeFeatures(word.word()), word.tag()), count);
    }
    datumCounter = null;

    tagDist = Distribution.laplaceSmoothedDistribution(tagCounter, tagCounter.size(), 0.5);
    tagCounter = null;
    applyThresholds(data);

    verbose("Making classifier...");
    QNMinimizer minim = new QNMinimizer();//new ResultStoringMonitor(5, "weights"));
//    minim.shutUp();

    LinearClassifierFactory factory = new LinearClassifierFactory(minim);
    factory.setTol(tol);
    factory.setSigma(sigma);
    if (tuneSigma) {
      factory.setTuneSigmaHeldOut();
    }
    scorer = factory.trainClassifier(data);

    verbose("Done training.");
  }

  private void applyThresholds(WeightedDataset data) {
    if (wordThreshold.second > 0) {
      featureThresholds.add(wordThreshold);
    }
    if (featExtractor.chars && charThreshold.second > 0) {
      featureThresholds.add(charThreshold);
    }
    if (featExtractor.bigrams && bigramThreshold.second > 0) {
      featureThresholds.add(bigramThreshold);
    }
    if ((featExtractor.conjunctions || featExtractor.mildConjunctions) && conjThreshold.second > 0) {
      featureThresholds.add(conjThreshold);
    }

    int types = data.numFeatureTypes();
    if (universalThreshold > 0) {
      data.applyFeatureCountThreshold(universalThreshold);
    }
    if (featureThresholds.size() > 0) {
      data.applyFeatureCountThreshold(featureThresholds);
    }
    int numRemoved = types - data.numFeatureTypes();
    if (numRemoved > 0) {
      verbose("Thresholding removed " + numRemoved + " features.");
    }
  }

  public static void main(String[] args) {
    TreebankLangParserParams tlpParams = new ChineseTreebankParserParams();
    TreebankLanguagePack ctlp = tlpParams.treebankLanguagePack();
    Options op = new Options(tlpParams);
    TreeAnnotator ta = new TreeAnnotator(tlpParams.headFinder(), tlpParams, op);

    System.err.println("Reading Trees...");
    FileFilter trainFilter = new NumberRangesFileFilter(args[1], true);
    Treebank trainTreebank = tlpParams.memoryTreebank();
    trainTreebank.loadPath(args[0], trainFilter);

    System.err.println("Annotating trees...");
    Collection<Tree> trainTrees = new ArrayList<Tree>();
    for (Tree tree : trainTreebank) {
      trainTrees.add(ta.transformTree(tree));
    }
    trainTreebank = null; // saves memory

    System.err.println("Training lexicon...");

    Index<String> wordIndex = new HashIndex<String>();
    Index<String> tagIndex = new HashIndex<String>();
    int featureLevel = DEFAULT_FEATURE_LEVEL;
    if (args.length > 3) {
      featureLevel = Integer.parseInt(args[3]);
    }
    ChineseMaxentLexicon lex = new ChineseMaxentLexicon(op, wordIndex, tagIndex, featureLevel);
    lex.initializeTraining(trainTrees.size());
    lex.train(trainTrees);
    lex.finishTraining();

    System.err.print("Testing");

    FileFilter testFilter = new NumberRangesFileFilter(args[2], true);
    Treebank testTreebank = tlpParams.memoryTreebank();
    testTreebank.loadPath(args[0], testFilter);
    List<TaggedWord> testWords = new ArrayList<TaggedWord>();
    for (Tree t : testTreebank) {
      for (TaggedWord tw : t.taggedYield()) {
        testWords.add(tw);
      }
      //testWords.addAll(t.taggedYield());
    }
    int[] totalAndCorrect = lex.testOnTreebank(testWords);

    System.err.println("done.");
    System.out.println(totalAndCorrect[1] + " correct out of " + totalAndCorrect[0] + " -- ACC: " + ((double) totalAndCorrect[1]) / totalAndCorrect[0]);
  }

  private int[] testOnTreebank(Collection<TaggedWord> testWords) {
    int[] totalAndCorrect = new int[2];
    totalAndCorrect[0] = 0;
    totalAndCorrect[1] = 0;
    for (TaggedWord word : testWords) {
      String goldTag = word.tag();
      String guessTag = ctlp.basicCategory(getTag(word.word()));
      totalAndCorrect[0]++;
      if (goldTag.equals(guessTag)) {
        totalAndCorrect[1]++;
      }
    }
    return totalAndCorrect;
  }

  public float score(IntTaggedWord iTW, int loc, String word, String featureSpec) {
    ensureProbs(iTW.word());
    double max = Counters.max(logProbs);
    double score = logProbs.getCount(iTW.tagString(tagIndex));
    if (score > max - iteratorCutoffFactor) {
      return (float) score;
    } else {
      return Float.NEGATIVE_INFINITY;
    }
  }


  public void writeData(Writer w) throws IOException {
    throw new UnsupportedOperationException();
  }

  public void readData(BufferedReader in) throws IOException {
    throw new UnsupportedOperationException();
  }

  public UnknownWordModel getUnknownWordModel() {
    // TODO Auto-generated method stub
    return null;
  }

  public void setUnknownWordModel(UnknownWordModel uwm) {
    // TODO Auto-generated method stub

  }

  @Override
  public void train(Collection<Tree> trees, Collection<Tree> rawTrees) {
    train(trees);

  }




}
TOP

Related Classes of edu.stanford.nlp.parser.lexparser.ChineseMaxentLexicon

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle, Inc. Contact coftware#gmail.com.