/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.extract;
import java.io.*;
import java.text.DecimalFormat;
import java.util.List;
import cc.mallet.fst.CRF;
import cc.mallet.fst.MaxLattice;
import cc.mallet.fst.MaxLatticeDefault;
import cc.mallet.fst.SumLatticeDefault;
import cc.mallet.fst.Transducer;
import cc.mallet.types.*;
/**
* Created: Oct 31, 2004
*
* @author <A HREF="mailto:casutton@cs.umass.edu>casutton@cs.umass.edu</A>
* @version $Id: LatticeViewer.java,v 1.1 2007/10/22 21:37:44 mccallum Exp $
*/
public class LatticeViewer {
private static final int FEATURE_CUTOFF_PCT = 25;
private static final int LENGTH = 10;
static void lattice2html (PrintStream out, ExtorInfo info)
{
PrintWriter writer = new PrintWriter (new OutputStreamWriter (out), true);
lattice2html (writer, info);
}
// if lattice == null, no alpha, beta values printed
static void lattice2html (PrintWriter out, ExtorInfo info)
{
assert (info.target.size() == info.predicted.size());
assert (info.input.size() == info.predicted.size());
int N = info.target.size();
for (int start = 0; start < N; start += LENGTH - 1) {
int end = Math.min (N, start + LENGTH);
if (!allSeqMatches (info.predicted, info.target, start, end)) {
error2html (out, info, start, end);
}
}
}
private static void writeHeader (PrintWriter out)
{
out.println ("<html><head><title>ERROR OUTPUT</title>\n<link rel=\"stylesheet\" href=\"errors.css\" type=\"text/css\" />\n</head><body>");
}
private static void writeFooter (PrintWriter out)
{
out.println ("</body></html>");
}
// Display HTML for one error
private static void error2html (PrintWriter out, ExtorInfo info, int start, int end)
{
String anchor = info.idx+":"+start+":"+end;
out.println ("<p><A NAME=\""+anchor+"\">");
out.println ("<p>Instance "+info.desc+" Position "+start+"..."+end);
if (info.link != null) {
out.println ("<a href=\""+info.link+"#"+anchor+"\">[Lattice]</a>");
}
out.println ("</p>");
out.println ("<table>");
outputIndices (out, start, end);
outputInputRow (out, info.input, start, end);
outputTableRow (out, "target", info.target, info.predicted, start, end);
outputTableRow (out, "predicted", info.predicted, info.target, start, end);
if (info.lattice != null) {
outputLatticeRows (out, info.lattice, start, end);
outputTransitionCosts (out, info, start, end);
outputFeatures (out, info.fvs, info.predicted, info.target, start, end);
}
out.println ("</table>");
}
public static int numMaxViterbi = 5;
private static void outputLatticeRows (PrintWriter out, MaxLattice lattice, int start, int end)
{
DecimalFormat f = new DecimalFormat ("0.##");
Transducer ducer = lattice.getTransducer ();
int max = Math.min (numMaxViterbi, ducer.numStates());
List<Sequence<Transducer.State>> stateSequences = lattice.bestStateSequences(max);
for (int k = 0; k < max; k++) {
out.println (" <tr class=\"delta\">");
out.println (" <td class=\"label\">δ rank "+k+"</td>");
for (int ip = start; ip < end; ip++) {
Transducer.State state = stateSequences.get(k).get(ip+1);
if (state.getName().equals (lattice.bestOutputSequence().get(ip))) {
out.print ("<td class=\"viterbi\">");
} else {
out.print ("<td>");
}
out.print (state.getName()+"<br />"+f.format (-lattice.getDelta (ip+1, state.getIndex ()))+"</td>");
}
out.println ("</tr>");
}
}
private static int numFeaturesToDisplay = 5;
public static int getNumFeaturesToDisplay ()
{
return numFeaturesToDisplay;
}
public static void setNumFeaturesToDisplay (int numFeaturesToDisplay)
{
LatticeViewer.numFeaturesToDisplay = numFeaturesToDisplay;
}
private static void outputTransitionCosts (PrintWriter out, ExtorInfo info, int start, int end)
{
Transducer ducer = info.lattice.getTransducer ();
out.println ("<tr class=\"predtrans\">");
out.println ("<td class=\"label\">Cost(pred. trans)</td>");
for (int ip = start; ip < end; ip++) {
if (ip == 0) {
out.println ("<td></td>");
continue;
}
Transducer.State from = ((CRF) ducer).getState (info.bestStates.get (ip - 1).toString ());
Transducer.TransitionIterator iter = from.transitionIterator (info.fvs, ip, info.predicted, ip);
if (iter.hasNext ()) {
iter.next ();
double cost = iter.getWeight();
String str = iter.describeTransition ((int) (Math.abs(cost) / FEATURE_CUTOFF_PCT));
out.print ("<td>" + str + "</td>");
} else {
out.print ("<td>No matching transition</td>");
}
}
out.println ("</tr>");
out.println ("<tr class=\"targettrans\">");
out.println ("<td class=\"label\">Cost(target trans)</td>");
for (int ip = start; ip < end; ip++) {
if (ip == 0) {
out.println ("<td></td>");
continue;
}
if (!seqMatches (info.predicted, info.target, ip) || !seqMatches (info.predicted, info.target, ip - 1)) {
Transducer.State from = ((CRF) ducer).getState (info.target.get (ip - 1).toString ());
if (from == null) {
out.println ("<td colspan='"+(end-start)+"'>Could not find state for "+info.target.get(ip-1)+"</td>");
} else {
Transducer.TransitionIterator iter = from.transitionIterator (info.fvs, ip, info.target, ip);
if (iter.hasNext ()) {
iter.next ();
double cost = iter.getWeight();
String str = iter.describeTransition ((int) (Math.abs(cost) / FEATURE_CUTOFF_PCT));
out.print ("<td>" + str + "</td>");
} else {
out.print ("<td>No matching transition</td>");
}
}
} else {
out.print ("<td></td>");
}
}
out.println ("</tr>");
out.println ("<tr class=\"predtargettrans\">");
out.println ("<td class=\"label\">Cost (pred->target trans)</td>");
for (int ip = start; ip < end; ip++) {
if (ip == 0) {
out.println ("<td></td>");
continue;
}
if (!seqMatches (info.predicted, info.target, ip) || !seqMatches (info.predicted, info.target, ip - 1)) {
Transducer.State from = ((CRF) ducer).getState (info.bestStates.get (ip - 1).toString ());
Transducer.TransitionIterator iter = from.transitionIterator (info.fvs, ip, info.target, ip);
if (iter.hasNext ()) {
iter.next ();
double cost = iter.getWeight();
String str = iter.describeTransition ((int) (Math.abs(cost) / FEATURE_CUTOFF_PCT));
out.print ("<td>" + str + "</td>");
} else {
out.print ("<td>No matching transition</td>");
}
} else {
out.print ("<td></td>");
}
}
out.println ("</tr>");
}
private static void outputLatticeRows (PrintWriter out, SumLatticeDefault lattice, int start, int end)
{
DecimalFormat f = new DecimalFormat ("0.##");
Transducer ducer = lattice.getTransducer ();
for (int k = 0; k < ducer.numStates(); k++) {
Transducer.State state = ducer.getState (k);
out.println (" <tr class=\"alpha\">");
out.println (" <td class=\"label\">α("+state.getName()+")</td>");
for (int ip = start; ip < end; ip++) {
out.print ("<td>"+f.format (lattice.getAlpha (ip+1, state))+"</td>");
}
out.println ("</tr>");
}
for (int k = 0; k < ducer.numStates(); k++) {
Transducer.State state = ducer.getState (k);
out.println (" <tr class=\"beta\">");
out.println (" <td class=\"label\">β("+state.getName()+")</td>");
for (int ip = start; ip < end; ip++) {
out.print ("<td>"+f.format (lattice.getBeta (ip+1, state))+"</td>");
}
out.println ("</tr>");
}
for (int k = 0; k < ducer.numStates(); k++) {
Transducer.State state = ducer.getState (k);
out.println (" <tr class=\"gamma\">");
out.println (" <td class=\"label\">γ("+state.getName()+")</td>");
for (int ip = start; ip < end; ip++) {
out.print ("<td>"+f.format (lattice.getGammaWeight(ip+1, state))+"</td>");
}
out.println ("</tr>");
}
}
private static void outputInputRow (PrintWriter out, TokenSequence input, int start, int end)
{
out.println (" <tr class=\"input\">");
out.println (" <td class=\"label\"></td>");
for (int ip = start; ip < end; ip++) {
out.print ("<td>"+input.get(ip).getText()+"</td>");
}
out.println (" </tr>");
}
private static void outputIndices (PrintWriter out, int start, int end)
{
out.println (" <tr class=\"indices\">");
out.println (" <td class=\"label\"></td>");
for (int ip = start; ip < end; ip++) {
out.print ("<td>"+ip+"</td>");
}
out.println (" </tr>");
}
private static void outputTableRow (PrintWriter out, String cssClass, Sequence seq1, Sequence seq2, int start, int end)
{
out.println (" <tr class=\""+cssClass+"\">");
out.println (" <td class=\"label\">"+cssClass+"</td>");
for (int i = start; i < end; i++) {
if (seqMatches (seq1, seq2, i)) {
out.print ("<td>");
} else {
out.print ("<td class=\"error\">");
}
out.print (seq1.get(i));
out.print ("</td>");
}
out.println (" </tr>");
}
private static void outputFeatures (PrintWriter out, FeatureVectorSequence fvs, Sequence in, Sequence output, int start, int end)
{
out.println (" <tr class=\"features\">\n<td class=\"label\">Features</td>");
for (int i = start; i < end; i++) {
if (!seqMatches (in, output, i)) {
out.print ("<td>");
FeatureVector fv = fvs.getFeatureVector (i);
for (int k = 0; k < fv.numLocations (); k++) {
out.print (fv.getAlphabet ().lookupObject (fv.indexAtLocation (k)));
if (fv.valueAtLocation (k) != 1.0) {
out.print (" "+fv.valueAtLocation (k));
}
out.println ("<br />");
}
out.println ("</td>");
} else {
out.println ("<td></td>");
}
}
out.println (" </tr>");
}
private static boolean seqMatches (Sequence seq1, Sequence seq2, int i)
{
return seq1.get(i).toString().equals (seq2.get(i).toString());
}
private static boolean allSeqMatches (Sequence seq1, Sequence seq2, int start, int end)
{
for (int i = start; i < end; i++) {
if (!seqMatches (seq1, seq2, i)) return false;
}
return true;
}
public static void extraction2html (Extraction extraction, CRFExtractor extor, PrintStream out)
{
PrintWriter writer = new PrintWriter (new OutputStreamWriter (out), true);
extraction2html (extraction, extor, out, false);
}
public static void extraction2html (Extraction extraction, CRFExtractor extor, PrintWriter out)
{
extraction2html (extraction, extor, out, false);
}
public static void extraction2html (Extraction extraction, CRFExtractor extor, PrintStream out, boolean showLattice)
{
PrintWriter writer = new PrintWriter (new OutputStreamWriter (out), true);
extraction2html (extraction, extor, writer, showLattice);
}
public static void extraction2html (Extraction extraction, CRFExtractor extor, PrintWriter out, boolean showLattice)
{
writeHeader (out);
for (int i = 0; i < extraction.getNumDocuments (); i++) {
DocumentExtraction docextr = extraction.getDocumentExtraction (i);
String desc = docextr.getName();
String doc = ((CharSequence) docextr.getDocument ()).toString();
ExtorInfo info = infoForDoc (doc, desc, "N"+i, docextr, extor, showLattice);
if (!showLattice) info.link = "lattice.html";
lattice2html (out, info);
}
writeFooter (out);
}
private static class ExtorInfo {
TokenSequence input;
Sequence predicted;
LabelSequence target;
FeatureVectorSequence fvs;
MaxLattice lattice;
Sequence bestStates;
String link; // If non-null, name of HTML file to use for cross-links
String desc;
String idx;
public ExtorInfo (TokenSequence input, Sequence predicted, LabelSequence target, String desc, String idx)
{
this.input = input;
this.predicted = predicted;
this.target = target;
this.desc = desc;
this.idx = idx;
}
}
private static ExtorInfo infoForDoc (String doc, String desc, String idx, DocumentExtraction docextr,
CRFExtractor extor, boolean showLattice)
{
// Instance c2 = new Instance (doc, null, null, null, extor.getTokenizationPipe ());
// TokenSequence input = (TokenSequence) c2.getData ();
TokenSequence input = (TokenSequence) docextr.getInput ();
LabelSequence target = docextr.getTarget ();
Sequence predicted = docextr.getPredictedLabels ();
ExtorInfo info = new ExtorInfo (input, predicted, target, desc, idx);
if (showLattice == true) {
CRF crf = extor.getCrf();
// xxx perhaps the next two lines could be a transducer method???
Instance carrier = extor.getFeaturePipe().pipe(new Instance (input, null, null, null));
info.fvs = (FeatureVectorSequence) carrier.getData ();
info.lattice = new MaxLatticeDefault (crf, (Sequence) carrier.getData(), null);
info.bestStates = info.lattice.bestOutputSequence();
}
return info;
}
// Lattice files get too large if too many instances are written to one file
private static final int EXTRACTIONS_PER_FILE = 25;
public static void viewDualResults (File dir, Extraction e1, CRFExtractor extor1, Extraction e2, CRFExtractor extor2) throws IOException
{
if (e1.getNumDocuments () != e2.getNumDocuments ())
throw new IllegalArgumentException ("Extractions don't match: different number of docs.");
PrintWriter errorStr = new PrintWriter (new FileWriter (new File (dir, "errors.html")));
writeDualExtractions (errorStr, e1, extor1, e2, extor2, 0, e1.getNumDocuments (), false);
errorStr.close ();
int max = e1.getNumDocuments ();
for (int start = 0; start < max; start += EXTRACTIONS_PER_FILE) {
int end = Math.min (start + EXTRACTIONS_PER_FILE, max);
PrintWriter latticeStr = new PrintWriter (new FileWriter (new File (dir, "lattice-"+start+".html")));
writeDualExtractions (latticeStr, e1, extor1, e2, extor2, start, end, true);
latticeStr.close ();
}
}
private static String computeLatticeFname (int docIdx)
{
int htmlDocNo = docIdx / EXTRACTIONS_PER_FILE; // this will get integer truncated
int start = EXTRACTIONS_PER_FILE * htmlDocNo;
return "lattice-"+start+".html";
}
private static void writeDualExtractions (PrintWriter out, Extraction e1, CRFExtractor extor1, Extraction e2, CRFExtractor extor2,
int start, int end, boolean showLattice)
{
writeHeader (out);
for (int i = start; i < end; i++) {
DocumentExtraction doc1 = e1.getDocumentExtraction (i);
DocumentExtraction doc2 = e2.getDocumentExtraction (i);
String desc = doc1.getName();
String doc1Str = ((CharSequence) doc1.getDocument ()).toString();
String doc2Str = ((CharSequence) doc2.getDocument ()).toString();
if (!doc1Str.equals (doc2Str)) {
System.err.println ("Skipping document "+i+": Extractions don't match");
continue;
}
Sequence targ1 = doc1.getPredictedLabels ();
Sequence targ2 = doc2.getPredictedLabels ();
if (!predictionsMatch (targ1, targ2)) {
ExtorInfo info1 = infoForDoc (doc1Str, "CRF1::"+desc, "C1I"+i, doc1, extor1, showLattice);
ExtorInfo info2 = infoForDoc (doc1Str, "CRF2::"+desc, "C2I"+i, doc2, extor2, showLattice);
if (!showLattice) { // add links from errors.html --> lattice.html
info1.link = info2.link = computeLatticeFname (i);
}
dualLattice2html (out, desc, info1, info2);
}
}
writeFooter (out);
}
// if lattice == null, no alpha, beta values printed
public static void dualLattice2html (PrintWriter out, String desc, ExtorInfo info1, ExtorInfo info2)
{
assert (info1.predicted.size() == info1.target.size());
assert (info1.input.size() == info1.predicted.size());
assert (info2.input.size() == info2.predicted.size());
assert (info2.predicted.size() == info2.target.size());
int N = info1.target.size();
for (int start = 0; start < N; start += LENGTH - 1) {
int end = Math.min (info1.predicted.size(), start + LENGTH);
if (!allSeqMatches (info1.predicted, info2.predicted, start, end)) {
error2html (out, info1, start, end);
error2html (out, info2, start, end);
}
}
}
private static boolean predictionsMatch (Sequence targ1, Sequence targ2)
{
if (targ1.size() != targ2.size()) return false;
for (int i = 0; i < targ1.size(); i++)
if (!targ1.get(i).toString().equals (targ2.get(i).toString()))
return false;
return true;
}
}