// Package mikera.matrixx.impl
//
// Source Code of mikera.matrixx.impl.SparseRowMatrix

package mikera.matrixx.impl;

import java.util.Arrays;
import java.util.List;
// import java.util.HashMap;
// import java.util.HashSet;
// import java.util.Map;
// import java.util.Map.Entry;

import mikera.arrayz.ISparse;
import mikera.matrixx.AMatrix;
import mikera.matrixx.Matrixx;
import mikera.matrixx.Matrix;
import mikera.vectorz.AVector;
import mikera.vectorz.Op;
import mikera.vectorz.Vector;
import mikera.vectorz.Vectorz;
import mikera.vectorz.impl.RepeatedElementVector;
import mikera.vectorz.impl.SparseIndexedVector;
import mikera.vectorz.util.ErrorMessages;
import mikera.vectorz.util.VectorzException;

/**
* Matrix stored as a sparse collection of sparse row vectors.
*
* This format is especially efficient for:
* - innerProduct() with another matrix, especially one with efficient
*   column access like SparseColumnMatrix
* - access via getRow() operation
* - transpose into SparseColumnMatrix
*
* @author Mike
*
*/
public class SparseRowMatrix extends ASparseRCMatrix implements ISparse, IFastRows {
  private static final long serialVersionUID = 8646257152425415773L;

  private static final long SPARSE_ELEMENT_THRESHOLD = 1000L;
 
  private final AVector emptyRow;

  protected SparseRowMatrix(int rowCount, int columnCount) {
    this(new AVector[rowCount],rowCount,columnCount);
  }

  protected SparseRowMatrix(AVector[] data, int rowCount, int columnCount) {
    super(rowCount,columnCount,data);
        if (data.length != rowCount)
            throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(rowCount, data.length));
    emptyRow=Vectorz.createZeroVector(columnCount);
  }

  protected SparseRowMatrix(AVector... vectors) {
    this(vectors, vectors.length, vectors[0].length());
  }

  protected SparseRowMatrix(List<AVector> data, int rowCount, int columnCount) {
    this(data.toArray(new AVector[0]),rowCount,columnCount);
  }

  protected SparseRowMatrix(List<AVector> data) {
    this(data.toArray(new AVector[0]));
  }

//    protected SparseRowMatrix(HashMap<Integer,AVector> data, int rowCount, int columnCount) {
//      super(rowCount,columnCount,data);
//      emptyColumn=Vectorz.createZeroVector(rowCount);
//    }

  public static SparseRowMatrix create(int rows, int cols) {
    return new SparseRowMatrix(rows, cols);
  }

  public static SparseRowMatrix create(AVector[] data, int rows, int cols) {
    return new SparseRowMatrix(data, rows, cols);
  }

  public static SparseRowMatrix create(AVector... vecs) {
    return new SparseRowMatrix(vecs);
        // don't validate; user can call validate() if they want it.
  }
 
  public static SparseRowMatrix create(List<AVector> vecs) {
    return create(vecs.toArray(new AVector[0]));
  }
 
  public static SparseRowMatrix wrap(AVector[] vecs, int rows, int cols) {
    return create(vecs, rows, cols);
  }
 
  public static SparseRowMatrix wrap(AVector... vecs) {
    return create(vecs);
  }
 
  public static SparseRowMatrix create(AMatrix source) {
    int rc = source.rowCount();
    int cc = source.columnCount();
    AVector[] data = new AVector[rc];
    for (int i = 0; i < rc; i++) {
      AVector row = source.getRow(i);
      if (!row.isZero())
          data[i] = Vectorz.createSparse(row);
    }
    return new SparseRowMatrix(data,rc,cc);
  }

  public static SparseRowMatrix wrap(List<AVector> vecs) {
    return create(vecs);
  }
 
//  public static SparseRowMatrix wrap(HashMap<Integer, AVector> data, int rows, int cols) {
//    return new SparseRowMatrix(data, rows, cols);
//  }

  @Override
  protected int lineCount() {
    return rows;
  }

  @Override
  protected int lineLength() {
    return cols;
  }

  @Override
  public double get(int i, int j) {
    return getRow(i).get(j);
  }

  @Override
  public void set(int i, int j, double value) {
    checkIndex(i,j);
    AVector v = unsafeGetVec(i);
    if (v == null) {
      if (value == 0.0)
        return;
      v = Vectorz.createSparseMutable(cols);
    } else if (v.isFullyMutable()) {
      v.set(j, value);
      return;
    } else {
      v = v.sparseClone();
    }
    unsafeSetVec(i, v);
    v.unsafeSet(j, value);
  }

  @Override
  public double unsafeGet(int row, int column) {
    return getRow(row).unsafeGet(column);
  }

  @Override
  public void unsafeSet(int row, int column, double value) {
    AVector v=getRow(row);
    if (v.isFullyMutable()) {
      v.unsafeSet(column,value);
    } else {
      v=v.mutable();
      replaceRow(row,v);
      v.unsafeSet(column,value);
    }
  }
 
  @Override
  public void set(AMatrix a) {
    checkSameShape(a);
    for (int i=0; i<rows; i++) {
      setRow(i,a.getRow(i));
    }
  }
 
  @Override
  public void setRow(int i, AVector v) {
    data[i]=v.copy();
  }
 
  @Override
  public void addAt(int i, int j, double d) {
    if (d==0.0) return;
    AVector v=unsafeGetVec(i);
    if (v.isFullyMutable()) {
      v.addAt(j, d);
    } else {
      v=v.mutable();
      v.addAt(j, d);
      replaceRow(i,v);
    }
  }
 
  @Override
  public void addToArray(double[] targetData, int offset) {
        for (int i = 0; i < rows; ++i) {
      AVector v = unsafeGetVec(i);
      if (v != null) v.addToArray(targetData, offset+cols*i);
    }
  }

  private AVector ensureMutableRow(int i) {
    AVector v = unsafeGetVec(i);
    if (v == null) {
      AVector nv=SparseIndexedVector.createLength(cols);
            unsafeSetVec(i, nv);
      return nv;
    }
    if (v.isFullyMutable()) return v;
    AVector mv=v.mutable();
    unsafeSetVec(i, mv);
    return mv;
  }

  @Override
  public AVector getRow(int i) {
    if ((i<0)||(i>=rows)) throw new IndexOutOfBoundsException(ErrorMessages.invalidSlice(this, 0, i));
    AVector v = unsafeGetVec(i);
    if (v == null) return emptyRow;
    return v;
  }
 
  @Override
  public AVector getRowView(int i) {
    return ensureMutableRow(i);
  }
 
  @Override
  public boolean isUpperTriangular() {
    int rc=rowCount();
    for (int i=1; i<rc; i++) {
      if (!getRow(i).isRangeZero(0, i)) return false;
    }
    return true;
  }

  @Override
  public void swapRows(int i, int j) {
    if (i == j)
      return;
    if ((i < 0) || (i >= rows))
      throw new IndexOutOfBoundsException(ErrorMessages.invalidSlice(this, 0, i));
    if ((j < 0) || (j >= rows))
      throw new IndexOutOfBoundsException(ErrorMessages.invalidSlice(this, 0, j));
    AVector a = unsafeGetVec(i);
    AVector b = unsafeGetVec(j);
    unsafeSetVec(i, b);
    unsafeSetVec(j, a);
  }

  @Override
  public void replaceRow(int i, AVector vec) {
    if ((i < 0) || (i >= rows))
      throw new IndexOutOfBoundsException(ErrorMessages.invalidSlice(this, 0, i));
    if (vec.length() != cols)
      throw new IllegalArgumentException(ErrorMessages.incompatibleShape(vec));
        unsafeSetVec(i, vec);
  }

  @Override
  public void add(AMatrix a) {
    int count=rowCount();
    for (int i=0; i<count; i++) {
      AVector myVec=unsafeGetVec(i);
      AVector aVec=a.getRow(i);
      if (myVec==null) {
        if (!aVec.isZero()) {
          unsafeSetVec(i,aVec.copy());
        }
      } else if (myVec.isMutable()) {
        myVec.add(aVec);
      } else {
        unsafeSetVec(i,myVec.addCopy(aVec));
      }
    }
  }
 
  @Override
  public void copyRowTo(int i, double[] data, int offset) {
    AVector v=this.unsafeGetVec(i);
    if (v==null) {
      Arrays.fill(data, offset, offset+cols, 0.0);     
    } else {
      v.getElements(data, offset);
    }
  }
 
  @Override
  public void copyColumnTo(int col, double[] targetData, int offset) {
    Arrays.fill(targetData, offset, offset+rows, 0.0);
        for (int i = 0; i < rows; ++i) {
            AVector v = unsafeGetVec(i);
            if (v != null)
          targetData[offset+i] = v.unsafeGet(col);
    }   
  }

  @Override
  public SparseColumnMatrix getTransposeView() {
    return SparseColumnMatrix.wrap(data, cols, rows);
  }

  @Override
  public AMatrix multiplyCopy(double a) {
    AVector[] ndata=new AVector[lineCount()];
    for (int i = 0; i < lineCount(); ++i) {
            AVector v = unsafeGetVec(i);
            if (v != null)
                ndata[i] = v.innerProduct(a);
    }
    return wrap(ndata,rows,cols);
  }

  @Override
  public AVector innerProduct(AVector a) {
    return transform(a);
  }
 
  @Override
  public AVector transform(AVector a) {
    AVector r=Vector.createLength(rows);
    for (int i=0; i<rows; i++) {
      r.set(i,getRow(i).dotProduct(a));
    }
    return r;
  }
 
  @Override
  public void applyOp(Op op) {
    boolean stoch = op.isStochastic();
    AVector rr = (stoch) ? null : RepeatedElementVector.create(lineLength(), op.apply(0.0));

    for (int i = 0; i < lineCount(); i++) {
      AVector v = unsafeGetVec(i);
      if (v == null) {
        if (!stoch) {
          unsafeSetVec(i, rr);
          continue;
        }
        v = Vector.createLength(lineLength());
        unsafeSetVec(i, v);
      } else if (!v.isFullyMutable()) {
        v = v.sparseClone();
        unsafeSetVec(i, v);
      }
      v.applyOp(op);
    }
  }

  @Override
  public double[] toDoubleArray() {
    double[] ds=new double[rows*cols];
    // we use adding to array, since rows themselves are likely to be sparse
        for (int i = 0; i < rows; ++i) {
            AVector v = unsafeGetVec(i);
      if (v != null)
                v.addToArray(ds, i*cols);
    }
    return ds;
  }
 
  @Override
  public AMatrix innerProduct(AMatrix a) {
    if (a instanceof SparseColumnMatrix) {
      return innerProduct((SparseColumnMatrix) a);
    }
    AMatrix r = Matrix.create(rows, a.columnCount());

        for (int i = 0; i < rows; ++i) {
      AVector row = unsafeGetVec(i);
            if (! ((row == null) || (row.isZero()))) {
          r.setRow(i,row.innerProduct(a));
            }
    }
    return r;
  }
 
  public AMatrix innerProduct(SparseColumnMatrix a) {
    AMatrix r = Matrixx.createSparse(rows, a.cols);

        for (int i = 0; i < rows; ++i) {
      AVector row = unsafeGetVec(i);
            if (! ((row == null) || (row.isZero()))) {
                for (int j = 0; j < cols; ++j) {
            AVector acol = a.unsafeGetVec(j);
            double v = ((acol == null) || acol.isZero()) ? 0.0 : row.dotProduct(acol);
            if (v!=0.0) r.unsafeSet(i, j, v);
          }
            }
    }
    return r;
  }

  @Override
  public SparseRowMatrix exactClone() {
    SparseRowMatrix result = new SparseRowMatrix(rows, cols);
        for (int i = 0; i < rows; ++i) {
      AVector row = unsafeGetVec(i);
      if (row != null)
                result.replaceRow(i, row.exactClone());
    }
    return result;
  }
 
  @Override
  public AMatrix clone() {
    if (this.elementCount() < SPARSE_ELEMENT_THRESHOLD)
      return super.clone();
    return exactClone();
  }

  @Override
  public AMatrix sparse() {
    return this;
  }

  @Override
  public void validate() {
    super.validate();
    for (int i=0; i<rows; i++) {
      if (getRow(i).length()!=cols) throw new VectorzException("Invalid column count at row: "+i);
    }
  }
 
  @Override
  public boolean equals(AMatrix m) {
    if (m==this) return true;
    if (!isSameShape(m)) return false;
    for (int i=0; i<rows; i++) {
      AVector v=unsafeGetVec(i);
            AVector ov = m.getRow(i);
      if (v==null) {
        if (!ov.isZero()) return false;
      } else {
        if (!v.equals(ov)) return false;
      }
    }
    return true;
  }
}
// TOP
//
// Related Classes of mikera.matrixx.impl.SparseRowMatrix
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.