Package mikera.matrixx.algo

Source Code of mikera.matrixx.algo.Multiplications

package mikera.matrixx.algo;

import mikera.matrixx.AMatrix;
import mikera.matrixx.Matrix;
import mikera.matrixx.impl.ImmutableMatrix;
import mikera.vectorz.util.DoubleArrays;
import mikera.vectorz.util.ErrorMessages;

public class Multiplications {
  // target number of elements in working set group
  // aim for around 200kb => fits comfortably in L2 cache in modern machines
  protected static final int WORKING_SET_TARGET=8192;
 
  /**
   * General purpose matrix multiplication, with smart selection of algorithm based
   * on matrix size and type.
   *
   * @param a
   * @param b
   * @return
   */
  public static Matrix multiply(AMatrix a, AMatrix b) {
    if (a instanceof Matrix) {
      return multiply((Matrix)a,b);
    } else if (a instanceof ImmutableMatrix) {
      return multiply(Matrix.wrap(a.rowCount(),a.columnCount(),((ImmutableMatrix)a).getInternalData()),b);
    } else {
      return blockedMultiply(a.toMatrix(),b);
    }
  }
 
  public static Matrix multiply(Matrix a, AMatrix b) {
    return blockedMultiply(a,b);
  }
 
  /**
   * Performs fast matrix multiplication using temporary working storage for the second matrix
   * @param a
   * @param b
   * @return
   */
  public static Matrix blockedMultiply(Matrix a, AMatrix b) {
    int rc=a.rowCount();
    int cc=b.columnCount();
    int ic=a.columnCount();
   
    if ((ic!=b.rowCount())) {
      throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(a,b));
    }   

    Matrix result=Matrix.create(rc, cc);
    if (ic==0) return result;
   
    int block=(WORKING_SET_TARGET/ic)+1;
    // working set stores up to <block> number of columns from second matrix
    Matrix wsb=Matrix.create(Math.min(block,cc), ic);
   
    for (int bj=0; bj<cc; bj+=block) {
      int bjsize=Math.min(block, cc-bj);
     
      // copy columns into working set
      for (int t=0; t<bjsize; t++) {
        b.copyColumnTo(bj+t,wsb.data,t*ic);
      }
     
      for (int bi=0; bi<rc; bi+=block) {
        int bisize=Math.min(block, rc-bi);
       
        // compute inner block
        for (int i=bi; i<(bi+bisize); i++) {
          int aDataOffset=i*ic;
          for (int j=bj; j<(bj+bjsize); j++) {
            double val=DoubleArrays.dotProduct(a.data, aDataOffset, wsb.data, ic*(j-bj), ic);
            result.unsafeSet(i, j, val);
          }
        }
      }
    }
    return result;
  }
 
  /**
   * Performs fast matrix multiplication using temporary working storage for both matrices
   * @param a
   * @param b
   * @return
   */
  public static Matrix doubleBlockedMultiply(AMatrix a, AMatrix b) {
    int rc=a.rowCount();
    int cc=b.columnCount();
    int ic=a.columnCount();
   
    if ((ic!=b.rowCount())) {
      throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(a,b));
    }   

    Matrix result=Matrix.create(rc, cc);
    if (ic==0) return result;
   
    int block=(WORKING_SET_TARGET/ic)+1;
    // working sets stores up to <block> number of columns from each matrix
    Matrix wsa=Matrix.create(Math.min(block,rc), ic);
    Matrix wsb=Matrix.create(Math.min(block,cc), ic);
   
    for (int bj=0; bj<cc; bj+=block) {
      int bjsize=Math.min(block, cc-bj);
     
      // copy columns into working set
      for (int t=0; t<bjsize; t++) {
        b.copyColumnTo(bj+t,wsb.data,t*ic);
      }
     
      for (int bi=0; bi<rc; bi+=block) {
        int bisize=Math.min(block, rc-bi);
       
        // copy columns into working set
        for (int t=0; t<bisize; t++) {
          b.copyRowTo(bi+t,wsa.data,t*ic);
        }
       
        // compute inner block
        for (int i=bi; i<(bi+bisize); i++) {
          for (int j=bj; j<(bj+bjsize); j++) {
            double val=DoubleArrays.dotProduct(wsa.data, ic*(i-bi), wsb.data, ic*(j-bj), ic);
            result.unsafeSet(i, j, val);
          }
        }
      }
    }
    return result;
  }
 
  public static Matrix directMultiply(Matrix a, AMatrix b) {
    int rc=a.rowCount();
    int cc=b.columnCount();
    int ic=a.columnCount();
   
    if ((ic!=b.rowCount())) {
      throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(a,b));
    }

    Matrix result=Matrix.create(rc,cc);
    double[] tmp=new double[ic];
    for (int j=0; j<cc; j++) {
      b.copyColumnTo(j, tmp, 0);
     
      for (int i=0; i<rc; i++) {
//        double acc=0.0;
//        for (int k=0; k<ic; k++) {
//          acc+=a.unsafeGet(i, k)*tmp[k];
//        }
        double acc=DoubleArrays.dotProduct(a.data, i*ic, tmp, 0, ic);
        result.unsafeSet(i,j,acc);
      }
    }
    return result;   
  }
 
  public static AMatrix naiveMultiply(AMatrix a, AMatrix b) {
    int rc=a.rowCount();
    int cc=b.columnCount();
    int ic=a.columnCount();
   
    if ((ic!=b.rowCount())) {
      throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(a,b));
    }

    Matrix result=Matrix.create(rc,cc);
    for (int i=0; i<rc; i++) {
      for (int j=0; j<cc; j++) {
        double acc=0.0;
        for (int k=0; k<ic; k++) {
          acc+=a.unsafeGet(i, k)*b.unsafeGet(k, j);
        }
        result.unsafeSet(i,j,acc);
      }
    }
    return result;   
  }
}
TOP

Related Classes of mikera.matrixx.algo.Multiplications

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.