package mikera.matrixx;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import mikera.arrayz.Arrayz;
import mikera.arrayz.INDArray;
import mikera.arrayz.NDArray;
import mikera.arrayz.TestArrays;
import mikera.indexz.Index;
import mikera.indexz.Indexz;
import mikera.matrixx.impl.BandedMatrix;
import mikera.matrixx.impl.BlockDiagonalMatrix;
import mikera.matrixx.impl.ColumnMatrix;
import mikera.matrixx.impl.IdentityMatrix;
import mikera.matrixx.impl.ImmutableMatrix;
import mikera.matrixx.impl.LowerTriangularMatrix;
import mikera.matrixx.impl.PermutationMatrix;
import mikera.matrixx.impl.PermutedMatrix;
import mikera.matrixx.impl.QuadtreeMatrix;
import mikera.matrixx.impl.RowMatrix;
import mikera.matrixx.impl.ScalarMatrix;
import mikera.matrixx.impl.SparseColumnMatrix;
import mikera.matrixx.impl.SparseRowMatrix;
import mikera.matrixx.impl.StridedMatrix;
import mikera.matrixx.impl.SubsetMatrix;
import mikera.matrixx.impl.UpperTriangularMatrix;
import mikera.matrixx.impl.VectorMatrixM3;
import mikera.matrixx.impl.VectorMatrixMN;
import mikera.matrixx.impl.ZeroMatrix;
import mikera.transformz.MatrixTransform;
import mikera.transformz.TestTransformz;
import mikera.vectorz.AVector;
import mikera.vectorz.Vector;
import mikera.vectorz.Vector3;
import mikera.vectorz.Vectorz;
import mikera.vectorz.impl.AxisVector;
import mikera.vectorz.ops.Constant;
import org.junit.Test;
public class TestMatrices {
/**
 * Checks that an exact clone owns its own data: every element of the clone is
 * changed in turn, and after each change the source matrix must still hold
 * its previous value while the clone holds a different one.
 * <p>
 * Does nothing when {@code m} reports that it is not fully mutable.
 * <p>
 * The original check used {@code assertNotSame} on two primitive doubles.
 * That compares the identity of two freshly boxed objects and therefore
 * always passed; the values are now compared directly.
 *
 * @param m the matrix to test; it is cloned first and never modified itself
 */
private void doMutationTest(AMatrix m) {
    if (!m.isFullyMutable()) return;
    m=m.exactClone();
    AMatrix m2=m.exactClone();
    assertEquals(m,m2);
    int rc=m.rowCount();
    int cc=m.columnCount();
    for (int i=0; i<rc; i++) {
        for (int j=0; j<cc; j++) {
            double before=m.get(i,j);
            m2.set(i,j,m2.get(i,j)+1.3);
            // the row obtained from the clone must reflect the write
            assertEquals(m2.get(i,j),m2.getRow(i).get(j),0.0);
            // the source of the clone must not have been touched ...
            assertEquals(before,m.get(i,j),0.0);
            // ... so the two matrices now differ at this position
            assertTrue(m.get(i,j)!=m2.get(i,j));
        }
    }
}
/**
 * Checks that the different ways of transposing {@code m} agree with each
 * other, and that transposing twice yields a matrix equal to {@code m}.
 *
 * @param m the matrix to test
 */
private void doTransposeTest(AMatrix m) {
    AMatrix transposed=m.clone().getTranspose();

    // m and its transpose must recognise each other in both directions
    assertTrue(m.equalsTranspose(transposed));
    assertTrue(transposed.equalsTranspose(m));

    // every transposing accessor has to produce equal contents
    assertEquals(transposed, m.getTranspose());
    assertEquals(transposed, m.getTransposeView());
    assertEquals(transposed, m.toMatrixTranspose());

    // a second transpose brings us back to the starting point
    AMatrix restored=transposed.getTranspose();
    assertEquals(m,restored);

    assertEquals(m.getTranspose().innerProduct(m),m.transposeInnerProduct(m));
}
/**
 * Compares {@code transposeInPlace()} (applied to a clone) against
 * {@code getTranspose()}, checks that the trace is the same for the matrix
 * and its transpose, and checks that a second in-place transpose restores
 * the clone to a matrix equal to {@code m}.
 *
 * @param m the matrix to test; only called for square matrices in this file
 */
private void doSquareTransposeTest(AMatrix m) {
    AMatrix expected=m.getTranspose();

    AMatrix flipped=m.clone();
    flipped.transposeInPlace();

    // both kinds of transpose must give the same result
    assertEquals(expected,flipped);
    assertEquals(m.trace(),expected.trace(),0.0);

    flipped.transposeInPlace();
    assertEquals(m,flipped);
}
/**
 * Checks that {@code getLeadingDiagonal()}, {@code checkSquare()},
 * {@code inverse()} and {@code determinant()} each throw when invoked on the
 * (non-square) matrix {@code m}.
 * <p>
 * Previously {@code fail()} was called inside a {@code try} block whose
 * {@code catch (Throwable)} also caught the {@code AssertionError} raised by
 * {@code fail()}, so none of these checks could ever fail. Whether a call
 * threw is now recorded in a flag and asserted outside the {@code try}.
 * Any {@code Throwable} is still accepted as a rejection, as before.
 * <p>
 * NOTE(review): because these checks were previously ineffective, an
 * implementation that does not throw for one of these calls will now be
 * reported as a test failure - confirm that is the intended contract.
 *
 * @param m a matrix that is expected to be non-square
 */
private void doNotSquareTests(AMatrix m) {
    boolean threw=false;
    try {
        m.getLeadingDiagonal();
    } catch (Throwable t) {
        threw=true;
    }
    assertTrue("getLeadingDiagonal() should throw for a non-square matrix",threw);

    threw=false;
    try {
        m.checkSquare();
    } catch (Throwable t) {
        threw=true;
    }
    assertTrue("checkSquare() should throw for a non-square matrix",threw);

    threw=false;
    try {
        m.inverse();
    } catch (Throwable t) {
        threw=true;
    }
    assertTrue("inverse() should throw for a non-square matrix",threw);

    threw=false;
    try {
        m.determinant();
    } catch (Throwable t) {
        threw=true;
    }
    assertTrue("determinant() should throw for a non-square matrix",threw);
}
/**
 * Checks that the row and column counts of {@code m} are equal and that each
 * element of {@code getLeadingDiagonal()} equals the matching {@code (k,k)}
 * element of the matrix.
 *
 * @param m the matrix to test
 */
private void doLeadingDiagonalTests(AMatrix m) {
    final int size=m.rowCount();
    assertEquals(size,m.columnCount());

    final AVector diagonal=m.getLeadingDiagonal();
    int k=0;
    while (k<size) {
        assertEquals(diagonal.get(k),m.get(k,k),0.0);
        k++;
    }
}
/**
 * Checks that the trace of a clone of {@code m} matches the trace of
 * {@code m} itself to within 0.00001.
 *
 * @param m the matrix to test
 */
private void doTraceTests(AMatrix m) {
    double cloneTrace=m.clone().trace();
    double directTrace=m.trace();
    assertEquals(cloneTrace,directTrace,0.00001);
}
/**
 * Dispatches on {@code m.isSquare()}: square matrices get the square-only
 * checks (dimension equality, {@code checkSquare()}, transpose, trace and
 * leading diagonal tests), all other matrices get the non-square checks.
 *
 * @param m the matrix to test
 */
private void doMaybeSquareTests(AMatrix m) {
    if (m.isSquare()) {
        assertEquals(m.rowCount(),m.columnCount());
        // checkSquare() is expected to return the common dimension
        assertEquals(m.rowCount(),m.checkSquare());
        doSquareTransposeTest(m);
        doTraceTests(m);
        doLeadingDiagonalTests(m);
        return;
    }

    assertNotEquals(m.rowCount(),m.columnCount());
    doNotSquareTests(m);
}
/**
 * Checks that swapping rows 0 and 1 (and then columns 0 and 1) of a clone
 * changes it, and that repeating the same swap restores a matrix equal to
 * the original.
 * <p>
 * Does nothing for matrices with fewer than two rows or two columns.
 * <p>
 * The original used the Java {@code assert} keyword, which is skipped
 * entirely unless the JVM runs with assertions enabled. JUnit assertions are
 * used instead so that the checks always execute.
 *
 * @param m the matrix to test; it is cloned and never modified itself. The
 *          "swap changes the matrix" checks rely on rows 0/1 and columns 0/1
 *          differing; the only caller in this file passes a matrix that has
 *          just been filled via {@code Matrixx.fillRandomValues}.
 */
private void doSwapTest(AMatrix m) {
    if ((m.rowCount()<2)||(m.columnCount()<2)) return;
    m=m.clone();
    AMatrix m2=m.clone();

    m2.swapRows(0, 1);
    assertTrue("swapping rows 0 and 1 should change the matrix",!m2.equals(m));
    m2.swapRows(0, 1);
    assertTrue("swapping rows 0 and 1 twice should restore the matrix",m2.equals(m));

    m2.swapColumns(0, 1);
    assertTrue("swapping columns 0 and 1 should change the matrix",!m2.equals(m));
    m2.swapColumns(0, 1);
    assertTrue("swapping columns 0 and 1 twice should restore the matrix",m2.equals(m));
}
/**
 * Runs the swap and mutation tests against a clone of {@code m} that has
 * been filled via {@code Matrixx.fillRandomValues}.
 *
 * @param m the matrix to clone; it is not modified by this method itself
 */
void doRandomTests(AMatrix m) {
    AMatrix randomised=m.clone();
    Matrixx.fillRandomValues(randomised);

    doSwapTest(randomised);
    doMutationTest(randomised);
}
/**
 * Checks that adding {@code m} twice to an exact clone gives the same result
 * (within epsilon) as a single {@code addMultiple(m, 2.0)} on another exact
 * clone. Skipped when {@code m} is not fully mutable.
 *
 * @param m the matrix to test
 */
private void doAddTest(AMatrix m) {
    if (m.isFullyMutable()) {
        AMatrix addedTwice=m.exactClone();
        AMatrix addedScaled=m.exactClone();

        for (int pass=0; pass<2; pass++) {
            addedTwice.add(m);
        }
        addedScaled.addMultiple(m, 2.0);

        assertTrue(addedTwice.epsilonEquals(addedScaled));
    }
}
/**
 * Checks that writing to a clone of {@code m} does not change {@code m}:
 * element (0,0) of the clone is overwritten with a value that differs from
 * the current one, after which the clone must hold the new value and
 * {@code m} must still hold the old one.
 * <p>
 * Does nothing for matrices with zero rows or zero columns.
 * <p>
 * The original check used {@code assertNotSame} on two primitive doubles.
 * That compares the identity of two freshly boxed objects and therefore
 * always passed. The replacement compares values, and picks {@code Math.E}
 * instead of {@code Math.PI} when element (0,0) already equals
 * {@code Math.PI}, so the written value is always different.
 *
 * @param m the matrix to test; it is not modified
 */
private void doCloneSafeTest(AMatrix m) {
    if ((m.rowCount()==0)||(m.columnCount()==0)) return;
    double original=m.get(0,0);
    // choose a replacement that is guaranteed to differ from the current value
    double replacement=(original==Math.PI)?Math.E:Math.PI;
    AMatrix m2=m.clone();
    m2.set(0,0,replacement);
    assertEquals(replacement,m2.get(0,0),0.0);
    // the write to the clone must not be visible through the original
    assertEquals(original,m.get(0,0),0.0);
}
/**
 * Takes the sub-matrix of {@code m} that starts at (1,1) and spans the
 * remaining rows and columns, then checks its dimensions and that every
 * element matches the corresponding element of {@code m}.
 * <p>
 * Skipped when {@code m} has fewer than two rows or two columns.
 *
 * @param m the matrix to test
 */
private void doSubMatrixTest(AMatrix m) {
    final int rows=m.rowCount();
    final int cols=m.columnCount();
    if ((rows<2)||(cols<2)) return;

    final int subRows=rows-1;
    final int subCols=cols-1;
    AMatrix sub=m.subMatrix(1, subRows, 1, subCols);
    assertEquals(subRows,sub.rowCount());
    assertEquals(subCols,sub.columnCount());

    // sub-matrix position (r,c) corresponds to position (r+1,c+1) in m
    for (int r=0; r<subRows; r++) {
        for (int c=0; c<subCols; c++) {
            assertEquals(m.get(r+1,c+1),sub.get(r,c),0.0);
        }
    }
}
/**
 * Checks that reading at a set of out-of-range positions throws
 * {@code IndexOutOfBoundsException}, and - when {@code m} is fully mutable -
 * that writing at out-of-range positions of an exact clone does so as well.
 *
 * @param m the matrix to test; writes go to an exact clone
 */
private void doBoundsTest(AMatrix m) {
    final int rows=m.rowCount();
    final int cols=m.columnCount();

    // each entry is a {row, column} pair lying outside the matrix
    final int[][] illegalReads={
        {-1,-1}, {rows,cols}, {0,-1}, {0,cols}, {-1,0}, {rows,0}
    };
    for (int[] pos : illegalReads) {
        try {
            m.get(pos[0],pos[1]);
            fail();
        } catch (IndexOutOfBoundsException expected) {/* OK */}
    }

    if (!m.isFullyMutable()) return;

    AMatrix writable=m.exactClone();
    final int[][] illegalWrites={
        {-1,-1}, {rows,cols}
    };
    for (int[] pos : illegalWrites) {
        try {
            writable.set(pos[0],pos[1],1);
            fail();
        } catch (IndexOutOfBoundsException expected) {/* OK */}
    }
}
/**
 * Checks row and column access on a clone of {@code m}:
 * <ul>
 * <li>a {@code MatrixTransform} built from {@code m} reports the row count
 *     as output dimensions and the column count as input dimensions;</li>
 * <li>each row equals its {@code cloneRow} and has the column count as
 *     length, each column has the row count as length;</li>
 * <li>writes through the row view, the column view and {@code asVector()}
 *     are all visible at position (0,0);</li>
 * <li>the three views pass the generic {@code TestArrays} checks.</li>
 * </ul>
 * Everything after the dimension checks is skipped for empty matrices.
 *
 * @param m the matrix to test; all writes go to a clone
 */
private void doRowColumnTests(AMatrix m) {
    assertEquals(m.rowCount(),new MatrixTransform(m).outputDimensions());
    assertEquals(m.columnCount(),new MatrixTransform(m).inputDimensions());

    AMatrix work=m.clone();
    final int rows=work.rowCount();
    final int cols=work.columnCount();
    if ((rows==0)||(cols==0)) return;

    for (int r=0; r<rows; r++) {
        AVector rowVector=work.getRow(r);
        assertEquals(rowVector,work.cloneRow(r));
        assertEquals(cols,rowVector.length());
    }
    for (int c=0; c<cols; c++) {
        AVector columnVector=work.getColumn(c);
        assertEquals(rows,columnVector.length());
    }

    // all three views share element (0,0), so a write through any one of
    // them has to be observable through the matrix and the other views
    AVector firstRow=work.getRowView(0);
    AVector firstColumn=work.getColumnView(0);

    firstRow.set(0,1.77);
    assertEquals(1.77,work.get(0,0),0.0);

    firstColumn.set(0,0.23);
    assertEquals(0.23,work.get(0,0),0.0);

    AVector flat=work.asVector();
    assertEquals(rows*cols,flat.length());
    flat.set(0,0.78);
    assertEquals(0.78,firstRow.get(0),0.0);
    assertEquals(0.78,firstColumn.get(0),0.0);

    for (AVector view : new AVector[] {firstRow, firstColumn, flat}) {
        new TestArrays().testArray(view);
    }
}
/**
 * Checks the band accessors of {@code m} for every band index from
 * {@code -rowCount} to {@code columnCount} inclusive: the band length must
 * agree with {@code bandLength}, the band must equal the same band of the
 * {@code toMatrix()} copy, and (for non-empty matrices) the wrapped band must
 * have length {@code max(rowCount, columnCount)}. Finally a
 * {@code BandedMatrix} created from {@code m} must equal {@code m}.
 *
 * @param m the matrix to test
 */
private void doBandTests(AMatrix m) {
    final int rows=m.rowCount();
    final int cols=m.columnCount();
    final int lowestBand=-rows;
    final int highestBand=cols;
    final Matrix dense=m.toMatrix();

    // TODO: decide how bands outside [lowestBand-1, highestBand+1] should
    // behave (e.g. whether getBand should return null for them)

    int band=lowestBand;
    while (band<=highestBand) {
        AVector bandVector=m.getBand(band);
        assertEquals(bandVector.length(),m.bandLength(band));
        assertEquals(dense.getBand(band),bandVector);

        // wrapped bands are only checked when the matrix has elements
        if (rows*cols!=0) {
            assertEquals(Math.max(rows, cols),m.getBandWrapped(band).length());
        }
        band++;
    }

    assertEquals(m,BandedMatrix.create(m));
}
/**
 * Checks the vector forms of a clone of {@code m}: {@code asVector()} must
 * equal {@code toVector()}, element sums must agree, a matrix rebuilt from
 * the vector must equal the clone, and a write through {@code asVector()}
 * at index 0 must show up at position (0,0) of the clone.
 *
 * @param m the matrix to test; all writes go to a clone
 */
private void doVectorTest(AMatrix m) {
    AMatrix copy=m.clone();
    AVector flat=copy.asVector();

    assertEquals(flat,copy.toVector());
    assertEquals(copy.elementSum(),flat.elementSum(),0.000001);

    AMatrix rebuilt=Matrixx.createFromVector(flat, copy.rowCount(), copy.columnCount());
    assertEquals(copy,rebuilt);
    assertEquals(flat,rebuilt.asVector());

    // nothing to write to when the matrix has no elements
    if (flat.length()<=0) return;

    flat.set(0,10.0);
    assertEquals(10.0,copy.get(0,0),0.0);
}
/**
 * Checks that parsing the {@code toString()} output of {@code m} with
 * {@code Matrixx.parse} yields a matrix equal to {@code m}. Skipped for
 * matrices with zero rows.
 *
 * @param m the matrix to test
 */
void doParseTest(AMatrix m) {
    if (m.rowCount()!=0) {
        String text=m.toString();
        assertEquals(m,Matrixx.parse(text));
    }
}
/**
 * Checks that composing matrices agrees with applying them one after the
 * other: for random square matrices {@code a} and {@code b} sized to fit
 * around {@code m}, transforming a random vector with
 * {@code a.compose(m.compose(b))} must match (within epsilon) transforming
 * it with {@code b}, then {@code m}, then {@code a}.
 *
 * @param m the matrix placed in the middle of the composition
 */
void doBigComposeTest(AMatrix m) {
    AMatrix left=Matrixx.createRandomSquareMatrix(m.rowCount());
    AMatrix right=Matrixx.createRandomSquareMatrix(m.columnCount());

    AMatrix inner=m.compose(right);
    AMatrix composed=left.compose(inner);

    AVector input=Vectorz.createUniformRandomVector(right.columnCount());
    AVector stepwise=left.transform(m.transform(right.transform(input)));

    assertTrue(composed.transform(input).epsilonEquals(stepwise));
}
/**
 * Checks that applying a constant op of 5.0 to an exact clone gives the same
 * matrix as filling the vector view of another exact clone with 5.0.
 * Skipped when {@code m} is not fully mutable.
 *
 * @param m the matrix to test; only exact clones are modified
 */
private void testApplyOp(AMatrix m) {
    if (m.isFullyMutable()) {
        AMatrix filled=m.exactClone();
        AMatrix viaOp=m.exactClone();

        filled.asVector().fill(5.0);
        viaOp.applyOp(Constant.create(5.0));

        assertEquals(filled,viaOp);
        assertTrue(viaOp.epsilonEquals(filled));
    }
}
/**
 * Checks that {@code m} equals each of its copies: the {@code toMatrix()}
 * copy, the {@code exactClone()} and the plain {@code clone()}.
 *
 * @param m the matrix to test
 */
private void testExactClone(AMatrix m) {
    AMatrix exact=m.exactClone();
    AMatrix plain=m.clone();
    Matrix dense=m.toMatrix();

    for (AMatrix copy : new AMatrix[] {dense, exact, plain}) {
        assertEquals(m,copy);
    }
}
/**
 * Checks that the results of {@code Matrixx.createSparse(m)} and
 * {@code Matrixx.createSparseRows(m)} are both equal to {@code m}.
 *
 * @param m the matrix to test
 */
private void testSparseClone(AMatrix m) {
    assertEquals(m,Matrixx.createSparse(m));
    assertEquals(m,Matrixx.createSparseRows(m));
}
/**
 * Checks scaling: an exact clone scaled by 2.0 must match (within epsilon) a
 * clone to which {@code m} has been added once, and scaling by 0.0 must give
 * a matrix for which {@code isZero()} holds. Skipped when {@code m} is not
 * fully mutable.
 *
 * @param m the matrix to test; only clones are modified
 */
void doScaleTest(AMatrix m) {
    if (m.isFullyMutable()) {
        AMatrix scaled=m.exactClone();
        AMatrix doubled=m.clone();

        scaled.scale(2.0);
        doubled.add(m);
        assertTrue(scaled.epsilonEquals(doubled));

        scaled.scale(0.0);
        assertTrue(scaled.isZero());
    }
}
/**
 * Checks that the three ways of applying {@code m} to a vector agree:
 * {@code transform(v, t)} into a destination, {@code transform(v)} and
 * {@code innerProduct(v)}. The input is a vector newly created with
 * {@code Vectorz.newVector}.
 *
 * @param m the matrix to test
 */
private void doMulTest(AMatrix m) {
    AVector input=Vectorz.newVector(m.columnCount());
    AVector output=Vectorz.newVector(m.rowCount());

    m.transform(input, output);
    AVector returned=m.transform(input);

    assertEquals(output,returned);
    assertEquals(output,m.innerProduct(input));
}
/**
 * Copies {@code m} into a new {@code NDArray} of the same shape and checks
 * that each row of {@code m} equals the matching slice of the array, and
 * each column equals the matching slice along dimension 1.
 *
 * @param m the matrix to test
 */
private void doNDArrayTest(AMatrix m) {
    NDArray array=NDArray.newArray(m.getShape());
    array.set(m);

    final int rows=m.rowCount();
    final int cols=m.columnCount();

    int r=0;
    while (r<rows) {
        assertEquals(m.getRow(r),array.slice(r));
        r++;
    }

    int c=0;
    while (c<cols) {
        assertEquals(m.getColumn(c),array.slice(1,c));
        c++;
    }
}
/**
 * Checks that the structural predicates of {@code m} are consistent with
 * each other: diagonal implies symmetric and both triangular forms;
 * symmetric implies square and equal to its transpose; symmetric plus both
 * triangular forms implies diagonal; and the transpose of an upper (lower)
 * triangular matrix is lower (upper) triangular.
 *
 * @param m the matrix to test
 */
private void doTriangularTests(AMatrix m) {
    final boolean symmetric=m.isSymmetric();
    final boolean diagonal=m.isDiagonal();
    final boolean upper=m.isUpperTriangular();
    final boolean lower=m.isLowerTriangular();

    if (diagonal) {
        assertTrue(symmetric);
        assertTrue(upper);
        assertTrue(lower);
    }

    if (symmetric) {
        assertTrue(m.isSquare());
        assertEquals(m,m.getTranspose());
    }

    if (symmetric && upper && lower) {
        assertTrue(diagonal);
    }

    // transposing swaps the two triangular forms
    if (upper) {
        assertTrue(m.getTranspose().isLowerTriangular());
    }
    if (lower) {
        assertTrue(m.getTranspose().isUpperTriangular());
    }
}
/**
 * Runs the full battery of generic checks in this class against {@code m},
 * followed by the generic transform checks from {@code TestTransformz} and
 * the generic array checks from {@code TestArrays}.
 * <p>
 * NOTE(review): the helpers are invoked in a fixed order; most of them
 * visibly work on clones, but the final {@code TestArrays.testArray(m)} call
 * receives {@code m} itself - confirm before reordering.
 *
 * @param m the matrix to test
 */
void doGenericTests(AMatrix m) {
// sanity check of the matrix itself before any other test runs
m.validate();
// cloning, sparse conversion and element-wise operation checks
testApplyOp(m);
testExactClone(m);
testSparseClone(m);
// structural checks (transpose, triangular/diagonal/symmetric predicates)
doTransposeTest(m);
doTriangularTests(m);
// conversions to and from other representations
doVectorTest(m);
doParseTest(m);
doNDArrayTest(m);
// arithmetic, bounds and accessor checks
doScaleTest(m);
doBoundsTest(m);
doMulTest(m);
doAddTest(m);
doRowColumnTests(m);
doBandTests(m);
// clone independence and mutation checks
doCloneSafeTest(m);
doMutationTest(m);
// checks that depend on whether the matrix is square
doMaybeSquareTests(m);
// checks that use randomly generated data
doRandomTests(m);
doBigComposeTest(m);
doSubMatrixTest(m);
// generic checks shared with the transform and array test suites
TestTransformz.doITransformTests(new MatrixTransform(m));
new TestArrays().testArray(m);
}
/**
 * Runs the generic tests against immutable zero matrices of several shapes,
 * including a reordered 3x3 matrix and shapes with zero rows and/or columns.
 */
@Test public void g_ZeroMatrix() {
    // each entry is a {rows, columns} shape
    final int[][] leadingShapes={ {3,2}, {5,5} };
    for (int[] shape : leadingShapes) {
        doGenericTests(Matrixx.createImmutableZeroMatrix(shape[0], shape[1]));
    }

    // zero matrix after a reorder
    doGenericTests(Matrixx.createImmutableZeroMatrix(3, 3).reorder(new int[] {2,0,1}));

    // single-row and empty shapes
    final int[][] trailingShapes={ {1,7}, {1,0}, {0,1}, {0,0} };
    for (int[] shape : trailingShapes) {
        doGenericTests(Matrixx.createImmutableZeroMatrix(shape[0], shape[1]));
    }
}
/**
 * Runs the generic tests against the specialised fixed-size matrix classes
 * {@code Matrix33}, {@code Matrix22} and {@code Matrix11}, each filled via
 * {@code randomise} first.
 */
@Test public void g_PrimitiveMatrix() {
    // 3x3
    Matrix33 threeByThree=new Matrix33();
    randomise(threeByThree);
    doGenericTests(threeByThree);

    // 2x2
    Matrix22 twoByTwo=new Matrix22();
    randomise(twoByTwo);
    doGenericTests(twoByTwo);

    // 1x1
    Matrix11 oneByOne=new Matrix11();
    randomise(oneByOne);
    doGenericTests(oneByOne);
}
/**
 * Runs the generic tests against a {@code PermutedMatrix} that wraps a 6x7
 * {@code VectorMatrixMN} with random row and column permutations, and
 * against a sub-matrix of it.
 */
@Test public void g_PermutedMatrix() {
    // underlying general M*N matrix
    VectorMatrixMN source=new VectorMatrixMN(6 ,7);

    // permuted wrapper: row permutation is generated first, then columns
    PermutedMatrix permuted=new PermutedMatrix(
        source,
        Indexz.createRandomPermutation(source.rowCount()),
        Indexz.createRandomPermutation(source.columnCount()));
    doGenericTests(permuted);

    AMatrix window=permuted.subMatrix(1, 4, 1, 5);
    doGenericTests(window);
}
/**
 * Runs the generic tests against {@code VectorMatrixM3} and
 * {@code VectorMatrixMN} instances of several sizes, plus sub-matrices of
 * the larger ones.
 */
@Test public void g_VectorMatrix() {
    // specialised Mx3 matrix, filled with a fixed seed
    VectorMatrixM3 tenByThree=new VectorMatrixM3(10);
    Arrayz.fillNormal(tenByThree,101);
    doGenericTests(tenByThree);
    doGenericTests(tenByThree.subMatrix(1, 1, 1, 1));

    // general M*N matrix
    VectorMatrixMN sixBySeven=new VectorMatrixMN(6 ,7);
    randomise(sixBySeven);
    doGenericTests(sixBySeven);
    doGenericTests(sixBySeven.subMatrix(1, 4, 1, 5));

    // small 2*2 matrix
    VectorMatrixMN twoByTwo=new VectorMatrixMN(2,2);
    doGenericTests(twoByTwo);

    // 1x0 matrix should work
    VectorMatrixMN oneByZero=new VectorMatrixMN(1 ,0);
    doGenericTests(oneByZero);

    // square M*M matrix
    VectorMatrixMN sixBySix=new VectorMatrixMN(6 ,6);
    doGenericTests(sixBySix);
}
/** Seed handed to the next {@code randomise} call; advanced by one on every call. */
private static long seed;

/**
 * Fills {@code m} via {@code Arrayz.fillNormal} using the current value of
 * {@code seed}, and advances {@code seed} by one.
 *
 * @param m the array to fill in place
 */
private void randomise(INDArray m) {
    long current=seed;
    // advance before filling, exactly as a post-increment argument would
    seed=current+1;
    Arrayz.fillNormal(m, current);
}
/**
 * Runs the generic tests against dense {@code Matrix} instances built from a
 * 3x3 scale matrix and from a randomised 6x7 {@code VectorMatrixMN}.
 */
@Test public void g_Matrix() {
    // general M*N source matrix
    VectorMatrixMN source=new VectorMatrixMN(6 ,7);
    randomise(source);

    Matrix fromScale=new Matrix(Matrix33.createScaleMatrix(4.0));
    doGenericTests(fromScale);

    Matrix fromSource=new Matrix(source);
    doGenericTests(fromSource);
}
/**
 * Runs the generic tests against the transposes of two dense {@code Matrix}
 * instances: one built from a 3x3 scale matrix and one 6x7 matrix filled via
 * {@code randomise}.
 */
@Test public void g_DenseColumnMatrix() {
    Matrix scale=new Matrix(Matrix33.createScaleMatrix(Math.PI));
    AMatrix scaleTransposed=scale.getTranspose();
    doGenericTests(scaleTransposed);

    Matrix random=new Matrix(new VectorMatrixMN(6 ,7));
    randomise(random);
    AMatrix randomTransposed=random.getTranspose();
    doGenericTests(randomTransposed);
}
/**
 * Runs the generic tests against {@code SubsetMatrix} instances created from
 * three different index/size combinations.
 */
@Test public void g_SubsetMatrix() {
    AMatrix first=SubsetMatrix.create(Index.of(0,1,2),3);
    doGenericTests(first);

    AMatrix second=SubsetMatrix.create(Index.of(0,1,3,10),12);
    doGenericTests(second);

    AMatrix third=SubsetMatrix.create(Index.of(0,3,2,1),4);
    doGenericTests(third);
}
/**
 * Runs the generic tests against {@code ScalarMatrix} instances created with
 * different sizes and values (including 0.0), and against a sub-matrix of
 * one.
 */
@Test public void g_ScalarMatrix() {
    AMatrix single=ScalarMatrix.create(1,3.0);
    doGenericTests(single);

    AMatrix threeE=ScalarMatrix.create(3,Math.E);
    doGenericTests(threeE);

    AMatrix fiveZero=ScalarMatrix.create(5,0.0);
    doGenericTests(fiveZero);

    AMatrix window=ScalarMatrix.create(5,2.0).subMatrix(1, 3, 1, 3);
    doGenericTests(window);
}
/**
 * Runs the generic tests against {@code RowMatrix} instances wrapping a
 * four-element {@code Vector} and a {@code Vector3}.
 */
@Test public void g_RowMatrix() {
    RowMatrix fromVector=new RowMatrix(Vector.of(1,2,3,4));
    doGenericTests(fromVector);

    RowMatrix fromVector3=new RowMatrix(Vector3.of(1,2,3));
    doGenericTests(fromVector3);
}
/**
 * Runs the generic tests against {@code ColumnMatrix} instances wrapping a
 * four-element {@code Vector} and a {@code Vector3}.
 */
@Test public void g_ColumnMatrix() {
    ColumnMatrix fromVector=new ColumnMatrix(Vector.of(1,2,3,4));
    doGenericTests(fromVector);

    ColumnMatrix fromVector3=new ColumnMatrix(Vector3.of(1,2,3));
    doGenericTests(fromVector3);
}
/**
 * Runs the generic tests against {@code StridedMatrix} instances obtained in
 * three ways: {@code create(1, 1)}, {@code create} from a random 3x4 matrix,
 * and {@code wrap} around a dense copy of a random 3x3 matrix.
 */
@Test public void g_StridedMatrix() {
    StridedMatrix tiny=StridedMatrix.create(1, 1);
    doGenericTests(tiny);

    StridedMatrix created=StridedMatrix.create(Matrixx.createRandomMatrix(3, 4));
    doGenericTests(created);

    StridedMatrix wrapped=StridedMatrix.wrap(Matrix.create(Matrixx.createRandomMatrix(3, 3)));
    doGenericTests(wrapped);
}
/**
 * Runs the generic tests against {@code PermutationMatrix} instances created
 * from explicit index lists and from random permutations, and against a
 * sub-matrix of one.
 */
@Test public void g_PermutationMatrix() {
    AMatrix inOrder=PermutationMatrix.create(0,1,2);
    doGenericTests(inOrder);

    AMatrix shuffled=PermutationMatrix.create(4,2,3,1,0);
    doGenericTests(shuffled);

    AMatrix random=PermutationMatrix.create(Indexz.createRandomPermutation(8));
    doGenericTests(random);

    AMatrix window=PermutationMatrix.create(Indexz.createRandomPermutation(6)).subMatrix(1,3,2,4);
    doGenericTests(window);
}
/**
 * Runs the generic tests against {@code BandedMatrix} instances obtained via
 * {@code create(3, 3, -2, 2)}, {@code create} from a random 2x2 matrix, and
 * {@code wrap} around a single three-element band vector.
 */
@Test public void g_BandedMatrix() {
    AMatrix fiveBands=BandedMatrix.create(3, 3, -2, 2);
    doGenericTests(fiveBands);

    AMatrix fromRandom=BandedMatrix.create(Matrixx.createRandomMatrix(2, 2));
    doGenericTests(fromRandom);

    AMatrix wrapped=BandedMatrix.wrap(3, 4, 0, 0,Vector.of(1,2,3));
    doGenericTests(wrapped);
}
/**
 * Runs the generic tests against a {@code QuadtreeMatrix} assembled from
 * four parts: a {@code Matrix22}, two {@code ZeroMatrix} instances (2x1 and
 * 1x2) and a scale matrix from {@code Matrixx.createScaleMatrix(1, 3)}.
 */
@Test public void g_QuadtreeMatrix() {
    AMatrix quadtree=QuadtreeMatrix.create(
        new Matrix22(1,0,0,2),
        ZeroMatrix.create(2, 1),
        ZeroMatrix.create(1, 2),
        Matrixx.createScaleMatrix(1, 3));
    doGenericTests(quadtree);
}
/**
 * Runs the generic tests against {@code SparseRowMatrix} and
 * {@code SparseColumnMatrix} instances, each created once from an explicit
 * list of vectors (some of them {@code null}) and once from a random square
 * matrix.
 */
@Test public void g_SparseMatrix() {
    // row-based sparse matrices
    AMatrix rowsFromVectors=SparseRowMatrix.create(Vector.of(0,1,-Math.E),null,null,AxisVector.create(2, 3));
    doGenericTests(rowsFromVectors);

    AMatrix rowsFromRandom=SparseRowMatrix.create(Matrixx.createRandomSquareMatrix(3));
    doGenericTests(rowsFromRandom);

    // column-based sparse matrices
    AMatrix columnsFromVectors=SparseColumnMatrix.create(Vector.of(0,1,-Math.PI),null,null,AxisVector.create(2, 3));
    doGenericTests(columnsFromVectors);

    AMatrix columnsFromRandom=SparseColumnMatrix.create(Matrixx.createRandomSquareMatrix(4));
    doGenericTests(columnsFromRandom);
}
/**
 * Runs the generic tests against {@code UpperTriangularMatrix} and
 * {@code LowerTriangularMatrix} instances created from random matrices of
 * shapes 1x1, 4x4, 4x3 and 2x3.
 */
@Test public void g_TriangularMatrix() {
    // upper triangular
    AMatrix upperOne=UpperTriangularMatrix.createFrom(Matrixx.createRandomSquareMatrix(1));
    doGenericTests(upperOne);
    AMatrix upperFour=UpperTriangularMatrix.createFrom(Matrixx.createRandomSquareMatrix(4));
    doGenericTests(upperFour);
    AMatrix upperTall=UpperTriangularMatrix.createFrom(Matrixx.createRandomMatrix(4,3));
    doGenericTests(upperTall);
    AMatrix upperWide=UpperTriangularMatrix.createFrom(Matrixx.createRandomMatrix(2,3));
    doGenericTests(upperWide);

    // lower triangular
    AMatrix lowerOne=LowerTriangularMatrix.createFrom(Matrixx.createRandomSquareMatrix(1));
    doGenericTests(lowerOne);
    AMatrix lowerFour=LowerTriangularMatrix.createFrom(Matrixx.createRandomSquareMatrix(4));
    doGenericTests(lowerFour);
    AMatrix lowerTall=LowerTriangularMatrix.createFrom(Matrixx.createRandomMatrix(4,3));
    doGenericTests(lowerTall);
    AMatrix lowerWide=LowerTriangularMatrix.createFrom(Matrixx.createRandomMatrix(2,3));
    doGenericTests(lowerWide);
}
/**
 * Runs the generic tests against {@code ImmutableMatrix} instances built
 * from a random 4x5 matrix and a random 3x3 matrix.
 */
@Test public void g_ImmutableMatrix() {
    ImmutableMatrix rectangular=new ImmutableMatrix(Matrixx.createRandomMatrix(4, 5));
    doGenericTests(rectangular);

    ImmutableMatrix square=new ImmutableMatrix(Matrixx.createRandomMatrix(3, 3));
    doGenericTests(square);
}
/**
 * Runs the generic tests against a {@code BlockDiagonalMatrix} created from
 * a 2x2 identity matrix and a random 2x2 matrix.
 */
@Test public void g_BlockDiagonalMatrix() {
    AMatrix blocks=BlockDiagonalMatrix.create(
        IdentityMatrix.create(2),
        Matrixx.createRandomSquareMatrix(2));
    doGenericTests(blocks);
}
}