// Package cc.mallet.fst
//
// Source Code of cc.mallet.fst.CRFTrainerByLabelLikelihood

package cc.mallet.fst;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Random;
import java.util.logging.Logger;

import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.ExpGain;
import cc.mallet.types.FeatureInducer;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.GradientGain;
import cc.mallet.types.InfoGain;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelSequence;
import cc.mallet.types.LabelVector;
import cc.mallet.types.RankedFeatureVector;
import cc.mallet.types.Sequence;
import cc.mallet.util.MalletLogger;


/**
* Unlike ClassifierTrainer, TransducerTrainer is not "stateless" between calls
* to train. A TransducerTrainer is constructed paired with a specific
* Transducer, and can only train that Transducer. CRF stores and has methods
* for FeatureSelection and weight freezing. CRFTrainer stores and has methods
* for determining the contents/dimensions/sparsity/FeatureInduction of the
* CRF's weights as determined by training data.
* <p>
* <b>Note:</b> In the future this class may go away in favor of some default
* version of CRFTrainerByValueGradients.
*/
public class CRFTrainerByLabelLikelihood extends TransducerTrainer implements TransducerTrainer.ByOptimization {
  private static Logger logger = MalletLogger.getLogger(CRFTrainerByLabelLikelihood.class.getName());

  static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 1.0;
  static final double DEFAULT_HYPERBOLIC_PRIOR_SLOPE = 0.2;
  static final double DEFAULT_HYPERBOLIC_PRIOR_SHARPNESS = 10.0;

  CRF crf;
  //OptimizableCRF ocrf;
  CRFOptimizableByLabelLikelihood ocrf;
  Optimizer opt;
  int iterationCount = 0;
  boolean converged;
 
  boolean usingHyperbolicPrior = false;
  double gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;
  double hyperbolicPriorSlope = DEFAULT_HYPERBOLIC_PRIOR_SLOPE;
  double hyperbolicPriorSharpness = DEFAULT_HYPERBOLIC_PRIOR_SHARPNESS;
  boolean useSparseWeights = true;
  boolean useNoWeights = false; // TODO remove this; it is just for debugging
  private transient boolean useSomeUnsupportedTrick = true;

  // Various values from CRF acting as indicators of when we need to ...
  private int cachedValueWeightsStamp = -1// ... re-calculate expectations and values to getValue() because weights' values changed
  private int cachedGradientWeightsStamp = -1; // ... re-calculate to getValueGradient() because weights' values changed
  private int cachedWeightsStructureStamp = -1; // ... re-allocate crf.weights, expectations & constraints because new states, transitions
  // Use mcrf.trainingSet to see when we need to re-allocate crf.weights, expectations & constraints because we are using a different TrainingList than last time

  // xxx temporary hack.  This is quite useful to have, though!! -cas
  public boolean printGradient = false;

 
 
  /**
   * Creates a trainer bound to the given CRF.
   *
   * @param crf the CRF to train; the reference is stored as-is (no copy, no null check)
   */
  public CRFTrainerByLabelLikelihood (CRF crf) {
    this.crf = crf;
  }
 
  public Transducer getTransducer() { return crf; }
  public CRF getCRF () { return crf; }
  public Optimizer getOptimizer() { return opt; }
  public boolean isConverged() { return converged; }
  public boolean isFinishedTraining() { return converged; }
  public int getIteration () { return iterationCount; }
 

  /**
   * Use this method to specify whether or not factors
   * are added to the CRF by this trainer.  If you have
   * already setup the factors in your CRF, you may
   * not want the trainer to add additional factors.
   *
   * @param flag If true, this trainer adds no factors to the CRF.
   */
  public void setAddNoFactors(boolean flag) {
    this.useNoWeights = flag;
  }

  public CRFOptimizableByLabelLikelihood getOptimizableCRF (InstanceList trainingSet) {
    if (cachedWeightsStructureStamp != crf.weightsStructureChangeStamp) {
      if (!useNoWeights) {
        if (useSparseWeights)
          crf.setWeightsDimensionAsIn (trainingSet, useSomeUnsupportedTrick)
        else
          crf.setWeightsDimensionDensely ();
      }
      //reallocateSufficientStatistics(); // Not necessary here because it is done in the constructor for OptimizableCRF
      ocrf = null;
      cachedWeightsStructureStamp = crf.weightsStructureChangeStamp;
    }
    if (ocrf == null || ocrf.trainingSet != trainingSet) {
      //ocrf = new OptimizableCRF (crf, trainingSet);
      ocrf = new CRFOptimizableByLabelLikelihood(crf, trainingSet);
      ocrf.setGaussianPriorVariance(gaussianPriorVariance);
      ocrf.setHyperbolicPriorSharpness(hyperbolicPriorSharpness);
      ocrf.setHyperbolicPriorSlope(hyperbolicPriorSlope);
      ocrf.setUseHyperbolicPrior(usingHyperbolicPrior);
      opt = null;
    }
    return ocrf;
  }
 
  public Optimizer getOptimizer (InstanceList trainingSet) {
    getOptimizableCRF(trainingSet); // this will set this.mcrf if necessary
    if (opt == null || ocrf != opt.getOptimizable())
      opt = new LimitedMemoryBFGS(ocrf)// Alternative: opt = new ConjugateGradient (0.001);
    return opt;
  }
 
  // Java question:
  // If I make a non-static inner class CRF.Trainer,
  // can that class be subclassed in another .java file,
  // and can that subclass still have access to all the CRF's
  // instance variables?
  // ANSWER: Yes and yes, but you have to use special syntax in the subclass ctor (see mallet-dev archive) -cas


  public boolean trainIncremental (InstanceList training)
  {
    return train (training, Integer.MAX_VALUE);
  }


  public boolean train (InstanceList trainingSet, int numIterations) {
    if (numIterations <= 0)
      return false;
    assert (trainingSet.size() > 0);

    getOptimizableCRF(trainingSet); // This will set this.mcrf if necessary
    getOptimizer(trainingSet); // This will set this.opt if necessary

    boolean converged = false;
    logger.info ("CRF about to train with "+numIterations+" iterations");
    for (int i = 0; i < numIterations; i++) {
      try {
        converged = opt.optimize (1);
        iterationCount++;
        logger.info ("CRF finished one iteration of maximizer, i="+i);
        runEvaluators();
      } catch (IllegalArgumentException e) {
        e.printStackTrace();
        logger.info ("Catching exception; saying converged.");
        converged = true;
      } catch (Exception e) {
        e.printStackTrace();
        logger.info("Catching exception; saying converged.");
        converged = true;
      }
      if (converged) {
        logger.info ("CRF training has converged, i="+i);
        break;
      }
    }
    return converged;
  }
 
 

  /**
   * Train a CRF on various-sized subsets of the data.  This method is typically used to accelerate training by
   * quickly getting to reasonable parameters on only a subset of the parameters first, then on progressively more data.
   * @param training The training Instances.
   * @param numIterationsPerProportion Maximum number of Maximizer iterations per training proportion.
   * @param trainingProportions If non-null, train on increasingly
   * larger portions of the data, e.g. new double[] {0.2, 0.5, 1.0}.  This can sometimes speedup convergence.
   * Be sure to end in 1.0 if you want to train on all the data in the end. 
   * @return True if training has converged.
   */
  public boolean train (InstanceList training, int numIterationsPerProportion, double[] trainingProportions)
  {
    int trainingIteration = 0;
    assert (trainingProportions.length > 0);
    boolean converged = false;
    for (int i = 0; i < trainingProportions.length; i++) {
      assert (trainingProportions[i] <= 1.0);
      logger.info ("Training on "+trainingProportions[i]+"% of the data this round.");
      if (trainingProportions[i] == 1.0)
        converged = this.train (training, numIterationsPerProportion);
      else
        converged = this.train (training.split (new Random(1)
            new double[] {trainingProportions[i]1-trainingProportions[i]})[0], numIterationsPerProportion);
      trainingIteration += numIterationsPerProportion;
    }
    return converged;
  }

  public boolean trainWithFeatureInduction (InstanceList trainingData,
                                            InstanceList validationData, InstanceList testingData,
                                            TransducerEvaluator eval, int numIterations,
                                            int numIterationsBetweenFeatureInductions,
                                            int numFeatureInductions,
                                            int numFeaturesPerFeatureInduction,
                                            double trueLabelProbThreshold,
                                            boolean clusteredFeatureInduction,
                                            double[] trainingProportions)
  {
    return trainWithFeatureInduction (trainingData, validationData, testingData,
        eval, numIterations, numIterationsBetweenFeatureInductions,
        numFeatureInductions, numFeaturesPerFeatureInduction,
        trueLabelProbThreshold, clusteredFeatureInduction,
        trainingProportions, "exp");
  }

  /**
   * Train a CRF using feature induction to generate conjunctions of
   * features. Feature induction is run periodically during
   * training. The features are added to improve performance on the
   * mislabeled instances, with the specific scoring criterion given
   * by the {@link FeatureInducer} specified by <code>gainName</code>
   * <p>
   * Each round: (1) optionally sub-sample the training data, (2) train
   * (skipped in round 0, before any features have been induced), (3) collect
   * every sequence position whose true-label probability is below the
   * threshold as an "error instance", (4) build one or more feature inducers
   * from those errors and apply them to the training and testing data.
   * After the last round the CRF is trained on all of
   * <code>trainingData</code> with the remaining iteration budget.
   *
   * @param trainingData The training Instances.
   * @param validationData The validation Instances; may be null. Only its
   * feature selection is set here.
   * @param testingData The testing instances; may be null.
   * @param eval For evaluation during training. Not referenced by the body
   * of this method.
   * @param numIterations Maximum number of Maximizer iterations.
   * @param numIterationsBetweenFeatureInductions Number of maximizer
   * iterations between each call to the Feature Inducer.
   * @param numFeatureInductions Maximum number of rounds of feature
   * induction.
   * @param numFeaturesPerFeatureInduction Maximum number of features
   * to induce at each round of induction.
   * @param trueLabelProbThreshold If the model's probability of the
   * true Label of an Instance is less than this value, it is added as
   * an error instance to the {@link FeatureInducer}.
   * @param clusteredFeatureInduction If true, a separate {@link
   * FeatureInducer} is constructed for each label pair. This can
   * avoid inducing a disproportionate number of features for a single
   * label.
   * @param trainingProportions If non-null, train on increasingly
   * larger portions of the data (e.g. [0.2, 0.5, 1.0]). This can
   * sometimes speedup convergence.
   * @param gainName The type of {@link FeatureInducer} to use. One of
   * "exp", "grad", or "info" for {@link ExpGain}, {@link
   * GradientGain}, or {@link InfoGain}. Any other value leaves the gain
   * factory null, and that null is passed to the FeatureInducer constructor.
   * @return True if training has converged.
   */
  public boolean trainWithFeatureInduction (InstanceList trainingData,
                                            InstanceList validationData, InstanceList testingData,
                                            TransducerEvaluator eval, int numIterations,
                                            int numIterationsBetweenFeatureInductions,
                                            int numFeatureInductions,
                                            int numFeaturesPerFeatureInduction,
                                            double trueLabelProbThreshold,
                                            boolean clusteredFeatureInduction,
                                            double[] trainingProportions,
                                            String gainName)
  {
    // Running total of the iteration budget charged to the induction rounds.
    int trainingIteration = 0;
    int numLabels = crf.outputAlphabet.size();

    // Share one FeatureSelection between the CRF and all three data sets.
    crf.globalFeatureSelection = trainingData.getFeatureSelection();
    if (crf.globalFeatureSelection == null) {
      // Mask out all features; some will be added later by FeatureInducer.induceFeaturesFor(.)
      crf.globalFeatureSelection = new FeatureSelection (trainingData.getDataAlphabet());
      trainingData.setFeatureSelection (crf.globalFeatureSelection);
    }
    // TODO Careful!  If validationData and testingData get removed as arguments to this method
    // then the next two lines of work will have to be done somewhere.
    if (validationData != null) validationData.setFeatureSelection (crf.globalFeatureSelection);
    if (testingData != null) testingData.setFeatureSelection (crf.globalFeatureSelection);

    for (int featureInductionIteration = 0;
    featureInductionIteration < numFeatureInductions;
    featureInductionIteration++)
    {
      // Print out some feature information
      logger.info ("Feature induction iteration "+featureInductionIteration);

      // Train the CRF
      InstanceList theTrainingData = trainingData;
      // Round k uses trainingProportions[k] while such an entry exists; later rounds use all the data.
      if (trainingProportions != null && featureInductionIteration < trainingProportions.length) {
        logger.info ("Training on "+trainingProportions[featureInductionIteration]+"% of the data this round.");
        // Fixed seed: the split is made with new Random(1) on every round.
        InstanceList[] sampledTrainingData = trainingData.split (new Random(1),
            new double[] {trainingProportions[featureInductionIteration],
          1-trainingProportions[featureInductionIteration]});
        theTrainingData = sampledTrainingData[0];
        theTrainingData.setFeatureSelection (crf.globalFeatureSelection); // xxx necessary?
            logger.info ("  which is "+theTrainingData.size()+" instances");
      }
      // NOTE(review): this local is assigned but never read afterwards.
      boolean converged = false;
      if (featureInductionIteration != 0)
        // Don't train until we have added some features
        converged = this.train (theTrainingData, numIterationsBetweenFeatureInductions);
      // The budget is charged on every round, including round 0 where no training ran.
      trainingIteration += numIterationsBetweenFeatureInductions;

      logger.info ("Starting feature induction with "+crf.inputAlphabet.size()+" features.");

      // Create the list of error tokens, for both unclustered and clustered feature induction
      InstanceList errorInstances = new InstanceList (trainingData.getDataAlphabet(),
          trainingData.getTargetAlphabet());
      // This errorInstances.featureSelection will get examined by FeatureInducer,
      // so it can know how to add "new" singleton features
      errorInstances.setFeatureSelection (crf.globalFeatureSelection);
      ArrayList errorLabelVectors = new ArrayList();
      // Clustered variants, indexed [predicted label at previous position][predicted label at this position].
      InstanceList clusteredErrorInstances[][] = new InstanceList[numLabels][numLabels];
      ArrayList clusteredErrorLabelVectors[][] = new ArrayList[numLabels][numLabels];

      for (int i = 0; i < numLabels; i++)
        for (int j = 0; j < numLabels; j++) {
          clusteredErrorInstances[i][j] = new InstanceList (trainingData.getDataAlphabet(),
              trainingData.getTargetAlphabet());
          clusteredErrorInstances[i][j].setFeatureSelection (crf.globalFeatureSelection);
          clusteredErrorLabelVectors[i][j] = new ArrayList();
        }

      // Scan every position of every training sequence for low-confidence true labels.
      for (int i = 0; i < theTrainingData.size(); i++) {
        logger.info ("instance="+i);
        Instance instance = theTrainingData.get(i);
        Sequence input = (Sequence) instance.getData();
        Sequence trueOutput = (Sequence) instance.getTarget();
        assert (input.size() == trueOutput.size());
        // Lattice built from the input only (the output-sequence argument is null).
        SumLattice lattice =
          crf.sumLatticeFactory.newSumLattice (crf, input, (Sequence)null, (Transducer.Incrementor)null, 
              (LabelAlphabet)theTrainingData.getTargetAlphabet());
        int prevLabelIndex = 0;          // This will put extra error instances in this cluster
        for (int j = 0; j < trueOutput.size(); j++) {
          Label label = (Label) ((LabelSequence)trueOutput).getLabelAtPosition(j);
          assert (label != null);
          //System.out.println ("Instance="+i+" position="+j+" fv="+lattice.getLabelingAtPosition(j).toString(true));
          LabelVector latticeLabeling = lattice.getLabelingAtPosition(j);
          double trueLabelProb = latticeLabeling.value(label.getIndex());
          // labelIndex is the model's best label at this position, not the true label.
          int labelIndex = latticeLabeling.getBestIndex();
          //System.out.println ("position="+j+" trueLabelProb="+trueLabelProb);
          if (trueLabelProb < trueLabelProbThreshold) {
            logger.info ("Adding error: instance="+i+" position="+j+" prtrue="+trueLabelProb+
                (label == latticeLabeling.getBestLabel() ? "  " : " *")+
                " truelabel="+label+
                " predlabel="+latticeLabeling.getBestLabel()+
                " fv="+((FeatureVector)input.get(j)).toString(true));
            // Record the error in both the flat list and the matching cluster.
            errorInstances.add (input.get(j), label, null, null);
            errorLabelVectors.add (latticeLabeling);
            clusteredErrorInstances[prevLabelIndex][labelIndex].add (input.get(j), label, null, null);
            clusteredErrorLabelVectors[prevLabelIndex][labelIndex].add (latticeLabeling);
          }
          prevLabelIndex = labelIndex;
        }
      }
      logger.info ("Error instance list size = "+errorInstances.size());
      if (clusteredFeatureInduction) {
        // One inducer per label pair; entries stay null for clusters that are skipped.
        FeatureInducer[][] klfi = new FeatureInducer[numLabels][numLabels];
        for (int i = 0; i < numLabels; i++) {
          for (int j = 0; j < numLabels; j++) {
            // Note that we may see some "impossible" transitions here (like O->I in a OIB model)
            // because we are using lattice gammas to get the predicted label, not Viterbi.
            // I don't believe this does any harm, and may do some good.
            logger.info ("Doing feature induction for "+
                crf.outputAlphabet.lookupObject(i)+" -> "+crf.outputAlphabet.lookupObject(j)+
                " with "+clusteredErrorInstances[i][j].size()+" instances");
            // Clusters with fewer than 20 error instances are not used for induction.
            if (clusteredErrorInstances[i][j].size() < 20) {
              logger.info ("..skipping because only "+clusteredErrorInstances[i][j].size()+" instances.");
              continue;
            }
            // Copy the raw list into a typed array for the gain factories.
            int s = clusteredErrorLabelVectors[i][j].size();
            LabelVector[] lvs = new LabelVector[s];
            for (int k = 0; k < s; k++)
              lvs[k] = (LabelVector) clusteredErrorLabelVectors[i][j].get(k);
            // Stays null when gainName matches none of the three known names.
            RankedFeatureVector.Factory gainFactory = null;
            if (gainName.equals ("exp"))
              gainFactory = new ExpGain.Factory (lvs, gaussianPriorVariance);
            else if (gainName.equals("grad"))
              gainFactory =  new GradientGain.Factory (lvs);
            else if (gainName.equals("info"))
              gainFactory =  new InfoGain.Factory ();
            klfi[i][j] = new FeatureInducer (gainFactory,
                clusteredErrorInstances[i][j],
                numFeaturesPerFeatureInduction,
                2*numFeaturesPerFeatureInduction,
                2*numFeaturesPerFeatureInduction);
            crf.featureInducers.add(klfi[i][j]);
          }
        }
        // Second pass: apply every inducer that was built to the training and testing data.
        for (int i = 0; i < numLabels; i++) {
          for (int j = 0; j < numLabels; j++) {
            logger.info ("Adding new induced features for "+
                crf.outputAlphabet.lookupObject(i)+" -> "+crf.outputAlphabet.lookupObject(j));
            if (klfi[i][j] == null) {
              logger.info ("...skipping because no features induced.");
              continue;
            }
            // Note that this adds features globally, but not on a per-transition basis
            klfi[i][j].induceFeaturesFor (trainingData, false, false);
            if (testingData != null) klfi[i][j].induceFeaturesFor (testingData, false, false);
          }
        }
        klfi = null;
      } else {
        // Unclustered: a single inducer built from all error instances.
        int s = errorLabelVectors.size();
        LabelVector[] lvs = new LabelVector[s];
        for (int i = 0; i < s; i++)
          lvs[i] = (LabelVector) errorLabelVectors.get(i);

        // Stays null when gainName matches none of the three known names.
        RankedFeatureVector.Factory gainFactory = null;
        if (gainName.equals ("exp"))
          gainFactory = new ExpGain.Factory (lvs, gaussianPriorVariance);
        else if (gainName.equals("grad"))
          gainFactory =  new GradientGain.Factory (lvs);
        else if (gainName.equals("info"))
          gainFactory =  new InfoGain.Factory ();
        FeatureInducer klfi =
          new FeatureInducer (gainFactory,
              errorInstances,
              numFeaturesPerFeatureInduction,
              2*numFeaturesPerFeatureInduction,
              2*numFeaturesPerFeatureInduction);
        crf.featureInducers.add(klfi);
        // Note that this adds features globally, but not on a per-transition basis
        klfi.induceFeaturesFor (trainingData, false, false);
        if (testingData != null) klfi.induceFeaturesFor (testingData, false, false);
        logger.info ("CRF4 FeatureSelection now includes "+crf.globalFeatureSelection.cardinality()+" features");
        klfi = null;
      }
      // This is done in CRF4.train() anyway
      //this.setWeightsDimensionAsIn (trainingData);
      ////this.growWeightsDimensionToInputAlphabet ();
    }
    // Final pass on the full training data with whatever budget is left;
    // train(...) returns false immediately when that budget is not positive.
    return this.train (trainingData, numIterations - trainingIteration);
  }
 
 
 
 
 
  public void setUseHyperbolicPrior (boolean f) { usingHyperbolicPrior = f; }
  public void setHyperbolicPriorSlope (double p) { hyperbolicPriorSlope = p; }
  public void setHyperbolicPriorSharpness (double p) { hyperbolicPriorSharpness = p; }
  public double getUseHyperbolicPriorSlope () { return hyperbolicPriorSlope; }
  public double getUseHyperbolicPriorSharpness () { return hyperbolicPriorSharpness; }
  public void setGaussianPriorVariance (double p) { gaussianPriorVariance = p; }
  public double getGaussianPriorVariance () { return gaussianPriorVariance; }
  //public int getDefaultFeatureIndex () { return defaultFeatureIndex;}
 
  public void setUseSparseWeights (boolean b) { useSparseWeights = b; }
  public boolean getUseSparseWeights () { return useSparseWeights; }

  /** Sets whether to use the 'some unsupported trick.' This trick is, if training a CRF
   * where some training has been done and sparse weights are used, to add a few weights
   * for feaures that do not occur in the tainig data.
   * <p>
   * This generally leads to better accuracy at only a  small memory cost.
   *
   * @param b Whether to use the trick
   */
  public void setUseSomeUnsupportedTrick (boolean b) { useSomeUnsupportedTrick = b; }
 

 


  // Serialization for CRFTrainerByLikelihood

  // Class-level serialization version identifier.
  private static final long serialVersionUID = 1;
  // Format version number written first by writeObject.
  private static final int CURRENT_SERIAL_VERSION = 1;
  // NOTE(review): not referenced anywhere in this class; presumably a sentinel for absent values - confirm before removing.
  static final int NULL_INTEGER = -1;

  /* Need to check for null pointers. */
  private void writeObject (ObjectOutputStream out) throws IOException {
    int i, size;
    out.writeInt (CURRENT_SERIAL_VERSION);
    //out.writeInt(defaultFeatureIndex);
    out.writeBoolean(usingHyperbolicPrior);
    out.writeDouble(gaussianPriorVariance);
    out.writeDouble(hyperbolicPriorSlope);
    out.writeDouble(hyperbolicPriorSharpness);
    out.writeInt(cachedGradientWeightsStamp);
    out.writeInt(cachedValueWeightsStamp);
    out.writeInt(cachedWeightsStructureStamp);
    out.writeBoolean(printGradient);
    out.writeBoolean (useSparseWeights);
    throw new IllegalStateException("Implementation not yet complete.");   
  }
 
  private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
    int size, i;
    int version = in.readInt ();
    //defaultFeatureIndex = in.readInt();
    usingHyperbolicPrior = in.readBoolean();
    gaussianPriorVariance = in.readDouble();
    hyperbolicPriorSlope = in.readDouble();
    hyperbolicPriorSharpness = in.readDouble();
    printGradient = in.readBoolean();
    useSparseWeights = in.readBoolean();
    throw new IllegalStateException("Implementation not yet complete.");   
  }
 
 
}
// TOP
//
// Related Classes of cc.mallet.fst.CRFTrainerByLabelLikelihood
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.