package mikera.matrixx.impl;
import java.util.Arrays;
import mikera.arrayz.ISparse;
import mikera.indexz.Index;
import mikera.indexz.Indexz;
import mikera.matrixx.AMatrix;
import mikera.matrixx.Matrix;
import mikera.vectorz.AVector;
import mikera.vectorz.Vector;
import mikera.vectorz.impl.AxisVector;
import mikera.vectorz.util.ErrorMessages;
import mikera.vectorz.util.VectorzException;
/**
* Class representing a square permutation matrix
* i.e. has single 1.0 in every row and column
* @author Mike
*
*/
public final class PermutationMatrix extends ABooleanMatrix implements IFastRows, IFastColumns, ISparse {
private static final long serialVersionUID = 8098287603508120428L;
private final Index perm;
private final int size;
private PermutationMatrix(Index perm) {
if (!perm.isPermutation()) throw new IllegalArgumentException("Not a valid permutation: "+perm);
this.perm=perm;
size=perm.length();
}
public static PermutationMatrix createIdentity(int length) {
return new PermutationMatrix(Indexz.createSequence(length));
}
public static PermutationMatrix createSwap(int i, int j, int length) {
PermutationMatrix p=createIdentity(length);
p.swapRows(i, j);
return p;
}
public static PermutationMatrix create(Index rowPermutations) {
return new PermutationMatrix(rowPermutations.clone());
}
public static PermutationMatrix wrap(Index rowPermutations) {
return new PermutationMatrix(rowPermutations);
}
public static PermutationMatrix create(int... rowPermutations) {
Index index=Index.of(rowPermutations);
return wrap(index);
}
public static PermutationMatrix wrap(int[] rowPermutations) {
return wrap(Index.wrap(rowPermutations));
}
public static PermutationMatrix createRandomPermutation(int length) {
Index index=Indexz.createRandomPermutation(length);
return new PermutationMatrix(index);
}
@Override
public void addToArray(double[] data, int offset) {
for (int i=0; i<size; i++) {
data[offset+(i*size)+perm.get(i)]+=1.0;
}
}
@Override
public boolean isMutable() {
// PermutationMatrix is mutable (rows can be swapped, etc.)
return true;
}
@Override
public boolean isSymmetric() {
return isIdentity();
}
@Override
public double determinant() {
return perm.isEvenPermutation()?1.0:-1.0;
}
@Override
public int rank() {
return size;
}
@Override
public boolean isIdentity() {
int[] data=perm.data;
for (int i=0; i<size; i++) {
if (data[i]!=i) return false;
}
return true;
}
@Override
public boolean isOrthogonal() {
return true;
}
@Override
public boolean isOrthogonal(double tolerance) {
return true;
}
@Override
public boolean hasOrthonormalColumns() {
return true;
}
@Override
public boolean hasOrthonormalRows() {
return true;
}
@Override
public boolean isDiagonal() {
return isIdentity();
}
@Override
public boolean isBoolean() {
return true;
}
@Override
public boolean isUpperTriangular() {
return isIdentity();
}
@Override
public boolean isLowerTriangular() {
return isIdentity();
}
@Override
public boolean isSquare() {
return true;
}
@Override
public int rowCount() {
return size;
}
@Override
public int columnCount() {
return size;
}
@Override
public double elementSum() {
return size;
}
@Override
public double elementSquaredSum() {
return size;
}
@Override
public long nonZeroCount() {
return size;
}
@Override
public double trace() {
int result=0;
for (int i=0; i<size; i++) {
if (perm.data[i]==i) result++;
}
return result;
}
@Override
public PermutationMatrix inverse() {
return getTranspose();
}
@Override
public PermutationMatrix getTranspose() {
return new PermutationMatrix(perm.invert());
}
@Override
public double get(int row, int column) {
if (column<0||(column>=size)) throw new IndexOutOfBoundsException(ErrorMessages.invalidIndex(this,row,column));
return (perm.get(row)==column)?1.0:0.0;
}
@Override
public void unsafeSet(int row, int column, double value) {
if (get(row,column)==value) return;
throw new UnsupportedOperationException(ErrorMessages.notFullyMutable(this,row,column));
}
@Override
public double unsafeGet(int row, int column) {
return (perm.unsafeGet(row)==column)?1.0:0.0;
}
@Override
public void set(int row, int column, double value) {
throw new UnsupportedOperationException(ErrorMessages.notFullyMutable(this,row,column));
}
@Override
public AxisVector getRow(int i) {
return AxisVector.create(perm.get(i), size);
}
@Override
public AxisVector getColumn(int j) {
return AxisVector.create(perm.find(j), size);
}
@Override
public void copyRowTo(int row, double[] dest, int destOffset) {
Arrays.fill(dest, destOffset,destOffset+size,0.0);
dest[destOffset+perm.get(row)]=1.0;
}
@Override
public void copyColumnTo(int col, double[] dest, int destOffset) {
Arrays.fill(dest, destOffset,destOffset+size,0.0);
dest[destOffset+perm.find(col)]=1.0;
}
@Override
public void swapRows(int i, int j) {
if (i!=j) {
perm.swap(i, j);
}
}
@Override
public void swapColumns(int i, int j) {
if (i!=j) {
int a=perm.find(i);
int b=perm.find(j);
perm.swap(a, b);
}
}
@Override
public void transform(AVector source, AVector dest) {
if ((source instanceof Vector )&&(dest instanceof Vector)) {
transform ((Vector)source, (Vector)dest);
return;
}
if(rowCount()!=dest.length()) throw new IllegalArgumentException(ErrorMessages.wrongDestLength(dest));
if(columnCount()!=source.length()) throw new IllegalArgumentException(ErrorMessages.wrongSourceLength(dest));
for (int i=0; i<size; i++) {
dest.unsafeSet(i,source.unsafeGet(perm.unsafeGet(i)));
}
}
@Override
public void transform(Vector source, Vector dest) {
int rc = rowCount();
int cc = columnCount();
if (source.length()!=cc) throw new IllegalArgumentException(ErrorMessages.wrongSourceLength(source));
if (dest.length()!=rc) throw new IllegalArgumentException(ErrorMessages.wrongDestLength(dest));
for (int i=0; i<size; i++) {
dest.unsafeSet(i,source.unsafeGet(perm.unsafeGet(i)));
}
}
@Override
public double calculateElement(int i, AVector inputVector) {
return inputVector.unsafeGet(perm.get(i));
}
@Override
public double calculateElement(int i, Vector inputVector) {
return inputVector.unsafeGet(perm.get(i));
}
@Override
public Matrix innerProduct(AMatrix a) {
if (a instanceof Matrix) return innerProduct((Matrix)a);
if (a.rowCount()!=size) throw new IllegalArgumentException(ErrorMessages.mismatch(this,a));
int cc=a.columnCount();
Matrix result=Matrix.create(size,cc);
for (int i=0; i<size; i++) {
int dstIndex=i*cc;
int srcRow=perm.get(i);
a.copyRowTo(srcRow, result.data, dstIndex);
}
return result;
}
@Override
public Matrix innerProduct(Matrix a) {
if (a.rowCount()!=size) throw new IllegalArgumentException(ErrorMessages.mismatch(this,a));
int cc=a.columnCount();
Matrix result=Matrix.create(size,cc);
for (int i=0; i<size; i++) {
int srcIndex=perm.get(i)*cc;
int dstIndex=i*cc;
System.arraycopy(a.data,srcIndex,result.data,dstIndex,cc);
}
return result;
}
@Override
public Matrix transposeInnerProduct(Matrix s) {
return getTranspose().innerProduct(s);
}
@Override
public void validate() {
super.validate();
if (size!=perm.length()) throw new VectorzException("Whoops!");
}
@Override
public double density() {
return 1.0/size;
}
@Override
public PermutationMatrix exactClone() {
return new PermutationMatrix(perm.clone());
}
@Override
public boolean hasUncountable() {
return false;
}
/**
* Returns the sum of all the elements raised to a specified power
* @return
*/
@Override
public double elementPowSum(double p) {
return size;
}
/**
* Returns the sum of the absolute values of all the elements raised to a specified power
* @return
*/
@Override
public double elementAbsPowSum(double p) {
return elementPowSum(p);
}
}