Package cc.mallet.fst.semi_supervised

Source Code of cc.mallet.fst.semi_supervised.CRFTrainerByLikelihoodAndGE

package cc.mallet.fst.semi_supervised;

import java.util.ArrayList;

import cc.mallet.fst.CRF;
import cc.mallet.fst.CRFCacheStaleIndicator;
import cc.mallet.fst.CRFOptimizableByBatchLabelLikelihood;
import cc.mallet.fst.CRFOptimizableByGradientValues;
import cc.mallet.fst.CRFOptimizableByLabelLikelihood;
import cc.mallet.fst.CRFTrainerByLabelLikelihood;
import cc.mallet.fst.CRFTrainerByThreadedLabelLikelihood;
import cc.mallet.fst.ThreadedOptimizable;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.TransducerTrainer;
import cc.mallet.fst.semi_supervised.CRFOptimizableByGE;
import cc.mallet.fst.semi_supervised.StateLabelMap;
import cc.mallet.fst.semi_supervised.constraints.GEConstraint;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.InstanceList;

public class CRFTrainerByLikelihoodAndGE extends TransducerTrainer {

  private boolean initSupervised;
  private boolean converged;
  private double geWeight;
  private double gpv;
  private int supIterations;
  private int numThreads;
  private int iteration;
  private CRF crf;
  private ArrayList<GEConstraint> constraints;
  private StateLabelMap map;
 
  public CRFTrainerByLikelihoodAndGE(CRF crf, ArrayList<GEConstraint> constraints, StateLabelMap map) {
    this.crf = crf;
    this.constraints = constraints;
    this.map = map;
    this.iteration = 0;
    this.converged = false;
    this.geWeight = 1.0;
    this.initSupervised = false;
    this.gpv = 10.0;
    this.numThreads = 1;
    this.supIterations = Integer.MAX_VALUE;
  }
 
  public void setGEWeight(double weight) {
    this.geWeight = weight;
  }
 
  public void setGaussianPriorVariance(double gpv) {
    this.gpv = gpv;
  }
 
  public void setInitSupervised(boolean flag) {
    this.initSupervised = flag;
  }
 
  public void setSupervisedIterations(int iterations) {
    this.supIterations = iterations;
  }
 
  public void setNumThreads(int numThreads) {
    this.numThreads = numThreads;
  }
 
  @Override
  public Transducer getTransducer() {
    return crf;
  }

  @Override
  public int getIteration() {
    return iteration;
  }

  @Override
  public boolean isFinishedTraining() {
    return converged;
  }
 
  public boolean train(InstanceList trainingSet, InstanceList unlabeledSet, int numIterations) {
    System.err.println(trainingSet.size());
    System.err.println(unlabeledSet.size());
    if (initSupervised) {
     
      // train supervised
      if (numThreads == 1) {
        CRFTrainerByLabelLikelihood trainer = new CRFTrainerByLabelLikelihood(crf);
        trainer.setAddNoFactors(true);
        trainer.setGaussianPriorVariance(gpv);
        trainer.train(trainingSet,supIterations);
      }
      else {
        CRFTrainerByThreadedLabelLikelihood trainer = new CRFTrainerByThreadedLabelLikelihood(crf,numThreads);
        trainer.setAddNoFactors(true);
        trainer.setGaussianPriorVariance(gpv);
        trainer.train(trainingSet,supIterations);
        trainer.shutdown();
      }
      runEvaluators();
    }
   
    // train semi-supervised
    Optimizable.ByGradientValue optLikelihood;
    if (numThreads == 1) {
      optLikelihood = new CRFOptimizableByLabelLikelihood(crf,trainingSet);
      ((CRFOptimizableByLabelLikelihood)optLikelihood).setGaussianPriorVariance(gpv);
    }
    else {
      CRFOptimizableByBatchLabelLikelihood likelihood = new CRFOptimizableByBatchLabelLikelihood(crf,trainingSet,numThreads);
      optLikelihood = new ThreadedOptimizable(likelihood,trainingSet,crf.getParameters().getNumFactors(),
        new CRFCacheStaleIndicator(crf));
      likelihood.setGaussianPriorVariance(gpv);
    }

    CRFOptimizableByGE ge = new CRFOptimizableByGE(crf,constraints,unlabeledSet,map,numThreads,geWeight);
    // turn off the prior... it already appears above!
    ge.setGaussianPriorVariance(Double.POSITIVE_INFINITY);

    CRFOptimizableByGradientValues opt =
      new CRFOptimizableByGradientValues(crf,new Optimizable.ByGradientValue[] { optLikelihood, ge });
   
    LimitedMemoryBFGS optimizer = new LimitedMemoryBFGS(opt);

    try {
      converged = optimizer.optimize(numIterations);
    }
    catch (Exception e) {
      e.printStackTrace();
    }
   
    optimizer.reset();
    try {
      converged = optimizer.optimize(numIterations);
    }
    catch (Exception e) {
      e.printStackTrace();
    }
   
    if (numThreads > 1) {
      ((ThreadedOptimizable)optLikelihood).shutdown();
      ge.shutdown();
    }
   
    return converged;
  }
 
  @Override
  public boolean train(InstanceList trainingSet, int numIterations) {
    throw new RuntimeException("Must use train(InstanceList trainingSet, InstanceList unlabeledSet, int numIterations) instead");
  }
}
TOP

Related Classes of cc.mallet.fst.semi_supervised.CRFTrainerByLikelihoodAndGE

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.