package mikera.matrixx.impl;
import java.util.Arrays;
import mikera.matrixx.AMatrix;
import mikera.matrixx.Matrix;
import mikera.vectorz.AVector;
import mikera.vectorz.Vector;
import mikera.vectorz.Vectorz;
import mikera.vectorz.impl.ADenseArrayVector;
import mikera.vectorz.impl.ASingleElementVector;
import mikera.vectorz.impl.SingleElementVector;
import mikera.vectorz.util.ErrorMessages;
import mikera.vectorz.util.VectorzException;
/**
 * Abstract base class for square diagonal matrices.
 *
 * <p>Every element off the leading diagonal is zero. Subclasses supply the diagonal values via
 * {@link #getLeadingDiagonal()} and element access via {@code unsafeGet(int, int)}; most
 * operations in this class are expressed purely in terms of the leading diagonal.
 *
 * @author Mike
 */
public abstract class ADiagonalMatrix extends ASingleBandMatrix {
    private static final long serialVersionUID = -6770867175103162837L;

    /** Number of rows of this square matrix, which is also its number of columns. */
    protected final int dimensions;

    /**
     * Creates a diagonal matrix of the given size.
     *
     * @param dimensions number of rows (and columns)
     */
    protected ADiagonalMatrix(int dimensions) {
        this.dimensions=dimensions;
    }

    /** The leading diagonal (band 0) is the only band that may hold non-zero values. */
    @Override
    public int nonZeroBand() {
        return 0;
    }

    @Override
    public boolean isSquare() {
        return true;
    }

    /** A diagonal matrix is zero exactly when its leading diagonal is zero. */
    @Override
    public boolean isZero() {
        return getLeadingDiagonal().isZero();
    }

    /** Off-diagonal zeros are already boolean, so only the leading diagonal needs checking. */
    @Override
    public boolean isBoolean() {
        return getLeadingDiagonal().isBoolean();
    }

    @Override
    public boolean isSymmetric() {
        return true;
    }

    @Override
    public boolean isDiagonal() {
        return true;
    }

    @Override
    public boolean isRectangularDiagonal() {
        return true;
    }

    @Override
    public boolean isUpperTriangular() {
        return true;
    }

    @Override
    public boolean isLowerTriangular() {
        return true;
    }

    @Override
    public abstract boolean isMutable();

    /**
     * Returns true only for matrices with at most one element whose diagonal is fully mutable.
     * With more than one row there are off-diagonal elements, which cannot be set to arbitrary
     * values.
     */
    @Override
    public boolean isFullyMutable() {
        return (dimensions<=1)&&(getLeadingDiagonal().isFullyMutable());
    }

    /**
     * Checks that {@code m} is {@code dimensions x dimensions}.
     *
     * @throws IndexOutOfBoundsException if the shapes differ
     */
    @Override
    protected void checkSameShape(AMatrix m) {
        int dims=dimensions;
        if ((dims!=m.rowCount())||(dims!=m.columnCount())) {
            throw new IndexOutOfBoundsException(ErrorMessages.mismatch(this, m));
        }
    }

    /**
     * Checks that {@code m} is {@code dimensions x dimensions}.
     *
     * @throws IndexOutOfBoundsException if the shapes differ
     */
    @Override
    protected void checkSameShape(ARectangularMatrix m) {
        int dims=dimensions;
        if ((dims!=m.rowCount())||(dims!=m.columnCount())) {
            throw new IndexOutOfBoundsException(ErrorMessages.mismatch(this, m));
        }
    }

    /**
     * Checks that (i, j) is a valid element position in this matrix.
     *
     * @throws IndexOutOfBoundsException if either index is out of range
     */
    @Override
    protected final void checkIndex(int i, int j) {
        if ((i<0)||(i>=dimensions)||(j<0)||(j>=dimensions)) {
            throw new IndexOutOfBoundsException(ErrorMessages.invalidIndex(this, i,j));
        }
    }

    /**
     * Tests whether a rectangular matrix has the same shape as this matrix.
     *
     * @param m matrix to compare against
     * @return true if {@code m} has {@code dimensions} rows and {@code dimensions} columns
     */
    public final boolean isSameShape(ARectangularMatrix m) {
        return (dimensions==m.rows)&&(dimensions==m.cols);
    }

    @Override
    public final int upperBandwidthLimit() {
        return 0;
    }

    @Override
    public final int lowerBandwidthLimit() {
        return 0;
    }

    /**
     * Returns the requested band. For band 0 this is the leading diagonal; for any other band it
     * is a zero vector of the band's length.
     *
     * @param band band offset, 0 for the leading diagonal
     * @return the band as a vector
     * @throws IndexOutOfBoundsException if {@code |band| > dimensions}
     */
    @Override
    public AVector getBand(int band) {
        if (band==0) return getLeadingDiagonal();
        if ((band>dimensions)||(band<-dimensions)) {
            throw new IndexOutOfBoundsException(ErrorMessages.invalidBand(this, band));
        }
        return Vectorz.createZeroVector(bandLength(band));
    }

    @Override
    public AVector getNonZeroBand() {
        return getLeadingDiagonal();
    }

    /** The determinant of a diagonal matrix is the product of its diagonal values. */
    @Override
    public double determinant() {
        double det=1.0;
        for (int i=0; i<dimensions; i++) {
            det*=unsafeGetDiagonalValue(i);
        }
        return det;
    }

    /** The rank of a diagonal matrix is the number of non-zero diagonal values. */
    @Override
    public int rank() {
        return (int)getLeadingDiagonal().nonZeroCount();
    }

    /**
     * Returns the number of dimensions of this diagonal matrix.
     *
     * @return the number of rows, which equals the number of columns
     */
    public int dimensions() {
        return dimensions;
    }

    @Override
    public boolean isSameShape(AMatrix m) {
        return (dimensions==m.rowCount())&&(dimensions==m.columnCount());
    }

    /** Always succeeds, because a diagonal matrix is square; returns the dimension count. */
    @Override
    public int checkSquare() {
        return dimensions;
    }

    /**
     * Returns the largest element. With more than one row, the implicit off-diagonal zeros are
     * included in the comparison.
     */
    @Override
    public double elementMax(){
        double ldv=getLeadingDiagonal().elementMax();
        if (dimensions>1) return Math.max(0, ldv); else return ldv;
    }

    /**
     * Returns the smallest element. With more than one row, the implicit off-diagonal zeros are
     * included in the comparison.
     */
    @Override
    public double elementMin(){
        double ldv=getLeadingDiagonal().elementMin();
        if (dimensions>1) return Math.min(0, ldv); else return ldv;
    }

    @Override
    public double elementSum(){
        return getLeadingDiagonal().elementSum();
    }

    @Override
    public double elementSquaredSum(){
        return getLeadingDiagonal().elementSquaredSum();
    }

    @Override
    public long nonZeroCount(){
        return getLeadingDiagonal().nonZeroCount();
    }

    /**
     * Writes row {@code row} into {@code dest}: {@code dimensions} zeros starting at
     * {@code destOffset}, with the diagonal value placed at {@code destOffset + row}.
     */
    @Override
    public void copyRowTo(int row, double[] dest, int destOffset) {
        Arrays.fill(dest, destOffset,destOffset+dimensions,0.0);
        dest[destOffset+row]=unsafeGetDiagonalValue(row);
    }

    /**
     * Adds this matrix to a dense {@code dimensions x dimensions} array starting at
     * {@code offset}. Only the diagonal positions, at stride {@code dimensions+1}, are touched.
     */
    @Override
    public void addToArray(double[] dest, int offset) {
        getLeadingDiagonal().addToArray(dest, offset, dimensions+1);
    }

    /**
     * Returns a new matrix equal to this + a. If {@code a} is diagonal, the result is a new
     * {@link DiagonalMatrix}; otherwise the work is delegated to {@code a.addCopy(this)}.
     *
     * @throws IllegalArgumentException if {@code a} is diagonal but of a different size
     */
    @Override
    public AMatrix addCopy(AMatrix a) {
        if (a.isDiagonal()) {
            if (a.rowCount()!=dimensions) throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this,a));
            DiagonalMatrix m=DiagonalMatrix.create(this.getLeadingDiagonal());
            a.getLeadingDiagonal().addToArray(m.data,0);
            return m;
        } else {
            return a.addCopy(this);
        }
    }

    /** Columns equal rows for a diagonal matrix, so this delegates to copyRowTo. */
    @Override
    public void copyColumnTo(int col, double[] dest, int destOffset) {
        copyRowTo(col,dest,destOffset);
    }

    /**
     * Multiplies two diagonal matrices. The result is diagonal, with each value being the
     * product of the corresponding diagonal values.
     *
     * @param a diagonal matrix of the same size
     * @return a new DiagonalMatrix holding the product
     * @throws IllegalArgumentException if the sizes differ
     */
    public AMatrix innerProduct(ADiagonalMatrix a) {
        int dims=this.dimensions;
        if (dims!=a.dimensions) throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this,a));
        DiagonalMatrix result=DiagonalMatrix.createDimensions(dims);
        for (int i=0; i<dims; i++) {
            result.data[i]=unsafeGetDiagonalValue(i)*a.unsafeGetDiagonalValue(i);
        }
        return result;
    }

    /**
     * Computes this * a by scaling row i of a clone of {@code a} by the i-th diagonal value.
     * Diagonal and dense {@link Matrix} arguments are sent to the specialised overloads.
     *
     * @throws IllegalArgumentException if {@code a.rowCount() != dimensions}
     */
    @Override
    public AMatrix innerProduct(AMatrix a) {
        if (a instanceof ADiagonalMatrix) {
            return innerProduct((ADiagonalMatrix) a);
        } else if (a instanceof Matrix) {
            return innerProduct((Matrix) a);
        }
        if (dimensions!=a.rowCount()) throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this,a));
        AMatrix m=a.clone();
        for (int i=0; i<dimensions; i++) {
            double dv=unsafeGetDiagonalValue(i);
            m.multiplyRow(i, dv);
        }
        return m;
    }

    /**
     * Computes this * a for a dense matrix by scaling each row of a clone of {@code a}.
     *
     * @throws IllegalArgumentException if {@code a.rowCount() != dimensions}
     */
    @Override
    public Matrix innerProduct(Matrix a) {
        if (dimensions!=a.rowCount()) throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this,a));
        Matrix m=a.clone();
        for (int i=0; i<dimensions; i++) {
            double dv=unsafeGetDiagonalValue(i);
            m.multiplyRow(i, dv);
        }
        return m;
    }

    /** This matrix is its own transpose, so transpose(this) * s equals this * s. */
    @Override
    public Matrix transposeInnerProduct(Matrix s) {
        return innerProduct(s);
    }

    /**
     * Multiplies each element of {@code v} in place by the corresponding diagonal value.
     *
     * @throws IllegalArgumentException if {@code v.length() != dimensions}
     */
    @Override
    public void transformInPlace(AVector v) {
        if (v instanceof ADenseArrayVector) {
            // The dense overload performs its own length check.
            transformInPlace((ADenseArrayVector) v);
            return;
        }
        if (v.length()!=dimensions) throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this,v));
        for (int i=0; i<dimensions; i++) {
            v.unsafeSet(i,v.unsafeGet(i)*unsafeGetDiagonalValue(i));
        }
    }

    /**
     * Multiplies each element of a dense vector in place by the corresponding diagonal value,
     * working directly on its backing array.
     *
     * @throws IllegalArgumentException if {@code v.length() != dimensions}
     */
    @Override
    public void transformInPlace(ADenseArrayVector v) {
        // The length must be validated here. The loop writes straight into the backing array,
        // which may be shared with other views, so a short vector would otherwise have elements
        // beyond its end modified.
        if (v.length()!=dimensions) throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this,v));
        double[] data=v.getArray();
        int offset=v.getArrayOffset();
        for (int i=0; i<dimensions; i++) {
            data[i+offset]*=unsafeGetDiagonalValue(i);
        }
    }

    /**
     * Computes dest[i] = source[i] * d[i], where d is the leading diagonal.
     *
     * <p>NOTE(review): this indexes the backing arrays of {@code source} and {@code dest} from
     * position 0, which assumes {@link Vector} storage starts at index 0. Confirm this holds.
     *
     * @throws IllegalArgumentException if either vector's length differs from the dimension
     */
    @Override
    public void transform(Vector source, Vector dest) {
        int rc = rowCount();
        if (source.length()!=rc) throw new IllegalArgumentException(ErrorMessages.wrongSourceLength(source));
        if (dest.length()!=rc) throw new IllegalArgumentException(ErrorMessages.wrongDestLength(dest));
        double[] sdata=source.getArray();
        double[] ddata=dest.getArray();
        for (int row = 0; row < rc; row++) {
            ddata[row]=sdata[row]*unsafeGetDiagonalValue(row);
        }
    }

    @Override
    public int rowCount() {
        return dimensions;
    }

    @Override
    public int columnCount() {
        return dimensions;
    }

    /** This is the identity exactly when every diagonal value equals 1.0. */
    @Override
    public boolean isIdentity() {
        return getLeadingDiagonal().elementsEqual(1.0);
    }

    /** No-op: a diagonal matrix is already equal to its transpose. */
    @Override
    public void transposeInPlace() {
        // already done!
    }

    /** Returns element i of this * v, which is v[i] * d[i]. */
    @Override
    public double calculateElement(int i, AVector v) {
        return v.unsafeGet(i)*unsafeGetDiagonalValue(i);
    }

    /**
     * This base implementation always throws.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void set(int row, int column, double value) {
        throw new UnsupportedOperationException(ErrorMessages.notFullyMutable(this, row, column));
    }

    @Override
    public abstract AVector getLeadingDiagonal();

    /**
     * Returns the i-th diagonal value, with bounds checking.
     *
     * @param i diagonal index in {@code [0, dimensions)}
     * @return the element at (i, i)
     * @throws IndexOutOfBoundsException if {@code i} is out of range
     */
    public double getDiagonalValue(int i) {
        if ((i<0)||(i>=dimensions)) throw new IndexOutOfBoundsException(ErrorMessages.invalidIndex(this, i, i));
        return unsafeGet(i,i);
    }

    /** Returns row {@code row} as a vector with a single (possibly) non-zero element. */
    @Override
    public ASingleElementVector getRow(int row) {
        return SingleElementVector.create(getDiagonalValue(row), row, dimensions);
    }

    /** Columns equal rows for a diagonal matrix. */
    @Override
    public ASingleElementVector getColumn(int col) {
        return getRow(col);
    }

    /**
     * Returns the i-th diagonal value without bounds checking.
     *
     * @param i diagonal index, assumed valid
     * @return the element at (i, i)
     */
    public double unsafeGetDiagonalValue(int i) {
        return unsafeGet(i,i);
    }

    /** A diagonal matrix is its own transpose, so this returns {@code this}. */
    @Override
    public ADiagonalMatrix getTranspose() {
        return this;
    }

    /** A diagonal matrix is its own transpose, so this returns {@code this}. */
    @Override
    public ADiagonalMatrix getTransposeView() {
        return this;
    }

    /**
     * Returns {@code 1.0/dimensions}, the fraction of elements that lie on the leading diagonal.
     * NOTE(review): for a 0x0 matrix this evaluates to positive infinity.
     */
    @Override
    public double density() {
        return 1.0/dimensions;
    }

    /** Returns a new dense Matrix with the diagonal values placed at stride dimensions+1. */
    @Override
    public Matrix toMatrix() {
        Matrix m=Matrix.create(dimensions, dimensions);
        for (int i=0; i<dimensions; i++) {
            m.data[i*(dimensions+1)]=unsafeGetDiagonalValue(i);
        }
        return m;
    }

    @Override
    public double trace() {
        return getLeadingDiagonal().elementSum();
    }

    @Override
    public double diagonalProduct() {
        return getLeadingDiagonal().elementProduct();
    }

    /** Returns a new dimensions*dimensions array that is zero except at the diagonal positions. */
    @Override
    public double[] toDoubleArray() {
        double[] data=new double[dimensions*dimensions];
        getLeadingDiagonal().addToArray(data, 0, dimensions+1);
        return data;
    }

    /** A diagonal matrix is its own transpose. */
    @Override
    public final Matrix toMatrixTranspose() {
        return toMatrix();
    }

    /** A diagonal matrix is its own transpose, so this is the same as {@code equals(m)}. */
    @Override
    public boolean equalsTranspose(AMatrix m) {
        return equals(m);
    }

    /**
     * Checks that the leading diagonal length matches the dimension count, then runs the
     * superclass validation.
     *
     * @throws VectorzException on a dimension mismatch
     */
    @Override
    public void validate() {
        if (dimensions!=getLeadingDiagonal().length()) throw new VectorzException("dimension mismatch: "+dimensions);
        super.validate();
    }

    @Override
    public abstract ADiagonalMatrix exactClone();

    @Override
    public boolean hasUncountable() {
        return getLeadingDiagonal().hasUncountable();
    }

    /**
     * Returns the sum of all the elements raised to a specified power.
     * Only the leading diagonal is summed.
     *
     * @param p the exponent applied to each element
     * @return the sum of each diagonal element raised to {@code p}
     */
    @Override
    public double elementPowSum(double p) {
        return getLeadingDiagonal().elementPowSum(p);
    }

    /**
     * Returns the sum of the absolute values of all the elements raised to a specified power.
     * Only the leading diagonal is summed.
     *
     * @param p the exponent applied to each absolute value
     * @return the sum of {@code |d_i|^p} over the diagonal elements
     */
    @Override
    public double elementAbsPowSum(double p) {
        return getLeadingDiagonal().elementAbsPowSum(p);
    }
}