// Package cc.mallet.grmm.learning
//
// Source Code of cc.mallet.grmm.learning.DefaultAcrfTrainer$FileEvaluator

/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.grmm.learning;

import gnu.trove.TIntArrayList;

import java.io.File;
import java.io.FileWriter;
import java.io.PrintWriter;
import java.util.Date;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.logging.Logger;

import cc.mallet.grmm.learning.ACRF;
import cc.mallet.grmm.learning.ACRFEvaluator;
import cc.mallet.grmm.learning.ACRF.MaximizableACRF;
import cc.mallet.grmm.util.LabelsAssignment;
import cc.mallet.optimize.ConjugateGradient;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.*;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Timing;


/**
* Class for training ACRFs.
* <p/>
* <p/>
* Created: Thu Oct 16 17:53:14 2003
*
* @author <a href="mailto:casutton@cs.umass.edu">Charles Sutton</a>
* @version $Id: DefaultAcrfTrainer.java,v 1.1 2007/10/22 21:37:43 mccallum Exp $
*/
public class DefaultAcrfTrainer implements ACRFTrainer {

  private static Logger logger = MalletLogger.getLogger (DefaultAcrfTrainer.class.getName ());
  private Optimizer maxer;
  private static boolean rethrowExceptions = false;

  public DefaultAcrfTrainer ()
  {

  } // ACRFTrainer constructor


  private File outputPrefix = new File ("");

  public void setOutputPrefix (File f)
  {
    outputPrefix = f;
  }


  public Optimizer getMaxer ()
  {
    return maxer;
  }

  public void setMaxer (Optimizer maxer)
  {
    this.maxer = maxer;
  }


  public static boolean isRethrowExceptions ()
  {
    return rethrowExceptions;
  }

  public static void setRethrowExceptions (boolean rethrowExceptions)
  {
    DefaultAcrfTrainer.rethrowExceptions = rethrowExceptions;
  }

  public boolean train (ACRF acrf, InstanceList training)
  {
    return train (acrf, training, null, null,
            new LogEvaluator (), 1);
  }

  public boolean train (ACRF acrf, InstanceList training, int numIter)
  {
    return train (acrf, training, null, null,
            new LogEvaluator (), numIter);
  }

  public boolean train (ACRF acrf, InstanceList training, ACRFEvaluator eval, int numIter)
  {
    return train (acrf, training, null, null, eval, numIter);
  }

  public boolean train (ACRF acrf,
                        InstanceList training,
                        InstanceList validation,
                        InstanceList testing,
                        int numIter)
  {
    return train (acrf, training, validation, testing,
            new LogEvaluator (), numIter);
  }

  public boolean train (ACRF acrf,
                        InstanceList trainingList,
                        InstanceList validationList,
                        InstanceList testSet,
                        ACRFEvaluator eval,
                        int numIter)
  {
    Optimizable.ByGradientValue macrf = createMaximizable (acrf, trainingList);
    return train (acrf, trainingList, validationList, testSet,
            eval, numIter, macrf);
  }

  protected Optimizable.ByGradientValue createMaximizable (ACRF acrf, InstanceList trainingList)
  {
    return acrf.getMaximizable (trainingList);
  }

/*
  public boolean threadedTrain (ACRF acrf,
                                InstanceList trainingList,
                                InstanceList validationList,
                                InstanceList testSet,
                                ACRFEvaluator eval,
                                int numIter)
  {
    Maximizable.ByGradient macrf = acrf.getThreadedMaximizable (trainingList);
    return train (dcrf, trainingList, validationList, testSet,
                          eval, numIter, mdcrf);
  }
*/

  public boolean incrementalTrain (ACRF acrf,
                                   InstanceList training,
                                   InstanceList validation,
                                   InstanceList testing,
                                   int numIter)
  {
    return incrementalTrain (acrf, training, validation, testing,
            new LogEvaluator (), numIter);
  }

  private static final double[] SIZE = new double[]{0.1, 0.5};
  private static final int SUBSET_ITER = 10;

  public boolean incrementalTrain (ACRF acrf,
                                   InstanceList training,
                                   InstanceList validation,
                                   InstanceList testing,
                                   ACRFEvaluator eval,
                                   int numIter)
  {
    long stime = new Date ().getTime ();
    for (int i = 0; i < SIZE.length; i++) {
      InstanceList subset = training.split (new double[]
              {SIZE[i], 1 - SIZE[i]})[0];
      logger.info ("Training on subset of size " + subset.size ());
      Optimizable.ByGradientValue subset_macrf = createMaximizable (acrf, subset);
      train (acrf, training, validation, null, eval,
              SUBSET_ITER, subset_macrf);
      logger.info ("Subset training " + i + " finished...");
    }
    long etime = new Date ().getTime ();
    logger.info ("All subset training finished.  Time = " + (etime - stime) + " ms.");
    return train (acrf, training, validation, testing, eval, numIter);
  }

  /**
   * Main training loop.  Runs the optimizer one step at a time for at most
   * <tt>numIter</tt> iterations, calling the evaluator after every step, and
   * finally tests on <tt>testSet</tt> if both it and <tt>eval</tt> are given.
   * <p/>
   * The loop ends early when the optimizer reports convergence, when the
   * evaluator asks to stop, when the objective value changes by less than
   * <tt>1e-5 * (total nodes)</tt> between iterations, or when the optimizer
   * throws a RuntimeException twice in a row.  After the first such exception
   * the optimizer is reset and the loop continues; after the second the loop
   * quits with <tt>converged = true</tt>, rethrowing the exception only if
   * the static rethrow flag is set.
   *
   * @param acrf           the model being trained; used for evaluation
   * @param trainingList   training instances; only forwarded to the evaluator
   * @param validationList validation instances; only forwarded to the evaluator
   * @param testSet        test instances, or null (a warning is logged)
   * @param eval           evaluator, or null to skip all evaluation
   * @param numIter        maximum number of single-step optimizer calls
   * @param macrf          the objective; handed to the optimizer factory and
   *                       queried for its value by the early-stopping check
   * @return true if training stopped for any of the reasons above, false if
   *         the iteration limit was reached first
   */
  public boolean train (ACRF acrf,
                        InstanceList trainingList,
                        InstanceList validationList,
                        InstanceList testSet,
                        ACRFEvaluator eval,
                        int numIter,
                        Optimizable.ByGradientValue macrf)
  {
    Optimizer maximizer = createMaxer (macrf);
//    Maximizer.ByGradient maximizer = new BoldDriver ();
//    Maximizer.ByGradient maximizer = new GradientDescent ();
    boolean converged = false;
    // True while the previous iteration succeeded, i.e. one failure may
    // still be absorbed by resetting the optimizer.
    boolean resetOnError = true;
    long stime = System.currentTimeMillis ();

    // For objectives that are not MaximizableACRF, numNodes is 0, so thresh
    // is 0 and the early-stopping test below can never succeed.
    int numNodes = (macrf instanceof ACRF.MaximizableACRF) ? ((ACRF.MaximizableACRF) macrf).getTotalNodes () : 0;
    double thresh = 1e-5 * numNodes; // "early" stopping (reasonably conservative)

    if (testSet == null) {
      logger.warning ("ACRF trainer: No test set provided.");
    }

    double prevValue = Double.NEGATIVE_INFINITY;
    double currentValue;
    int iter;
    for (iter = 0; iter < numIter; iter++) {
      long etime = new java.util.Date ().getTime ();
      logger.info ("ACRF trainer iteration " + iter + " at time " + (etime - stime));

      try {
        // One optimizer step, then evaluation; either one may request a stop.
        converged = maximizer.optimize (1);
        converged |= callEvaluator (acrf, trainingList, validationList, testSet, iter, eval);

        if (converged) break;
        // A clean iteration re-arms the one-shot reset.
        resetOnError = true;

      } catch (RuntimeException e) {
        e.printStackTrace ();

        // If we get a maximizing error, reset LBFGS memory and try
        // again.  If we get an error on the second try too, then just
        // give up.
        if (resetOnError) {
          logger.warning ("Exception in iteration " + iter + ":" + e + "\n  Resetting LBFGs and trying again...");
          if (maximizer instanceof LimitedMemoryBFGS) ((LimitedMemoryBFGS) maximizer).reset ();
          if (maximizer instanceof ConjugateGradient) ((ConjugateGradient) maximizer).reset ();
          resetOnError = false;
        } else {
          logger.warning ("Exception in iteration " + iter + ":" + e + "\n   Quitting and saying converged...");
          converged = true;
          if (rethrowExceptions) throw e;
          break;
        }
      }
      // Reached with converged == true only if the exception was thrown
      // after optimize() had already reported convergence.
      if (converged) break;

      // "early" stopping
      currentValue = macrf.getValue ();
      if (Math.abs (currentValue - prevValue) < thresh) {
        // ignore cutoff if we're about to reset L-BFGS
        if (resetOnError) {
          logger.info ("ACRFTrainer saying converged: " +
                  " Current value " + currentValue + ", previous " + prevValue +
                  "\n...threshold was " + thresh + " = 1e-5 * " + numNodes);
          converged = true;
          break;
        }
      } else {
        // prevValue only advances when the change was at least thresh.
        prevValue = currentValue;
      }
    }

    // iter reaches numIter only when the loop ran to completion without a break.
    if (iter >= numIter) {
      logger.info ("ACRFTrainer: Too many iterations, stopping training.  maxIter = "+numIter);
    }

    long etime = System.currentTimeMillis ();
    logger.info ("ACRF training time (ms) = " + (etime - stime));

    if (macrf instanceof MaximizableACRF) {
      ((MaximizableACRF) macrf).report ();
    }

    if ((testSet != null) && (eval != null)) {
      // don't cache test set
      boolean oldCache = acrf.isCacheUnrolledGraphs ();
      acrf.setCacheUnrolledGraphs (false);
      eval.test (acrf, testSet, "Testing");
      acrf.setCacheUnrolledGraphs (oldCache);
    }

    return converged;
  }

  private Optimizer createMaxer (Optimizable.ByGradientValue macrf)
  {
    if (maxer == null) {
      return new LimitedMemoryBFGS (macrf);
    } else return maxer;
  }

  /**
   * @return true means stop, false means keep going (opposite of evaluators... ugh!)
   */
  protected boolean callEvaluator (ACRF acrf, InstanceList trainingList, InstanceList validationList,
                                 InstanceList testSet, int iter, ACRFEvaluator eval)
  {
    if (eval == null) return false// If no evaluator specified, keep going blindly

    eval.setOutputPrefix (outputPrefix);

    // don't cache test set
    boolean wasCached = acrf.isCacheUnrolledGraphs ();
    acrf.setCacheUnrolledGraphs (false);

    Timing timing = new Timing ();

    if (!eval.evaluate (acrf, iter+1, trainingList, validationList, testSet)) {
      logger.info ("ACRF trainer: evaluator returned false. Quitting.");
      timing.tick ("Evaluation time (iteration "+iter+")");
      return true;
    }

    timing.tick ("Evaluation time (iteration "+iter+")");

    // set test set caching back to normal
    acrf.setCacheUnrolledGraphs (wasCached);
    return false;
  }

  public boolean someUnsupportedTrain (ACRF acrf,
                        InstanceList trainingList,
                        InstanceList validationList,
                        InstanceList testSet,
                        ACRFEvaluator eval,
                        int numIter)
  {

    Optimizable.ByGradientValue macrf = createMaximizable (acrf, trainingList);
    train (acrf, trainingList, validationList, testSet, eval, 5, macrf);
    ACRF.Template[] tmpls = acrf.getTemplates ();
    for (int ti = 0; ti < tmpls.length; ti++)
      tmpls[ti].addSomeUnsupportedWeights (trainingList);
    logger.info ("Some unsupporetd weights initialized.  Training...");
    return train (acrf, trainingList, validationList, testSet, eval, numIter, macrf);
  }

  public void test (ACRF acrf, InstanceList testing, ACRFEvaluator eval)
  {
    test (acrf, testing, new ACRFEvaluator[]{eval});
  }

  public void test (ACRF acrf, InstanceList testing, ACRFEvaluator[] evals)
  {
    List pred = acrf.getBestLabels (testing);
    for (int i = 0; i < evals.length; i++) {
      evals[i].setOutputPrefix (outputPrefix);
      evals[i].test (testing, pred, "Testing");
    }
  }

  private static final Random r = new Random (1729);


  public static Random getRandom ()
  {
    return r;
  }

  public void train (ACRF acrf, InstanceList training, InstanceList validation, InstanceList testing,
                     ACRFEvaluator eval, double[] proportions, int iterPerProportion)
  {
    for (int i = 0; i < proportions.length; i++) {
      double proportion = proportions[i];
      InstanceList[] lists = training.split (r, new double[]{proportion, 1.0});
      logger.info ("ACRF trainer: Round " + i + ", training proportion = " + proportion);
      train (acrf, lists[0], validation, testing, eval, iterPerProportion);
    }

    logger.info ("ACRF trainer: Training on full data");
    train (acrf, training, validation, testing, eval, 99999);
  }

  public static class LogEvaluator extends ACRFEvaluator {

    private TestResults lastResults;

    public LogEvaluator ()
    {
    }

    ;

    public boolean evaluate (ACRF acrf, int iter,
                             InstanceList training,
                             InstanceList validation,
                             InstanceList testing)
    {
      if (shouldDoEvaluate (iter)) {
        if (training != null) { test (acrf, training, "Training"); }
        if (testing != null) { test (acrf, testing, "Testing"); }
      }
      return true;
    }

    public void test (InstanceList testList, List returnedList,
                      String description)
    {
      logger.info (description+": Number of instances = " + testList.size ());
      TestResults results = computeTestResults (testList, returnedList);
      results.log (description);
      lastResults = results;
//    results.printConfusion ();
    }

    public static TestResults computeTestResults (InstanceList testList, List returnedList)
    {
      TestResults results = new TestResults (testList);
      Iterator it1 = testList.iterator ();
      Iterator it2 = returnedList.iterator ();
      while (it1.hasNext ()) {
        Instance inst = (Instance) it1.next ();
//      System.out.println ("\n\nInstance");
        LabelsAssignment lblseq = (LabelsAssignment) inst.getTarget ();
        LabelsSequence target = lblseq.getLabelsSequence ();
        LabelsSequence returned = (LabelsSequence) it2.next ();
//      System.out.println (target);
        compareLabelings (results, returned, target);
      }

      results.computeStatistics ();
      return results;
    }


    static void compareLabelings (TestResults results,
                                  LabelsSequence returned,
                                  LabelsSequence target)
    {
      assert returned.size () == target.size ();
      for (int i = 0; i < returned.size (); i++) {
//      System.out.println ("Time "+i);
        Labels lblsReturned = returned.getLabels (i);
        Labels lblsTarget = target.getLabels (i);
        results.incrementCount (lblsReturned, lblsTarget);
      }
    }

    public double getJointAccuracy ()
    {
      return lastResults.getJointAccuracy ();
    }
  }

  public static class FileEvaluator extends ACRFEvaluator {

    private File file;

    public FileEvaluator (File file)
    {
      this.file = file;
    }

    ;

    public boolean evaluate (ACRF acrf, int iter,
                             InstanceList training,
                             InstanceList validation,
                             InstanceList testing)
    {
      if (shouldDoEvaluate (iter)) {
        test (acrf, testing, "Testing ");
      }
      return true;
    }

    public void test (InstanceList testList, List returnedList,
                      String description)
    {
      logger.info ("Number of testing instances = " + testList.size ());
      TestResults results = LogEvaluator.computeTestResults (testList, returnedList);

      try {
        PrintWriter writer = new PrintWriter (new FileWriter (file, true));
        results.print (description, writer);
        writer.close ();
      } catch (Exception e) {
        e.printStackTrace ();
      }
//    results.printConfusion ();
    }
  }

  public static class TestResults {

    public int[][] confusion;  // Confusion matrix
    public int numClasses;

    // Marginals of confusion matrix
    public int[] trueCounts;
    public int[] returnedCounts;

    // Per-class precision, recall, and F1.
    public double[] precision;
    public double[] recall;
    public double[] f1;

    // Measuring accuracy of each factor
    public TIntArrayList[] factors;

    // Measuring joint accuracy
    public int maxT = 0;
    public int correctT = 0;

    public Alphabet alphabet;

    TestResults (InstanceList ilist)
    {
      this (ilist.get (0));
    }

    TestResults (Instance inst)
    {
      alphabet = new Alphabet ();
      setupAlphabet (inst);

      numClasses = alphabet.size ();
      confusion = new int [numClasses][numClasses];
      precision = new double [numClasses];
      recall = new double [numClasses];
      f1 = new double [numClasses];
    }

    // This isn't pretty, but I swear there's
    //  not an easy way...
    private void setupAlphabet (Instance inst)
    {
      LabelsAssignment lblseq = (LabelsAssignment) inst.getTarget ();
      factors = new TIntArrayList [lblseq.numSlices ()];
      for (int i = 0; i < lblseq.numSlices (); i++) {
        LabelAlphabet dict = lblseq.getOutputAlphabet (i);
        factors[i] = new TIntArrayList (dict.size ());
        for (int j = 0; j < dict.size (); j++) {
          int idx = alphabet.lookupIndex (dict.lookupObject (j));
          factors[i].add (idx);
        }
      }
    }

    void incrementCount (Labels lblsReturned, Labels lblsTarget)
    {
      boolean allSame = true;

      // and per-label accuracy
      for (int j = 0; j < lblsReturned.size (); j++) {
        Label lret = lblsReturned.get (j);
        Label ltarget = lblsTarget.get (j);
//        System.out.println(ltarget+" vs. "+lret);
        int idxTrue = alphabet.lookupIndex (ltarget.getEntry ());
        int idxRet = alphabet.lookupIndex (lret.getEntry ());
        if (idxTrue != idxRet) allSame = false;
        confusion[idxTrue][idxRet]++;
      }

      // Measure joint accuracy
      maxT++;
      if (allSame) correctT++;
    }

    void computeStatistics ()
    {
      // Compute marginals of confusion matrix.
      //  Assumes that confusion[i][j] means true label i and
      //  returned label j
      trueCounts = new int [numClasses];
      returnedCounts = new int [numClasses];
      for (int i = 0; i < numClasses; i++) {
        for (int j = 0; j < numClasses; j++) {
          trueCounts[i] += confusion[i][j];
          returnedCounts[j] += confusion[i][j];
        }
      }

      // Compute per-class precision, recall, and F1
      for (int i = 0; i < numClasses; i++) {
        double correct = confusion[i][i];

        if (returnedCounts[i] == 0) {
          precision[i] = (correct == 0) ? 1.0 : 0.0;
        } else {
          precision[i] = correct / returnedCounts[i];
        }

        if (trueCounts[i] == 0) {
          recall[i] = 1.0;
        } else {
          recall[i] = correct / trueCounts[i];
        }

        f1[i] = (2 * precision[i] * recall[i]) / (precision[i] + recall[i]);
      }
    }

    public void log ()
    {
      log ("");
    }

    public void log (String desc)
    {
      logger.info (desc+":  i\tLabel\tN\tCorrect\tReturned\tP\tR\tF1");
      for (int i = 0; i < numClasses; i++) {
        logger.info (desc+":  "+i + "\t" + alphabet.lookupObject (i) + "\t"
                + trueCounts[i] + "\t"
                + confusion[i][i] + "\t"
                + returnedCounts[i] + "\t"
                + precision[i] + "\t"
                + recall[i] + "\t"
                + f1[i] + "\t");
      }
      for (int fnum = 0; fnum < factors.length; fnum++) {
        int correct = 0;
        int returned = 0;
        for (int i = 0; i < factors[fnum].size (); i++) {
          int lbl = factors[fnum].get (i);
          correct += confusion[lbl][lbl];
          returned += returnedCounts[lbl];
        }
        logger.info (desc + ":  Factor " + fnum + " accuracy: (" + correct + " " + returned + ") "
                + (correct / ((double) returned)));
      }

      logger.info (desc + " CorrectT " + correctT + "  maxt " + maxT);
      logger.info (desc + " Joint accuracy: " + ((double) correctT) / maxT);
    }

    public void print (String desc, PrintWriter out)
    {
      out.println ("i\tLabel\tN\tCorrect\tReturned\tP\tR\tF1");
      for (int i = 0; i < numClasses; i++) {
        out.println (i + "\t" + alphabet.lookupObject (i) + "\t"
                + trueCounts[i] + "\t"
                + confusion[i][i] + "\t"
                + returnedCounts[i] + "\t"
                + precision[i] + "\t"
                + recall[i] + "\t"
                + f1[i] + "\t");
      }
      for (int fnum = 0; fnum < factors.length; fnum++) {
        int correct = 0;
        int returned = 0;
        for (int i = 0; i < factors[fnum].size (); i++) {
          int lbl = factors[fnum].get (i);
          correct += confusion[lbl][lbl];
          returned += returnedCounts[lbl];
        }
        out.println (desc + " Factor " + fnum + " accuracy: (" + correct + " " + returned + ") "
                + (correct / ((double) returned)));
      }

      out.println (desc + " CorrectT " + correctT + "  maxt " + maxT);
      out.println (desc + " Joint accuracy: " + ((double) correctT) / maxT);
    }

    void printConfusion ()
    {
      System.out.println ("True\t\tReturned\tCount");
      for (int i = 0; i < numClasses; i++) {
        for (int j = 0; j < numClasses; j++) {
          System.out.println (i + "\t\t" + j + "\t" + confusion[i][j]);
        }
      }
    }

    public double getJointAccuracy ()
    {
      return ((double) correctT) / maxT;
    }

  } // TestResults


} // ACRFTrainer
// TOP
//
// Related Classes of cc.mallet.grmm.learning.DefaultAcrfTrainer$FileEvaluator
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.