package cc.mallet.fst;
import java.util.logging.Logger;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.MalletLogger;
/**
 * A {@link TransducerTrainer} that repeatedly re-estimates an {@link HMM}
 * from instance lists until the summed lattice weight stops changing.
 *
 * <p>Each iteration builds one {@code SumLatticeDefault} per instance, handing
 * it a fresh {@code hmm.new Incrementor()}, sums the values returned by
 * {@code getTotalWeight()}, and then calls {@code hmm.estimate()}. Training
 * stops early once the absolute change of that sum between two consecutive
 * iterations falls below the convergence threshold.
 *
 * <p>NOTE(review): the log messages call the summed value a log-likelihood;
 * that reading depends on what {@code SumLatticeDefault.getTotalWeight()}
 * returns, which is not visible from this class - confirm.
 */
public class HMMTrainerByLikelihood extends TransducerTrainer {
    private static final Logger logger = MalletLogger
            .getLogger(HMMTrainerByLikelihood.class.getName());

    /** Convergence threshold used unless {@link #setConvergenceThreshold(double)} is called. */
    public static final double DEFAULT_CONVERGENCE_THRESHOLD = 0.001;

    /** The model being trained; returned as-is by {@link #getTransducer()}. */
    HMM hmm;

    // NOTE(review): these two fields are never assigned inside this class; the
    // train(...) parameters of the same names shadow them. Kept for
    // compatibility with any package-level readers.
    InstanceList trainingSet, unlabeledSet;

    /** Number of completed iterations, accumulated across all train(...) calls. */
    int iterationCount = 0;

    /** Whether the most recent train(...) call stopped because of convergence. */
    boolean converged = false;

    /** Largest change between consecutive iterations that still counts as converged (exclusive). */
    private double convergenceThreshold = DEFAULT_CONVERGENCE_THRESHOLD;

    /**
     * Creates a trainer for the given model.
     *
     * @param hmm the model to train; stored by reference, not copied
     */
    public HMMTrainerByLikelihood(HMM hmm) {
        this.hmm = hmm;
    }

    /**
     * Returns the model this trainer was constructed with.
     *
     * @return the HMM being trained
     */
    @Override
    public Transducer getTransducer() {
        return hmm;
    }

    /**
     * Returns how many training iterations have been completed so far.
     *
     * @return the iteration count, which is never reset by this class
     */
    @Override
    public int getIteration() {
        return iterationCount;
    }

    /**
     * Reports whether the last training run converged.
     *
     * @return {@code true} if the most recent train(...) call ended because
     *         the convergence test succeeded
     */
    @Override
    public boolean isFinishedTraining() {
        return converged;
    }

    /**
     * Returns the threshold used by the convergence test.
     *
     * @return the current convergence threshold
     */
    public double getConvergenceThreshold() {
        return convergenceThreshold;
    }

    /**
     * Sets the threshold used by the convergence test. Training is considered
     * converged when the absolute change of the summed value between two
     * consecutive iterations is strictly less than this threshold, so a value
     * of {@code 0} disables early stopping.
     *
     * @param convergenceThreshold the new threshold; must be non-negative
     * @throws IllegalArgumentException if the value is negative or NaN
     */
    public void setConvergenceThreshold(double convergenceThreshold) {
        if (Double.isNaN(convergenceThreshold) || convergenceThreshold < 0) {
            throw new IllegalArgumentException(
                    "convergenceThreshold must be non-negative, got "
                            + convergenceThreshold);
        }
        this.convergenceThreshold = convergenceThreshold;
    }

    /**
     * Trains on labeled data only; equivalent to calling
     * {@link #train(InstanceList, InstanceList, int)} with a {@code null}
     * unlabeled set.
     *
     * @param trainingSet   labeled instances
     * @param numIterations maximum number of iterations to run
     * @return {@code true} if training converged within the given iterations
     */
    @Override
    public boolean train(InstanceList trainingSet, int numIterations) {
        return train(trainingSet, null, numIterations);
    }

    /**
     * Trains on labeled instances and, optionally, on unlabeled instances.
     *
     * <p>If {@code hmm.emissionEstimator} is {@code null}, {@code hmm.reset()}
     * is called first. The loop then runs at most {@code numIterations} times;
     * a non-positive value performs no iterations and returns {@code false}.
     *
     * @param trainingSet   labeled instances whose data and target are both
     *                      cast to {@code FeatureSequence}
     * @param unlabeledSet  instances whose data is cast to
     *                      {@code FeatureSequence} and whose target is
     *                      ignored; may be {@code null} to skip this pass
     * @param numIterations maximum number of iterations to run
     * @return {@code true} if the convergence test succeeded before the
     *         iteration limit was reached
     */
    public boolean train(InstanceList trainingSet, InstanceList unlabeledSet,
            int numIterations) {
        if (hmm.emissionEstimator == null) {
            hmm.reset();
        }
        converged = false;
        // Snapshot once so a single run uses one consistent threshold.
        final double threshold = convergenceThreshold;
        double logLikelihood = Double.NEGATIVE_INFINITY;
        for (int iter = 0; iter < numIterations; iter++) {
            // Starts at negative infinity, so the first iteration can never
            // satisfy the convergence test.
            final double prevLogLikelihood = logLikelihood;
            logLikelihood = addLabeledWeights(trainingSet, 0);
            logger.info("getValue() (observed log-likelihood) = "
                    + logLikelihood);
            if (unlabeledSet != null) {
                logLikelihood = addUnlabeledWeights(unlabeledSet, logLikelihood);
            }
            logger.info("getValue() (log-likelihood) = " + logLikelihood);
            hmm.estimate();
            iterationCount++;
            logger.info("HMM finished one iteration of maximizer, i=" + iter);
            runEvaluators();
            if (Math.abs(logLikelihood - prevLogLikelihood) < threshold) {
                converged = true;
                logger.info("HMM training has converged, i=" + iter);
                break;
            }
        }
        return converged;
    }

    /**
     * Adds the total lattice weight of every labeled instance to a running sum.
     * The running total is threaded through so the floating-point accumulation
     * order is one continuous left-to-right sum.
     */
    private double addLabeledWeights(InstanceList instances, double runningTotal) {
        double total = runningTotal;
        for (Instance inst : instances) {
            FeatureSequence input = (FeatureSequence) inst.getData();
            FeatureSequence output = (FeatureSequence) inst.getTarget();
            total += new SumLatticeDefault(hmm, input, output,
                    hmm.new Incrementor()).getTotalWeight();
        }
        return total;
    }

    /**
     * Adds the total lattice weight of every unlabeled instance (null output
     * sequence) to a running sum, printing a progress counter to standard
     * error every 100 instances and a line break at the end.
     */
    private double addUnlabeledWeights(InstanceList instances, double runningTotal) {
        double total = runningTotal;
        int numEx = 0;
        for (Instance inst : instances) {
            numEx++;
            if (numEx % 100 == 0) {
                System.err.print(numEx + ". ");
                System.err.flush();
            }
            FeatureSequence input = (FeatureSequence) inst.getData();
            total += new SumLatticeDefault(hmm, input, null,
                    hmm.new Incrementor()).getTotalWeight();
        }
        System.err.println();
        return total;
    }
}