package cc.mallet.fst;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

import java.util.BitSet;
import java.util.Random;
import java.util.logging.Logger;

import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;

import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.OptimizationException;
import cc.mallet.optimize.Optimizer;

import cc.mallet.util.MalletLogger;


/**
* A CRF trainer that can combine multiple objective functions, each represented
 * by an Optimizable.ByGradientValue.
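 * <p>
 * A typical usage sketch (assuming the CRF's weight dimensions have already been
 * set, e.g. via {@code crf.setWeightsDimensionAsIn(trainingSet, false)}, and using
 * {@link CRFOptimizableByLabelLikelihood} as one possible objective):
 * <pre>{@code
 * CRFOptimizableByLabelLikelihood optLabel =
 *     new CRFOptimizableByLabelLikelihood(crf, trainingSet);
 * Optimizable.ByGradientValue[] opts =
 *     new Optimizable.ByGradientValue[] { optLabel };
 * CRFTrainerByValueGradients trainer = new CRFTrainerByValueGradients(crf, opts);
 * trainer.train(trainingSet, Integer.MAX_VALUE); // train until convergence
 * }</pre>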
*/
public class CRFTrainerByValueGradients extends TransducerTrainer implements TransducerTrainer.ByOptimization {

  private static Logger logger = MalletLogger.getLogger(CRFTrainerByValueGradients.class.getName());

  CRF crf;
  // gsc: keep objects instead of classnames, this will give more flexibility to the
  // user to setup new CRFOptimizable* objects and then pass them directly in the constructor,
  // so the CRFOptimizable inner class no longer creates CRFOptimizable* objects
  Optimizable.ByGradientValue[] optimizableByValueGradientObjects;
//  Class[] optimizableByValueGradientClasses;
  OptimizableCRF ocrf;
  Optimizer opt;
  int iterationCount = 0;
  boolean converged;
  // gsc: removing these options, the user ought to set the weights before
  // creating the trainer object
//  boolean useSparseWeights = true;
//  // gsc
//  boolean useUnsupportedTrick = false;
 
  // Various values from CRF acting as indicators of when we need to ...
  private int cachedValueWeightsStamp = -1; // ... re-calculate expectations and value in getValue() because the weights' values changed
  private int cachedGradientWeightsStamp = -1; // ... re-calculate the gradient in getValueGradient() because the weights' values changed
 
  // gsc: removing this because the user will call setWeightsDimensionsAsIn
//  private int cachedWeightsStructureStamp = -1; // ... re-allocate crf.weights, expectations & constraints because new states, transitions
  // Use mcrf.trainingSet to see when we need to re-allocate crf.weights, expectations & constraints because we are using a different TrainingList than last time

  // gsc: number of times to reset the optimizer and continue training when the
  // "could not step in current direction" exception occurs
  public static final int DEFAULT_MAX_RESETS = 3;
  int maxResets = DEFAULT_MAX_RESETS;
 
  public CRFTrainerByValueGradients (CRF crf, Optimizable.ByGradientValue[] optimizableByValueGradientObjects) {
    this.crf = crf;
    this.optimizableByValueGradientObjects = optimizableByValueGradientObjects;
  }
 
  public Transducer getTransducer() { return crf; }
  public CRF getCRF () { return crf; }
  public Optimizer getOptimizer() { return opt; }
  /** Returns true if training converged, false otherwise. */
  public boolean isConverged() { return converged; }
  /** Returns true if training converged, false otherwise. */
  public boolean isFinishedTraining() { return converged; }
  public int getIteration () { return iterationCount; }
 
  // gsc
  public Optimizable.ByGradientValue[] getOptimizableByGradientValueObjects() {
    return optimizableByValueGradientObjects;
  }

  /**
   * Returns an optimizable CRF that contains a collection of objective functions.
   * <p>
   * If one doesn't exist, or if the training set has changed, creates one and resets the optimizer to null.
   */
  public OptimizableCRF getOptimizableCRF (InstanceList trainingSet) {
    // gsc: user should call setWeightsDimensionsAsIn before the optimizable and
    // trainer objects are created
//    if (cachedWeightsStructureStamp != crf.weightsStructureChangeStamp) {
//        if (useSparseWeights)
//          crf.setWeightsDimensionAsIn (trainingSet, useUnsupportedTrick); 
//        else
//          crf.setWeightsDimensionDensely ();
//      ocrf = null;
//      cachedWeightsStructureStamp = crf.weightsStructureChangeStamp;
//    }
    if (ocrf == null || ocrf.trainingSet != trainingSet) {
      ocrf = new OptimizableCRF (crf, trainingSet);
      opt = null;
    }
    return ocrf;
  }
 
  /**
   * Returns an L-BFGS optimizer, creating one if it doesn't exist.
   * <p>
   * Also creates an optimizable CRF if required.
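   * <p>
   * A minimal sketch of driving the returned optimizer directly (assumes a
   * {@code trainer} and {@code trainingSet} built as in the class-level example):
   * <pre>{@code
   * Optimizer o = trainer.getOptimizer(trainingSet);
   * boolean converged = o.optimize(10); // run up to 10 L-BFGS iterations
   * }</pre>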
   */
  public Optimizer getOptimizer (InstanceList trainingSet) {
    getOptimizableCRF(trainingSet); // this will set this.ocrf if necessary
    if (opt == null || ocrf != opt.getOptimizable())
      opt = new LimitedMemoryBFGS(ocrf); // Alternative: opt = new ConjugateGradient (0.001);
    return opt;
  }

  /** Trains a CRF until convergence. */
  public boolean trainIncremental (InstanceList training)
  {
    return train (training, Integer.MAX_VALUE);
  }

  /**
   * Trains a CRF until convergence or specified number of iterations, whichever is earlier.
   * <p>
   * Also creates an optimizable CRF and an optimizer if required.
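   * <p>
   * For example (a sketch; {@code trainer} and {@code trainingSet} are assumed
   * to exist already):
   * <pre>{@code
   * boolean converged = trainer.train(trainingSet, 500); // at most 500 iterations
   * }</pre>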
   */
  public boolean train (InstanceList trainingSet, int numIterations) {
    if (numIterations <= 0)
      return false;
    assert (trainingSet.size() > 0);

    getOptimizableCRF(trainingSet); // This will set this.ocrf if necessary
    getOptimizer(trainingSet); // This will set this.opt if necessary

    int numResets = 0;
    converged = false; // assign the instance field (not a shadowing local) so isConverged() reflects this run
    logger.info ("CRF about to train with "+numIterations+" iterations");
    for (int i = 0; i < numIterations; i++) {
      try {
        // gsc: timing each iteration
        long startTime = System.currentTimeMillis();
        converged = opt.optimize (1);
        logger.info ("CRF finished one iteration of maximizer, i="+i+", "+
                     +(System.currentTimeMillis()-startTime)/1000 + " secs.");
        iterationCount++;
        runEvaluators();
      } catch (OptimizationException e) {
            // gsc: resetting the optimizer for specified number of times
        e.printStackTrace();
        logger.info ("Catching exception.");
        if (numResets < maxResets) {
          // reset the optimizer and get a new one
          logger.info("Resetting optimizer.");
          ++numResets;
          opt = null;
          getOptimizer(trainingSet);
//        logger.info ("Catching exception; saying converged.");
//        converged = true;
        } else {
          logger.info("Saying converged.");
          converged = true;
        }
      }
      if (converged) {
        logger.info ("CRF training has converged, i="+i);
        break;
      }
    }
    return converged;
  }
 

  /**
   * Train a CRF on various-sized subsets of the data.  This method is typically used to accelerate training by
   * quickly getting to reasonable parameters on only a subset of the data first, then on progressively more data.
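   * <p>
   * For example (a sketch; {@code trainer} and {@code training} are assumed to
   * exist already), train on 20%, then 50%, then all of the data, for at most
   * 100 iterations each:
   * <pre>{@code
   * boolean converged = trainer.train(training, 100, new double[] {0.2, 0.5, 1.0});
   * }</pre>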
   * @param training The training Instances.
   * @param numIterationsPerProportion Maximum number of Maximizer iterations per training proportion.
   * @param trainingProportions If non-null, train on increasingly
   * larger portions of the data, e.g. new double[] {0.2, 0.5, 1.0}.  This can sometimes speedup convergence.
   * Be sure to end in 1.0 if you want to train on all the data in the end. 
   * @return True if training has converged.
   */
  public boolean train (InstanceList training, int numIterationsPerProportion, double[] trainingProportions)
  {
    int trainingIteration = 0;
    assert (trainingProportions.length > 0);
    boolean converged = false;
    for (int i = 0; i < trainingProportions.length; i++) {
      assert (trainingProportions[i] <= 1.0);
      logger.info ("Training on "+trainingProportions[i]+"% of the data this round.");
      if (trainingProportions[i] == 1.0)
        converged = this.train (training, numIterationsPerProportion);
      else
        converged = this.train (training.split (new Random(1),
            new double[] {trainingProportions[i], 1-trainingProportions[i]})[0], numIterationsPerProportion);
      trainingIteration += numIterationsPerProportion;
    }
    return converged;
  }
 
  // gsc: see comment in getOptimizableCRF
//  public void setUseSparseWeights (boolean b) { useSparseWeights = b; }
//  public boolean getUseSparseWeights () { return useSparseWeights; }
//
//  // gsc
//  public void setUseUnsupportedTrick (boolean b) { useUnsupportedTrick = b; }
//  public boolean getUseUnsupportedTrick () { return useUnsupportedTrick; }

  // gsc: change the max. number of times the optimizer is reset, and training
  // continued, when the "could not step in current direction" exception occurs
  /**
   * Sets the max. number of times the optimizer can be reset after an
   * OptimizationException before training is simply declared converged.
   * <p>
   * Default value: <tt>DEFAULT_MAX_RESETS</tt>.
   */
  public void setMaxResets(int maxResets) { this.maxResets = maxResets; }
 
  /** An optimizable CRF that contains a collection of objective functions. */
  public class OptimizableCRF implements Optimizable.ByGradientValue, Serializable
  {
    InstanceList trainingSet;
    double cachedValue = -123456789;
    double[] cachedGradie;
    BitSet infiniteValues = null;
    CRF crf;
    Optimizable.ByGradientValue[] opts;
   
    protected OptimizableCRF (CRF crf, InstanceList ilist)
    {
      // Set up
      this.crf = crf;
      this.trainingSet = ilist;
      this.opts = optimizableByValueGradientObjects;
      cachedGradie = new double[crf.parameters.getNumFactors()];
      cachedValueWeightsStamp = -1;
      cachedGradientWeightsStamp = -1;
    }

//    protected OptimizableCRF (CRF crf, InstanceList ilist)
//    {
//      // Set up
//      this.crf = crf;
//      this.trainingSet = ilist;
//      cachedGradie = new double[crf.parameters.getNumFactors()];
//      Class[] parameterTypes = new Class[] {CRF.class, InstanceList.class};
//      for (int i = 0; i < optimizableByValueGradientClasses.length; i++) {
//        try { 
//          Constructor c = optimizableByValueGradientClasses[i].getConstructor(parameterTypes);
//          opts[i] = (Optimizable.ByGradientValue) c.newInstance(crf, ilist);
//        } catch (Exception e) { throw new IllegalStateException ("Couldn't construct Optimizable.ByGradientValue"); }
//      }
//      cachedValueWeightsStamp = -1;
//      cachedGradientWeightsStamp = -1;
//    }

    // TODO Move these implementations into CRF.java, and put here stubs that call them!
    public int getNumParameters () {
      return crf.parameters.getNumFactors();
    }

    public void getParameters (double[] buffer) {
      crf.parameters.getParameters(buffer);
    }

    public double getParameter (int index) {
      return crf.parameters.getParameter(index);
    }

    public void setParameters (double [] buff) {
      crf.parameters.setParameters(buff);
      crf.weightsValueChanged();
    }

    public void setParameter (int index, double value) {
      crf.parameters.setParameter(index, value);
      crf.weightsValueChanged();
    }

    /** Returns the sum of the values of all the objective functions (for label likelihood, the log probability of the training sequence labels plus the log prior over parameters). */
    public double getValue ()
    {
      if (crf.weightsValueChangeStamp != cachedValueWeightsStamp) {
        // The cached value is not up to date; it was calculated for a different set of CRF weights.
        long startingTime = System.currentTimeMillis();

        cachedValue = 0;
        for (int i = 0; i < opts.length; i++)
          cachedValue += opts[i].getValue();
       
        cachedValueWeightsStamp = crf.weightsValueChangeStamp;  // cachedValue is now no longer stale
        logger.info ("getValue() (loglikelihood) = "+cachedValue);
        logger.fine ("Inference milliseconds = "+(System.currentTimeMillis() - startingTime));
      }
      return cachedValue;
    }

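    /** Fills {@code buffer} with the sum of the gradients of all the objective functions. */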
    public void getValueGradient (double [] buffer)
    {
      // PriorGradient is -parameter/gaussianPriorVariance
      // Gradient is (constraint - expectation + PriorGradient)
      // == -(expectation - constraint - PriorGradient).
      // Gradient points "up-hill", i.e. in the direction of higher value
      if (cachedGradientWeightsStamp != crf.weightsValueChangeStamp) {
        getValue (); // Ensure the cached value (and each sub-objective's expectations) is up to date
        MatrixOps.setAll(cachedGradie, 0);
        double[] b2 = new double[buffer.length];
        for (int i = 0; i < opts.length; i++) {
          MatrixOps.setAll(b2, 0);
          opts[i].getValueGradient(b2);
          MatrixOps.plusEquals(cachedGradie, b2);
        }
        cachedGradientWeightsStamp = crf.weightsValueChangeStamp;
      }
      System.arraycopy(cachedGradie, 0, buffer, 0, cachedGradie.length);
    }

    // Serialization of OptimizableCRF

    private static final long serialVersionUID = 1;
    private static final int CURRENT_SERIAL_VERSION = 0;

    private void writeObject (ObjectOutputStream out) throws IOException {
      out.writeInt (CURRENT_SERIAL_VERSION);
      out.writeObject(trainingSet);
      out.writeDouble(cachedValue);
      out.writeObject(cachedGradie);
      out.writeObject(infiniteValues);
      out.writeObject(crf);
    }

    private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
      in.readInt ();
      trainingSet = (InstanceList) in.readObject();
      cachedValue = in.readDouble();
      cachedGradie = (double[]) in.readObject();
      infiniteValues = (BitSet) in.readObject();
      crf = (CRF)in.readObject();
    }

  }

  // Serialization for CRFTrainerByValueGradients

  private static final long serialVersionUID = 1;
  private static final int CURRENT_SERIAL_VERSION = 1;
  static final int NULL_INTEGER = -1;

  /* Need to check for null pointers. */
  private void writeObject (ObjectOutputStream out) throws IOException {
    out.writeInt (CURRENT_SERIAL_VERSION);
    //out.writeInt(defaultFeatureIndex);
    out.writeInt(cachedGradientWeightsStamp);
    out.writeInt(cachedValueWeightsStamp);
//    out.writeInt(cachedWeightsStructureStamp);
//    out.writeBoolean (useSparseWeights);
    throw new IllegalStateException("Implementation not yet complete.");   
  }
 
  private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
    in.readInt ();
    //defaultFeatureIndex = in.readInt();
//    useSparseWeights = in.readBoolean();
    throw new IllegalStateException("Implementation not yet complete.");   
  }
}