Package mikera.matrixx.impl

Source Code of mikera.matrixx.impl.ADiagonalMatrix

package mikera.matrixx.impl;

import java.util.Arrays;

import mikera.matrixx.AMatrix;
import mikera.matrixx.Matrix;
import mikera.vectorz.AVector;
import mikera.vectorz.Vector;
import mikera.vectorz.Vectorz;
import mikera.vectorz.impl.ADenseArrayVector;
import mikera.vectorz.impl.ASingleElementVector;
import mikera.vectorz.impl.SingleElementVector;
import mikera.vectorz.util.ErrorMessages;
import mikera.vectorz.util.VectorzException;

/**
* Abstract base class for square diagonal matrices
* @author Mike
*
*/
public abstract class ADiagonalMatrix extends ASingleBandMatrix {
  private static final long serialVersionUID = -6770867175103162837L;

  protected final int dimensions;
 
  protected ADiagonalMatrix(int dimensions) {
    this.dimensions=dimensions;
  }
 
  @Override
  public int nonZeroBand() {
    return 0;
  }

  @Override
  public boolean isSquare() {
    return true;
  }
 
  @Override
  public boolean isZero() {
    return getLeadingDiagonal().isZero();
  }
 
  @Override
  public boolean isBoolean() {
    return getLeadingDiagonal().isBoolean();
  }
 
  @Override
  public boolean isSymmetric() {
    return true;
  }
 
  @Override
  public boolean isDiagonal() {
    return true;
  }
 
  @Override
  public boolean isRectangularDiagonal() {
    return true;
  }
 
  @Override
  public boolean isUpperTriangular() {
    return true;
  }
 
  @Override
  public boolean isLowerTriangular() {
    return true;
  }
 
  @Override
  public abstract boolean isMutable();
 
  @Override
  public boolean isFullyMutable() {
    return (dimensions<=1)&&(getLeadingDiagonal().isFullyMutable());
  }
 
  @Override
  protected void checkSameShape(AMatrix m) {
    int dims=dimensions;
    if((dims!=m.rowCount())||(dims!=m.columnCount())) {
      throw new IndexOutOfBoundsException(ErrorMessages.mismatch(this, m));
    }
  }
 
  @Override
  protected void checkSameShape(ARectangularMatrix m) {
    int dims=dimensions;
    if((dims!=m.rowCount())||(dims!=m.columnCount())) {
      throw new IndexOutOfBoundsException(ErrorMessages.mismatch(this, m));
    }
  }
 
  @Override
  protected final void checkIndex(int i, int j) {
    if ((i<0)||(i>=dimensions)||(j<0)||(j>=dimensions)) {
      throw new IndexOutOfBoundsException(ErrorMessages.invalidIndex(this, i,j));
    }
  }
 
  public final boolean isSameShape(ARectangularMatrix m) {
    return (dimensions==m.rows)&&(dimensions==m.cols);
  }
 
  @Override
  public final int upperBandwidthLimit() {
    return 0;
  }
 
  @Override
  public final int lowerBandwidthLimit() {
    return 0;
  }
 
  @Override
  public AVector getBand(int band) {
    if (band==0) {
      return getLeadingDiagonal();
    } else {
      if ((band>dimensions)||(band<-dimensions)) throw new IndexOutOfBoundsException(ErrorMessages.invalidBand(this, band));
      return Vectorz.createZeroVector(bandLength(band));
    }
  }
 
  @Override
  public AVector getNonZeroBand() {
    return getLeadingDiagonal();
  }
 
  @Override
  public double determinant() {
    double det=1.0;
    for (int i=0; i<dimensions; i++) {
      det*=unsafeGetDiagonalValue(i);
    }
    return det;
  }
 
  @Override
  public int rank() {
    return (int)getLeadingDiagonal().nonZeroCount();
  }
 
  /**
   * Returns the number of dimensions of this diagonal matrix
   * @return
   */
  public int dimensions() {
    return dimensions;
  }
 
  @Override
  public boolean isSameShape(AMatrix m) {
    return (dimensions==m.rowCount())&&(dimensions==m.columnCount());
  }
 
  @Override
  public int checkSquare() {
    return dimensions;
  }
 
  @Override
  public double elementMax(){
    double ldv=getLeadingDiagonal().elementMax();
    if (dimensions>1) return Math.max(0, ldv); else return ldv;
  }
 
  @Override
  public double elementMin(){
    double ldv=getLeadingDiagonal().elementMin();
    if (dimensions>1) return Math.min(0, ldv); else return ldv;
  }
 
  @Override
  public double elementSum(){
    return getLeadingDiagonal().elementSum();
  }
 
  @Override
  public double elementSquaredSum(){
    return getLeadingDiagonal().elementSquaredSum();
  }
 
  @Override
  public long nonZeroCount(){
    return getLeadingDiagonal().nonZeroCount();
  }
 
  @Override
  public void copyRowTo(int row, double[] dest, int destOffset) {
    Arrays.fill(dest, destOffset,destOffset+dimensions,0.0);
    dest[destOffset+row]=unsafeGetDiagonalValue(row);
  }
 
  @Override
  public void addToArray(double[] dest, int offset) {
    getLeadingDiagonal().addToArray(dest, offset, dimensions+1);
  }
 
  @Override
  public AMatrix addCopy(AMatrix a) {
    if (a.isDiagonal()) {
      if (a.rowCount()!=dimensions) throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this,a));
      DiagonalMatrix m=DiagonalMatrix.create(this.getLeadingDiagonal());
      a.getLeadingDiagonal().addToArray(m.data,0);
      return m;
    } else {
      return a.addCopy(this);
    }
  }
 
  @Override
  public void copyColumnTo(int col, double[] dest, int destOffset) {
    // copying rows and columns is the same!
    copyRowTo(col,dest,destOffset);
  }
 
  public AMatrix innerProduct(ADiagonalMatrix a) {
    int dims=this.dimensions;
    if (dims!=a.dimensions) throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this,a));
    DiagonalMatrix result=DiagonalMatrix.createDimensions(dims);
    for (int i=0; i<dims; i++) {
      result.data[i]=unsafeGetDiagonalValue(i)*a.unsafeGetDiagonalValue(i);
    }
    return result;
  }
 
  @Override
  public AMatrix innerProduct(AMatrix a) {
    if (a instanceof ADiagonalMatrix) {
      return innerProduct((ADiagonalMatrix) a);
    } else if (a instanceof Matrix) {
      return innerProduct((Matrix) a);
    }
    if (!(dimensions==a.rowCount())) throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this,a));
    AMatrix m=a.clone();
    for (int i=0; i<dimensions; i++) {
      double dv=unsafeGetDiagonalValue(i);
      m.multiplyRow(i, dv);
    }
    return m;
  }
 
  @Override
  public Matrix innerProduct(Matrix a) {
    if (!(dimensions==a.rowCount())) throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this,a));
    Matrix m=a.clone();
    for (int i=0; i<dimensions; i++) {
      double dv=unsafeGetDiagonalValue(i);
      m.multiplyRow(i, dv);
    }
    return m;
  }
 
  @Override
  public Matrix transposeInnerProduct(Matrix s) {
    return innerProduct(s);
  }
 
  @Override
  public void transformInPlace(AVector v) {
    if (v instanceof ADenseArrayVector) {
      transformInPlace((ADenseArrayVector) v);
      return;
    }
    if (v.length()!=dimensions) throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this,v));
    for (int i=0; i<dimensions; i++) {
      v.unsafeSet(i,v.unsafeGet(i)*unsafeGetDiagonalValue(i));
    }
  }
 
  @Override
  public void transformInPlace(ADenseArrayVector v) {
    double[] data=v.getArray();
    int offset=v.getArrayOffset();
    for (int i=0; i<dimensions; i++) {
      data[i+offset]*=unsafeGetDiagonalValue(i);
    }
  }
 
  @Override
  public void transform(Vector source, Vector dest) {
    int rc = rowCount();
    int cc = rc;
    if (source.length()!=cc) throw new IllegalArgumentException(ErrorMessages.wrongSourceLength(source));
    if (dest.length()!=rc) throw new IllegalArgumentException(ErrorMessages.wrongDestLength(dest));
    double[] sdata=source.getArray();
    double[] ddata=dest.getArray();
    for (int row = 0; row < rc; row++) {
      ddata[row]=sdata[row]*unsafeGetDiagonalValue(row);
    }
  }
 
  @Override
  public int rowCount() {
    return dimensions;
  }

  @Override
  public int columnCount() {
    return dimensions;
  }
 
  @Override
  public boolean isIdentity() {
    return getLeadingDiagonal().elementsEqual(1.0);
  }
 
  @Override
  public void transposeInPlace() {
    // already done!
  }
 
  @Override
  public double calculateElement(int i, AVector v) {
    return v.unsafeGet(i)*unsafeGetDiagonalValue(i);
  }
 
  @Override
  public void set(int row, int column, double value) {
    throw new UnsupportedOperationException(ErrorMessages.notFullyMutable(this, row, column));
  }
 
  @Override
  public abstract AVector getLeadingDiagonal();

  public double getDiagonalValue(int i) {
    if ((i<0)||(i>=dimensions)) throw new IndexOutOfBoundsException();
    return unsafeGet(i,i);
  }
 
  @Override
  public ASingleElementVector getRow(int row) {
    return SingleElementVector.create(getDiagonalValue(row), row, dimensions);
  }
 
  @Override
  public ASingleElementVector getColumn(int col) {
    return getRow(col);
  }
 
  public double unsafeGetDiagonalValue(int i) {
    return unsafeGet(i,i);
  }
 
  @Override
  public ADiagonalMatrix getTranspose() {
    return this;
  }
 
  @Override
  public ADiagonalMatrix getTransposeView() {
    return this;
  }
 
  @Override
  public double density() {
    return 1.0/dimensions;
  }
 
  @Override
  public Matrix toMatrix() {
    Matrix m=Matrix.create(dimensions, dimensions);
    for (int i=0; i<dimensions; i++) {
      m.data[i*(dimensions+1)]=unsafeGetDiagonalValue(i);
    }
    return m;
  }
 
  @Override
  public double trace() {
    return getLeadingDiagonal().elementSum();
  }
 
  @Override
  public double diagonalProduct() {
    return getLeadingDiagonal().elementProduct();
  }
 
  @Override
  public double[] toDoubleArray() {
    double[] data=new double[dimensions*dimensions];
    getLeadingDiagonal().addToArray(data, 0, dimensions+1);
    return data;
  }
 
  @Override
  public final Matrix toMatrixTranspose() {
    return toMatrix();
  }
 
  @Override
  public boolean equalsTranspose(AMatrix m) {
    return equals(m);
  }
 
  @Override
  public void validate() {
    if (dimensions!=getLeadingDiagonal().length()) throw new VectorzException("dimension mismatch: "+dimensions);
   
    super.validate();
  }
 
  @Override
  public abstract ADiagonalMatrix exactClone();

  @Override
  public boolean hasUncountable() {
    return getLeadingDiagonal().hasUncountable();
  }
 
    /**
     * Returns the sum of all the elements raised to a specified power
     * @return
     */
    @Override
    public double elementPowSum(double p) {
        return getLeadingDiagonal().elementPowSum(p);
    }
   
    /**
     * Returns the sum of the absolute values of all the elements raised to a specified power
     * @return
     */
    @Override
    public double elementAbsPowSum(double p) {
        return getLeadingDiagonal().elementAbsPowSum(p);
    }
}
TOP

Related Classes of mikera.matrixx.impl.ADiagonalMatrix

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.