@Override
public GaussianArHpWfParticle update(
GaussianArHpWfParticle predState) {
final MultivariateGaussian posteriorState = predState.getState().clone();
final KalmanFilter kf = predState.getFilter().clone();
/*
* The following are the parameter learning updates;
* they can be done off-line, but we'll do them now.
* TODO FIXME check that the input/offset thing is working!
*/
final int xDim = posteriorState.getInputDimensionality();
final Matrix H = MatrixFactory.getDefault().createMatrix(xDim, xDim * 2);
H.setSubMatrix(0, 0, Ix);
H.setSubMatrix(0, xDim,
// x_{t-1}
MatrixFactory.getDefault().createDiagonal(predState.getStateSample()));
final InverseGammaDistribution sigma2SS = predState.getSigma2SS().clone();
// TODO FIXME matrix inverse!!
final Matrix postStatePrec = posteriorState.getCovarianceInverse().scale(
sigma2SS.getShape()/sigma2SS.getScale());
MultivariateStudentTDistribution postStateMarginal = new MultivariateStudentTDistribution(
sigma2SS.getShape(),
posteriorState.getMean(), postStatePrec);
final Vector postStateSample = postStateMarginal.sample(this.rng);
final Vector psiPriorSmpl = predState.getPsiSample();
// x_t
final Vector xHdiff = postStateSample.minus(H.times(psiPriorSmpl));
/*
* 1. Update the sigma2 sufficient stats.
*/
final double newN = sigma2SS.getShape() + 1d;
final double d = sigma2SS.getScale() + xHdiff.dotProduct(xHdiff);
sigma2SS.setScale(d);
sigma2SS.setShape(newN);
/*
* 2. Update psi sufficient stats. (i.e. offset and AR(1)).
*
* Note that we divide out the previous scale param, since
* we want to update A alone.
* TODO FIXME inverse! ewww.
*/
final Matrix priorAInv = predState.getPsiSS().getCovarianceInverse();
/*
* TODO FIXME: we don't have a generalized outer product, so we're only
* supporting the 1d case for now.
*/
final Vector Hv = H.convertToVector();
/*
* TODO FIXME inverse! ewww.
*/
final Matrix postAInv = priorAInv.plus(Hv.outerProduct(Hv)).inverse();
final Vector postPsiMean = postAInv.times(priorAInv.times(psiPriorSmpl).plus(
H.transpose().times(postStateSample)));
final MultivariateGaussian postPsi = predState.getPsiSS().clone();
postPsi.setMean(postPsiMean);
postPsi.setCovariance(postAInv);
final double sigma2Smpl = sigma2SS.sample(this.rng);
final GaussianArHpWfParticle postState =
new GaussianArHpWfParticle(kf, predState.getObservation(),
posteriorState, postStateSample,