double previousLogLikelihood = 0d;
logLikelihood = Double.NEGATIVE_INFINITY;
// Initialize model to fit to initial mixture.
fittedModel = new MixtureMultivariateNormalDistribution(initialMixture.getComponents());
while (numIterations++ <= maxIterations &&
Math.abs(previousLogLikelihood - logLikelihood) > threshold) {
previousLogLikelihood = logLikelihood;
double sumLogLikelihood = 0d;
// Mixture components
final List<Pair<Double, MultivariateNormalDistribution>> components
= fittedModel.getComponents();
// Weight and distribution of each component
final double[] weights = new double[k];
final MultivariateNormalDistribution[] mvns = new MultivariateNormalDistribution[k];
for (int j = 0; j < k; j++) {
weights[j] = components.get(j).getFirst();
mvns[j] = components.get(j).getSecond();
}
// E-step: compute the data dependent parameters of the expectation
// function.
// gamma[i][j]: the fraction of row i's total mixture density that is
// contributed by component j
final double[][] gamma = new double[n][k];
// Sum of gamma for each component
final double[] gammaSums = new double[k];
// Sum of gamma times its row for each component
final double[][] gammaDataProdSums = new double[k][numCols];
for (int i = 0; i < n; i++) {
final double rowDensity = fittedModel.density(data[i]);
sumLogLikelihood += Math.log(rowDensity);
for (int j = 0; j < k; j++) {
gamma[i][j] = weights[j] * mvns[j].density(data[i]) / rowDensity;
gammaSums[j] += gamma[i][j];
for (int col = 0; col < numCols; col++) {
gammaDataProdSums[j][col] += gamma[i][j] * data[i][col];
}
}
}
logLikelihood = sumLogLikelihood / n;
// M-step: compute the new parameters based on the expectation
// function.
final double[] newWeights = new double[k];
final double[][] newMeans = new double[k][numCols];
for (int j = 0; j < k; j++) {
newWeights[j] = gammaSums[j] / n;
for (int col = 0; col < numCols; col++) {
newMeans[j][col] = gammaDataProdSums[j][col] / gammaSums[j];
}
}
// Compute new covariance matrices
final RealMatrix[] newCovMats = new RealMatrix[k];
for (int j = 0; j < k; j++) {
newCovMats[j] = new Array2DRowRealMatrix(numCols, numCols);
}
for (int i = 0; i < n; i++) {
for (int j = 0; j < k; j++) {
final RealMatrix vec
= new Array2DRowRealMatrix(MathArrays.ebeSubtract(data[i], newMeans[j]));
final RealMatrix dataCov
= vec.multiply(vec.transpose()).scalarMultiply(gamma[i][j]);
newCovMats[j] = newCovMats[j].add(dataCov);
}
}
// Converting to arrays for use by fitted model
final double[][][] newCovMatArrays = new double[k][numCols][numCols];
for (int j = 0; j < k; j++) {
newCovMats[j] = newCovMats[j].scalarMultiply(1d / gammaSums[j]);
newCovMatArrays[j] = newCovMats[j].getData();
}
// Update current model
fittedModel = new MixtureMultivariateNormalDistribution(newWeights,
newMeans,
newCovMatArrays);
}
if (Math.abs(previousLogLikelihood - logLikelihood) > threshold) {