Package edu.stanford.nlp.sempre.vis

Source Code of edu.stanford.nlp.sempre.vis.ExampleDerivations

package edu.stanford.nlp.sempre.vis;

import com.google.common.base.Joiner;
import com.google.common.collect.Lists;
import edu.stanford.nlp.sempre.Derivation;
import edu.stanford.nlp.sempre.Example;
import edu.stanford.nlp.sempre.Params;
import edu.stanford.nlp.sempre.Vis;
import fig.basic.IOUtils;
import fig.basic.LogInfo;
import fig.basic.MapUtils;
import fig.basic.Pair;
import fig.exec.Execution;

import java.io.File;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
* Visualize predicate derivations from one or more executions of the semantic
* parser.
*
* @author Roy Frostig
*/
public class ExampleDerivations {
  private final List<String> execPaths;
  private final int topN;

  public ExampleDerivations(List<String> execPaths, int topN) {
    this.execPaths = execPaths;
    this.topN = topN;
  }

  private PrintWriter tmpLog;
  private void pushLog(PrintWriter newLog) {
    tmpLog = LogInfo.getFileOut();
    LogInfo.setFileOut(newLog);
  }
  private void popLog() {
    LogInfo.setFileOut(tmpLog);
  }

  public boolean write(int iter, String group) {
    List<File> files = Vis.getFilesPerExec(execPaths, iter, group);

    if (files == null)
      return false;

    LogInfo.logs("Reading files: %s", files);
    List<Params> params = getParamsPerExec(iter);
    assert files.size() == params.size();

    String basePath = "vis.preds-iter" + iter + "-" + group + ".examples";
    String outPath = Execution.getFile(basePath);
    PrintWriter out = IOUtils.openOutHard(outPath);
    LogInfo.log("Writing " + basePath);
    pushLog(out);

    int i = 0;
    for (List<Example> row : Vis.zipExamples(files)) {
      Example first = row.get(0);
      LogInfo.begin_track("Example %d: %s [%s]", i, first.getUtterance(), first.getId());
      List<Map<Derivation, Integer>> rowDerivIndices = getRowDerivIndices(row);
      List<Derivation> firstDerivs = first.getPredDerivations();

      LogInfo.logs("STAT %s", first.getEvaluation().summary());

      for (int j = 0; j < firstDerivs.size(); j++) {
        Derivation firstDeriv = firstDerivs.get(j);
        LogInfo.begin_track("Ex %d, derivation %d", i, j);
        if (firstDeriv.getCompatibility() == 1.0d)
          LogInfo.log("DERIV CORRECT");
        else
          LogInfo.log("DERIV WRONG");
        LogInfo.logs("DERIV %s", firstDeriv);

        // TODO no executor stats being written/loaded at present.
        //firstDeriv.getExecutorStats().logStats(""+j);

        List<Pair<String, Double>> topFeatures;
        List<Pair<String, Double>> botFeatures;
        if (row.size() >= 2) {
          Integer derivIndex = rowDerivIndices.get(1).get(firstDeriv);
          Derivation secondDeriv = (derivIndex == null)
                                   ? null
                                   : row.get(1).getPredDerivations().get(derivIndex);
          topFeatures = getTopFeatures(
              topN,
              params.get(0),
              params.get(1),
              firstDeriv,
              secondDeriv,
              false);
          botFeatures = getTopFeatures(
              topN,
              params.get(0),
              params.get(1),
              firstDeriv,
              secondDeriv,
              true);
        } else {
          topFeatures = getTopFeatures(
              topN,
              params.get(0),
              null,
              firstDeriv,
              null,
              false);
          botFeatures = getTopFeatures(
              topN,
              params.get(0),
              null,
              firstDeriv,
              null,
              true);
        }

        String[] positions = new String[row.size()];
        String[][] chart = new String[topFeatures.size() + botFeatures.size()][row.size()];
        String[][] totals = new String[4][row.size()];

        for (int k = 0; k < row.size(); k++) {
          Integer derivIndex = rowDerivIndices.get(k).get(firstDeriv);

          // Positions
          positions[k] = String.format("%12s", (derivIndex == null) ? "~" : ("" + derivIndex));

          // Features
          // Walk down topFeatures and up botFeatures at the same time.
          int n = Math.max(topFeatures.size(), botFeatures.size());
          for (int f = 0; f < n; f++) {
            if (f < topFeatures.size()) {
              String featureName = topFeatures.get(f).getFirst();
              double val = params.get(k).getWeight(featureName);
              chart[f][k] = String.format("%12.4f", val);
            }
            if (f < botFeatures.size()) {
              String featureName = botFeatures.get(f).getFirst();
              double val = params.get(k).getWeight(featureName);
              chart[chart.length - 1 - f][k] = String.format("%12.4f", val);
            }
          }

          // Totals
          if (derivIndex == null) {
            totals[0][k] = totals[1][k] = totals[2][k] = totals[3][k] = String.format("%12s", "~");
          } else {
            Derivation deriv = row.get(k).getPredDerivations().get(derivIndex);
            totals[0][k] = String.format("%12.4f", deriv.getScore());
            totals[1][k] = String.format("%12.4f", deriv.getProb());
            totals[2][k] = String.format("%12.4f", deriv.getCompatibility());
            totals[3][k] = String.format("%12d", deriv.getMaxBeamPosition());
          }
        }
        LogInfo.log("");

        LogInfo.logs("%-40s\t%12s%s", "POS", "", Joiner.on(' ').join(positions));
        LogInfo.log("");

        for (int f = 0; f < chart.length; f++) {
          List<Pair<String, Double>> tops = (f < topFeatures.size()) ? topFeatures : botFeatures;
          int topsIndex = (f < topFeatures.size()) ? f : chart.length - 1 - f;
          if (f == topFeatures.size())
            LogInfo.log("...");
          LogInfo.logs(
              "%-40s\t%12.4f%s",
              "FEAT " + tops.get(topsIndex).getFirst(),
              tops.get(topsIndex).getSecond(),
              Joiner.on(' ').join(chart[f]));
        }
        LogInfo.log("");
        LogInfo.logs("%-40s\t%12s%s", "SCORE", "", Joiner.on(' ').join(totals[0]));
        LogInfo.logs("%-40s\t%12s%s", "PROB", "", Joiner.on(' ').join(totals[1]));
        LogInfo.logs("%-40s\t%12s%s", "COMPAT", "", Joiner.on(' ').join(totals[2]));
        LogInfo.logs("%-40s\t%12s%s", "MAXBEAMPOS", "", Joiner.on(' ').join(totals[3]));
        LogInfo.end_track();
      }
      LogInfo.end_track();
      i++;
    }

    popLog();
    return true;
  }

  public void writeAll() {
    boolean done = false;
    for (int iter = 0; !done; iter++)
      for (String group : new String[]{"train", "dev"})
        if (done = !write(iter, group))
          break;
  }

  private List<Map<Derivation, Integer>> getRowDerivIndices(List<Example> row) {
    List<Map<Derivation, Integer>> res = new ArrayList<Map<Derivation, Integer>>(row.size());
    for (Example ex : row) {
      Map<Derivation, Integer> m = new HashMap<Derivation, Integer>();
      int i = 0;
      for (Derivation d : ex.getPredDerivations()) {
        if (!m.containsKey(d))
          m.put(d, i);
        i++;
      }
      res.add(m);
    }
    return res;
  }

  private List<Pair<String, Double>> getTopFeatures(int topN,
                                                    Params firstParams,
                                                    Params secondParams,
                                                    Derivation firstDeriv,
                                                    Derivation secondDeriv,
                                                    boolean reverse) {
    Map<String, Double> sortBy;
    Map<String, Double> firstFeats = new HashMap<String, Double>();
    firstDeriv.incrementAllFeatureVector(1.0d, firstFeats);
    double factor = reverse ? -1.0d : 1.0d;
    if (secondDeriv == null) {
      sortBy = Utils.scale(
          factor,
          Utils.elementwiseProduct(
              firstFeats,
              firstParams.getWeights()));
    } else {
      Map<String, Double> secondFeats = new HashMap<String, Double>();
      secondDeriv.incrementAllFeatureVector(1.0d, secondFeats);
      sortBy = Utils.linearComb(
          factor, -factor,
          Utils.elementwiseProduct(
              firstFeats,
              firstParams.getWeights()),
          Utils.elementwiseProduct(
              secondFeats,
              secondParams.getWeights()));
    }
    List<Pair<String, Double>> top = MapUtils.getTopN(sortBy, topN);
    List<Pair<String, Double>> topFeats = Lists.newArrayListWithExpectedSize(top.size());
    for (Pair<String, Double> pair : top) {
      topFeats.add(
          Pair.newPair(
              pair.getFirst(),
              firstFeats.get(pair.getFirst())));
    }
    return topFeats;
  }

  private List<Params> getParamsPerExec(int iter) {
    List<Params> params = new ArrayList<Params>();
    for (String execPath : execPaths) {
      Params p = new Params();
      p.read(execPath + "/params." + iter);
      params.add(p);
    }
    return params;
  }
}
TOP

Related Classes of edu.stanford.nlp.sempre.vis.ExampleDerivations

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.