Package edu.stanford.nlp.neural

Source Code of edu.stanford.nlp.neural.SimpleTensor$SimpleMatrixIteratorWrapper

package edu.stanford.nlp.neural;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Iterator;
import java.util.NoSuchElementException;

import org.ejml.simple.SimpleMatrix;

/**
* This class defines a block tensor, somewhat like a three
* dimensional matrix.  This can be created in various ways, such as
* by providing an array of SimpleMatrix slices, by providing the
* initial size to create a 0-initialized tensor, or by creating a
* random matrix.
*
* @author John Bauer
* @author Richard Socher
*/
public class SimpleTensor implements Serializable {
  private final SimpleMatrix[] slices;

  final int numRows;
  final int numCols;
  final int numSlices;

  /**
   * Creates a zero initialized tensor
   */
  public SimpleTensor(int numRows, int numCols, int numSlices) {
    slices = new SimpleMatrix[numSlices];
    for (int i = 0; i < numSlices; ++i) {
      slices[i] = new SimpleMatrix(numRows, numCols);
    }

    this.numRows = numRows;
    this.numCols = numCols;
    this.numSlices = numSlices;
  }

  /**
   * Copies the data in the slices.  Slices are copied rather than
   * reusing the original SimpleMatrix objects.  Each slice must be
   * the same size.
   */
  public SimpleTensor(SimpleMatrix[] slices) {
    this.numRows = slices[0].numRows();
    this.numCols = slices[0].numCols();
    this.numSlices = slices.length;
    this.slices = new SimpleMatrix[slices.length];
    for (int i = 0; i < numSlices; ++i) {
      if (slices[i].numRows() != numRows || slices[i].numCols() != numCols) {
        throw new IllegalArgumentException("Slice " + i + " has matrix dimensions " + slices[i].numRows() + "," + slices[i].numCols() + ", expected " + numRows + "," + numCols);
      }
      this.slices[i] = new SimpleMatrix(slices[i]);
    }
   
  }

  /**
   * Returns a randomly initialized tensor with values draft from the
   * uniform distribution between minValue and maxValue.
   */
  public static SimpleTensor random(int numRows, int numCols, int numSlices, double minValue, double maxValue, java.util.Random rand) {
    SimpleTensor tensor = new SimpleTensor(numRows, numCols, numSlices);
    for (int i = 0; i < numSlices; ++i) {
      tensor.slices[i] = SimpleMatrix.random(numRows, numCols, minValue, maxValue, rand);
    }
    return tensor;
  }

  /**
   * Number of rows in the tensor
   */
  public int numRows() {
    return numRows;
  }

  /**
   * Number of columns in the tensor
   */
  public int numCols() {
    return numCols;
  }

  /**
   * Number of slices in the tensor
   */
  public int numSlices() {
    return numSlices;
  }

  /**
   * Total number of elements in the tensor
   */
  public int getNumElements() {
    return numRows * numCols * numSlices;
  }

  public void set(double value) {
    for (int slice = 0; slice < numSlices; ++slice) {
      slices[slice].set(value);
    }
  }

  /**
   * Returns a new tensor which has the values of the original tensor
   * scaled by <code>scaling</code>.  The original object is
   * unaffected.
   */
  public SimpleTensor scale(double scaling) {
    SimpleTensor result = new SimpleTensor(numRows, numCols, numSlices);
    for (int slice = 0; slice < numSlices; ++slice) {
      result.slices[slice] = slices[slice].scale(scaling);
    }
    return result;
  }

  /**
   * Returns <code>other</code> added to the current object, which is unaffected.
   */
  public SimpleTensor plus(SimpleTensor other) {
    if (other.numRows != numRows || other.numCols != numCols || other.numSlices != numSlices) {
      throw new IllegalArgumentException("Sizes of tensors do not match.  Our size: " + numRows + "," + numCols + "," + numSlices + "; other size " + other.numRows + "," + other.numCols + "," + other.numSlices);
    }
    SimpleTensor result = new SimpleTensor(numRows, numCols, numSlices);
    for (int i = 0; i < numSlices; ++i) {
      result.slices[i] = slices[i].plus(other.slices[i]);
    }
    return result;
  }

  /**
   * Performs elementwise multiplication on the tensors.  The original
   * objects are unaffected.
   */
  public SimpleTensor elementMult(SimpleTensor other) {
    if (other.numRows != numRows || other.numCols != numCols || other.numSlices != numSlices) {
      throw new IllegalArgumentException("Sizes of tensors do not match.  Our size: " + numRows + "," + numCols + "," + numSlices + "; other size " + other.numRows + "," + other.numCols + "," + other.numSlices);
    }
    SimpleTensor result = new SimpleTensor(numRows, numCols, numSlices);
    for (int i = 0; i < numSlices; ++i) {
      result.slices[i] = slices[i].elementMult(other.slices[i]);
    }
    return result;
  }

  /**
   * Returns the sum of all elements in the tensor.
   */
  public double elementSum() {
    double sum = 0.0;
    for (SimpleMatrix slice : slices) {
      sum += slice.elementSum();
    }
    return sum;
  }

  /**
   * Use the given <code>matrix</code> in place of <code>slice</code>.
   * Does not copy the <code>matrix</code>, but rather uses the actual object.
   */
  public void setSlice(int slice, SimpleMatrix matrix) {
    if (slice < 0 || slice >= numSlices) {
      throw new IllegalArgumentException("Unexpected slice number " + slice + " for tensor with " + numSlices + " slices");
    }
    if (matrix.numCols() != numCols) {
      throw new IllegalArgumentException("Incompatible matrix size.  Has " + matrix.numCols() + " columns, tensor has " + numCols);
    }
    if (matrix.numRows() != numRows) {
      throw new IllegalArgumentException("Incompatible matrix size.  Has " + matrix.numRows() + " columns, tensor has " + numRows);
    }
    slices[slice] = matrix;
  }

  /**
   * Returns the SimpleMatrix at <code>slice</code>.
   * <br>
   * The actual slice is returned - do not alter this unless you know what you are doing.
   */
  public SimpleMatrix getSlice(int slice) {
    if (slice < 0 || slice >= numSlices) {
      throw new IllegalArgumentException("Unexpected slice number " + slice + " for tensor with " + numSlices + " slices");
    }
    return slices[slice];
  }

  /**
   * Returns a column vector where each entry is the nth bilinear
   * product of the nth slices of the two tensors.
   */
  public SimpleMatrix bilinearProducts(SimpleMatrix in) {
    if (in.numCols() != 1) {
      throw new AssertionError("Expected a column vector");
    }
    if (in.numRows() != numCols) {
      throw new AssertionError("Number of rows in the input does not match number of columns in tensor");
    }
    if (numRows != numCols) {
      throw new AssertionError("Can only perform this operation on a SimpleTensor with square slices");
    }
    SimpleMatrix inT = in.transpose();
    SimpleMatrix out = new SimpleMatrix(numSlices, 1);
    for (int slice = 0; slice < numSlices; ++slice) {
      double result = inT.mult(slices[slice]).mult(in).get(0);
      out.set(slice, result);
    }
    return out;
  }

  /**
   * Returns true iff every element of the tensor is 0
   */
  public boolean isZero() {
    for (int i = 0; i < numSlices; ++i) {
      if (!NeuralUtils.isZero(slices[i])) {
        return false;
      }
    }
    return true;
  }

  /**
   * Returns an iterator over the <code>SimpleMatrix</code> objects contained in the tensor.
   */
  public Iterator<SimpleMatrix> iteratorSimpleMatrix() {
    return Arrays.asList(slices).iterator();
  }

  /**
   * Returns an Iterator which returns the SimpleMatrices represented
   * by an Iterator over tensors.  This is useful for if you want to
   * perform some operation on each of the SimpleMatrix slices, such
   * as turning them into a parameter stack.
   */
  public static Iterator<SimpleMatrix> iteratorSimpleMatrix(Iterator<SimpleTensor> tensors) {
    return new SimpleMatrixIteratorWrapper(tensors);
  }

  private static class SimpleMatrixIteratorWrapper implements Iterator<SimpleMatrix> {
    Iterator<SimpleTensor> tensors;
    Iterator<SimpleMatrix> currentIterator;

    public SimpleMatrixIteratorWrapper(Iterator<SimpleTensor> tensors) {
      this.tensors = tensors;
      advanceIterator();
    }

    public boolean hasNext() {
      if (currentIterator == null) {
        return false;
      }
      if (currentIterator.hasNext()) {
        return true;
      }
      advanceIterator();
      return (currentIterator != null);
    }

    public SimpleMatrix next() {
      if (currentIterator != null && currentIterator.hasNext()) {
        return currentIterator.next();
      }
      advanceIterator();
      if (currentIterator != null) {
        return currentIterator.next();
      }
      throw new NoSuchElementException();
    }

    private void advanceIterator() {
      if (currentIterator != null && currentIterator.hasNext()) {
        return;
      }
      while (tensors.hasNext()) {
        currentIterator = tensors.next().iteratorSimpleMatrix();
        if (currentIterator.hasNext()) {
          return;
        }
      }
      currentIterator = null;
    }

    public void remove() {
      throw new UnsupportedOperationException();
    }
  }

  @Override
  public String toString() {
    StringBuilder result = new StringBuilder();
    for (int slice = 0; slice < numSlices; ++slice) {
      result.append("Slice " + slice + "\n");
      result.append(slices[slice]);
    }
    return result.toString();
  }

  private static final long serialVersionUID = 1;
}
TOP

Related Classes of edu.stanford.nlp.neural.SimpleTensor$SimpleMatrixIteratorWrapper

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.