// to take advantage of the sparsity of matrix. Although this
// will likely cause cache misses in Hprime, it's likely better
// than traversing _every_ cell in matrix, which may be in the
// millions.
long start = System.currentTimeMillis();
Matrix Hprime = new ArrayMatrix(H.rows(), H.columns());
double s = 0;
for (int k = 0; k < numDimensions; ++k) {
for (int n = 0; n < matrix.rows(); ++n) {
SparseDoubleVector v = matrix.getRowVector(n);
int[] nonZeros = v.getNonZeroIndices();
for (int m : nonZeros)
Hprime.set(k, m, Hprime.get(k, m) +
W.get(n,k) * v.get(m));
}
}
long end = System.currentTimeMillis();
LOG.info("Step 1: " + (end-start) + "ms");
// Compute WtW using standard matrix multiplication.
start = System.currentTimeMillis();
Matrix WtW = new ArrayMatrix(numDimensions, numDimensions);
for (int k = 0; k < numDimensions; ++k) {
for (int l = 0; l < numDimensions; ++l) {
double sum = 0;
for (int n = 0; n < W.rows(); ++n)
sum += W.get(n, k) * W.get(n, l);
WtW.set(k, l, sum);
}
}
end = System.currentTimeMillis();
LOG.info("Step 2: " + (end-start) + "ms");
// Compute the final update to H, which is
// H <- H .* (WtA) ./ (WtWH).
//
// Do this by computing each cell of WtWH and then let
// v <- Hprime[k, m]
// w <- H[k, m]
// sum <- WtWH[k, m]
// Hprime[k,m] <- w * v / sum
// This saves us from ever storing WtWH in memory. We can
// store the updated values in Hprime because we only access
// each cell once, but we cannot use H itself since we need to
// maintain those values until every value of WtWH is computed.
start = System.currentTimeMillis();
for (int k = 0; k < numDimensions; ++k) {
for (int m = 0; m < H.columns(); ++m) {
double sum = 0;
for (int l = 0; l < numDimensions; ++l)
sum += WtW.get(k, l) * H.get(l, m);
double v = Hprime.get(k, m);
double w = H.get(k, m);
Hprime.set(k, m, w * v / sum);
}
}
end = System.currentTimeMillis();
LOG.info("Step 3: " + (end-start) + "ms");
// Update H with the new value.
H = Hprime;
}
LOG.info("Updating W matrix");
// Update the W matrix by holding the H matrix fixed for a few
// iterations.
for (int j = 0; j < innerLoop; ++j) {
// Compute Wprime, which is AHt. Since A is the left matrix, we
// can take advantage of its sparsity using the standard matrix
// multiplication techniques.
long start = System.currentTimeMillis();
Matrix Wprime = new ArrayMatrix(W.rows(), W.columns());
for (int n = 0; n < matrix.rows(); ++ n) {
SparseDoubleVector v = matrix.getRowVector(n);
int[] nonZeros = v.getNonZeroIndices();
for (int k = 0; k < numDimensions; ++k) {
double sum = 0;
for (int m : nonZeros)
sum += v.get(m) * H.get(k, m);
Wprime.set(n, k, sum);
}
}
long end = System.currentTimeMillis();
LOG.info("Step 4: " + (end-start) + "ms");
// Compute HHt using standard matrix multiplication.
start = System.currentTimeMillis();
Matrix HHt = new ArrayMatrix(numDimensions, numDimensions);
for (int k = 0; k < numDimensions; ++k) {
for (int l = 0; l < numDimensions; ++l) {
double sum = 0;
for (int m = 0; m < H.columns(); ++m)
sum += H.get(k, m) * H.get(l, m);
HHt.set(k, l, sum);
}
}
end = System.currentTimeMillis();
LOG.info("Step 5: " + (end-start) + "ms");
// Compute W(HHt) and update Wprime using the following update:
// W <- W .* (AHt) ./ (W(HHt)).
//
// Do this by computing each cell of W(HHt) and then let
// v <- Wprime[n, k]
// w <- W[n, k]
// sum <- W(HHt)[n, k]
// Wprime[n, k] <- w * v / sum
// This saves us from ever storing W(HHt) in memory. We can
// store the updated values in Wprime because we only access
// each cell once, but we cannot use W itself since we need to
// maintain those values until every value of W(HHt) is
// computed.
start = System.currentTimeMillis();
for (int n = 0; n < W.rows(); ++n) {
for (int k = 0; k < W.columns(); ++k) {
double sum = 0;
for (int l = 0; l < numDimensions; ++l)
sum += W.get(n, l) * HHt.get(l, k);
double v = Wprime.get(n, k);
double w = W.get(n, k);
Wprime.set(n, k, w * v / sum);
}
}