Package edu.stanford.nlp.tagger.maxent

Source Code of edu.stanford.nlp.tagger.maxent.TestSentence

/**
* Title:        StanfordMaxEnt<p>
* Description:  A Maximum Entropy Toolkit<p>
* Copyright:    Copyright (c) Kristina Toutanova<p>
* Company:      Stanford University<p>
*/

package edu.stanford.nlp.tagger.maxent;

import edu.stanford.nlp.io.EncodingPrintWriter;
import edu.stanford.nlp.io.PrintFile;
import edu.stanford.nlp.ling.HasWord;
import edu.stanford.nlp.ling.Sentence;
import edu.stanford.nlp.ling.TaggedWord;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.maxent.iis.LambdaSolve;
import edu.stanford.nlp.sequences.BestSequenceFinder;
import edu.stanford.nlp.sequences.ExactBestSequenceFinder;
import edu.stanford.nlp.sequences.SequenceModel;
import edu.stanford.nlp.util.ArrayUtils;

import java.util.*;
import java.text.NumberFormat;
import java.text.DecimalFormat;


/**
* @author Kristina Toutanova
* @author Michel Galley
* @version 1.0
*/
public class TestSentence implements SequenceModel {

  // When true, tagSentence() echoes each input sentence to stderr.
  static boolean VERBOSE = false;
  // Written between a word and its tag by getTaggedNice().
  static String tagSeparator = "/"; // currently a yucky static. Remove someday.
  // Context widths returned by leftWindow() / rightWindow(). Sequence-model
  // positions are offset by leftWindow() relative to sentence positions.
  static int leftContext = 2;
  static int rightContext = 2;
  // End-of-sentence token appended by tagSentence().
  protected static final String eosWord = "EOS";
  // Placeholder used by toNice() for null values and by stringTagsAt() for
  // positions that fall outside the sentence.
  protected static final String naTag = "NA";
  // Compile-time switch for debugging output and assertions.
  protected static final boolean DBG = false;
  // When true, append() passes candidate tags through
  // GlobalHolder.tags.deterministicallyExpandTags().
  protected static boolean doDeterministicTagExpansion = true;
  // Number of hypotheses scored by calculateProbs(); only the single best
  // sequence is currently produced (see the TODO there).
  protected static int kBestSize = 1;
  // Word/tag slots of the sentence currently being scored.
  protected final PairsHolder pairs = new PairsHolder();
  protected List<String> sent; // = null // DON'T NEED? = new ArrayList<String>();
  protected int size; // this always has the value of sent.size(). Remove it? [cdm 2008]
  // protected double[][][] probabilities;
  // Gold tags; only set by the evaluation constructor (null otherwise).
  protected String[] correctTags;
  // Best tag for each position, filled in by runTagInference().
  protected String[] finalTags;
  // Trained model; its lambda array supplies the feature weights.
  protected LambdaSolve prob;
  // Accuracy counters: numRight, numWrong and numWrongUnknown are updated by
  // writeTagsAndErrors(); numUnknown is updated by init().
  int numRight;
  int numWrong;
  int numUnknown;
  int numWrongUnknown;
  // One past the last pairs slot used by the current sentence:
  // initializeScorer() advances it by size and revert(0) resets it.
  private int endSizePairs = 0;

  // NOTE(review): volatile only guarantees visibility; nothing in this class
  // synchronizes access, so an instance should not be shared across threads.
  // History object reused by scoresOf() for every position.
  private volatile History history = new History(pairs);
  // Cache of local-feature scores keyed by word. It is not cleared anywhere
  // in this file, so it persists across sentences tagged by this instance.
  protected volatile Map<String,double[]> localScores = new HashMap<String,double[]>();
  // Per-position cache of local-context (+ local) scores, reallocated by init().
  protected volatile double[][] localContextScores;

  public TestSentence(LambdaSolve prob) {
    if(prob == null)
      throw new RuntimeException("LambdaSolve set to null.");
    this.prob = prob;
  }

  public TestSentence(LambdaSolve prob, String[] s, String[] correctTags, PrintFile pf, Dictionary wrongWords) {
    if (DBG) {
      assert(s.length == correctTags.length);
      System.err.println("Entering TestSentence(); s.length is " + s.length + "; startSizePairs is " + endSizePairs + "; endSizePairs is " + endSizePairs);
    }
    this.prob = prob;
    this.sent = new ArrayList<String>(Arrays.asList(s));
    this.size = sent.size();
    this.correctTags = correctTags;
    init();
    testTagInference(pf, wrongWords);
    if (DBG) {
      System.err.println("Exiting TestSentence(); startSizePairs is " + endSizePairs + "; endSizePairs is " + endSizePairs);
    }
  }

  /**
   * Tags the sentence s by running maxent model.  Returns a Sentence of
   * TaggedWord objects.
   *
   * @param s Input sentence (List).  This isn't changed.
   * @return Tagged sentence
   */
  public Sentence<TaggedWord> tagSentence(List<? extends HasWord> s) {
    int sz = s.size();
    List<String> sentence = new ArrayList<String>(sz + 1);
    this.sent = sentence;
    for (int j = 0; j < sz; j++) {
      sentence.add(s.get(j).word());
    }
    sent.add(eosWord);
    size = sz + 1;
    if (VERBOSE) {
      System.err.println("Sentence is " + sentence.toString());
    }
    init();
    return testTagInference();
  }


  protected void revert(int prevSize) {
    endSizePairs = prevSize;
  }

  protected void init() {
    //the eos are assumed already there
    localContextScores = new double[size][];
    for (int i = 0; i < size - 1; i++) {
      if (GlobalHolder.dict.isUnknown(sent.get(i))) {
        numUnknown++;
      }
    }
  }

  /**
   * Returns a string representation of the sentence.
   * @return tagged sentence
   */
  String getTaggedNice() {
    StringBuilder sb = new StringBuilder();
    // size - 1 means to exclude the EOS (end of string) symbol
    for (int i = 0; i < size - 1; i++) {
      sb.append(toNice(sent.get(i))).append(tagSeparator).append(toNice(finalTags[i]));
      sb.append(' ');
    }
    return sb.toString();
  }


  Sentence<TaggedWord> getTaggedSentence() {
    Sentence<TaggedWord> taggedSentence = new Sentence<TaggedWord>();
    for (int j = 0; j < size - 1; j++) {
      String tag = finalTags[j];
      TaggedWord w = new TaggedWord(sent.get(j), tag);
      taggedSentence.add(w);
    }
    return taggedSentence;
  }

  static String toNice(String s) {
    if (s == null) {
      return naTag;
    } else {
      return s;
    }
  }

  /** calculateProbs puts log probs of taggings in the probabilities array.
   *
   *  @param probabilities Array with indices sent size, k best size, numTags
   */
  protected void calculateProbs(double[][][] probabilities) {
    ArrayUtils.fill(probabilities, Double.NEGATIVE_INFINITY);
    for (int hyp = 0; hyp < kBestSize; hyp++) {
      // put the whole thing in pairs, give its beginning and end
      pairs.setSize(size);
      for (int i = 0; i < size; i++) {
        pairs.setWord(i,sent.get(i));
        pairs.setTag(i,finalTags[i]);
        //pairs.add(new WordTag(sent.get(i),finalTags[i]));
        // TODO: if kBestSize > 1, use KBestSequenceFinder and save
        // k-best hypotheses into finalTags:
        //pairs.setTag(i,finalTags[i]);
      }
      int start = endSizePairs;
      int end = endSizePairs + size - 1;
      endSizePairs = endSizePairs + size;
      // iterate over the sentence
      for (int current = 0; current < size; current++) {
        History h = new History(start, end, current + start, pairs);
        String[] tags = stringTagsAt(h.current - h.start + leftWindow());
        double[] probs = getHistories(tags, h);
        ArrayMath.logNormalize(probs);

        // System.err.println("word: " + pairs.getWord(current));
        // System.err.println("tags: " + Arrays.asList(tags));
        // System.err.println("probs: " + ArrayMath.toString(probs));

        for (int j = 0; j < tags.length; j++) {
          // score the j-th tag
          String tag = tags[j];
          boolean approximate = GlobalHolder.defaultScore > 0.0;
          int tagindex = approximate ? GlobalHolder.tags.getIndex(tag) : j;
          // System.err.println("Mapped from j="+ j + " " + tag + " to " + tagindex);
          probabilities[current][hyp][tagindex] = probs[j];
        }
      } // for current
    } // for hyp
    // clean up the stuff in PairsHolder (added by cdm in Aug 2008)
    revert(0);
  } // end calculateProbs()


  /** Write the tagging and note any errors (if pf != null) and accumulate
   *  global statistics.
   *
   *  @param finalTags Chosen tags for sentence
   *  @param pf File to write tagged output too (can be null, then no output;
   *               at present it is non-null iff the debug property is set)
   *  @param wrongWords Dictionary to accumulate wrong word counts in (cannot be null)
   */
  protected void writeTagsAndErrors(String[] finalTags, PrintFile pf, Dictionary wrongWords) {
    //without the EOS word
    for (int i = 0; i < correctTags.length - 1; i++) {
      if (pf != null) {
        pf.print(toNice(sent.get(i)));
        pf.print('_');
        pf.print(finalTags[i]);
      }
      if ((correctTags[i]).equals(finalTags[i])) {
        numRight++;
      } else {
        numWrong++;
        if (pf != null) pf.print('|' + correctTags[i]);
        // TODO: This writes the below output in utf8, not the user's specified encoding
        EncodingPrintWriter.out.println("Word: " + sent.get(i) + "; correct: " + correctTags[i] + "; guessed: " + finalTags[i]);

        wrongWords.add(sent.get(i) + correctTags[i], finalTags[i]);
        if (GlobalHolder.dict.isUnknown(sent.get(i))) {
          numWrongUnknown++;
          if (pf != null) pf.print("*");
        }// if
      }// else
      if (pf != null) pf.print(" ");
    }// for
    if (pf != null) pf.println();
  }


  // Test using (exact Viterbi) TagInference.
  private void testTagInference(PrintFile pf, Dictionary wrongWords) {
    runTagInference();
    writeTagsAndErrors(finalTags, pf, wrongWords);
  }


  /**
   * Test using (exact Viterbi) TagInference.
   *
   * @return The tagged sentence
   */
  private Sentence<TaggedWord> testTagInference() {
    runTagInference();

    ArrayList<TaggedWord> taggedWords = new ArrayList<TaggedWord>();
    // leave out EOS
    for (int j = 0, len = size - 1; j < len; j++) {
      String tag = finalTags[j];
      TaggedWord w = new TaggedWord(sent.get(j), tag);
      taggedWords.add(w);
    }

    return new Sentence<TaggedWord>(taggedWords);
  }

  private void runTagInference() {
    this.initializeScorer(sent);

    BestSequenceFinder ti = new ExactBestSequenceFinder();
      //new BeamBestSequenceFinder(50);
      //new KBestSequenceFinder()
    int[] bestTags = ti.bestSequence(this);
    finalTags = new String[bestTags.length];
    for (int j = 0; j < size; j++) {
      finalTags[j] = GlobalHolder.tags.getTag(bestTags[j + leftWindow()]);
    }
    cleanUpScorer();
  }


  // This is used for Dan's tag inference methods.
  // current is the actual word number + leftW
  private void setHistory(int current, History h, int[] tags) {
    //writes over the tags in the last thing in pairs

    int left = leftWindow();
    int right = rightWindow();

    for (int j = current - left; j <= current + right; j++) {
      if (j < left) {
        continue;
      } //but shouldn't happen
      if (j >= size + left) {
        break;
      } //but shouldn't happen
      h.setTag(j - left, GlobalHolder.tags.getTag(tags[j]));
    }
  }

  // do initializations for the TagScorer interface
  protected void initializeScorer(List<String> sentence) {
    this.sent = sentence;
    this.size = sent.size();
    pairs.setSize(size);
    for (int i = 0; i < size; i++)
      pairs.setWord(i,sent.get(i));
    endSizePairs += size;
  }


  /**
   * clean-up after the scorer
   */
  protected void cleanUpScorer() {
    revert(0);
  }

  protected static String[] append(String[] tags, String word) {
    if (doDeterministicTagExpansion) {
      return GlobalHolder.tags.deterministicallyExpandTags(tags, word);
    } else {
      return tags;
    }
  }

  // This scores the current assignment in PairsHolder at
  // current position h.current (returns normalized scores)
  private double[] getScores(History h) {
    if (GlobalHolder.defaultScore > 0) {
      return getApproximateScores(h);
    }
    return getExactScores(h);
  }

  private double[] getExactScores(History h) {
    String[] tags = stringTagsAt(h.current - h.start + leftWindow());
    double[] histories = getHistories(tags, h); // log score for each tag
    ArrayMath.logNormalize(histories);
    double[] scores = new double[tags.length];
    for (int j = 0; j < tags.length; j++) {
      // score the j-th tag
      String tag = tags[j];
      int tagindex = GlobalHolder.tags.getIndex(tag);
      scores[j] = histories[tagindex];
    }
    return scores;
  }

  // In this method, each tag that is incompatible with the current word
  // (e.g., apple_CC) gets a default (constant) score instead of its exact score.
  // The scores of all other tags are computed exactly.
  private double[] getApproximateScores(History h) {
    String[] tags = stringTagsAt(h.current - h.start + leftWindow());
    double[] scores = getHistories(tags, h); // log score for each active tag, unnormalized

    // Number of tags that get assigned a default score:
    double nDefault = GlobalHolder.ySize - tags.length;
    double logScore = ArrayMath.logSum(scores);
    double logScoreInactiveTags = Math.log(nDefault*GlobalHolder.defaultScore);
    double logTotal =
      ArrayMath.logSum(new double[] {logScore, logScoreInactiveTags});
    ArrayMath.addInPlace(scores, -logTotal);

    return scores;
  }

  // This precomputes scores of local features (localScores).
  protected double[] getHistories(String[] tags, History h) {
      boolean rare = GlobalHolder.isRare(ExtractorFrames.cWord.extract(h));
      //boolean rare = false;//GlobalHolder.isRare(ExtractorFrames.cWord.extract(h));
    Extractors ex = GlobalHolder.extractors, exR = GlobalHolder.extractorsRare;
    //Extractors ex = GlobalHolder.extractors, exR = GlobalHolder.extractors;
    String w = pairs.getWord(h.current);
    double[] lS, lcS;
    if((lS = localScores.get(w)) == null) {
      lS = getHistories(tags, h, ex.local, rare ? exR.local : null);
      localScores.put(w,lS);
    }
    if((lcS = localContextScores[h.current]) == null) {
      lcS = getHistories(tags, h, ex.localContext, rare ? exR.localContext : null);
      localContextScores[h.current] = lcS;
      ArrayMath.pairwiseAddInPlace(lcS,lS);
    }
    double[] totalS = getHistories(tags, h, ex.dynamic, rare ? exR.dynamic : null);
    ArrayMath.pairwiseAddInPlace(totalS,lcS);
    return totalS;
  }

  private double[] getHistories(String[] tags, History h, Map<Integer,Extractor> extractors, Map<Integer,Extractor> extractorsRare) {
    if(GlobalHolder.defaultScore > 0)
      return getApproximateHistories(tags, h, extractors, extractorsRare);
    return getExactHistories(h, extractors, extractorsRare);
  }

  private double[] getExactHistories(History h, Map<Integer,Extractor> extractors, Map<Integer,Extractor> extractorsRare) {
    double[] scores = new double[GlobalHolder.ySize];
    FeatureKey s = new FeatureKey();
    int szCommon = GlobalHolder.extractors.getSize();

    for(Map.Entry<Integer,Extractor> e : extractors.entrySet()) {
      int kf = e.getKey();
      Extractor ex = e.getValue();
      for (int i = 0; i < GlobalHolder.ySize; i++) {
        String tag = GlobalHolder.tags.getTag(i);
        s.set(kf, ex.extract(h), tag);
        int fNum = GlobalHolder.getNum(s);
        if (fNum > -1) {
          scores[i] += prob.lambda[fNum];
        }
      }
    }
    if(extractorsRare != null) {
        //System.out.println("USING THE RARE EXTRACTOR");
      for(Map.Entry<Integer,Extractor> e : extractorsRare.entrySet()) {
        int kf = e.getKey();
        Extractor ex = e.getValue();
        for (int i = 0; i < GlobalHolder.ySize; i++) {
          String tag = GlobalHolder.tags.getTag(i);
          s.set(szCommon+kf, ex.extract(h), tag);
          int fNum = GlobalHolder.getNum(s);
          if (fNum > -1) {
            scores[i] += prob.lambda[fNum];
          } // end for
        }
      }
    } else{
        //System.out.println("NOT USING THE RARE EXTRACTOR");
    }
    return scores;
  }

  // Returns an unnormalized score (in log space) for each tag
  private double[] getApproximateHistories(String[] tags, History h, Map<Integer,Extractor> extractors, Map<Integer,Extractor> extractorsRare) {

    double[] scores = new double[tags.length];
    FeatureKey s = new FeatureKey();
    int szCommon = GlobalHolder.extractors.getSize();

    for(Map.Entry<Integer,Extractor> e : extractors.entrySet()) {
      int kf = e.getKey();
      Extractor ex = e.getValue();
      for (int j = 0; j < tags.length; j++) {
        String tag = tags[j];
        s.set(kf, ex.extract(h), tag);
        int fNum = GlobalHolder.getNum(s);
        if (fNum > -1) {
          scores[j] += prob.lambda[fNum];
        }
      }
    }
    if(extractorsRare != null) {
        //System.out.println("USING THE RARE EXTRACTOR");
      for(Map.Entry<Integer,Extractor> e : extractorsRare.entrySet()) {
        int kf = e.getKey();
        Extractor ex = e.getValue();
        for (int j = 0; j < tags.length; j++) {
          String tag = tags[j];
          s.set(szCommon+kf, ex.extract(h), tag);
          int fNum = GlobalHolder.getNum(s);
          if (fNum > -1) {
            scores[j] += prob.lambda[fNum];
          } // end for
        }
      }
    } else {
        //System.out.println("NOT USING THE RARE EXTRACTOR");
    }
    return scores;
  }


  /**
   * This method should be called after the sentence has been tagged.
   * For every unknown word, this method prints the 3 most probable tags
   * to the file pfu.
   *
   * @param numSent The sentence number
   * @param pfu The file to print the probable tags to
   */
  void printUnknown(int numSent, PrintFile pfu) {
    NumberFormat nf = new DecimalFormat("0.0000");
    int numTags = GlobalHolder.tags.getSize();
    double[][][] probabilities = new double[size][kBestSize][numTags];
    calculateProbs(probabilities);
    for (int current = 0; current < size; current++) {
      if (GlobalHolder.dict.isUnknown(sent.get(current))) {
        pfu.print(sent.get(current));
        pfu.print(':');
        pfu.print(numSent);
        double[] probs = new double[3];
        String[] tag3 = new String[3];
        getTop3(probabilities, current, probs, tag3);
        for (int i = 0; i < 3; i++) {
          if (probs[i] > Double.NEGATIVE_INFINITY) {
            pfu.print('\t');
            pfu.print(tag3[i]);
            pfu.print(' ');
            pfu.print(nf.format(Math.exp(probs[i])));
          }
        }
        int rank;
        String correctTag = toNice(this.correctTags[current]);
        for (rank = 0; rank < 3; rank++) {
          if (correctTag.equals(tag3[rank])) {
            break;
          } //if
        }
        pfu.print('\t');
        switch (rank) {
          case 0:
            pfu.print("Correct");
            break;
          case 1:
            pfu.print("2nd");
            break;
          case 2:
            pfu.print("3rd");
            break;
          default:
            pfu.print("Not top 3");
        }
        pfu.println();
      }// if
    }// for
  }

  // This method should be called after a sentence has been tagged.
  // For every word token, this method prints the 3 most probable tags
  // to the file pfu except for
  void printTop(PrintFile pfu) {
    NumberFormat nf = new DecimalFormat("0.0000");
    int numTags = GlobalHolder.tags.getSize();
    double[][][] probabilities = new double[size][kBestSize][numTags];
    calculateProbs(probabilities);
    for (int current = 0; current < size; current++) {
      pfu.print(sent.get(current));
      double[] probs = new double[3];
      String[] tag3 = new String[3];
      getTop3(probabilities, current, probs, tag3);
      for (int i = 0; i < 3; i++) {
        if (probs[i] > Double.NEGATIVE_INFINITY) {
          pfu.print('\t');
          pfu.print(tag3[i]);
          pfu.print(' ');
          pfu.print(nf.format(Math.exp(probs[i])));
        }
      }
      int rank;
      String correctTag = toNice(this.correctTags[current]);
      for (rank = 0; rank < 3; rank++) {
        if (correctTag.equals(tag3[rank])) {
          break;
        } //if
      }
      pfu.print('\t');
      switch (rank) {
      case 0:
        pfu.print("Correct");
        break;
      case 1:
        pfu.print("2nd");
        break;
      case 2:
        pfu.print("3rd");
        break;
      default:
        pfu.print("Not top 3");
      }
      pfu.println();
    } // for
  }

  /** probs and tags should be passed in as arrays of size 3!
   *  If probs[i] == Double.NEGATIVE_INFINITY, then the entry should be ignored.
   */
  private static void getTop3(double[][][] probabilities, int current, double[] probs, String[] tags) {
    int[] topIds = new int[3];
    double[] probTags = probabilities[current][0];
    Arrays.fill(probs, Double.NEGATIVE_INFINITY);
    for (int i = 0; i < probTags.length; i++) {
      if (probTags[i] > probs[0]) {
        probs[2] = probs[1];
        probs[1] = probs[0];
        probs[0] = probTags[i];
        topIds[2] = topIds[1];
        topIds[1] = topIds[0];
        topIds[0] = i;
      } else if (probTags[i] > probs[1]) {
        probs[2] = probs[1];
        probs[1] = probTags[i];
        topIds[2] = topIds[1];
        topIds[1] = i;
      } else if (probTags[i] > probs[2]) {
        probs[2] = probTags[i];
        topIds[2] = i;
      }
    }
    for (int j = 0; j < 3; j++) {
      tags[j] = toNice(GlobalHolder.tags.getTag(topIds[j]));
    }
  }

  /*
   * Implementation of the SequenceModel interface follows
   */

  public int length() {
    return sent.size();
  }

  public int leftWindow() {
    return leftContext; //hard-code for now
  }

  public int rightWindow() {
    return rightContext; //hard code for now
  }


  public int[] getPossibleValues(int pos) {
    String[] arr1 = stringTagsAt(pos);
    // for (int i=0; i < arr1.length; i++) {
    //    System.err.println("TAG " + pos+" " + arr1 + " " + a);
    // }
    int[] arr = new int[arr1.length];
    for (int i = 0; i < arr.length; i++) {
      arr[i] = GlobalHolder.tags.getIndex(arr1[i]);
    }

    return arr;
  }

  public double scoreOf(int[] tags, int pos) {
    double[] scores = scoresOf(tags, pos);
    double score = Double.NEGATIVE_INFINITY;
    int[] pv = getPossibleValues(pos);
    for (int i = 0; i < scores.length; i++) {
      if (pv[i] == tags[pos]) {
        score = scores[i];
      }
    }
    return score;
  }

  public double scoreOf(int[] sequence) {
    throw new UnsupportedOperationException();
  }

  public double[] scoresOf(int[] tags, int pos) {
    if (DBG) {
      System.err.println("scoresOf(): length of tags is " + tags.length + "; position is " + pos + "; endSizePairs = " + endSizePairs + "; size is " + size + "; leftWindow is " + leftWindow());
      System.err.println("  History h = new History(" + (endSizePairs - size) + ", " + (endSizePairs - 1) + ", " + (endSizePairs - size + pos - leftWindow()) + ")");
    }
    history.init(endSizePairs - size, endSizePairs - 1, endSizePairs - size + pos - leftWindow());
    setHistory(pos, history, tags);
    return getScores(history);
  }

  protected String[] stringTagsAt(int pos) {
    String[] arr1;
    if ((pos < leftWindow()) || (pos >= size + leftWindow())) {
      arr1 = new String[1];
      arr1[0] = naTag;
      return arr1;
    }
    //if (GlobalHolder.isRare(sent.get(pos - leftWindow()))) {
        //arr1 = GlobalHolder.tags.getOpenTags().toArray(new String[GlobalHolder.tags.getOpenTags().size()]);
   
    if (!GlobalHolder.isSeen(sent.get(pos-leftWindow()))) {
      arr1 = GlobalHolder.tags.index.toArray(new String[GlobalHolder.tags.index.size()]);
    } else {
      arr1 = GlobalHolder.dict.getTags(sent.get(pos - leftWindow()));
    }
    arr1 = append(arr1, sent.get(pos - leftWindow()));
    return arr1;
  }
}
TOP

Related Classes of edu.stanford.nlp.tagger.maxent.TestSentence

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.