/*
* Compute and output RS forward errors
*/
ResampleType rsResampleType = rsDistribution.getMaxValueKey().getResampleType();
Vector rsStateProbDiffs = computeStateDiffs(i, hmm.getNumStates(), rsDistribution, forwardResults);
String[] rsLine = {Integer.toString(k), Integer.toString(i), "p(x_t=0|y^t)",
rsResampleType.toString(),
Double.toString(rsStateProbDiffs.getElement(0))};
writer.writeNext(rsLine);
log.info("rsStateProbDiffs=" + rsStateProbDiffs);
if (i > numPreRuns) {
wfFilter.update(wfDistribution, obsState);
RingAccumulator<MutableDouble> pfAtTRate = new RingAccumulator<MutableDouble>();
for (P state : wfDistribution.getDomain()) {
final double err = (x == state.getClassId()) ? wfDistribution.getFraction(state) : 0d;
pfAtTRate.accumulate(new MutableDouble(err));
}
pfRunningRate.accumulate(new MutableDouble(pfAtTRate.getSum()));
ResampleType wfResampleType = wfDistribution.getMaxValueKey().getResampleType();
Vector wfStateProbDiffs = computeStateDiffs(i, hmm.getNumStates(), wfDistribution, forwardResults);
String[] wfLine = {Integer.toString(k), Integer.toString(i), "p(x_t=0|y^t)", "water-filling",
wfResampleType.toString(),
Double.toString(wfStateProbDiffs.getElement(0))};
writer.writeNext(wfLine);
log.info("wfStateProbDiffs=" + wfStateProbDiffs);
}
}
log.info("viterbiRate:" + viterbiRate.getMean());
log.info("pfRunningRate:" + pfRunningRate.getMean());
// RingAccumulator<MutableDouble> pfRate2 = new RingAccumulator<MutableDouble>();
// for (HMMTransitionState<Integer> state : distribution.getDomain()) {
// final double chainLogLikelihood = distribution.getLogFraction(state);
// RingAccumulator<MutableDouble> pfAtTimeRate = new RingAccumulator<MutableDouble>();
// for (int i = 0; i < T; i++) {
// final double x = DoubleMath.roundToInt(sample.getSecond().get(i), RoundingMode.HALF_EVEN);
// final double err;
// if (i < T - 1) {
// final WeightedValue<Integer> weighedState = state.getStateHistory().get(i);
// err = (x == weighedState.getValue() ? 1d : 0d);
// } else {
// err = (x == state.getState() ? 1d : 0d);
// }
// pfAtTimeRate.accumulate(new MutableDouble(err));
// }
// pfRate2.accumulate(new MutableDouble(pfAtTimeRate.getMean().doubleValue()
// * Math.exp(chainLogLikelihood)));
// }
// log.info("pfChainRate:" + pfRate2.getSum());
/*
* Loop through the smoothed trajectories and compute the
* class probabilities for each state.
*/
for (int t = 0; t < T; t++) {
CountedDataDistribution<Integer> wfStateSums = new CountedDataDistribution<Integer>(true);
CountedDataDistribution<Integer> rsStateSums = new CountedDataDistribution<Integer>(true);
if (t < T - 1) {
for (HmmTransitionState<T, H> state : wfDistribution.getDomain()) {
final WeightedValue<Integer> weighedState = state.getStateHistory().get(t);
wfStateSums.increment(weighedState.getValue(), weighedState.getWeight());
}
for (HmmTransitionState<T, H> state : rsDistribution.getDomain()) {
final WeightedValue<Integer> weighedState = state.getStateHistory().get(t);
rsStateSums.increment(weighedState.getValue(), weighedState.getWeight());
}
} else {
for (P state : wfDistribution.getDomain()) {
wfStateSums.adjust(state.getClassId(), wfDistribution.getLogFraction(state), wfDistribution.getCount(state));
}
for (P state : rsDistribution.getDomain()) {
rsStateSums.adjust(state.getClassId(), rsDistribution.getLogFraction(state), rsDistribution.getCount(state));
}
}
Vector wfStateProbDiffs = VectorFactory.getDefault().createVector(hmm.getNumStates());
Vector rsStateProbDiffs = VectorFactory.getDefault().createVector(hmm.getNumStates());
for (int j = 0; j < hmm.getNumStates(); j++) {
/*
* Sometimes all the probability goes to one class...
*/
final double wfStateProb;
if (!wfStateSums.getDomain().contains(j))
wfStateProb = 0d;
else
wfStateProb = wfStateSums.getFraction(j);
wfStateProbDiffs.setElement(j, gammas.get(t).getElement(j) - wfStateProb);
final double rsStateProb;
if (!rsStateSums.getDomain().contains(j))
rsStateProb = 0d;
else
rsStateProb = rsStateSums.getFraction(j);
rsStateProbDiffs.setElement(j, gammas.get(t).getElement(j) - rsStateProb);
}
String[] wfLine = {Integer.toString(k), Integer.toString(t), "p(x_t=0|y^T)", "water-filling",
wfDistribution.getMaxValueKey().getResampleType().toString(),
Double.toString(wfStateProbDiffs.getElement(0))};
writer.writeNext(wfLine);
String[] rsLine = {Integer.toString(k), Integer.toString(t), "p(x_t=0|y^T)", "resample",
rsDistribution.getMaxValueKey().getResampleType().toString(),
Double.toString(rsStateProbDiffs.getElement(0))};
writer.writeNext(rsLine);
}
}
writer.close();
}