Package cc.mallet.fst.semi_supervised

Source Code of cc.mallet.fst.semi_supervised.CRFOptimizableByGECriteria

package cc.mallet.fst.semi_supervised;

import java.io.Serializable;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.logging.Logger;

import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;

import cc.mallet.fst.CRF;
import cc.mallet.fst.SumLattice;
import cc.mallet.fst.SumLatticeDefault;
import cc.mallet.fst.Transducer;

import cc.mallet.optimize.Optimizable;

import cc.mallet.util.MalletLogger;

/**
* GE criteria for training a linear chain CRF.
*
* @author Gaurav Chandalia
* @author Gregory Druck
*/
public class CRFOptimizableByGECriteria implements Optimizable.ByGradientValue,
                                                    Serializable {
 
  private static double DEFAULT_GPV = Double.POSITIVE_INFINITY;
 
  private static final long serialVersionUID = 1;
  private static Logger logger =
      MalletLogger.getLogger(CRFOptimizableByGECriteria.class.getName());

  // unlabeled data
  protected InstanceList data;
 
  private int numThreads;
  private int cachedValueWeightsStamp = -1;
  private int cachedGradientWeightsStamp = -1;

  // the model
  protected CRF crf;

  // GE criteria
  protected GECriteria geCriteria;

  // gradient of GE criteria
  protected CRF.Factors gradient;
  protected Transducer.Incrementor incrementor;

  // GE value
  protected double cachedValue;
  // GE gradient (double[] form)
  protected double[] cachedGradient;
 
  protected double priorVariance;

  // thread handler used to create lattices in new threads and update the
  // gradient
  protected transient LatticeCreationExecutor geLatticeExecutor;
  protected transient ThreadPoolExecutor sumLatticeExecutor;

  /**
   * Initializes the structures.
   *
   * @param geCriteria GE criteria.
   * @param crf Model.
   * @param ilist Data used for training.
   */
  public CRFOptimizableByGECriteria(GECriteria geCriteria,
                                     CRF crf, InstanceList ilist,
                                     int numThreads) {
    this.data = ilist;
    this.crf = crf;
    this.geCriteria = geCriteria;

    // initialize
    gradient = new CRF.Factors(crf);
    incrementor = gradient.new Incrementor();
    cachedValue = 0.0;
    cachedGradient = new double[crf.getParameters().getNumFactors()];
    priorVariance = DEFAULT_GPV;
    geCriteria.setConstraintBits(data, 0, data.size());

    this.numThreads = numThreads;
    geLatticeExecutor = new LatticeCreationExecutor(numThreads);
    sumLatticeExecutor = (ThreadPoolExecutor)Executors.newFixedThreadPool(numThreads);
  }

  public void shutdown() {
    geLatticeExecutor.shutdown();
    sumLatticeExecutor.shutdown();
  }
 
  public void setGaussianPriorVariance(double priorVariance) {
    this.priorVariance = priorVariance;
  }
 
  public GECriteria getGECriteria() {
    return geCriteria;
  }

  /**
   * Initializes the gradient to zero and re-computes expectations
   * for a new iteration. <p>
   *
   * Also creates the executor to compute the gradient (if not done yet).
   */
  public void initialize(Map<Integer, SumLattice> lattices) {
    assert(gradient.structureMatches(crf.getParameters()));
    gradient.zero();
    geLatticeExecutor.initialize();

    // compute the expected prior distribution over labels for all feature-label
    // pairs (constraints)
    geCriteria.calculateExpectations(data, crf, lattices);
  }

  /**
   * Fills gradient from a single instance. <p>
   */
  public void computeGradient(FeatureVectorSequence input,
                                  double[][] gammas, double[][][] xis) {
    new GELattice(input, gammas, xis, crf, incrementor, geCriteria, false);
  }

  /**
   * Resets, computes and fills gradient from all instances. <p>
   *
   * Analogous to <tt>CRFOptimizableByLabelLikelihood.getExpectationValue<tt>.
   */
  public void computeGradient(Map<Integer, SumLattice> lattices) {
    this.initialize(lattices);
    logger.info("Updating gradient...");
    long time = System.currentTimeMillis();
    geLatticeExecutor.computeGradient(lattices);
    time = (System.currentTimeMillis() - time) / 1000;
    logger.info(String.valueOf(time) + " secs.");
  }
 
  public double getValue() {
    if (crf.getWeightsValueChangeStamp() != cachedValueWeightsStamp) {
      // The cached value is not up to date; it was calculated for a different set of CRF weights.
      cachedValueWeightsStamp = crf.getWeightsValueChangeStamp();

      // compute lattices in multiple threads
      ArrayList<Callable<Void>> handlers = new ArrayList<Callable<Void>>();
      // number of instances per subset
      int numInstancesSubset = data.size() / numThreads;
      // range of each subset
      int start = -1, end = -1;
      for (int i = 0; i < numThreads; ++i) {
        // get the indices of subset
        if (i == 0) {
          start = 0;
          end = start + numInstancesSubset;
        } else if (i == numThreads - 1) {
          start = end;
          end = data.size();
        } else {
          start = end;
          end = start + numInstancesSubset;
        }
        SumLatticeHandler handler = new SumLatticeHandler(start, end);
        handlers.add(handler);
      }
      // run tasks
      try {
        sumLatticeExecutor.invokeAll(handlers);
      } catch (InterruptedException e) {
        e.printStackTrace();
      }
     
      // combine lattices from multiple threads
      HashMap<Integer,SumLattice> lattices = new HashMap<Integer,SumLattice>();
      for (Callable<Void> handler : handlers) {
        lattices.putAll(((SumLatticeHandler)handler).getLattices());
      }

      computeGradient(lattices);
     
      cachedValue = geCriteria.getGEValue();

      if (priorVariance != Double.POSITIVE_INFINITY) {
        cachedValue += crf.getParameters().gaussianPrior(priorVariance);
      }
      assert(!Double.isNaN(cachedValue) && !Double.isInfinite(cachedValue))
          : "Likelihood due to GE criteria is NaN/Infinite";
      logger.info("getValue() (GE) = " + cachedValue);
    }
    return cachedValue;
  }

  public void getValueGradient(double[] buffer) {
    if (cachedGradientWeightsStamp != crf.getWeightsValueChangeStamp()) {
      cachedGradientWeightsStamp = crf.getWeightsValueChangeStamp();

      getValue();
      gradient.assertNotNaNOrInfinite();
      // fill up gradient
      cachedGradient = new double[cachedGradient.length];
      if (priorVariance != Double.POSITIVE_INFINITY) {
        gradient.plusEqualsGaussianPriorGradient(crf.getParameters(), -priorVariance);
      }
      gradient.getParameters(cachedGradient);
      MatrixOps.timesEquals(cachedGradient, -1.0);
    }
    System.arraycopy(cachedGradient, 0, buffer, 0, cachedGradient.length);
  }

  public void printGradientAbsNorm() {
    logger.info("absNorm(GE Gradient): " + gradient.getParametersAbsNorm());
  }

  // some get/set methods that have to be implemented
  public int getNumParameters() {
    return crf.getParameters().getNumFactors();
  }

  public void getParameters(double[] buffer) {
    crf.getParameters().getParameters(buffer);
  }

  public void setParameters(double[] buffer) {
    crf.getParameters().setParameters(buffer);
    crf.weightsValueChanged();
  }

  public double getParameter(int index) {
    return crf.getParameters().getParameter(index);
  }

  public void setParameter(int index, double value) {
    crf.getParameters().setParameter(index, value);
    crf.weightsValueChanged();
  }

  /**
   * Computes GE gradient. <p>
   *
   * Uses multi-threading, each thread does computations for a subset of the
   * data. To reduce sharing, each thread has its own CRF.Factors structure for
   * gradient. The final gradient is obtained by combining from all
   * subset gradients.
   */
  private class LatticeCreationExecutor {
    // key: instance index, value: Lattice
    private Map<Integer, SumLattice> lattices;

    // number of data subsets == number of threads
    private int numSubsets;
    // gradient obtained from subsets of data
    private List<FactorsIncrementorPair> mtGradient;

    // thread pool to create lattices and update gradient
    private ThreadPoolExecutor executor;

    // milliseconds
    public static final int SLEEP_TIME = 1000;

    // key: unique integer identifying a thread running forward-backward, the
    // respective thread sets the bit when its computation is over, range: (0,
    // numSubsets - 1)
    private BitSet threadIds;

    /**
     * Initializes the executor with specified number of threads.
     */
    public LatticeCreationExecutor(int numThreads) {
      lattices = null;
      numSubsets = numThreads;
      mtGradient = new ArrayList<FactorsIncrementorPair>(numSubsets);
      // initialize from the main gradients object
      for (int i = 0; i < numSubsets; ++i) {
        mtGradient.add(new FactorsIncrementorPair(gradient));
      }
      logger.info("Creating " + numSubsets +
                   " threads for updating gradient...");
      executor =
          (ThreadPoolExecutor) Executors.newFixedThreadPool(numSubsets);
      threadIds = new BitSet(numSubsets);
    }

    /**
     * Initializes each thread's gradient to zero.
     */
    public void initialize() {
      for (int i = 0; i < mtGradient.size(); ++i) {
        FactorsIncrementorPair exp = mtGradient.get(i);
        exp.subsetFactors.zero();
      }
    }

    /**
     * Computes lattices and fills gradient from all instances. <p>
     */
    public void computeGradient(Map<Integer, SumLattice> lattices) {
      this.lattices = lattices;
      threadIds.clear();

      // number of instances per subset
      int numInstancesSubset = data.size() / numSubsets;
      // range of each subset
      int start = -1, end = -1;
      for (int i = 0; i < numSubsets; ++i) {
        // get the indices of subset
        if (i == 0) {
          start = 0;
          end = start + numInstancesSubset;
        } else if (i == numSubsets - 1) {
          start = end;
          end = data.size();
        } else {
          start = end;
          end = start + numInstancesSubset;
        }
        executor.execute(new LatticeHandler(i, start, end));
      }

      // wait till all threads finish
      int numSetBits = 0;
      while (numSetBits != numSubsets) {
        synchronized(this) {
          numSetBits = threadIds.cardinality();
        }
        try {
          Thread.sleep(SLEEP_TIME);
        } catch (InterruptedException ie) {
          ie.printStackTrace();
          System.exit(1);
        }
      }

      // update main gradient
      this.updateGradient();
      lattices = null;
    }

    /**
     * Aggregates all subset gradients into the main gradient object.
     */
    private void updateGradient() {
      for (int i = 0; i < mtGradient.size(); ++i) {
        CRF.Factors subsetGradient = mtGradient.get(i).subsetFactors;
        gradient.plusEquals(subsetGradient, 1.0);
      }
    }
   
    public void shutdown() {
      executor.shutdown();
    }


    /**
     * Runs forward-backward for a subset of data in a new thread, uses
     * subset-specific incrementor.
     */
    private class LatticeHandler implements Runnable {
      // index to determine which incrementor to use
      private int index;

      // start, end indices of subset of data
      private int start;
      private int end;

      /**
       * Initializes the indices.
       */
      public LatticeHandler(int index, int startIndex, int endIndex) {
        this.index = index;
        this.start = startIndex;
        this.end = endIndex;
      }

      /**
       * Creates lattice, updates gradient.
       */
      public void run() {
        Transducer.Incrementor incrementor =
            mtGradient.get(index).subsetIncrementor;
        BitSet constraintBits = geCriteria.getConstraintBits();
        for (int i = start; i < end; ++i) {
          // skip if the instance doesn't have any constraints
          if (!constraintBits.get(i)) {
            continue;
          }
          FeatureVectorSequence fvs =
              (FeatureVectorSequence) data.get(i).getData();
          SumLattice lattice = lattices.get(i);
          assert(lattice != null)
              : "Lattice is null:: " + i + ", size: " + lattices.size();
          new GELattice(
              fvs, lattice.getGammas(), lattice.getXis(), crf, incrementor,
              geCriteria, false);
        }
        synchronized(LatticeCreationExecutor.this) {
          threadIds.set(index);
        }
      }
    }
  }
 
  /**
   * Runs forward-backward for a subset of data in a new thread, uses
   * subset-specific incrementor.
   */
  private class SumLatticeHandler implements Callable<Void> {
    // start, end indices of subset of data
    private int start;
    private int end;
    private HashMap<Integer,SumLatticeDefault> lattices;

    /**
     * Initializes the indices.
     */
    public SumLatticeHandler(int startIndex, int endIndex) {
      this.start = startIndex;
      this.end = endIndex;
      this.lattices = new HashMap<Integer,SumLatticeDefault>();
    }
   
    public HashMap<Integer,SumLatticeDefault> getLattices() {
      return lattices;
    }

    public Void call() throws Exception {
      BitSet constraintBits = geCriteria.getConstraintBits();
      for (int ii = start; ii < end; ii++) {
        if (!constraintBits.get(ii)) {
          continue;
        }
        FeatureVectorSequence fvs = (FeatureVectorSequence)data.get(ii).getData();
        lattices.put(ii, new SumLatticeDefault(crf,fvs,true));
      }
      return null;
    }
  }

 
  private class FactorsIncrementorPair {
    // model's Factors from a subset of data
    public CRF.Factors subsetFactors;
    public Transducer.Incrementor subsetIncrementor;

    /**
     * Initialize Factors using the structure of main Factors object.
     */
    public FactorsIncrementorPair(CRF.Factors other) {
      subsetFactors = new CRF.Factors(other);
      subsetIncrementor = subsetFactors.new Incrementor();
    }
  }
}
TOP

Related Classes of cc.mallet.fst.semi_supervised.CRFOptimizableByGECriteria

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.