Package cc.mallet.fst

Source Code of cc.mallet.fst.MEMMTrainer$MEMMOptimizableByLabelLikelihood

package cc.mallet.fst;

import java.util.BitSet;
import java.util.logging.Logger;

import cc.mallet.types.FeatureSequence;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;

import cc.mallet.fst.MEMM.State;
import cc.mallet.fst.MEMM.TransitionIterator;

import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.Optimizer;

import cc.mallet.util.MalletLogger;

/**
* Trains and evaluates a {@link MEMM}.
*/
public class MEMMTrainer extends TransducerTrainer
{
  private static Logger logger = MalletLogger.getLogger(MEMMTrainer.class.getName());

  MEMM memm;
  private boolean gatheringTrainingData = false;
  // After training sets have been gathered in the states, record which
  //   InstanceList we've gathered them for, so we don't double-count instances.
  private InstanceList trainingGatheredFor;
  // gsc: user is supposed to set the weights manually, so this flag is not needed
//  boolean useSparseWeights = true;
  MEMMOptimizableByLabelLikelihood omemm;
 
  public MEMMTrainer (MEMM memm) {
    this.memm = memm;
  }

  public MEMMOptimizableByLabelLikelihood getOptimizableMEMM (InstanceList trainingSet) {
    return new MEMMOptimizableByLabelLikelihood (memm, trainingSet);
  }

//  public MEMMTrainer setUseSparseWeights (boolean f) { useSparseWeights = f;  return this; }

  /**
   * Trains a MEMM until convergence.
   */
  public boolean train (InstanceList training) {
    return train (training, Integer.MAX_VALUE);
  }

  /**
   * Trains a MEMM for specified number of iterations or until convergence whichever
   * occurs first; returns true if training converged within specified iterations.
   */
  public boolean train (InstanceList training, int numIterations)
  {
    if (numIterations <= 0)
      return false;
    assert (training.size() > 0);

    // Allocate space for the parameters, and place transition FeatureVectors in
    // per-source-state InstanceLists.
    // Here, gatheringTrainingSets will be true, and these methods will result
    // in new InstanceList's being created in each source state, and the FeatureVectors
    // of their outgoing transitions to be added to them as the data field in the Instances.
    if (trainingGatheredFor != training) {
      gatherTrainingSets (training);
    }
    // gsc: the user has to set the weights manually
//    if (useSparseWeights) {
//      memm.setWeightsDimensionAsIn (training, false);
//    } else {
//      memm.setWeightsDimensionDensely ();
//    }


    /*
    if (false) {
      // Expectation-based placement of training data would go here.
      for (int i = 0; i < training.size(); i++) {
        Instance instance = training.get(i);
        FeatureVectorSequence input = (FeatureVectorSequence) instance.getData();
        FeatureSequence output = (FeatureSequence) instance.getTarget();
        // Do it for the paths consistent with the labels...
        gatheringConstraints = true;
        new SumLatticeDefault (this, input, output, true);
        // ...and also do it for the paths selected by the current model (so we will get some negative weights)
        gatheringConstraints = false;
        if (this.someTrainingDone)
          // (do this once some training is done)
          new SumLatticeDefault (this, input, null, true);
      }
      gatheringWeightsPresent = false;
      SparseVector[] newWeights = new SparseVector[weights.length];
      for (int i = 0; i < weights.length; i++) {
        int numLocations = weightsPresent[i].cardinality ();
        logger.info ("CRF weights["+weightAlphabet.lookupObject(i)+"] num features = "+numLocations);
        int[] indices = new int[numLocations];
        for (int j = 0; j < numLocations; j++) {
          indices[j] = weightsPresent[i].nextSetBit (j == 0 ? 0 : indices[j-1]+1);
          //System.out.println ("CRF4 has index "+indices[j]);
        }
        newWeights[i] = new IndexedSparseVector (indices, new double[numLocations],
            numLocations, numLocations, false, false, false);
        newWeights[i].plusEqualsSparse (weights[i]);
      }
      weights = newWeights;
    }
    */

    omemm = new MEMMOptimizableByLabelLikelihood (memm, training);
    // Gather the constraints
    omemm.gatherExpectationsOrConstraints (true);
    Optimizer maximizer = new LimitedMemoryBFGS(omemm);
    int i;
//    boolean continueTraining = true;
    boolean converged = false;
    logger.info ("CRF about to train with "+numIterations+" iterations");
    for (i = 0; i < numIterations; i++) {
      try {
        converged = maximizer.optimize (1);
        logger.info ("CRF finished one iteration of maximizer, i="+i);
        runEvaluators();
      } catch (IllegalArgumentException e) {
        e.printStackTrace();
        logger.info ("Catching exception; saying converged.");
        converged = true;
      }
      if (converged) {
        logger.info ("CRF training has converged, i="+i);
        break;
      }
    }
    logger.info ("About to setTrainable(false)");
    return converged;
  }


  /**
   * Runs each training instance through a sum lattice and, for every transition
   * with nonzero count, adds that transition's input/output pair (weighted by
   * the count) to the source state's per-state training set.  May only be
   * called once per trainer instance; a second call with a different set throws.
   */
  void gatherTrainingSets (InstanceList training)
  {
    if (trainingGatheredFor != null) {
      // It would be easy enough to support this, just go through all the states and set trainingSet to null.
      throw new UnsupportedOperationException ("Training with multiple sets not supported.");
    }

    trainingGatheredFor = training;
    for (int i = 0; i < training.size(); i++) {
      Instance instance = training.get(i);
      FeatureVectorSequence input = (FeatureVectorSequence) instance.getData();
      FeatureSequence output = (FeatureSequence) instance.getTarget();
      // Do it for the paths consistent with the labels...
      // (passing `output` constrains the lattice to label-consistent paths;
      // the Incrementor below receives the resulting per-transition counts)
      new SumLatticeDefault (memm, input, output, new Transducer.Incrementor() {
        // Initial/final state counts are irrelevant here: MEMMs train each
        // state's local classifier from its outgoing transitions only.
        public void incrementFinalState(Transducer.State s, double count) { }
        public void incrementInitialState(Transducer.State s, double count) { }
        public void incrementTransition(Transducer.TransitionIterator ti, double count) {
          MEMM.State source = (MEMM.State) ti.getSourceState();
          if (count != 0) {
            // Create the source state's trainingSet if it doesn't exist yet.
            if (source.trainingSet == null)
              // New InstanceList with a null pipe, because it doesn't do any processing of input.
              source.trainingSet = new InstanceList (null);
            // TODO We should make sure we don't add duplicates (through a second call to setWeightsDimenstion..!
            // TODO Note that when the training data still allows ambiguous outgoing transitions
            // this will add the same FV more than once to the source state's trainingSet, each
            // with >1.0 weight.  Not incorrect, but inefficient.
//            System.out.println ("From: "+source.getName()+" ---> "+getOutput()+" : "+getInput());
            // Instance data = the transition's input FeatureVector; target = its output label.
            source.trainingSet.add (new Instance(ti.getInput (), ti.getOutput (), null, null), count);
          }
        }
      });
    }
  }

  /**
   * Not implemented yet.
   *
   * @throws UnsupportedOperationException
   */
  public boolean train (InstanceList training, InstanceList validation, InstanceList testing,
      TransducerEvaluator eval, int numIterations,
      int numIterationsPerProportion,
      double[] trainingProportions)
  {
    throw new UnsupportedOperationException();
  }

  /**
   * Not implemented yet.
   *
   * @throws UnsupportedOperationException
   */
  public boolean trainWithFeatureInduction (InstanceList trainingData,
      InstanceList validationData, InstanceList testingData,
      TransducerEvaluator eval, int numIterations,
      int numIterationsBetweenFeatureInductions,
      int numFeatureInductions,
      int numFeaturesPerFeatureInduction,
      double trueLabelProbThreshold,
      boolean clusteredFeatureInduction,
      double[] trainingProportions,
      String gainName)
  {
    throw new UnsupportedOperationException();
  }


  public void printInstanceLists ()
  {
    for (int i = 0; i < memm.numStates(); i++) {
      State state = (State) memm.getState (i);
      InstanceList training = state.trainingSet;
      System.out.println ("State "+i+" : "+state.getName());
      if (training == null) {
        System.out.println ("No data");
        continue;
      }
      for (int j = 0; j < training.size(); j++) {
        Instance inst = training.get (j);
        System.out.println ("From : "+state.getName()+" To : "+inst.getTarget());
        System.out.println ("Instance "+j);
        System.out.println (inst.getTarget());
        System.out.println (inst.getData());
      }
    }
  }

  /**
   * Represents the terms in the objective function.
   * <p>
   * The weights are trained by matching the expectations of the model to the observations gathered from the data.
   */
  @SuppressWarnings("serial")
  public class MEMMOptimizableByLabelLikelihood extends CRFOptimizableByLabelLikelihood implements Optimizable.ByGradientValue
  {
    // Tracks which per-state training instances had infinite-weight transitions,
    // so totals stay comparable across evaluations (see below).  Currently an
    // infinite weight throws instead of being recorded.
    BitSet infiniteValues = null;

    protected MEMMOptimizableByLabelLikelihood (MEMM memm, InstanceList trainingData)
    {
      super (memm, trainingData);
      // Replace the factors allocated by the superclass with fresh ones sized
      // for this MEMM's parameterization.
      expectations = new CRF.Factors (memm);
      constraints = new CRF.Factors (memm);
    }
   
    // if constraints=false, return log probability of the training labels
    /**
     * Accumulates either constraints (observed counts, when
     * {@code gatherConstraints} is true) or model expectations into the
     * corresponding CRF.Factors, iterating over every state's gathered
     * per-state training set.
     *
     * @return the label log-probability when gathering expectations; 0 counts
     *         toward the total when gathering constraints
     */
    protected double gatherExpectationsOrConstraints (boolean gatherConstraints)
    {
      // Instance values must either always or never be included in
      // the total values; we can't just sometimes skip a value
      // because it is infinite, this throws off the total values.
      boolean initializingInfiniteValues = false;
      CRF.Factors factors = gatherConstraints ? constraints : expectations;
      CRF.Factors.Incrementor factorIncrementor = factors.new Incrementor ();

      if (infiniteValues == null) {
        infiniteValues = new BitSet ();
        initializingInfiniteValues = true;
      }

      double labelLogProb = 0;
      for (int i = 0; i < memm.numStates(); i++) {
        MEMM.State s = (State) memm.getState (i);

        if (s.trainingSet == null) {
          System.out.println ("Empty training set for state "+s.name);
          continue;
        }

        for (int j = 0; j < s.trainingSet.size(); j++) {
          Instance instance = s.trainingSet.get (j);
          double instWeight = s.trainingSet.getInstanceWeight (j);
          FeatureVector fv = (FeatureVector) instance.getData ();
          String labelString = (String) instance.getTarget ();
          // When gathering constraints, restrict the iterator to the observed
          // label; otherwise iterate over all outgoing transitions.
          TransitionIterator iter = new TransitionIterator (s, fv, gatherConstraints?labelString:null, memm);
          while (iter.hasNext ()) {
            // gsc
            iter.nextState(); // advance the iterator
//            State destination = (MEMM.State) iter.nextState();  // Just to advance the iterator
            double weight = iter.getWeight();
            // exp(weight) is the transition probability mass; scale by the
            // instance weight gathered in gatherTrainingSets.
            factorIncrementor.incrementTransition(iter, Math.exp(weight) * instWeight);
            //iter.incrementCount (Math.exp(weight) * instWeight);
            // NOTE(review): `==` compares object identity, not string content;
            // this presumably relies on getOutput() returning the very same
            // String object stored as the instance target — confirm, or this
            // silently skips the log-prob contribution.
            if (!gatherConstraints && iter.getOutput() == labelString) {
              if (!Double.isInfinite (weight))
                labelLogProb += instWeight * weight; // xxx   ?????
              else {
                logger.warning ("State "+i+" transition "+j+" has infinite cost; skipping.");
                if (initializingInfiniteValues)
                  throw new IllegalStateException ("Infinite-cost transitions not yet supported"); //infiniteValues.set (j);
                else if (!infiniteValues.get(j))
                  throw new IllegalStateException ("Instance i used to have non-infinite value, "
                      +"but now it has infinite value.");
              }
            }
          }
        }
      }

      // Force initial & final weight parameters to 0 by making sure that
      // whether factor refers to expectation or constraint, they have the same value.
      for (int i = 0; i < memm.numStates(); i++) {
        factors.initialWeights[i] = 0.0;
        factors.finalWeights[i] = 0.0;
      }

      return labelLogProb;
    }

    // log probability of the training sequence labels, and fill in expectations[]
    protected double getExpectationValue ()
    {
      return gatherExpectationsOrConstraints (false);
    }

  }
 
 
  @Override
  public int getIteration() {
    // TODO Auto-generated method stub
    return 0;
  }

  @Override
  public Transducer getTransducer() {
    return memm;
  }

  @Override
  public boolean isFinishedTraining() {
    // TODO Auto-generated method stub
    return false;
  }


}
TOP

Related Classes of cc.mallet.fst.MEMMTrainer$MEMMOptimizableByLabelLikelihood

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.