Package no.uib.cipr.matrix.sparse

Source Code of no.uib.cipr.matrix.sparse.LinkedSparseMatrix$ReusableMatrixEntry

package no.uib.cipr.matrix.sparse;

import lombok.AllArgsConstructor;
import lombok.ToString;
import lombok.extern.java.Log;
import no.uib.cipr.matrix.AbstractMatrix;
import no.uib.cipr.matrix.Matrix;
import no.uib.cipr.matrix.MatrixEntry;
import no.uib.cipr.matrix.Vector;
import no.uib.cipr.matrix.io.MatrixInfo;
import no.uib.cipr.matrix.io.MatrixSize;
import no.uib.cipr.matrix.io.MatrixVectorReader;

import java.io.IOException;
import java.util.Iterator;
import java.util.NoSuchElementException;

/**
* A Linked List (with shortcuts to important nodes)
* implementation of an {@code n x m} Matrix with {@code z} elements that
* has a typical {@code O(z / m)} insertion / lookup cost
* and an iterator that traverses columns then rows:
* a good fit for unstructured sparse matrices. A secondary
* link maintains fast transpose iteration.
* <p>
* However, memory requirements
* ({@code 1 instance (8 bytes), 2 int (16 bytes), 2 ref (16 bytes), 1 double (8 bytes) = 48 bytes}
* per matrix element, plus {@code 8 x numcol + 8 x numrow} bytes for the cache) are slightly higher
* than structured sparse matrix storage. Note that on 32 bit JVMs, or on 64 bit JVMs
* with <a href="https://wikis.oracle.com/display/HotSpotInternals/CompressedOops">CompressedOops</a>
* enabled, references and ints only cost 4 bytes each, bringing the cost to 28 bytes per element.
*
* @author Sam Halliday
*/
@Log
public class LinkedSparseMatrix extends AbstractMatrix {

  // java.util.LinkedList is doubly linked and therefore too heavyweight.
  @AllArgsConstructor
  @ToString(exclude = {"rowTail", "colTail"})
  static class Node {
    final int row, col;
    double val;
    Node rowTail, colTail;
  }

  // there is a lot of duplicated code in this class between
  // row and col linkages, but subtle differences make it
  // extremely difficult to factor away.
  //
  // Holds the stored elements as Node objects that are threaded onto two
  // singly linked chains at once: one followed through Node.rowTail and one
  // followed through Node.colTail. Both chains start at the same head node.
  class Linked {

    // permanent node for element (0, 0); it is never unlinked, a "delete"
    // of (0, 0) only resets its value to 0
    final Node head = new Node(0, 0, 0, null, null);

    // shortcut caches, one slot per row / per column; a slot is filled by
    // updateCache with the inserted node of largest column (for rows) or
    // largest row (for cols), and stays null while nothing has been cached
    Node[] rows = new Node[numRows], cols = new Node[numColumns];

    // true if (row, col) is the position of the head node, i.e. (0, 0)
    private boolean isHead(int row, int col) {
      return head.row == row && head.col == col;
    }

    // true if node exists, its row tail exists, and has this row/col
    private boolean isNextByRow(Node node, int row, int col) {
      return node != null && node.rowTail != null && node.rowTail.row == row && node.rowTail.col == col;
    }

    // true if node exists, its col tail exists, and has this row/col
    private boolean isNextByCol(Node node, int row, int col) {
      return node != null && node.colTail != null && node.colTail.col == col && node.colTail.row == row;
    }

    // returns the stored value at (row, col), or 0 when no node exists there
    public double get(int row, int col) {
      if (isHead(row, col))
        return head.val;
      // search along the row chain when col <= row, otherwise along the
      // column chain
      if (col <= row) {
        Node node = findPreceedingByRow(row, col);
        if (isNextByRow(node, row, col))
          return node.rowTail.val;
      } else {
        Node node = findPreceedingByCol(row, col);
        if (isNextByCol(node, row, col))
          return node.colTail.val;
      }
      return 0;
    }

    // stores val at (row, col): a zero value removes the element, an
    // existing node is overwritten in place, otherwise a new node is
    // spliced into both chains
    public void set(int row, int col, double val) {
      if (val == 0) {
        delete(row, col);
        return;
      }
      if (isHead(row, col))
        head.val = val;
      else {
        Node prevRow = findPreceedingByRow(row, col);
        if (isNextByRow(prevRow, row, col))
          prevRow.rowTail.val = val;
        else {
          Node prevCol = findPreceedingByCol(row, col);
          Node nextCol = findNextByCol(row, col);
          // the new node takes over prevRow's old row successor and the
          // column successor found above, then both predecessors are
          // pointed at it
          prevRow.rowTail = new Node(row, col, val, prevRow.rowTail, nextCol);
          prevCol.colTail = prevRow.rowTail;
          updateCache(prevRow.rowTail);
        }
      }
    }

    // walks the column chain from the cached start for col - 1 and returns
    // the first node lying after (row, col): either in a later column, or
    // in the same-or-later column with a larger row; null if there is none
    private Node findNextByCol(int row, int col) {
      Node cur = cachedByCol(col - 1);
      while (cur != null) {
        if (row < cur.row && col <= cur.col || col < cur.col) return cur;
        cur = cur.colTail;
      }
      return cur;
    }

    // records the inserted node in the row / column cache when the slot is
    // empty or the node lies further along than the one currently cached
    private void updateCache(Node inserted) {
      if (rows[inserted.row] == null || inserted.col > rows[inserted.row].col)
        rows[inserted.row] = inserted;
      if (cols[inserted.col] == null || inserted.row > cols[inserted.col].row)
        cols[inserted.col] = inserted;
    }

    // removes the element at (row, col) if present; the head is only zeroed
    private void delete(int row, int col) {
      if (isHead(row, col)) {
        head.val = 0;
        return;
      }
      // both predecessors are looked up before either chain is modified
      Node precRow = findPreceedingByRow(row, col);
      Node precCol = findPreceedingByCol(row, col);
      if (isNextByRow(precRow, row, col)) {
        // if the doomed node is the cached one, fall back to its
        // predecessor when that is in the same row, else clear the slot
        if (rows[row] == precRow.rowTail)
          rows[row] = precRow.row == row ? precRow : null;
        precRow.rowTail = precRow.rowTail.rowTail;
      }
      if (isNextByCol(precCol, row, col)) {
        // same cache repair as above, for the column slot
        if (cols[col] == precCol.colTail)
          cols[col] = precCol.col == col ? precCol : null;
        precCol.colTail = precCol.colTail.colTail;
      }
    }

    // returns the node that either references this
    // index, or should reference it if inserted.
    // The walk follows rowTail links starting from the cached node for an
    // earlier row (or the head when none is cached).
    Node findPreceedingByRow(int row, int col) {
      Node last = cachedByRow(row - 1);
      Node cur = last;
      while (cur != null && cur.row <= row) {
        if (cur.row == row && cur.col >= col) return last;
        last = cur;
        cur = cur.rowTail;
      }
      return last;
    }

    // helper for findPreceeding
    // scans the row cache downwards from the given row and returns the
    // first non-null slot, or the head when every slot is empty
    private Node cachedByRow(int row) {
      for (int i = row; i >= 0; i--)
        if (rows[i] != null)
          return rows[i];
      return head;
    }

    // column-chain counterpart of findPreceedingByRow: follows colTail
    // links and returns the node that references (row, col), or should
    // reference it if inserted
    Node findPreceedingByCol(int row, int col) {
      Node last = cachedByCol(col - 1);
      Node cur = last;
      while (cur != null && cur.col <= col) {
        if (cur.col == col && cur.row >= row) return last;
        last = cur;
        cur = cur.colTail;
      }
      return last;
    }

    // scans the column cache downwards from the given column and returns
    // the first non-null slot, or the head when every slot is empty
    private Node cachedByCol(int col) {
      for (int i = col; i >= 0; i--)
        if (cols[i] != null)
          return cols[i];
      return head;
    }

    // first node of the given row: always the head for row 0, otherwise
    // the node following the predecessor of (row, 0) if it is in that row,
    // else null
    Node startOfRow(int row) {
      if (row == 0) return head;
      Node prec = findPreceedingByRow(row, 0);
      if (prec.rowTail != null && prec.rowTail.row == row)
        return prec.rowTail;
      return null;
    }

    // first node of the given column: always the head for column 0,
    // otherwise the node following the predecessor of (0, col) if it is in
    // that column, else null
    Node startOfCol(int col) {
      if (col == 0) return head;
      Node prec = findPreceedingByCol(0, col);
      if (prec != null && prec.colTail != null && prec.colTail.col == col)
        return prec.colTail;
      return null;
    }
  }

  Linked links;

  public LinkedSparseMatrix(int numRows, int numColumns) {
    super(numRows, numColumns);
    links = new Linked();
  }

  public LinkedSparseMatrix(Matrix A) {
    super(A);
    links = new Linked();
    set(A);
  }

  public LinkedSparseMatrix(MatrixVectorReader r) throws IOException {
    super(0, 0);
    try {
      MatrixInfo info = r.readMatrixInfo();
      if (info.isComplex()) throw new IllegalArgumentException("complex matrices not supported");
      if (!info.isCoordinate()) throw new IllegalArgumentException("only coordinate matrices supported");
      MatrixSize size = r.readMatrixSize(info);
      numRows = size.numRows();
      numColumns = size.numColumns();
      links = new Linked();

      int nz = size.numEntries();
      int[] row = new int[nz];
      int[] column = new int[nz];
      double[] entry = new double[nz];
      r.readCoordinate(row, column, entry);
      r.add(-1, row);
      r.add(-1, column);
      for (int i = 0; i < nz; ++i)
        set(row[i], column[i], entry[i]);
    } finally {
      r.close();
    }
  }


  @Override
  public Matrix zero() {
    links = new Linked();
    return this;
  }

  @Override
  public double get(int row, int column) {
    return links.get(row, column);
  }

  @Override
  public void set(int row, int column, double value) {
    check(row, column);
    links.set(row, column, value);
  }

  // avoids object creation
  static class ReusableMatrixEntry implements MatrixEntry {

    int row, col;
    double val;

    @Override
    public int column() {
      return col;
    }

    @Override
    public int row() {
      return row;
    }

    @Override
    public double get() {
      return val;
    }

    @Override
    public void set(double value) {
      throw new UnsupportedOperationException();
    }

    @Override
    public String toString() {
      return row + "," + col + "=" + val;
    }

  }

  @Override
  public Iterator<MatrixEntry> iterator() {
    return new Iterator<MatrixEntry>() {
      Node cur = links.head;
      ReusableMatrixEntry entry = new ReusableMatrixEntry();

      @Override
      public boolean hasNext() {
        return cur != null;
      }

      @Override
      public MatrixEntry next() {
        entry.row = cur.row;
        entry.col = cur.col;
        entry.val = cur.val;
        cur = cur.rowTail;
        return entry;
      }

      @Override
      public void remove() {
        throw new UnsupportedOperationException("TODO");
      }
    };
  }

  @Override
  public Matrix scale(double alpha) {
    if (alpha == 0)
      zero();
    else if (alpha != 1) for (MatrixEntry e : this)
      set(e.row(), e.column(), e.get() * alpha);
    return this;
  }

  @Override
  public Matrix copy() {
    return new LinkedSparseMatrix(this);
  }

  @Override
  public Matrix transpose() {
    Linked old = links;
    numRows = numColumns;
    numColumns = old.rows.length;
    links = new Linked();
    Node node = old.head;
    while (node != null) {
      set(node.col, node.row, node.val);
      node = node.rowTail;
    }
    return this;
  }

  @Override
  public Vector multAdd(double alpha, Vector x, Vector y) {
    checkMultAdd(x, y);
    if (alpha == 0) return y;
    Node node = links.head;
    while (node != null) {
      y.add(node.row, alpha * node.val * x.get(node.col));
      node = node.rowTail;
    }
    return y;
  }

  @Override
  public Vector transMultAdd(double alpha, Vector x, Vector y) {
    checkTransMultAdd(x, y);
    if (alpha == 0) return y;
    Node node = links.head;
    while (node != null) {
      y.add(node.col, alpha * node.val * x.get(node.row));
      node = node.colTail;
    }
    return y;
  }

  // TODO: optimise matrix mults based on RHS Matrix

  @Override
  public Matrix multAdd(double alpha, Matrix B, Matrix C) {
    checkMultAdd(B, C);
    if (alpha == 0)
      return C;
    for (int i = 0; i < numRows; i++) {
      Node row = links.startOfRow(i);
      if (row != null)
        for (int j = 0; j < B.numColumns(); j++) {
          Node node = row;
          double v = 0;
          while (node != null && node.row == i) {
            v += (B.get(node.col, j) * node.val);
            node = node.rowTail;
          }
          if (v != 0) C.add(i, j, alpha * v);
        }
    }
    return C;
  }

  @Override
  public Matrix transBmultAdd(double alpha, Matrix B, Matrix C) {
    checkTransBmultAdd(B, C);
    if (alpha == 0)
      return C;
    for (int i = 0; i < numRows; i++) {
      Node row = links.startOfRow(i);
      if (row != null)
        for (int j = 0; j < B.numRows(); j++) {
          Node node = row;
          double v = 0;
          while (node != null && node.row == i) {
            v += (B.get(j, node.col) * node.val);
            node = node.rowTail;
          }
          if (v != 0) C.add(i, j, alpha * v);
        }
    }
    return C;
  }

  @Override
  public Matrix transAmultAdd(double alpha, Matrix B, Matrix C) {
    checkTransAmultAdd(B, C);
    if (alpha == 0)
      return C;
    for (int i = 0; i < numColumns; i++) {
      Node row = links.startOfCol(i);
      if (row != null)
        for (int j = 0; j < B.numColumns(); j++) {
          Node node = row;
          double v = 0;
          while (node != null && node.col == i) {
            v += (B.get(node.row, j) * node.val);
            node = node.colTail;
          }
          if (v != 0) C.add(i, j, alpha * v);
        }
    }
    return C;
  }

  @Override
  public Matrix transABmultAdd(double alpha, Matrix B, Matrix C) {
    checkTransABmultAdd(B, C);
    if (alpha == 0)
      return C;
    for (int i = 0; i < numColumns; i++) {
      Node row = links.startOfCol(i);
      if (row != null)
        for (int j = 0; j < B.numRows(); j++) {
          Node node = row;
          double v = 0;
          while (node != null && node.col == i) {
            v += (B.get(j, node.row) * node.val);
            node = node.colTail;
          }
          if (v != 0) C.add(i, j, alpha * v);
        }
    }
    return C;
  }
}
TOP

Related Classes of no.uib.cipr.matrix.sparse.LinkedSparseMatrix$ReusableMatrixEntry

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.