Package net.myrrix.common.math

Source Code of net.myrrix.common.math.MatrixUtils

/*
* Copyright Myrrix Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*      http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package net.myrrix.common.math;

import java.lang.reflect.Field;
import java.util.Arrays;

import com.google.common.base.Preconditions;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;

import net.myrrix.common.ClassUtils;
import net.myrrix.common.collection.FastByIDFloatMap;
import net.myrrix.common.collection.FastByIDMap;
import net.myrrix.common.collection.FastIDSet;

/**
* Contains utility methods for dealing with matrices, which are here represented as
* {@link FastByIDMap}s of {@link FastByIDFloatMap}s, or of {@code float[]}.
*
* @author Sean Owen
* @since 1.0
*/
public final class MatrixUtils {

  private static final int PRINT_COLUMN_WIDTH = 12;
  // This hack saves a lot of time spent copying out data from Array2DRowRealMatrix objects
  private static final Field MATRIX_DATA_FIELD;
  private static final LinearSystemSolver MATRIX_INVERTER;
  static {
    MATRIX_DATA_FIELD = ClassUtils.loadField(Array2DRowRealMatrix.class, "data");
    String lssClassName = Boolean.parseBoolean(System.getProperty("common.matrix.nativeMath", "false")) ?
        "net.myrrix.common.math.JBlasLinearSystemSolver" : "net.myrrix.common.math.CommonsMathLinearSystemSolver";
    MATRIX_INVERTER = ClassUtils.loadInstanceOf(lssClassName, LinearSystemSolver.class);
  }

  private MatrixUtils() {
  }

  /**
   * Efficiently increments an entry in two parallel, sparse matrices.
   *
   * @param row row to increment
   * @param column column to increment
   * @param value increment value
   * @param RbyRow matrix R to update, keyed by row
   * @param RbyColumn matrix R to update, keyed by column
   */
  public static void addTo(long row,
                           long column,
                           float value,
                           FastByIDMap<FastByIDFloatMap> RbyRow,
                           FastByIDMap<FastByIDFloatMap> RbyColumn) {
    addToByRow(row, column, value, RbyRow);
    addToByRow(column, row, value, RbyColumn);
  }

  /**
   * Efficiently increments an entry in a row-major sparse matrix.
   *
   * @param row row to increment
   * @param column column to increment
   * @param value increment value
   * @param RbyRow matrix R to update, keyed by row
   */
  private static void addToByRow(long row,
                                 long column,
                                 float value,
                                 FastByIDMap<FastByIDFloatMap> RbyRow) {

    FastByIDFloatMap theRow = RbyRow.get(row);
    if (theRow == null) {
      theRow = new FastByIDFloatMap();
      RbyRow.put(row, theRow);
    }
    theRow.increment(column, value);
  }

  /**
   * Efficiently removes an entry in two parallel, sparse matrices.
   *
   * @param row row to remove
   * @param column column to remove
   * @param RbyRow matrix R to update, keyed by row
   * @param RbyColumn matrix R to update, keyed by column
   */
  public static void remove(long row,
                            long column,
                            FastByIDMap<FastByIDFloatMap> RbyRow,
                            FastByIDMap<FastByIDFloatMap> RbyColumn) {
    removeByRow(row, column, RbyRow);
    removeByRow(column, row, RbyColumn);
  }
 
  /**
   * Efficiently removes an entry from a row-major sparse matrix.
   *
   * @param row row to remove
   * @param column column to remove
   * @param RbyRow matrix R to update, keyed by row
   */
  private static void removeByRow(long row, long column, FastByIDMap<FastByIDFloatMap> RbyRow) {
    FastByIDFloatMap theRow = RbyRow.get(row);
    if (theRow != null) {
      theRow.remove(column);
      if (theRow.isEmpty()) {
        RbyRow.remove(row);
      }
    }
  }

  /**
   * @return {@link LinearSystemSolver#isNonSingular(RealMatrix)}
   */
  public static boolean isNonSingular(RealMatrix M) {
    return MATRIX_INVERTER.isNonSingular(M);   
  }

  /**
   * @return {@link LinearSystemSolver#getSolver(RealMatrix)}
   */
  public static Solver getSolver(RealMatrix M) {
    return MATRIX_INVERTER.getSolver(M);
  }

  /**
   * @param M small {@link RealMatrix}
   * @param S wide, short matrix
   * @return M * S as a newly allocated matrix
   */
  public static FastByIDMap<float[]> multiply(RealMatrix M, FastByIDMap<float[]> S) {
    FastByIDMap<float[]> result = new FastByIDMap<float[]>(S.size());
    double[][] matrixData = accessMatrixDataDirectly(M);
    for (FastByIDMap.MapEntry<float[]> entry : S.entrySet()) {
      result.put(entry.getKey(), matrixMultiply(matrixData, entry.getValue()));
    }
    return result;
  }

  public static RealMatrix multiplyXYT(FastByIDMap<float[]> X, FastByIDMap<float[]> Y) {
    int Ysize = Y.size();
    int Xsize = X.size();
    RealMatrix result = new Array2DRowRealMatrix(Xsize, Ysize);
    for (int row = 0; row < Xsize; row++) {
      for (int col = 0; col < Ysize; col++) {
        result.setEntry(row, col, SimpleVectorMath.dot(X.get(row), Y.get(col)));
      }
    }
    return result;
  }

  /**
   * @param matrix an {@link Array2DRowRealMatrix}
   * @return its "data" field -- not a copy
   */
  public static double[][] accessMatrixDataDirectly(RealMatrix matrix) {
    try {
      return (double[][]) MATRIX_DATA_FIELD.get(matrix);
    } catch (IllegalAccessException iae) {
      throw new IllegalStateException(iae);
    }
  }

  public static double[] multiply(RealMatrix matrix, float[] V) {
    double[][] M = accessMatrixDataDirectly(matrix);
    int rows = M.length;
    int cols = V.length;
    double[] out = new double[rows];
    for (int i = 0; i < rows; i++) {
      double total = 0.0;
      double[] matrixRow = M[i];
      for (int j = 0; j < cols; j++) {
        total += V[j] * matrixRow[j];
      }
      out[i] = total;
    }
    return out;
  }

  /**
   * @param M matrix
   * @param V column vector
   * @return column vector M * V
   */
  private static float[] matrixMultiply(double[][] M, float[] V) {
    int rows = M.length;
    int cols = V.length;
    float[] out = new float[rows];
    for (int i = 0; i < rows; i++) {
      double total = 0.0;
      double[] matrixRow = M[i];
      for (int j = 0; j < cols; j++) {
        total += V[j] * matrixRow[j];
      }
      out[i] = (float) total;
    }
    return out;
  }

  /**
   * @param M tall, skinny matrix
   * @return MT * M as a dense matrix
   */
  public static RealMatrix transposeTimesSelf(FastByIDMap<float[]> M) {
    if (M == null || M.isEmpty()) {
      return null;
    }
    RealMatrix result = null;
    for (FastByIDMap.MapEntry<float[]> entry : M.entrySet()) {
      float[] vector = entry.getValue();
      int dimension = vector.length;
      if (result == null) {
        result = new Array2DRowRealMatrix(dimension, dimension);
      }
      for (int row = 0; row < dimension; row++) {
        float rowValue = vector[row];
        for (int col = 0; col < dimension; col++) {
          result.addToEntry(row, col, rowValue * vector[col]);
        }
      }
    }
    Preconditions.checkNotNull(result);
    return result;
  }

  /**
   * @param M matrix to print
   * @return a print-friendly rendering of a sparse matrix. Not useful for wide matrices.
   */
  public static String matrixToString(FastByIDMap<FastByIDFloatMap> M) {
    StringBuilder result = new StringBuilder();
    long[] colKeys = unionColumnKeysInOrder(M);
    appendWithPadOrTruncate("", result);
    for (long colKey : colKeys) {
      result.append('\t');
      appendWithPadOrTruncate(colKey, result);
    }
    result.append("\n\n");
    long[] rowKeys = keysInOrder(M);
    for (long rowKey : rowKeys) {
      appendWithPadOrTruncate(rowKey, result);
      FastByIDFloatMap row = M.get(rowKey);
      for (long colKey : colKeys) {
        result.append('\t');
        float value = row.get(colKey);
        if (Float.isNaN(value)) {
          appendWithPadOrTruncate("", result);
        } else {
          appendWithPadOrTruncate(value, result);
        }
      }
      result.append('\n');
    }
    result.append('\n');
    return result.toString();
  }

  private static long[] keysInOrder(FastByIDMap<?> map) {
    FastIDSet keys = new FastIDSet(map.size());
    LongPrimitiveIterator it = map.keySetIterator();
    while (it.hasNext()) {
      keys.add(it.nextLong());
    }
    long[] keysArray = keys.toArray();
    Arrays.sort(keysArray);
    return keysArray;
  }

  private static long[] unionColumnKeysInOrder(FastByIDMap<FastByIDFloatMap> M) {
    FastIDSet keys = new FastIDSet(1000);
    for (FastByIDMap.MapEntry<FastByIDFloatMap> entry : M.entrySet()) {
      LongPrimitiveIterator it = entry.getValue().keySetIterator();
      while (it.hasNext()) {
        keys.add(it.nextLong());
      }
    }
    long[] keysArray = keys.toArray();
    Arrays.sort(keysArray);
    return keysArray;
  }

  private static void appendWithPadOrTruncate(long value, StringBuilder to) {
    appendWithPadOrTruncate(Long.toString(value), to);
  }

  private static void appendWithPadOrTruncate(float value, StringBuilder to) {
    String stringValue = Float.toString(value);
    if (value >= 0.0f) {
      stringValue = ' ' + stringValue;
    }
    appendWithPadOrTruncate(stringValue, to);
  }

  private static void appendWithPadOrTruncate(CharSequence value, StringBuilder to) {
    int length = value.length();
    if (length >= PRINT_COLUMN_WIDTH) {
      to.append(value, 0, PRINT_COLUMN_WIDTH);
    } else {
      for (int i = length; i < PRINT_COLUMN_WIDTH; i++) {
        to.append(' ');
      }
      to.append(value);
    }
  }

}
TOP

Related Classes of net.myrrix.common.math.MatrixUtils

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.