/*
* Copyright Myrrix Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package net.myrrix.common.math;
import java.lang.reflect.Field;
import java.util.Arrays;
import com.google.common.base.Preconditions;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import net.myrrix.common.ClassUtils;
import net.myrrix.common.collection.FastByIDFloatMap;
import net.myrrix.common.collection.FastByIDMap;
import net.myrrix.common.collection.FastIDSet;
/**
* Contains utility methods for dealing with matrices, which are here represented as
* {@link FastByIDMap}s of {@link FastByIDFloatMap}s, or of {@code float[]}.
*
* @author Sean Owen
* @since 1.0
*/
public final class MatrixUtils {
private static final int PRINT_COLUMN_WIDTH = 12;
// This hack saves a lot of time spent copying out data from Array2DRowRealMatrix objects
private static final Field MATRIX_DATA_FIELD;
private static final LinearSystemSolver MATRIX_INVERTER;
static {
MATRIX_DATA_FIELD = ClassUtils.loadField(Array2DRowRealMatrix.class, "data");
String lssClassName = Boolean.parseBoolean(System.getProperty("common.matrix.nativeMath", "false")) ?
"net.myrrix.common.math.JBlasLinearSystemSolver" : "net.myrrix.common.math.CommonsMathLinearSystemSolver";
MATRIX_INVERTER = ClassUtils.loadInstanceOf(lssClassName, LinearSystemSolver.class);
}
private MatrixUtils() {
}
/**
* Efficiently increments an entry in two parallel, sparse matrices.
*
* @param row row to increment
* @param column column to increment
* @param value increment value
* @param RbyRow matrix R to update, keyed by row
* @param RbyColumn matrix R to update, keyed by column
*/
public static void addTo(long row,
long column,
float value,
FastByIDMap<FastByIDFloatMap> RbyRow,
FastByIDMap<FastByIDFloatMap> RbyColumn) {
addToByRow(row, column, value, RbyRow);
addToByRow(column, row, value, RbyColumn);
}
/**
* Efficiently increments an entry in a row-major sparse matrix.
*
* @param row row to increment
* @param column column to increment
* @param value increment value
* @param RbyRow matrix R to update, keyed by row
*/
private static void addToByRow(long row,
long column,
float value,
FastByIDMap<FastByIDFloatMap> RbyRow) {
FastByIDFloatMap theRow = RbyRow.get(row);
if (theRow == null) {
theRow = new FastByIDFloatMap();
RbyRow.put(row, theRow);
}
theRow.increment(column, value);
}
/**
* Efficiently removes an entry in two parallel, sparse matrices.
*
* @param row row to remove
* @param column column to remove
* @param RbyRow matrix R to update, keyed by row
* @param RbyColumn matrix R to update, keyed by column
*/
public static void remove(long row,
long column,
FastByIDMap<FastByIDFloatMap> RbyRow,
FastByIDMap<FastByIDFloatMap> RbyColumn) {
removeByRow(row, column, RbyRow);
removeByRow(column, row, RbyColumn);
}
/**
* Efficiently removes an entry from a row-major sparse matrix.
*
* @param row row to remove
* @param column column to remove
* @param RbyRow matrix R to update, keyed by row
*/
private static void removeByRow(long row, long column, FastByIDMap<FastByIDFloatMap> RbyRow) {
FastByIDFloatMap theRow = RbyRow.get(row);
if (theRow != null) {
theRow.remove(column);
if (theRow.isEmpty()) {
RbyRow.remove(row);
}
}
}
/**
* @return {@link LinearSystemSolver#isNonSingular(RealMatrix)}
*/
public static boolean isNonSingular(RealMatrix M) {
return MATRIX_INVERTER.isNonSingular(M);
}
/**
* @return {@link LinearSystemSolver#getSolver(RealMatrix)}
*/
public static Solver getSolver(RealMatrix M) {
return MATRIX_INVERTER.getSolver(M);
}
/**
* @param M small {@link RealMatrix}
* @param S wide, short matrix
* @return M * S as a newly allocated matrix
*/
public static FastByIDMap<float[]> multiply(RealMatrix M, FastByIDMap<float[]> S) {
FastByIDMap<float[]> result = new FastByIDMap<float[]>(S.size());
double[][] matrixData = accessMatrixDataDirectly(M);
for (FastByIDMap.MapEntry<float[]> entry : S.entrySet()) {
result.put(entry.getKey(), matrixMultiply(matrixData, entry.getValue()));
}
return result;
}
public static RealMatrix multiplyXYT(FastByIDMap<float[]> X, FastByIDMap<float[]> Y) {
int Ysize = Y.size();
int Xsize = X.size();
RealMatrix result = new Array2DRowRealMatrix(Xsize, Ysize);
for (int row = 0; row < Xsize; row++) {
for (int col = 0; col < Ysize; col++) {
result.setEntry(row, col, SimpleVectorMath.dot(X.get(row), Y.get(col)));
}
}
return result;
}
/**
* @param matrix an {@link Array2DRowRealMatrix}
* @return its "data" field -- not a copy
*/
public static double[][] accessMatrixDataDirectly(RealMatrix matrix) {
try {
return (double[][]) MATRIX_DATA_FIELD.get(matrix);
} catch (IllegalAccessException iae) {
throw new IllegalStateException(iae);
}
}
public static double[] multiply(RealMatrix matrix, float[] V) {
double[][] M = accessMatrixDataDirectly(matrix);
int rows = M.length;
int cols = V.length;
double[] out = new double[rows];
for (int i = 0; i < rows; i++) {
double total = 0.0;
double[] matrixRow = M[i];
for (int j = 0; j < cols; j++) {
total += V[j] * matrixRow[j];
}
out[i] = total;
}
return out;
}
/**
* @param M matrix
* @param V column vector
* @return column vector M * V
*/
private static float[] matrixMultiply(double[][] M, float[] V) {
int rows = M.length;
int cols = V.length;
float[] out = new float[rows];
for (int i = 0; i < rows; i++) {
double total = 0.0;
double[] matrixRow = M[i];
for (int j = 0; j < cols; j++) {
total += V[j] * matrixRow[j];
}
out[i] = (float) total;
}
return out;
}
/**
* @param M tall, skinny matrix
* @return MT * M as a dense matrix
*/
public static RealMatrix transposeTimesSelf(FastByIDMap<float[]> M) {
if (M == null || M.isEmpty()) {
return null;
}
RealMatrix result = null;
for (FastByIDMap.MapEntry<float[]> entry : M.entrySet()) {
float[] vector = entry.getValue();
int dimension = vector.length;
if (result == null) {
result = new Array2DRowRealMatrix(dimension, dimension);
}
for (int row = 0; row < dimension; row++) {
float rowValue = vector[row];
for (int col = 0; col < dimension; col++) {
result.addToEntry(row, col, rowValue * vector[col]);
}
}
}
Preconditions.checkNotNull(result);
return result;
}
/**
* @param M matrix to print
* @return a print-friendly rendering of a sparse matrix. Not useful for wide matrices.
*/
public static String matrixToString(FastByIDMap<FastByIDFloatMap> M) {
StringBuilder result = new StringBuilder();
long[] colKeys = unionColumnKeysInOrder(M);
appendWithPadOrTruncate("", result);
for (long colKey : colKeys) {
result.append('\t');
appendWithPadOrTruncate(colKey, result);
}
result.append("\n\n");
long[] rowKeys = keysInOrder(M);
for (long rowKey : rowKeys) {
appendWithPadOrTruncate(rowKey, result);
FastByIDFloatMap row = M.get(rowKey);
for (long colKey : colKeys) {
result.append('\t');
float value = row.get(colKey);
if (Float.isNaN(value)) {
appendWithPadOrTruncate("", result);
} else {
appendWithPadOrTruncate(value, result);
}
}
result.append('\n');
}
result.append('\n');
return result.toString();
}
private static long[] keysInOrder(FastByIDMap<?> map) {
FastIDSet keys = new FastIDSet(map.size());
LongPrimitiveIterator it = map.keySetIterator();
while (it.hasNext()) {
keys.add(it.nextLong());
}
long[] keysArray = keys.toArray();
Arrays.sort(keysArray);
return keysArray;
}
private static long[] unionColumnKeysInOrder(FastByIDMap<FastByIDFloatMap> M) {
FastIDSet keys = new FastIDSet(1000);
for (FastByIDMap.MapEntry<FastByIDFloatMap> entry : M.entrySet()) {
LongPrimitiveIterator it = entry.getValue().keySetIterator();
while (it.hasNext()) {
keys.add(it.nextLong());
}
}
long[] keysArray = keys.toArray();
Arrays.sort(keysArray);
return keysArray;
}
private static void appendWithPadOrTruncate(long value, StringBuilder to) {
appendWithPadOrTruncate(Long.toString(value), to);
}
private static void appendWithPadOrTruncate(float value, StringBuilder to) {
String stringValue = Float.toString(value);
if (value >= 0.0f) {
stringValue = ' ' + stringValue;
}
appendWithPadOrTruncate(stringValue, to);
}
private static void appendWithPadOrTruncate(CharSequence value, StringBuilder to) {
int length = value.length();
if (length >= PRINT_COLUMN_WIDTH) {
to.append(value, 0, PRINT_COLUMN_WIDTH);
} else {
for (int i = length; i < PRINT_COLUMN_WIDTH; i++) {
to.append(' ');
}
to.append(value);
}
}
}