/* Copyright (C) 2011 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.fst.semi_supervised.pr;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.logging.Logger;
import cc.mallet.fst.CRF;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.Transducer.TransitionIterator;
import cc.mallet.fst.semi_supervised.pr.constraints.PRConstraint;
import cc.mallet.optimize.Optimizable.ByGradientValue;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Sequence;
import cc.mallet.util.MalletLogger;
/**
 * Optimizable for E-step/I-projection in Posterior Regularization (PR).
 *
 * <p>The optimized parameters are those of the {@link PRAuxiliaryModel}; the CRF
 * itself is only read. The CRF transition weights for every training instance
 * are cached once at construction time (see {@link #cacheDotProducts()}), and
 * the objective is evaluated by splitting the training set into
 * {@code numThreads} contiguous ranges that are processed on a fixed thread pool.
 *
 * <p>Call {@link #shutdown()} when the optimizable is no longer needed so that
 * the worker threads are released.
 *
 * @author Kedar Bellare
 * @author Gregory Druck
 */
public class ConstraintsOptimizableByPR implements Serializable, ByGradientValue {
  private static final Logger logger = MalletLogger.getLogger(ConstraintsOptimizableByPR.class.getName());
  private static final long serialVersionUID = 1;

  /** True when the parameters changed since the value/gradient were last computed. */
  protected boolean cacheStale;
  /** Number of parameters of the auxiliary model. */
  protected int numParameters;
  /** Number of worker threads, and also the number of ranges the training set is split into. */
  protected int numThreads;
  protected InstanceList trainingSet;
  /** Last computed objective value; the initial value is only a placeholder. */
  protected double cachedValue = -123456789;
  /** Last computed gradient, valid while {@code cacheStale} is false. */
  protected double[] cachedGradient;
  protected CRF crf;
  // NOTE(review): ThreadPoolExecutor is not Serializable, so serializing an
  // instance of this class presumably fails on this field -- confirm whether
  // serialization of this class is actually relied upon anywhere.
  protected ThreadPoolExecutor executor;
  /**
   * Cached CRF transition weights, indexed as
   * {@code [instance][position][source state][destination state]}.
   */
  protected double[][][][] cachedDots;
  PRAuxiliaryModel model;

  /**
   * Creates a single-threaded optimizable.
   *
   * @param crf the CRF whose transition weights are cached
   * @param ilist the training instances
   * @param model the auxiliary model whose parameters are optimized
   */
  public ConstraintsOptimizableByPR(CRF crf, InstanceList ilist, PRAuxiliaryModel model) {
    this(crf, ilist, model, 1);
  }

  /**
   * Creates an optimizable that evaluates the objective on {@code numThreads} threads.
   *
   * @param crf the CRF whose transition weights are cached
   * @param ilist the training instances
   * @param model the auxiliary model whose parameters are optimized
   * @param numThreads size of the fixed thread pool; must be positive
   *        ({@link Executors#newFixedThreadPool(int)} rejects other values)
   */
  public ConstraintsOptimizableByPR(CRF crf, InstanceList ilist, PRAuxiliaryModel model, int numThreads) {
    this.crf = crf;
    this.trainingSet = ilist;
    this.model = model;
    this.numParameters = model.numParameters();
    cachedGradient = new double[numParameters];
    this.cacheStale = true;
    this.numThreads = numThreads;
    this.executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(numThreads);
    cacheDotProducts();
  }

  /**
   * (Re)builds {@link #cachedDots} from the current CRF weights.
   *
   * <p>Every entry starts as {@link Transducer#IMPOSSIBLE_WEIGHT}; entries for
   * transitions that the CRF actually offers at a position are then overwritten
   * with the weight reported by the transition iterator.
   */
  public void cacheDotProducts() {
    final int numInstances = trainingSet.size();
    final int numStates = crf.numStates();
    cachedDots = new double[numInstances][][][];
    for (int i = 0; i < numInstances; i++) {
      FeatureVectorSequence input = (FeatureVectorSequence) trainingSet.get(i).getData();
      final int length = input.size();
      cachedDots[i] = new double[length][numStates][numStates];
      for (int j = 0; j < length; j++) {
        for (int k = 0; k < numStates; k++) {
          double[] row = cachedDots[i][j][k];
          // Default: no transition from state k at this position.
          for (int l = 0; l < numStates; l++) {
            row[l] = Transducer.IMPOSSIBLE_WEIGHT;
          }
          // Overwrite the destinations that are reachable from state k.
          TransitionIterator iter = crf.getState(k).transitionIterator(input, j);
          while (iter.hasNext()) {
            int l = iter.next().getIndex();
            row[l] = iter.getWeight();
          }
        }
      }
    }
  }

  public int getNumParameters() {
    return numParameters;
  }

  public void getParameters(double[] params) {
    model.getParameters(params);
  }

  public double getParameter(int index) {
    return model.getParameter(index);
  }

  /** Sets all auxiliary model parameters and invalidates the cached value/gradient. */
  public void setParameters(double[] params) {
    cacheStale = true;
    model.setParameters(params);
  }

  /** Sets one auxiliary model parameter and invalidates the cached value/gradient. */
  public void setParameter(int index, double value) {
    cacheStale = true;
    model.setParameter(index, value);
  }

  /**
   * Computes the objective value and accumulates constraint expectations into
   * {@link #model}.
   *
   * <p>The training set is split into {@code numThreads} contiguous ranges of
   * {@code size / numThreads} instances; the last range additionally takes the
   * remainder. Each range is processed against its own copy of the model, and
   * the expectations of the copies are added back into {@link #model} afterwards.
   *
   * @return the sum of the per-range values plus {@code model.getValue()}
   * @throws IllegalStateException if a worker task fails or the calling thread
   *         is interrupted; in both cases the value would otherwise be computed
   *         from incomplete data
   */
  protected double getExpectationValue() {
    model.zeroExpectations();

    // One task per thread, each with a private model copy to accumulate into.
    final int numInstances = trainingSet.size();
    final int increment = numInstances / numThreads;
    List<Callable<Double>> tasks = new ArrayList<Callable<Double>>(numThreads);
    for (int taskIndex = 0; taskIndex < numThreads; taskIndex++) {
      int start = taskIndex * increment;
      // The last task absorbs the instances left over by integer division.
      int end = (taskIndex == numThreads - 1) ? numInstances : start + increment;
      tasks.add(new ExpectationTask(start, end, model.copy()));
    }

    double value = 0;
    try {
      List<Future<Double>> results = executor.invokeAll(tasks);
      for (Future<Double> f : results) {
        value += f.get();
      }
    } catch (ExecutionException ee) {
      // A failed task means part of the objective is missing; do not return a
      // silently wrong value to the optimizer.
      throw new IllegalStateException("Expectation task failed while computing the auxiliary objective", ee);
    } catch (InterruptedException ie) {
      // Preserve the interrupt for callers further up the stack.
      Thread.currentThread().interrupt();
      throw new IllegalStateException("Interrupted while computing the auxiliary objective", ie);
    }

    // combine results
    combine(model, tasks);

    // mu*b - w*||mu||^2
    value += model.getValue();
    return value;
  }

  /**
   * Returns the value of the auxiliary (E-step) objective at the current
   * parameters, i.e. the result of {@link #getExpectationValue()}.
   *
   * <p>The value and the gradient are recomputed only when the parameters have
   * changed since the last call; otherwise the cached value is returned.
   */
  public double getValue() {
    if (cacheStale) {
      cachedValue = getExpectationValue();
      model.getValueGradient(cachedGradient);
      cacheStale = false;
      logger.info("getValue (auxiliary distribution) = " + cachedValue);
    }
    return cachedValue;
  }

  /**
   * Returns {@code model.getCompleteValueContribution()}, first refreshing the
   * cached value/expectations if the parameters have changed.
   */
  public double getCompleteValueContribution() {
    if (cacheStale) {
      getValue();
    }
    return model.getCompleteValueContribution();
  }

  /**
   * Copies the gradient at the current parameters into {@code buffer},
   * recomputing it first if the parameters have changed.
   *
   * @param buffer destination array; must hold at least {@code getNumParameters()} values
   */
  public void getValueGradient(double[] buffer) {
    if (cacheStale) {
      getValue();
    }
    System.arraycopy(cachedGradient, 0, buffer, 0, cachedGradient.length);
  }

  /** Adds the constraint expectations accumulated by every task's model copy into {@code orig}. */
  private void combine(PRAuxiliaryModel orig, List<Callable<Double>> tasks) {
    for (Callable<Double> callable : tasks) {
      PRAuxiliaryModel copy = ((ExpectationTask) callable).getModelCopy();
      for (int ci = 0; ci < copy.numConstraints(); ci++) {
        PRConstraint origConstraint = orig.getConstraint(ci);
        PRConstraint copyConstraint = copy.getConstraint(ci);
        double[] expectation = new double[origConstraint.numDimensions()];
        copyConstraint.getExpectations(expectation);
        origConstraint.addExpectations(expectation);
      }
    }
  }

  /** Shuts down the thread pool; the objective can no longer be evaluated afterwards. */
  public void shutdown() {
    executor.shutdown();
  }

  /** Returns the cached transition weights (the internal array, not a copy). */
  public double[][][][] getCachedDots() {
    return cachedDots;
  }

  public PRAuxiliaryModel getAuxModel() {
    return model;
  }

  /**
   * Processes the training instances in {@code [start, end)} against a private
   * model copy and returns the negated sum of their lattice total weights.
   */
  private class ExpectationTask implements Callable<Double> {
    private final int start;
    private final int end;
    private final PRAuxiliaryModel modelCopy;

    public ExpectationTask(int start, int end, PRAuxiliaryModel modelCopy) {
      this.start = start;
      this.end = end;
      this.modelCopy = modelCopy;
    }

    @Override
    public Double call() throws Exception {
      double value = 0;
      for (int ii = start; ii < end; ii++) {
        Instance inst = trainingSet.get(ii);
        Sequence input = (Sequence) inst.getData();
        // logZ
        value -= new SumLatticePR(crf, ii, input, null, modelCopy, cachedDots[ii], true, null, null, false).getTotalWeight();
      }
      return value;
    }

    public PRAuxiliaryModel getModelCopy() {
      return modelCopy;
    }
  }
}