Package mikera.arrayz

Source Code of mikera.arrayz.NDArray

package mikera.arrayz;

import java.nio.DoubleBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

import mikera.arrayz.impl.BaseNDArray;
import mikera.arrayz.impl.IStridedArray;
import mikera.arrayz.impl.ImmutableArray;
import mikera.matrixx.Matrix;
import mikera.vectorz.AVector;
import mikera.vectorz.IOperator;
import mikera.vectorz.Op;
import mikera.vectorz.Vector;
import mikera.vectorz.Vectorz;
import mikera.vectorz.impl.ArrayIndexScalar;
import mikera.vectorz.impl.ArraySubVector;
import mikera.vectorz.impl.SingleDoubleIterator;
import mikera.vectorz.impl.Vector0;
import mikera.vectorz.util.ErrorMessages;
import mikera.vectorz.util.IntArrays;
import mikera.vectorz.util.VectorzException;

/**
* General purpose NDArray class.
*
* Allows arbitrary strided access over a dense double[] array.
*
* @author Mike
*
*/
public final class NDArray extends BaseNDArray {
  private static final long serialVersionUID = -262272579159731240L;

  NDArray(int... shape) {
    super(new double[(int)IntArrays.arrayProduct(shape)],
        shape.length,
        0,
        shape,
        IntArrays.calcStrides(shape));
  }
 
  NDArray(double[] data, int offset, int[] shape, int[] stride) {
    this(data,shape.length,offset,shape,stride);
  }
 
  NDArray(double[] data, int dimensions, int offset, int[] shape, int[] stride) {
    super(data,shape.length,offset,shape,stride);
  }
 
  public static NDArray wrap(double[] data, int[] shape) {
    int dims=shape.length;
    return new NDArray(data,dims,0,shape,IntArrays.calcStrides(shape));
  }
 
  public static NDArray wrap(Vector v) {
    return wrap(v.getArray(),v.getShape());
  }

  public static NDArray wrap(Matrix m) {
    return wrap(m.data,m.getShape());
  }
 
  public static NDArray wrap(IStridedArray a) {
    return new NDArray(a.getArray(),a.getArrayOffset(),a.getShape(),a.getStrides());
  }
 
  public static NDArray wrap(INDArray a) {
    if (!(a instanceof IStridedArray)) throw new IllegalArgumentException(a.getClass()+" is not a strided array!");
    return wrap((IStridedArray)a);
  }
 
  public static NDArray newArray(int... shape) {
    return new NDArray(shape);
  }
 
  @Override
  public void set(double value) {
    if (dimensions==0) {
      data[offset]=value;
    } else if (dimensions==1) {
      int n=sliceCount();
      int st=getStride(0);
      for (int i=0; i<n; i++) {
        data[offset+i*st]=value;
      }
    } else {
      for (INDArray s:getSlices()) {
        s.set(value);
      }
    }
  }

  @Override
  public void set(int x, double value) {
    if (dimensions==1) {
      data[offset+x*getStride(0)]=value;
    } else {
      throw new UnsupportedOperationException(ErrorMessages.invalidIndex(this,x));
    }
  }

  @Override
  public void set(int x, int y, double value) {
    if (dimensions==2) {
      data[offset+x*getStride(0)+y*getStride(1)]=value;
    } else {
      throw new UnsupportedOperationException(ErrorMessages.invalidIndex(this,x,y));
    }
  }

  @Override
  public void set(int[] indexes, double value) {
    int ix=offset;
    if (indexes.length!=dimensions) throw new IllegalArgumentException(ErrorMessages.invalidIndex(this,indexes));
    for (int i=0; i<dimensions; i++) {
      ix+=indexes[i]*getStride(i);
    }
    data[ix]=value;
  }
 
  @Override
  public INDArray getTranspose() {
    return getTransposeView();
  }
 
  @Override
  public INDArray getTransposeView() {
    return Arrayz.wrapStrided(data,offset,IntArrays.reverse(shape),IntArrays.reverse(stride));
  }

  @Override
  public AVector asVector() {
    if (dimensions==0) {
      return ArraySubVector.wrap(data,offset,1);
    } else if (dimensions==1) {
      return Vectorz.wrapStrided(data, offset, getShape(0), getStride(0));
    } else {
      AVector v=Vector0.INSTANCE;
      int n=sliceCount();
      for (int i=0; i<n; i++) {
        v=v.join(slice(i).asVector());
      }
      return v;
    }
  }

  @Override
  public INDArray reshape(int... dimensions) {
    return super.reshape(dimensions);
  }

  @Override
  public INDArray broadcast(int... dimensions) {
    return super.broadcast(dimensions);
  }

  @Override
  public INDArray slice(int majorSlice) {
    // if ((majorSlice<0)||(majorSlice>=shape[0])) throw new IllegalArgumentException(ErrorMessages.invalidSlice(this,majorSlice));
    if (dimensions==0) {
      throw new IllegalArgumentException("Can't slice a 0-d NDArray");
    } else if (dimensions==1) {
      return new ArrayIndexScalar(data,offset+majorSlice*getStride(0));
    } else if (dimensions==2) {
      int st=stride[1];
      return Vectorz.wrapStrided(data, offset+majorSlice*getStride(0), getShape(1), st);
    } else {
      return Arrayz.wrapStrided(data,
          offset+majorSlice*getStride(0),
          Arrays.copyOfRange(shape, 1,dimensions),
          Arrays.copyOfRange(stride, 1,dimensions));
    }
  }
 
  @Override
  public INDArray slice(int dimension, int index) {
    if ((dimension<0)||(dimension>=dimensions)) throw new IllegalArgumentException(ErrorMessages.invalidDimension(this, dimension));
    if (dimension==0) return slice(index);
    if (dimensions==2) {
      // note: dimension must be 1 if we are here
      return Vectorz.wrapStrided(data, offset+index*getStride(1), getShape(0), getStride(0));
    }
    return Arrayz.wrapStrided(data,
        offset+index*stride[dimension],
        IntArrays.removeIndex(shape,index),
        IntArrays.removeIndex(stride,index))
 
 
  @Override
  public NDArray subArray(int[] offsets, int[] shape) {
    int n=dimensions;
    if (offsets.length!=n) throw new IllegalArgumentException(ErrorMessages.invalidIndex(this, offsets));
    if (shape.length!=n) throw new IllegalArgumentException(ErrorMessages.invalidIndex(this, offsets));
   
    if (IntArrays.equals(shape, this.shape)) {
      if (IntArrays.isZero(offsets)) {
        return this;
      } else {
        throw new IllegalArgumentException("Invalid subArray offsets");
      }
    }
   
    return new NDArray(data,
        offset+IntArrays.dotProduct(offsets, stride),
        IntArrays.copyOf(shape),
        stride);
  }

  @Override
  public boolean isMutable() {
    return true;
  }

  @Override
  public boolean isFullyMutable() {
    return true;
  }

  @Override
  public boolean isElementConstrained() {
    return false;
  }

  @Override
  public boolean isView() {
    return (!isPackedArray());
  }

  @Override
  public void applyOp(Op op) {
    if (dimensions==0) {
      data[offset]=op.apply(data[offset]);
    } else if (dimensions==1) {
      int len=sliceCount();
      int st=getStride(0);
      if (st==1) {
        op.applyTo(data, offset, len);
      } else {
        for (int i=0; i<len; i++) {
          data[offset+i*st]=op.apply(data[offset+i*st]);
        }
      }
    } else {
      int n=sliceCount();
      for (int i=0; i<n; i++) {
        slice(i).applyOp(op);
      }   
    }
  }

  @Override
  public void applyOp(IOperator op) {
    applyOp((Op)op);
  }
 
  public boolean equals(NDArray a) {
    if (dimensions!=a.dimensions) return false;
    if (dimensions==0) return get()==a.get();
   
    int sc=sliceCount();
    if (a.sliceCount()!=sc) return false;
   
    for (int i=0; i<sc; i++) {
      if (!(slice(i).equals(a.slice(i)))) return false;
    }
    return true;
  }

  @Override
  public boolean equals(INDArray a) {
    if (a instanceof NDArray) {
      return equals((NDArray)a);
    }
    if (dimensions!=a.dimensionality()) return false;
    if (dimensions==0) return (get()==a.get());

    int sc=sliceCount();
    if (a.sliceCount()!=sc) return false;
   
    for (int i=0; i<sc; i++) {
      if (!(slice(i).equals(a.slice(i)))) return false;
    }
    return true;
  }

  @Override
  public NDArray exactClone() {
    NDArray c=new NDArray(data.clone(),offset,shape.clone(),stride.clone());
    return c;
  }
 
  @Override
  public INDArray clone() {
    return Array.create(this);
  }

  @Override
  public void multiply(double d) {
    if (dimensions==0) {
      data[offset]*=d;
    } else if (dimensions==1) {
      int n=sliceCount();
      for (int i=0; i<n; i++) {
        data[offset+i*getStride(0)]*=d;
      }
    } else {
      int n=sliceCount();
      for (int i=0; i<n; i++) {
        slice(i).scale(d);
      }
    }
  }

  @Override
  public void setElements(int pos,double[] values, int offset, int length) {
    if (length==0) return;
    if (dimensions==0) {
      if (length!=1) throw new IllegalArgumentException("Must have one element!");
      if (pos!=0) throw new IllegalArgumentException("Element index out of bounds: "+pos);
      data[this.offset]=values[offset];
    } else if (dimensions==1) {
      asVector().setElements(pos,values,offset,length);
    } else {
      super.setElements(pos, values, offset, length);
    }
  }
 
  @Override
  public void toDoubleBuffer(DoubleBuffer dest) {
    if (dimensions==0) {
      dest.put(data[offset]);
    } else if (isPackedArray()) {
      dest.put(data,0,data.length);
    } else {
      int sc=sliceCount();
      for (int i=0; i<sc; i++) {
        INDArray s=slice(i);
        s.toDoubleBuffer(dest);
      }
    }
  }

  @Override
  public double[] asDoubleArray() {
    return isPackedArray()?data:null;
  }

  @Override
  public List<INDArray> getSlices() {
    if (dimensions==0) {
      throw new IllegalArgumentException(ErrorMessages.noSlices(this));
    } else {
      ArrayList<INDArray> al=new ArrayList<INDArray>();
      int n=getShape(0);
      for (int i=0; i<n; i++) {
        al.add(slice(i));
      }
      return al;
    }
  }
 
  @Override
  public Iterator<Double> elementIterator() {
    if (dimensionality()==0) {
      return new SingleDoubleIterator(data[offset]);
    } else {
      return super.elementIterator();
    }
  }
 
  @Override public void validate() {
    if (dimensions>shape.length) throw new VectorzException("Insufficient shape data");
    if (dimensions>stride.length) throw new VectorzException("Insufficient stride data");
   
    if ((offset<0)||(offset>=data.length)) throw new VectorzException("Offset out of bounds");
    int[] endIndex=IntArrays.decrementAll(shape);
    int endOffset=offset+IntArrays.dotProduct(endIndex, stride);
    if ((endOffset<0)||(endOffset>data.length)) throw new VectorzException("End offset out of bounds");
    super.validate();
  }
 
  @Override
  public INDArray immutable() {
    return ImmutableArray.create(this);
  }

  @Override
  public double[] getArray() {
    return data;
  }

  public static INDArray wrapStrided(double[] data, int offset,
      int[] shape, int[] strides) {
    return new NDArray(data,offset,shape,strides);
  }
}
TOP

Related Classes of mikera.arrayz.NDArray

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.