package mikera.matrixx.algo;
import mikera.matrixx.AMatrix;
import mikera.matrixx.Matrix;
import mikera.matrixx.impl.ImmutableMatrix;
import mikera.vectorz.util.DoubleArrays;
import mikera.vectorz.util.ErrorMessages;
public class Multiplications {
    // Target number of elements (doubles) held in each working-set block.
    // 8192 doubles occupy 64KB, aiming for working sets that fit comfortably in the L2 cache
    // of typical modern machines (doubleBlockedMultiply holds two such working sets).
    protected static final int WORKING_SET_TARGET=8192;

    /**
     * General purpose matrix multiplication.
     *
     * <p>The way {@code a} is accessed is selected from its type:
     * <ul>
     * <li>a {@link Matrix} is used directly;</li>
     * <li>an {@link ImmutableMatrix} has its internal data array wrapped as a {@link Matrix};</li>
     * <li>any other matrix is first converted with {@link AMatrix#toMatrix()}.</li>
     * </ul>
     * The product is then computed with {@link #blockedMultiply(Matrix, AMatrix)}.
     *
     * @param a left-hand matrix, of shape (m x k)
     * @param b right-hand matrix, of shape (k x n)
     * @return a new dense matrix of shape (m x n) containing the product {@code a * b}
     * @throws IllegalArgumentException if {@code a.columnCount() != b.rowCount()}
     */
    public static Matrix multiply(AMatrix a, AMatrix b) {
        if (a instanceof Matrix) {
            return multiply((Matrix)a,b);
        } else if (a instanceof ImmutableMatrix) {
            double[] data=((ImmutableMatrix)a).getInternalData();
            return multiply(Matrix.wrap(a.rowCount(),a.columnCount(),data),b);
        } else {
            return blockedMultiply(a.toMatrix(),b);
        }
    }

    /**
     * Multiplies a dense matrix by an arbitrary matrix.
     *
     * <p>Currently delegates to {@link #blockedMultiply(Matrix, AMatrix)}.
     *
     * @param a left-hand dense matrix, of shape (m x k)
     * @param b right-hand matrix, of shape (k x n)
     * @return a new dense matrix of shape (m x n) containing the product {@code a * b}
     * @throws IllegalArgumentException if {@code a.columnCount() != b.rowCount()}
     */
    public static Matrix multiply(Matrix a, AMatrix b) {
        return blockedMultiply(a,b);
    }

    /**
     * Performs fast matrix multiplication using temporary working storage for the second matrix.
     *
     * <p>Columns of {@code b} are copied, a block at a time, into a contiguous working set.
     * Each output element is then a dot product of two contiguous arrays: a row of
     * {@code a.data} and a copied column of {@code b}. The block size is chosen so that the
     * working set holds roughly {@link #WORKING_SET_TARGET} elements.
     *
     * @param a left-hand dense matrix, of shape (m x k). Its {@code data} array is read
     *          directly, with row {@code i} starting at offset {@code i*k}.
     * @param b right-hand matrix, of shape (k x n)
     * @return a new dense matrix of shape (m x n) containing the product {@code a * b}
     * @throws IllegalArgumentException if {@code a.columnCount() != b.rowCount()}
     */
    public static Matrix blockedMultiply(Matrix a, AMatrix b) {
        checkMultiplyShapes(a,b);
        int rc=a.rowCount();
        int cc=b.columnCount();
        int ic=a.columnCount();
        Matrix result=Matrix.create(rc, cc);
        if (ic==0) return result; // empty inner dimension: the product is all zeros
        int block=(WORKING_SET_TARGET/ic)+1;
        // working set stores up to <block> columns of b, each as a contiguous run of ic values
        Matrix wsb=Matrix.create(Math.min(block,cc), ic);
        for (int bj=0; bj<cc; bj+=block) {
            int bjsize=Math.min(block, cc-bj);
            copyColumnBlock(b, bj, bjsize, wsb.data, ic);
            // rows of a are read in place from a.data, so they need no blocking
            for (int i=0; i<rc; i++) {
                int aDataOffset=i*ic;
                for (int t=0; t<bjsize; t++) {
                    double val=DoubleArrays.dotProduct(a.data, aDataOffset, wsb.data, t*ic, ic);
                    result.unsafeSet(i, bj+t, val);
                }
            }
        }
        return result;
    }

    /**
     * Performs fast matrix multiplication using temporary working storage for both matrices.
     *
     * <p>Blocks of rows of {@code a} and blocks of columns of {@code b} are copied into two
     * contiguous working sets. Each output element is then a dot product of two contiguous
     * arrays. Because {@code a} is only accessed through {@link AMatrix#copyRowTo}, any
     * matrix implementation can be used for it.
     *
     * @param a left-hand matrix, of shape (m x k)
     * @param b right-hand matrix, of shape (k x n)
     * @return a new dense matrix of shape (m x n) containing the product {@code a * b}
     * @throws IllegalArgumentException if {@code a.columnCount() != b.rowCount()}
     */
    public static Matrix doubleBlockedMultiply(AMatrix a, AMatrix b) {
        checkMultiplyShapes(a,b);
        int rc=a.rowCount();
        int cc=b.columnCount();
        int ic=a.columnCount();
        Matrix result=Matrix.create(rc, cc);
        if (ic==0) return result; // empty inner dimension: the product is all zeros
        int block=(WORKING_SET_TARGET/ic)+1;
        // working sets store up to <block> rows of a / columns of b, each contiguous of length ic
        Matrix wsa=Matrix.create(Math.min(block,rc), ic);
        Matrix wsb=Matrix.create(Math.min(block,cc), ic);
        for (int bj=0; bj<cc; bj+=block) {
            int bjsize=Math.min(block, cc-bj);
            copyColumnBlock(b, bj, bjsize, wsb.data, ic);
            for (int bi=0; bi<rc; bi+=block) {
                int bisize=Math.min(block, rc-bi);
                // copy rows of the LEFT-hand matrix a into its working set
                for (int t=0; t<bisize; t++) {
                    a.copyRowTo(bi+t, wsa.data, t*ic);
                }
                // compute the (bisize x bjsize) output block
                for (int i=0; i<bisize; i++) {
                    for (int j=0; j<bjsize; j++) {
                        double val=DoubleArrays.dotProduct(wsa.data, ic*i, wsb.data, ic*j, ic);
                        result.unsafeSet(bi+i, bj+j, val);
                    }
                }
            }
        }
        return result;
    }

    /**
     * Multiplies matrices by copying one column of {@code b} at a time into a temporary array.
     * Each output element is then a dot product against a row of {@code a.data}.
     *
     * @param a left-hand dense matrix, of shape (m x k). Its {@code data} array is read
     *          directly, with row {@code i} starting at offset {@code i*k}.
     * @param b right-hand matrix, of shape (k x n)
     * @return a new dense matrix of shape (m x n) containing the product {@code a * b}
     * @throws IllegalArgumentException if {@code a.columnCount() != b.rowCount()}
     */
    public static Matrix directMultiply(Matrix a, AMatrix b) {
        checkMultiplyShapes(a,b);
        int rc=a.rowCount();
        int cc=b.columnCount();
        int ic=a.columnCount();
        Matrix result=Matrix.create(rc,cc);
        if (ic==0) return result; // empty inner dimension: the product is all zeros
        double[] tmp=new double[ic];
        for (int j=0; j<cc; j++) {
            b.copyColumnTo(j, tmp, 0);
            for (int i=0; i<rc; i++) {
                double acc=DoubleArrays.dotProduct(a.data, i*ic, tmp, 0, ic);
                result.unsafeSet(i,j,acc);
            }
        }
        return result;
    }

    /**
     * Reference triple-loop matrix multiplication that uses element access only.
     * It is slow but works for any {@link AMatrix} implementations.
     *
     * @param a left-hand matrix, of shape (m x k)
     * @param b right-hand matrix, of shape (k x n)
     * @return a new matrix of shape (m x n) containing the product {@code a * b}
     * @throws IllegalArgumentException if {@code a.columnCount() != b.rowCount()}
     */
    public static AMatrix naiveMultiply(AMatrix a, AMatrix b) {
        checkMultiplyShapes(a,b);
        int rc=a.rowCount();
        int cc=b.columnCount();
        int ic=a.columnCount();
        Matrix result=Matrix.create(rc,cc);
        for (int i=0; i<rc; i++) {
            for (int j=0; j<cc; j++) {
                double acc=0.0;
                for (int k=0; k<ic; k++) {
                    acc+=a.unsafeGet(i, k)*b.unsafeGet(k, j);
                }
                result.unsafeSet(i,j,acc);
            }
        }
        return result;
    }

    /** Throws IllegalArgumentException unless a (m x k) and b (k x n) can be multiplied. */
    private static void checkMultiplyShapes(AMatrix a, AMatrix b) {
        if (a.columnCount()!=b.rowCount()) {
            throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(a,b));
        }
    }

    /** Copies columns [start, start+count) of m into dest, with column t at offset t*stride. */
    private static void copyColumnBlock(AMatrix m, int start, int count, double[] dest, int stride) {
        for (int t=0; t<count; t++) {
            m.copyColumnTo(start+t, dest, t*stride);
        }
    }
}