Package edu.stanford.nlp.stats

Source Code of edu.stanford.nlp.stats.MultiClassChunkEvalStats

package edu.stanford.nlp.stats;

import edu.stanford.nlp.classify.Classifier;
import edu.stanford.nlp.classify.GeneralDataset;
import edu.stanford.nlp.pipeline.LabeledChunkIdentifier;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.StringUtils;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Properties;

/**
* Calculates phrase based precision and recall (similar to conlleval)
* Handles various encodings such as IO, IOB, IOE, BILOU, SBEIO, []
*
* Usage: java edu.stanford.nlp.stats.MultiClassChunkEvalStats [options] < filename
*        -r - Do raw token based evaluation
*        -d <delimiter> - Specifies delimiter to use (instead of tab)
*        -b <boundary> - Boundary token (default is -X- )
*        -t <defaultTag> - Default tag to use if tag is not prefixed (i.e. is not X-xxx )
*        -ignoreProvidedTag - Discards the provided tag (i.e. if label is X-xxx, just use xxx for evaluation)
* @author Angel Chang
*/
public class MultiClassChunkEvalStats extends MultiClassPrecisionRecallExtendedStats.MultiClassStringLabelStats {
  private boolean inCorrect = false;
  private LabeledChunkIdentifier.LabelTagType prevCorrect = null;
  private LabeledChunkIdentifier.LabelTagType prevGuess = null;
  private LabeledChunkIdentifier chunker;
  private boolean useLabel = false;

  public <F> MultiClassChunkEvalStats(Classifier<String,F> classifier, GeneralDataset<String,F> data, String negLabel)
  {
    super(classifier, data, negLabel);
    chunker = new LabeledChunkIdentifier();
    chunker.setNegLabel(negLabel);
  }

  public MultiClassChunkEvalStats(String negLabel)
  {
    super(negLabel);
    chunker = new LabeledChunkIdentifier();
    chunker.setNegLabel(negLabel);
  }

  public MultiClassChunkEvalStats(Index<String> dataLabelIndex, String negLabel)
  {
    super(dataLabelIndex, negLabel);
    chunker = new LabeledChunkIdentifier();
    chunker.setNegLabel(negLabel);
  }

  public LabeledChunkIdentifier getChunker()
  {
    return chunker;
  }

  public void clearCounts()
  {
    super.clearCounts();
    inCorrect = false;
    prevCorrect = null;
    prevGuess = null;
  }

  protected void finalizeCounts()
  {
    markBoundary();
    super.finalizeCounts();
  }

  private String getTypeLabel(LabeledChunkIdentifier.LabelTagType tagType)
  {
    if (useLabel) return tagType.label;
    else return tagType.type;
  }
  protected void markBoundary()
  {
    if (inCorrect) {
      inCorrect=false;
      correctGuesses.incrementCount(getTypeLabel(prevCorrect));
    }
    prevGuess = null;
    prevCorrect = null;
  }

  protected void addGuess(String guess, String trueLabel, boolean addUnknownLabels)
  {
    LabeledChunkIdentifier.LabelTagType guessTagType = chunker.getTagType(guess);
    LabeledChunkIdentifier.LabelTagType correctTagType = chunker.getTagType(trueLabel);
    addGuess(guessTagType, correctTagType, addUnknownLabels);
  }

  protected void addGuess(LabeledChunkIdentifier.LabelTagType guess,
                          LabeledChunkIdentifier.LabelTagType correct, boolean addUnknownLabels)
  {
    if (addUnknownLabels) {
      if (labelIndex == null) {
        labelIndex = new HashIndex<String>();
      }
      labelIndex.add(getTypeLabel(guess));
      labelIndex.add(getTypeLabel(correct));
    }
    if (inCorrect) {
       boolean prevCorrectEnded = chunker.isEndOfChunk(prevCorrect, correct);
       boolean prevGuessEnded = chunker.isEndOfChunk(prevGuess, guess);
       if (prevCorrectEnded && prevGuessEnded && prevGuess.typeMatches(prevCorrect)) {
         inCorrect=false;
         correctGuesses.incrementCount(getTypeLabel(prevCorrect));
       } else if (prevCorrectEnded != prevGuessEnded || !guess.typeMatches(correct)) {
         inCorrect=false;
       }
    }

    boolean correctStarted = chunker.isStartOfChunk(prevCorrect, correct);
    boolean guessStarted = chunker.isStartOfChunk(prevGuess, guess);
    if ( correctStarted && guessStarted && guess.typeMatches(correct)) {
      inCorrect = true;
    }

    if ( correctStarted ) {
      foundCorrect.incrementCount(getTypeLabel(correct));
    }
    if ( guessStarted ) {
      foundGuessed.incrementCount(getTypeLabel(guess));
    }

    if (chunker.isIgnoreProvidedTag()) {
      if (guess.typeMatches(correct)) {
        tokensCorrect++;
      }
    } else {
      if (guess.label.equals(correct.label)) {
        tokensCorrect++;
      }
    }
   
    tokensCount++;
    prevGuess = guess;
    prevCorrect = correct;
  }

  // Returns string precision recall in ConllEval format
  public String getConllEvalString()
  {
    return getConllEvalString(true);
  }

  public static void main(String[] args) throws Exception
  {
    StringUtils.printErrInvocationString("MultiClassChunkEvalStats", args);
    Properties props = StringUtils.argsToProperties(args);
    String boundary = props.getProperty("b","-X-");
    String delimiter = props.getProperty("d","\t");
    String defaultPosTag = props.getProperty("t", "I");
    boolean raw = Boolean.valueOf(props.getProperty("r","false"));
    boolean ignoreProvidedTag = Boolean.valueOf(props.getProperty("ignoreProvidedTag","false"));
    String format = props.getProperty("format", "conll");
    String filename = props.getProperty("i");
    String backgroundLabel = props.getProperty("k", "O");
    try {
      MultiClassPrecisionRecallExtendedStats stats;
      if (raw) {
        stats = new MultiClassStringLabelStats(backgroundLabel);
      } else {
        MultiClassChunkEvalStats mstats = new MultiClassChunkEvalStats(backgroundLabel);
        mstats.getChunker().setDefaultPosTag(defaultPosTag);
        mstats.getChunker().setIgnoreProvidedTag(ignoreProvidedTag);
        stats = mstats;
      }
      if (filename != null) {
        stats.score(filename, delimiter, boundary);
      } else {
        stats.score(new BufferedReader(new InputStreamReader(System.in)), delimiter, boundary);
      }
      if ("conll".equalsIgnoreCase(format)) {
        System.out.println(stats.getConllEvalString());
      } else {
        System.out.println(stats.getDescription(6));       
      }
    } catch (IOException ex) {
      System.err.println("Error processing file: " + ex.toString());
      ex.printStackTrace(System.err);
    }
  }
}
TOP

Related Classes of edu.stanford.nlp.stats.MultiClassChunkEvalStats

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.