Package net.fec.openrq.util.linearalgebra.matrix.sparse

Source Code of net.fec.openrq.util.linearalgebra.matrix.sparse.CRSByteMatrix

/*
* Copyright 2014 OpenRQ Team
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* Copyright 2011-2014, by Vladimir Kostyukov and Contributors.
*
* This file is part of la4j project (http://la4j.org)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* You may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* Contributor(s): Chandler May
* Maxim Samoylov
* Anveshi Charuvaka
* Clement Skau
* Catherine da Graca
*/
package net.fec.openrq.util.linearalgebra.matrix.sparse;


import static net.fec.openrq.util.math.OctetOps.aIsGreaterThanB;
import static net.fec.openrq.util.math.OctetOps.aIsLessThanB;
import static net.fec.openrq.util.math.OctetOps.aPlusB;
import static net.fec.openrq.util.math.OctetOps.aTimesB;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.WritableByteChannel;
import java.util.Objects;

import net.fec.openrq.util.checking.Indexables;
import net.fec.openrq.util.linearalgebra.LinearAlgebra;
import net.fec.openrq.util.linearalgebra.factory.Factory;
import net.fec.openrq.util.linearalgebra.io.ByteVectorIterator;
import net.fec.openrq.util.linearalgebra.matrix.ByteMatrices;
import net.fec.openrq.util.linearalgebra.matrix.ByteMatrix;
import net.fec.openrq.util.linearalgebra.matrix.functor.MatrixFunction;
import net.fec.openrq.util.linearalgebra.matrix.functor.MatrixProcedure;
import net.fec.openrq.util.linearalgebra.matrix.source.MatrixSource;
import net.fec.openrq.util.linearalgebra.serialize.Serialization;
import net.fec.openrq.util.linearalgebra.vector.ByteVector;


/**
* This is a CRS (Compressed Row Storage) matrix class.
*/
public class CRSByteMatrix extends AbstractCompressedByteMatrix implements SparseByteMatrix {

    private final SparseVectors sparseRows;


    public CRSByteMatrix() {

        this(0, 0);
    }

    public CRSByteMatrix(int rows, int columns, byte array[]) {

        this(ByteMatrices.asArray1DSource(rows, columns, array));
    }

    public CRSByteMatrix(ByteMatrix matrix) {

        this(ByteMatrices.asMatrixSource(matrix));
    }

    public CRSByteMatrix(byte array[][]) {

        this(ByteMatrices.asArray2DSource(array));
    }

    public CRSByteMatrix(MatrixSource source) {

        this(source.rows(), source.columns());

        for (int i = 0; i < rows(); i++) {
            sparseRows.initializeVector(i, ByteMatrices.asRowVectorSource(source, i));
        }
    }

    public CRSByteMatrix(int rows, int columns) {

        super(LinearAlgebra.CRS_FACTORY, rows, columns);
        this.sparseRows = new SparseVectors(rows, columns);
    }

    public CRSByteMatrix(int rows, int columns, byte columnValues[][], int columnIndices[][], int[] rowCardinalities) {

        super(LinearAlgebra.CRS_FACTORY, rows, columns);
        this.sparseRows = new SparseVectors(rows, columns, columnValues, columnIndices, rowCardinalities);
    }

    private CRSByteMatrix(int rows, int columns, SparseVectors sparseRows) {

        super(LinearAlgebra.CRS_FACTORY, rows, columns);
        this.sparseRows = Objects.requireNonNull(sparseRows);
    }

    @Override
    public int cardinality() {

        int cardinality = 0;
        for (int i = 0; i < rows(); i++) {
            cardinality += sparseRows.vectorR(i).nonZeros();
        }

        return cardinality;
    }

    @Override
    public byte safeGet(int i, int j) {

        return sparseRows.vectorR(i).get(j);
    }

    @Override
    public void safeSet(int i, int j, byte value) {

        sparseRows.vectorRW(i).set(j, value);
    }

    @Override
    public void clear() {

        for (int i = 0; i < rows(); i++) {
            clearRow(i);
        }
    }

    @Override
    public void clearRow(int i) {

        sparseRows.vectorR(i).clear(); // this is non mutable on an empty vector
    }

    // =========================================================================
    // Optimized multiplications that take advantage of row sparsity in matrix.

    @Override
    public ByteMatrix multiply(byte value) {

        return multiply(value, factory());
    }

    @Override
    public ByteMatrix multiply(byte value, Factory factory) {

        ensureFactoryIsNotNull(factory);

        ByteMatrix result = blank(factory);

        if (value != 0) {
            for (int i = 0; i < rows(); i++) {
                ByteVectorIterator it = nonZeroRowIterator(i);
                while (it.hasNext()) {
                    it.next();
                    final byte prod = aTimesB(value, it.get());
                    result.set(i, it.index(), prod);
                }
            }
        }

        return result;
    }

    @Override
    public ByteVector multiply(ByteVector vector) {

        return multiply(vector, factory());
    }

    @Override
    public ByteVector multiply(ByteVector vector, Factory factory) {

        ensureFactoryIsNotNull(factory);
        ensureArgumentIsNotNull(vector, "vector");

        if (columns() != vector.length()) {
            fail("Wrong vector length: " + vector.length() + ". Should be: " + columns() + ".");
        }

        ByteVector result = factory.createVector(rows());

        for (int i = 0; i < rows(); i++) {
            byte acc = 0;
            ByteVectorIterator it = nonZeroRowIterator(i);
            while (it.hasNext()) {
                it.next();
                final byte prod = aTimesB(it.get(), vector.get(it.index()));
                acc = aPlusB(acc, prod);
            }

            if (acc != 0) {
                result.set(i, acc);
            }
        }

        return result;
    }

    @Override
    public ByteMatrix multiply(ByteMatrix matrix) {

        return multiply(matrix, factory());
    }

    @Override
    public ByteMatrix multiply(ByteMatrix matrix, Factory factory) {

        ensureFactoryIsNotNull(factory);
        ensureArgumentIsNotNull(matrix, "matrix");

        if (columns() != matrix.rows()) {
            fail("Wrong matrix dimensions: " + matrix.rows() + "x" + matrix.columns() +
                 ". Should be: " + columns() + "x_.");
        }

        ByteMatrix result = factory.createMatrix(rows(), matrix.columns());

        for (int i = 0; i < rows(); i++) {
            for (int j = 0; j < result.columns(); j++) {
                byte acc = 0;
                ByteVectorIterator it = nonZeroRowIterator(i);
                while (it.hasNext()) {
                    it.next();
                    final byte prod = aTimesB(it.get(), matrix.get(it.index(), j));
                    acc = aPlusB(acc, prod);
                }

                if (acc != 0) {
                    result.set(i, j, acc);
                }
            }
        }

        return result;
    }

    @Override
    public ByteMatrix multiply(
        ByteMatrix matrix,
        int fromThisRow,
        int toThisRow,
        int fromThisColumn,
        int toThisColumn,
        int fromOtherRow,
        int toOtherRow,
        int fromOtherColumn,
        int toOtherColumn)
    {

        return multiply(
            matrix,
            fromThisRow, toThisRow,
            fromThisColumn, toThisColumn,
            fromOtherRow, toOtherRow,
            fromOtherColumn, toOtherColumn,
            factory());
    }

    @Override
    public ByteMatrix multiply(
        ByteMatrix matrix,
        int fromThisRow,
        int toThisRow,
        int fromThisColumn,
        int toThisColumn,
        int fromOtherRow,
        int toOtherRow,
        int fromOtherColumn,
        int toOtherColumn,
        Factory factory)
    {

        ensureFactoryIsNotNull(factory);
        ensureArgumentIsNotNull(matrix, "matrix");
        Indexables.checkFromToBounds(fromThisRow, toThisRow, rows());
        Indexables.checkFromToBounds(fromThisColumn, toThisColumn, columns());
        Indexables.checkFromToBounds(fromOtherRow, fromOtherRow, matrix.rows());
        Indexables.checkFromToBounds(fromOtherColumn, toOtherColumn, matrix.columns());

        if ((toThisColumn - fromThisColumn) != (toOtherRow - fromOtherRow)) {
            fail("Wrong matrix dimensions: " +
                 (toOtherRow - fromOtherRow) + "x" + (toOtherColumn - fromOtherColumn) +
                 ". Should be: " + (toThisColumn - fromThisColumn) + "x_.");
        }

        ByteMatrix result = factory.createMatrix(toThisRow - fromThisRow, toOtherColumn - fromOtherColumn);

        for (int i = fromThisRow; i < toThisRow; i++) {
            for (int j = fromOtherColumn; j < toOtherColumn; j++) {
                byte acc = 0;
                ByteVectorIterator it = nonZeroRowIterator(i, fromThisColumn, toThisColumn);
                while (it.hasNext()) {
                    it.next();
                    final byte prod = aTimesB(it.get(), matrix.get(it.index(), j));
                    acc = aPlusB(acc, prod);
                }

                if (acc != 0) {
                    result.set(i - fromThisRow, j - fromOtherColumn, acc);
                }
            }
        }

        return result;
    }

    @Override
    public ByteVector multiplyRow(int i, ByteMatrix matrix) {

        return multiplyRow(i, matrix, factory());
    }

    @Override
    public ByteVector multiplyRow(int i, ByteMatrix matrix, Factory factory) {

        ensureFactoryIsNotNull(factory);
        ensureArgumentIsNotNull(matrix, "matrix");

        if (columns() != matrix.rows()) {
            fail("Wrong matrix dimensions: " + matrix.rows() + "x" + matrix.columns() +
                 ". Should be: " + columns() + "x_.");
        }

        ByteVector result = factory.createVector(matrix.columns());

        for (int j = 0; j < matrix.columns(); j++) {
            byte acc = 0;

            ByteVectorIterator it = nonZeroRowIterator(i);
            while (it.hasNext()) {
                it.next();
                final byte prod = aTimesB(it.get(), matrix.get(it.index(), j));
                acc = aPlusB(acc, prod);
            }

            result.set(j, acc);
        }

        return result;
    }

    @Override
    public ByteVector multiplyRow(int i, ByteMatrix matrix, int fromColumn, int toColumn) {

        return multiplyRow(i, matrix, fromColumn, toColumn, factory());
    }

    @Override
    public ByteVector multiplyRow(int i, ByteMatrix matrix, int fromColumn, int toColumn, Factory factory) {

        ensureFactoryIsNotNull(factory);
        ensureArgumentIsNotNull(matrix, "matrix");
        Indexables.checkFromToBounds(fromColumn, toColumn, columns());

        if ((toColumn - fromColumn) != matrix.rows()) {
            fail("Wrong matrix dimensions: " + matrix.rows() + "x" + matrix.columns() +
                 ". Should be: " + (toColumn - fromColumn) + "x_.");
        }

        ByteVector result = factory.createVector(matrix.columns());

        for (int j = 0; j < matrix.columns(); j++) {
            byte acc = 0;

            ByteVectorIterator it = nonZeroRowIterator(i, fromColumn, toColumn);
            while (it.hasNext()) {
                it.next();
                final byte prod = aTimesB(it.get(), matrix.get(it.index() - fromColumn, j));
                acc = aPlusB(acc, prod);
            }

            result.set(j, acc);
        }

        return result;
    }

    // Optimized multiplications that take advantage of row sparsity in matrix.
    // =========================================================================

    @Override
    public ByteMatrix transpose() {

        return transpose(factory());
    }

    @Override
    public ByteMatrix transpose(Factory factory) {

        ensureFactoryIsNotNull(factory);

        ByteMatrix result = factory.createMatrix(columns(), rows());

        for (int i = 0; i < rows(); i++) {
            ByteVectorIterator it = nonZeroRowIterator(i);
            while (it.hasNext()) {
                it.next();
                result.set(it.index(), i, it.get());
            }
        }

        return result;
    }

    @Override
    public void divideRowInPlace(int i, byte value) {

        Indexables.checkIndexBounds(i, rows());

        sparseRows.vectorRW(i).divideInPlace(value);
    }

    @Override
    public void divideRowInPlace(int i, byte value, int fromColumn, int toColumn) {

        Indexables.checkIndexBounds(i, rows());
        Indexables.checkFromToBounds(fromColumn, toColumn, columns());

        sparseRows.vectorRW(i).divideInPlace(value, fromColumn, toColumn);
    }

    @Override
    public ByteVector getRow(int i) {

        Indexables.checkIndexBounds(i, rows());

        return sparseRows.vectorR(i).copy();
    }

    @Override
    public void swapRows(int i, int j) {

        Indexables.checkIndexBounds(i, rows());
        Indexables.checkIndexBounds(j, rows());

        sparseRows.swapVectors(i, j);
    }

    @Override
    public void swapColumns(int i, int j) {

        Indexables.checkIndexBounds(i, columns());
        Indexables.checkIndexBounds(j, columns());

        if (i != j) {
            for (int row = 0; row < rows(); row++) {
                sparseRows.vectorR(row).swap(i, j); // vectorR because swap is non-destructive in empty vectors
            }
        }
    }

    @Override
    public ByteMatrix copy() {

        return new CRSByteMatrix(rows(), columns(), sparseRows.copy());
    }

    @Override
    public boolean nonZeroAt(int i, int j) {

        checkBounds(i, j);

        return sparseRows.vectorR(i).nonZeroAt(j);
    }

    @Override
    public int nonZerosInRow(int i) {

        Indexables.checkIndexBounds(i, rows());

        return sparseRows.vectorR(i).nonZeros();
    }

    @Override
    public int nonZerosInRow(int i, int fromColumn, int toColumn) {

        Indexables.checkIndexBounds(i, rows());
        Indexables.checkFromToBounds(fromColumn, toColumn, columns());

        return sparseRows.vectorR(i).nonZeros(fromColumn, toColumn);
    }

    @Override
    public int[] nonZeroPositionsInRow(int i) {

        Indexables.checkIndexBounds(i, rows());

        return sparseRows.vectorR(i).nonZeroPositions();
    }

    @Override
    public int[] nonZeroPositionsInRow(int i, int fromColumn, int toColumn) {

        Indexables.checkIndexBounds(i, rows());
        Indexables.checkFromToBounds(fromColumn, toColumn, columns());

        return sparseRows.vectorR(i).nonZeroPositions(fromColumn, toColumn);
    }

    @Override
    public void each(MatrixProcedure procedure) {

        for (int i = 0; i < rows(); i++) {
            ByteVectorIterator it = rowIterator(i);
            while (it.hasNext()) {
                it.next();
                procedure.apply(i, it.index(), it.get());
            }
        }
    }

    @Override
    public void eachNonZero(MatrixProcedure procedure) {

        for (int i = 0; i < rows(); i++) {
            ByteVectorIterator it = nonZeroRowIterator(i);
            while (it.hasNext()) {
                it.next();
                procedure.apply(i, it.index(), it.get());
            }
        }
    }

    @Override
    public void safeUpdate(int i, int j, MatrixFunction function) {

        sparseRows.vectorRW(i).update(j, ByteMatrices.asRowVectorFunction(function, i));
    }

    @Override
    public void updateNonZero(MatrixFunction function) {

        for (int i = 0; i < rows(); i++) {
            ByteVectorIterator it = nonZeroRowIterator(i);
            while (it.hasNext()) {
                it.next();
                it.set(function.evaluate(i, it.index(), it.get()));
            }
        }
    }

    @Override
    public void addRowsInPlace(int srcRow, int destRow) {

        Indexables.checkIndexBounds(srcRow, rows());
        Indexables.checkIndexBounds(destRow, rows());

        sparseRows.vectorRW(destRow).addInPlace(sparseRows.vectorR(srcRow));
    }

    @Override
    public void addRowsInPlace(int srcRow, int destRow, int fromColumn, int toColumn) {

        Indexables.checkIndexBounds(srcRow, rows());
        Indexables.checkIndexBounds(destRow, rows());
        Indexables.checkFromToBounds(fromColumn, toColumn, columns());

        sparseRows.vectorRW(destRow).addInPlace(sparseRows.vectorR(srcRow), fromColumn, toColumn);
    }

    @Override
    public void addRowsInPlace(byte srcMultiplier, int srcRow, int destRow) {

        Indexables.checkIndexBounds(srcRow, rows());
        Indexables.checkIndexBounds(destRow, rows());

        sparseRows.vectorRW(destRow).addInPlace(srcMultiplier, sparseRows.vectorR(srcRow));
    }

    @Override
    public void addRowsInPlace(byte srcMultiplier, int srcRow, int destRow, int fromColumn, int toColumn) {

        Indexables.checkIndexBounds(srcRow, rows());
        Indexables.checkIndexBounds(destRow, rows());
        Indexables.checkFromToBounds(fromColumn, toColumn, columns());

        sparseRows.vectorRW(destRow).addInPlace(srcMultiplier, sparseRows.vectorR(srcRow), fromColumn, toColumn);
    }

    @Override
    public byte maxInRow(int i) {

        byte max = foldNonZeroInRow(i, ByteMatrices.mkMaxAccumulator());
        if (sparseRows.vectorR(i).nonZeros() == columns() || aIsGreaterThanB(max, (byte)0)) {
            return max;
        }
        else {
            return 0;
        }
    }

    @Override
    public byte minInRow(int i) {

        byte min = foldNonZeroInRow(i, ByteMatrices.mkMinAccumulator());
        if (sparseRows.vectorR(i).nonZeros() == columns() || aIsLessThanB(min, (byte)0)) {
            return min;
        }
        else {
            return 0;
        }
    }

    @Override
    public ByteVectorIterator rowIterator(int i) {

        Indexables.checkIndexBounds(i, rows());
        return sparseRows.vectorRW(i).iterator();
    }

    @Override
    public ByteVectorIterator rowIterator(int i, int fromColumn, int toColumn) {

        Indexables.checkIndexBounds(i, rows());
        Indexables.checkFromToBounds(fromColumn, toColumn, columns());
        return sparseRows.vectorRW(i).iterator(fromColumn, toColumn);
    }

    @Override
    public ByteVectorIterator nonZeroRowIterator(int i) {

        Indexables.checkIndexBounds(i, rows());
        return sparseRows.vectorRW(i).nonZeroIterator();
    }

    @Override
    public ByteVectorIterator nonZeroRowIterator(int i, int fromColumn, int toColumn) {

        Indexables.checkIndexBounds(i, rows());
        Indexables.checkFromToBounds(fromColumn, toColumn, columns());
        return sparseRows.vectorRW(i).nonZeroIterator(fromColumn, toColumn);
    }

    @Override
    public ByteBuffer serializeToBuffer() {

        final ByteBuffer buffer = ByteBuffer.allocate(getSerializedDataSize());
        Serialization.writeType(buffer, Serialization.Type.SPARSE_ROW_MATRIX);
        Serialization.writeMatrixRows(buffer, rows());
        Serialization.writeMatrixColumns(buffer, columns());

        for (int i = 0; i < rows(); i++) {
            Serialization.writeMatrixRowCardinality(buffer, nonZerosInRow(i));
            ByteVectorIterator it = nonZeroRowIterator(i);
            while (it.hasNext()) {
                it.next();
                Serialization.writeMatrixColumnIndex(buffer, it.index());
                Serialization.writeMatrixValue(buffer, it.get());
            }
        }

        buffer.rewind();
        return buffer;
    }

    @Override
    public void serializeToChannel(WritableByteChannel ch) throws IOException {

        Serialization.writeType(ch, Serialization.Type.SPARSE_ROW_MATRIX);
        Serialization.writeMatrixRows(ch, rows());
        Serialization.writeMatrixColumns(ch, columns());

        for (int i = 0; i < rows(); i++) {
            Serialization.writeMatrixRowCardinality(ch, nonZerosInRow(i));
            ByteVectorIterator it = nonZeroRowIterator(i);
            while (it.hasNext()) {
                it.next();
                Serialization.writeMatrixColumnIndex(ch, it.index());
                Serialization.writeMatrixValue(ch, it.get());
            }
        }
    }

    private int getSerializedDataSize() {

        final long dataSize = Serialization.SERIALIZATION_TYPE_NUMBYTES +
                              Serialization.MATRIX_ROWS_NUMBYTES +
                              Serialization.MATRIX_COLUMNS_NUMBYTES +
                              Serialization.MATRIX_ROW_CARDINALITY_NUMBYTES * (long)rows() +
                              Serialization.MATRIX_COLUMN_INDEX_NUMBYTES * (long)cardinality() +
                              cardinality();

        if (dataSize > Integer.MAX_VALUE) {
            throw new UnsupportedOperationException("matrix is too large to be serialized");
        }

        return (int)dataSize;
    }
}
TOP

Related Classes of net.fec.openrq.util.linearalgebra.matrix.sparse.CRSByteMatrix

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.