/*
* Copyright 2011-2013, by Vladimir Kostyukov, Mike Anderson and Contributors.
*
* This file is adapted from the la4j project (http://la4j.org)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* You may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* Contributor(s): Julia Kostyukova
*
*/
package mikera.matrixx.decompose.impl.svd;
import mikera.matrixx.AMatrix;
import mikera.matrixx.Matrix;
import mikera.matrixx.algo.impl.Constants;
import mikera.matrixx.decompose.ISVDResult;
import mikera.matrixx.impl.DiagonalMatrix;
import mikera.vectorz.Vector;
/**
* This class implements a thin SVD decomposition of a matrix
*
* @author Mike
*/
public class ThinSVD {
public static ISVDResult decompose(AMatrix a) {
return decompose(Matrix.create(a));
}
public static ISVDResult decompose(Matrix matrix) {
return decomposeInternal(matrix.clone());
}
// internal decomposition function, destructively modifies input Matrix
private static ISVDResult decomposeInternal(Matrix a) {
int rc = a.rowCount();
int cc = a.columnCount();
if (rc < cc) { throw new IllegalArgumentException("Wrong matrix size: "
+ "rows < columns"); }
// TODO: confirm this is a "Thin SVD"
// as per Wikipedia
int n = Math.min(rc, cc); // this should always be cc??
Matrix u = Matrix.create(rc, n);
Vector s = Vector.createLength(cc);
Matrix v = Matrix.create(cc, cc);
Vector e = Vector.createLength(cc);
Vector work = Vector.createLength(rc);
int nct = Math.min(rc - 1, cc);
int nrt = Math.max(0, Math.min(cc - 2, rc));
for (int k = 0; k < Math.max(nct, nrt); k++) {
if (k < nct) {
for (int i = k; i < rc; i++) {
s.set(k, Math.hypot(s.get(k), a.get(i, k)));
}
if (Math.abs(s.get(k)) > Constants.EPS) {
if (a.get(k, k) < 0.0) {
s.set(k, -s.get(k));
}
for (int i = k; i < rc; i++) {
a.set(i, k, a.get(i, k) / (s.get(k)));
}
a.addAt(k, k, 1.0);
}
s.set(k, -s.get(k));
}
for (int j = k + 1; j < cc; j++) {
if ((k < nct) && (Math.abs(s.get(k)) > Constants.EPS)) {
double t = 0;
for (int i = k; i < rc; i++) {
t += a.get(i, k) * a.get(i, j);
}
t = -t / a.get(k, k);
for (int i = k; i < rc; i++) {
a.addAt(i, j, (t * a.get(i, k)));
}
}
e.set(j, a.get(k, j));
}
if (k < nct) {
for (int i = k; i < rc; i++) {
u.set(i, k, a.get(i, k));
}
}
if (k < nrt) {
e.set(k, 0);
for (int i = k + 1; i < cc; i++) {
e.set(k, Math.hypot(e.get(k), e.get(i)));
}
if (Math.abs(e.get(k)) > Constants.EPS) {
if (e.get(k + 1) < 0.0) {
e.set(k, -e.get(k));
}
for (int i = k + 1; i < cc; i++) {
e.set(i, e.get(i) / (e.get(k)));
}
e.addAt(k + 1, 1.0);
}
e.set(k, -e.get(k));
if ((k + 1 < rc) && (Math.abs(e.get(k)) > Constants.EPS)) {
for (int j = k + 1; j < cc; j++) {
for (int i = k + 1; i < rc; i++) {
work.addAt(i, (e.get(j) * a.get(i, j)));
}
}
for (int j = k + 1; j < cc; j++) {
double t = -e.get(j) / e.get(k + 1);
for (int i = k + 1; i < rc; i++) {
a.addAt(i, j, (t * work.get(i)));
}
}
}
for (int i = k + 1; i < cc; i++) {
v.set(i, k, e.get(i));
}
}
}
int p = Math.min(cc, rc + 1);
if (nct < cc) {
s.set(nct, a.get(nct, nct));
}
if (rc < p) {
s.set(p - 1, 0.0);
}
if (nrt + 1 < p) {
e.set(nrt, a.get(nrt, p - 1));
}
e.set(p - 1, 0.0);
for (int j = nct; j < n; j++) {
for (int i = 0; i < rc; i++) {
u.set(i, j, 0.0);
}
u.set(j, j, 1.0);
}
for (int k = nct - 1; k >= 0; k--) {
if (Math.abs(s.get(k)) > Constants.EPS) {
for (int j = k + 1; j < n; j++) {
double t = 0;
for (int i = k; i < rc; i++) {
t += u.get(i, k) * u.get(i, j);
}
t = -t / u.get(k, k);
for (int i = k; i < rc; i++) {
u.addAt(i, j, (t * u.get(i, k)));
}
}
for (int i = k; i < rc; i++) {
u.set(i, k, -u.get(i, k));
}
u.addAt(k, k, 1.0);
for (int i = 0; i < k - 1; i++) {
u.set(i, k, 0.0);
}
} else {
for (int i = 0; i < rc; i++) {
u.set(i, k, 0.0);
}
u.set(k, k, 1.0);
}
}
for (int k = n - 1; k >= 0; k--) {
if ((k < nrt) & (Math.abs(e.get(k)) > Constants.EPS)) {
for (int j = k + 1; j < n; j++) {
double t = 0;
for (int i = k + 1; i < cc; i++) {
t += v.get(i, k) * v.get(i, j);
}
t = -t / v.get(k + 1, k);
for (int i = k + 1; i < cc; i++) {
v.addAt(i, j, (t * v.get(i, k)));
}
}
}
for (int i = 0; i < cc; i++) {
v.set(i, k, 0.0);
}
v.set(k, k, 1.0);
}
int pp = p - 1;
int iter = 0;
double eps = Math.pow(2.0, -52.0);
double tiny = Math.pow(2.0, -966.0);
while (p > 0) {
int k, kase;
for (k = p - 2; k >= -1; k--) {
if (k == -1) break;
if (Math.abs(e.get(k)) <= tiny
+ eps
* (Math.abs(s.get(k)) + Math
.abs(s.get(k + 1)))) {
e.set(k, 0.0);
break;
}
}
if (k == p - 2) {
kase = 4;
} else {
int ks;
for (ks = p - 1; ks >= k; ks--) {
if (ks == k) break;
double t = (ks != p ? Math.abs(e.get(ks)) : 0.)
+ (ks != k + 1 ? Math.abs(e.get(ks - 1)) : 0.);
if (Math.abs(s.get(ks)) <= tiny + eps * t) {
s.set(ks, 0.0);
break;
}
}
if (ks == k) {
kase = 3;
} else if (ks == p - 1) {
kase = 1;
} else {
kase = 2;
k = ks;
}
}
k++;
switch (kase) {
case 1: {
double f = e.get(p - 2);
e.set(p - 2, 0.0);
for (int j = p - 2; j >= k; j--) {
double sj=s.unsafeGet(j);
double t = Math.hypot(sj, f);
double cs = sj / t;
double sn = f / t;
s.set(j, j, t);
if (j != k) {
f = -sn * e.get(j - 1);
e.set(j - 1, cs * e.get(j - 1));
}
for (int i = 0; i < cc; i++) {
t = cs * v.get(i, j) + sn * v.get(i, p - 1);
v.set(i, p - 1,
-sn * v.get(i, j) + cs * v.get(i, p - 1));
v.set(i, j, t);
}
}
}
break;
case 2: {
double f = e.get(k - 1);
e.set(k - 1, 0.0);
for (int j = k; j < p; j++) {
double sj=s.unsafeGet(j);
double t = Math.hypot(sj, f);
double cs = sj / t;
double sn = f / t;
s.set(j, j, t);
f = -sn * e.get(j);
e.set(j, cs * e.get(j));
for (int i = 0; i < rc; i++) {
t = cs * u.get(i, j) + sn * u.get(i, k - 1);
u.set(i, k - 1,
-sn * u.get(i, j) + cs * u.get(i, k - 1));
u.set(i, j, t);
}
}
}
break;
case 3: {
double scale = Math
.max(Math.max(Math.max(
Math.max(Math.abs(s.get(p - 1)),
Math.abs(s.get(p - 2))),
Math.abs(e.get(p - 2))), Math.abs(s.get(k))),
Math.abs(e.get(k)));
double sp = s.get(p - 1) / scale;
double spm1 = s.get(p - 2) / scale;
double epm1 = e.get(p - 2) / scale;
double sk = s.get(k) / scale;
double ek = e.get(k) / scale;
double b = ((spm1 + sp) * (spm1 - sp) + epm1 * epm1) / 2.0;
double c = (sp * epm1) * (sp * epm1);
double shift = 0.0;
if ((b != 0.0) | (c != 0.0)) {
shift = Math.sqrt(b * b + c);
if (b < 0.0) {
shift = -shift;
}
shift = c / (b + shift);
}
double f = (sk + sp) * (sk - sp) + shift;
double g = sk * ek;
for (int j = k; j < p - 1; j++) {
double t = Math.hypot(f, g);
double cs = f / t;
double sn = g / t;
if (j != k) {
e.set(j - 1, t);
}
double sj=s.unsafeGet(j);
f = cs * sj + sn * e.get(j);
e.set(j, cs * e.get(j) - sn * sj);
g = sn * s.get(j + 1);
s.set(j + 1, cs * s.get(j + 1));
for (int i = 0; i < cc; i++) {
t = cs * v.get(i, j) + sn * v.get(i, j + 1);
v.set(i, j + 1,
-sn * v.get(i, j) + cs * v.get(i, j + 1));
v.set(i, j, t);
}
t = Math.hypot(f, g);
cs = f / t;
sn = g / t;
s.set(j, t);
f = cs * e.get(j) + sn * s.get(j + 1);
s.set(j + 1,
-sn * e.get(j) + cs * s.get(j + 1));
g = sn * e.get(j + 1);
e.set(j + 1, e.get(j + 1) * (cs));
if (j < rc - 1) {
for (int i = 0; i < rc; i++) {
t = cs * u.get(i, j) + sn * u.get(i, j + 1);
u.set(i, j + 1,
-sn * u.get(i, j) + cs * u.get(i, j + 1));
u.set(i, j, t);
}
}
}
e.set(p - 2, f);
iter = iter + 1;
}
break;
case 4: {
double skk = s.get(k);
if (skk <= 0.0) {
s.set(k, -skk);
for (int i = 0; i <= pp; i++) {
v.set(i, k, -v.get(i, k));
}
}
while (k < pp) {
if (s.get(k) >= s.get(k + 1)) {
break;
}
double t = s.get(k);
s.set(k, s.get(k + 1));
s.set(k + 1, t);
if (k < cc - 1) {
v.swapColumns(k, k + 1);
// for (int i = 0; i < cc; i++) {
// t = v.get(i, k + 1);
// v.set(i, k + 1, v.get(i, k));
// v.set(i, k, t);
// }
}
if (k < rc - 1) {
u.swapColumns(k, k + 1);
// for (int i = 0; i < rc; i++) {
// t = u.get(i, k + 1);
// u.set(i, k + 1, u.get(i, k));
// u.set(i, k, t);
// }
}
k++;
}
iter = 0;
p--;
}
break;
}
}
return new SVDResult (u, DiagonalMatrix.wrap(s), v, s);
}
}