package mikera.arrayz;
import java.nio.DoubleBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import mikera.arrayz.impl.BaseNDArray;
import mikera.arrayz.impl.IStridedArray;
import mikera.arrayz.impl.ImmutableArray;
import mikera.matrixx.Matrix;
import mikera.vectorz.AVector;
import mikera.vectorz.IOperator;
import mikera.vectorz.Op;
import mikera.vectorz.Vector;
import mikera.vectorz.Vectorz;
import mikera.vectorz.impl.ArrayIndexScalar;
import mikera.vectorz.impl.ArraySubVector;
import mikera.vectorz.impl.SingleDoubleIterator;
import mikera.vectorz.impl.Vector0;
import mikera.vectorz.util.ErrorMessages;
import mikera.vectorz.util.IntArrays;
import mikera.vectorz.util.VectorzException;
/**
 * General purpose NDArray class.
 *
 * Allows arbitrary strided access over a dense double[] array: the element at
 * index (i0, i1, ...) lives at data[offset + i0*stride[0] + i1*stride[1] + ...].
 *
 * @author Mike
 *
 */
public final class NDArray extends BaseNDArray {
	private static final long serialVersionUID = -262272579159731240L;

	/**
	 * Creates an array of the given shape backed by a freshly allocated double[]
	 * (all elements zero), with offset 0 and strides from IntArrays.calcStrides.
	 *
	 * @param shape the size of each dimension
	 */
	NDArray(int... shape) {
		super(new double[(int)IntArrays.arrayProduct(shape)],
				shape.length,
				0,
				shape,
				IntArrays.calcStrides(shape));
	}

	/**
	 * Creates an array over the given data, offset, shape and strides. The
	 * dimensionality is taken from shape.length.
	 */
	NDArray(double[] data, int offset, int[] shape, int[] stride) {
		this(data,shape.length,offset,shape,stride);
	}

	/**
	 * Creates an array over the given data, offset, shape and strides.
	 *
	 * NOTE: the dimensions argument is not used; the dimensionality handed to the
	 * superclass is always shape.length.
	 */
	NDArray(double[] data, int dimensions, int offset, int[] shape, int[] stride) {
		super(data,shape.length,offset,shape,stride);
	}

	/**
	 * Wraps a double[] as an NDArray of the given shape, using offset 0 and
	 * strides from IntArrays.calcStrides. The array reference is passed through
	 * as-is (no copy is made here).
	 */
	public static NDArray wrap(double[] data, int[] shape) {
		int dims=shape.length;
		return new NDArray(data,dims,0,shape,IntArrays.calcStrides(shape));
	}

	/**
	 * Wraps the backing array of a Vector, using the vector's shape.
	 */
	public static NDArray wrap(Vector v) {
		return wrap(v.getArray(),v.getShape());
	}

	/**
	 * Wraps the data array of a Matrix, using the matrix's shape.
	 */
	public static NDArray wrap(Matrix m) {
		return wrap(m.data,m.getShape());
	}

	/**
	 * Wraps a strided array, reusing its array, array offset, shape and strides.
	 */
	public static NDArray wrap(IStridedArray a) {
		return new NDArray(a.getArray(),a.getArrayOffset(),a.getShape(),a.getStrides());
	}

	/**
	 * Wraps an INDArray, which must be an IStridedArray.
	 *
	 * @throws IllegalArgumentException if a is not an IStridedArray
	 */
	public static NDArray wrap(INDArray a) {
		if (!(a instanceof IStridedArray)) throw new IllegalArgumentException(a.getClass()+" is not a strided array!");
		return wrap((IStridedArray)a);
	}

	/**
	 * Creates a new NDArray of the given shape (see the shape constructor).
	 */
	public static NDArray newArray(int... shape) {
		return new NDArray(shape);
	}

	/**
	 * Sets every element of this array to the given value.
	 */
	@Override
	public void set(double value) {
		if (dimensions==0) {
			data[offset]=value;
		} else if (dimensions==1) {
			int n=sliceCount();
			int st=getStride(0);
			for (int i=0; i<n; i++) {
				data[offset+i*st]=value;
			}
		} else {
			// higher dimensions: delegate to each major slice
			for (INDArray s:getSlices()) {
				s.set(value);
			}
		}
	}

	/**
	 * Sets the element at index x of a 1-dimensional array.
	 *
	 * @throws UnsupportedOperationException if this array is not 1-dimensional
	 */
	@Override
	public void set(int x, double value) {
		if (dimensions==1) {
			data[offset+x*getStride(0)]=value;
		} else {
			throw new UnsupportedOperationException(ErrorMessages.invalidIndex(this,x));
		}
	}

	/**
	 * Sets the element at index (x,y) of a 2-dimensional array.
	 *
	 * @throws UnsupportedOperationException if this array is not 2-dimensional
	 */
	@Override
	public void set(int x, int y, double value) {
		if (dimensions==2) {
			data[offset+x*getStride(0)+y*getStride(1)]=value;
		} else {
			throw new UnsupportedOperationException(ErrorMessages.invalidIndex(this,x,y));
		}
	}

	/**
	 * Sets the element at the given index vector.
	 *
	 * @throws IllegalArgumentException if the number of indexes differs from the
	 *         dimensionality of this array
	 */
	@Override
	public void set(int[] indexes, double value) {
		int ix=offset;
		if (indexes.length!=dimensions) throw new IllegalArgumentException(ErrorMessages.invalidIndex(this,indexes));
		for (int i=0; i<dimensions; i++) {
			ix+=indexes[i]*getStride(i);
		}
		data[ix]=value;
	}

	/**
	 * Returns the transpose of this array; this is the same as getTransposeView().
	 */
	@Override
	public INDArray getTranspose() {
		return getTransposeView();
	}

	/**
	 * Returns a strided array over the same data and offset with the shape and
	 * stride arrays reversed.
	 */
	@Override
	public INDArray getTransposeView() {
		return Arrayz.wrapStrided(data,offset,IntArrays.reverse(shape),IntArrays.reverse(stride));
	}

	/**
	 * Returns the elements of this array as a vector. For 0 and 1 dimensions the
	 * vector wraps the underlying data; for higher dimensions it is built by
	 * joining the vectors of each major slice in order.
	 */
	@Override
	public AVector asVector() {
		if (dimensions==0) {
			return ArraySubVector.wrap(data,offset,1);
		} else if (dimensions==1) {
			return Vectorz.wrapStrided(data, offset, getShape(0), getStride(0));
		} else {
			AVector v=Vector0.INSTANCE;
			int n=sliceCount();
			for (int i=0; i<n; i++) {
				v=v.join(slice(i).asVector());
			}
			return v;
		}
	}

	/** Delegates to the superclass implementation. */
	@Override
	public INDArray reshape(int... dimensions) {
		return super.reshape(dimensions);
	}

	/** Delegates to the superclass implementation. */
	@Override
	public INDArray broadcast(int... dimensions) {
		return super.broadcast(dimensions);
	}

	/**
	 * Returns the slice at the given position along the first dimension, built
	 * over the same underlying data array.
	 *
	 * Note: no bounds check is performed on majorSlice here.
	 *
	 * @throws IllegalArgumentException if this array is 0-dimensional
	 */
	@Override
	public INDArray slice(int majorSlice) {
		if (dimensions==0) {
			throw new IllegalArgumentException("Can't slice a 0-d NDArray");
		} else if (dimensions==1) {
			return new ArrayIndexScalar(data,offset+majorSlice*getStride(0));
		} else if (dimensions==2) {
			int st=stride[1];
			return Vectorz.wrapStrided(data, offset+majorSlice*getStride(0), getShape(1), st);
		} else {
			// drop the leading dimension from both shape and stride
			return Arrayz.wrapStrided(data,
					offset+majorSlice*getStride(0),
					Arrays.copyOfRange(shape, 1,dimensions),
					Arrays.copyOfRange(stride, 1,dimensions));
		}
	}

	/**
	 * Returns the slice at position index along the given dimension, built over
	 * the same underlying data array.
	 *
	 * @param dimension the dimension to slice along (0 <= dimension < dimensionality)
	 * @param index the position along that dimension
	 * @throws IllegalArgumentException if dimension is out of range
	 */
	@Override
	public INDArray slice(int dimension, int index) {
		if ((dimension<0)||(dimension>=dimensions)) throw new IllegalArgumentException(ErrorMessages.invalidDimension(this, dimension));
		if (dimension==0) return slice(index);
		if (dimensions==2) {
			// note: dimension must be 1 if we are here
			return Vectorz.wrapStrided(data, offset+index*getStride(1), getShape(0), getStride(0));
		}
		// The sliced dimension (not the slice position) is the entry that must be
		// removed from the shape and stride arrays.
		return Arrayz.wrapStrided(data,
				offset+index*stride[dimension],
				IntArrays.removeIndex(shape,dimension),
				IntArrays.removeIndex(stride,dimension));
	}

	/**
	 * Returns a sub-array starting at the given offsets with the given shape,
	 * sharing this array's data and strides. If the requested shape equals this
	 * array's shape, this array itself is returned (offsets must then all be zero).
	 *
	 * @throws IllegalArgumentException if offsets or shape have the wrong length,
	 *         or if the full shape is requested with non-zero offsets
	 */
	@Override
	public NDArray subArray(int[] offsets, int[] shape) {
		int n=dimensions;
		if (offsets.length!=n) throw new IllegalArgumentException(ErrorMessages.invalidIndex(this, offsets));
		if (shape.length!=n) throw new IllegalArgumentException(ErrorMessages.invalidIndex(this, shape));
		if (IntArrays.equals(shape, this.shape)) {
			if (IntArrays.isZero(offsets)) {
				return this;
			} else {
				throw new IllegalArgumentException("Invalid subArray offsets");
			}
		}
		return new NDArray(data,
				offset+IntArrays.dotProduct(offsets, stride),
				IntArrays.copyOf(shape),
				stride);
	}

	@Override
	public boolean isMutable() {
		return true;
	}

	@Override
	public boolean isFullyMutable() {
		return true;
	}

	@Override
	public boolean isElementConstrained() {
		return false;
	}

	/**
	 * Returns true exactly when this array is not a packed array.
	 */
	@Override
	public boolean isView() {
		return (!isPackedArray());
	}

	/**
	 * Applies the operator to every element of this array in place.
	 */
	@Override
	public void applyOp(Op op) {
		if (dimensions==0) {
			data[offset]=op.apply(data[offset]);
		} else if (dimensions==1) {
			int len=sliceCount();
			int st=getStride(0);
			if (st==1) {
				// contiguous elements: let the operator work on the array range directly
				op.applyTo(data, offset, len);
			} else {
				for (int i=0; i<len; i++) {
					data[offset+i*st]=op.apply(data[offset+i*st]);
				}
			}
		} else {
			int n=sliceCount();
			for (int i=0; i<n; i++) {
				slice(i).applyOp(op);
			}
		}
	}

	/**
	 * Applies the operator to every element of this array in place.
	 *
	 * The argument is cast to Op, so a ClassCastException is thrown if it is an
	 * IOperator that is not an Op.
	 */
	@Override
	public void applyOp(IOperator op) {
		applyOp((Op)op);
	}

	/**
	 * Tests equality with another NDArray: same dimensionality, same slice count,
	 * and all corresponding slices equal (or equal scalar values for 0 dimensions).
	 */
	public boolean equals(NDArray a) {
		if (dimensions!=a.dimensions) return false;
		if (dimensions==0) return get()==a.get();
		int sc=sliceCount();
		if (a.sliceCount()!=sc) return false;
		for (int i=0; i<sc; i++) {
			if (!(slice(i).equals(a.slice(i)))) return false;
		}
		return true;
	}

	/**
	 * Tests equality with an arbitrary INDArray: same dimensionality, same slice
	 * count, and all corresponding slices equal (or equal scalar values for
	 * 0 dimensions).
	 */
	@Override
	public boolean equals(INDArray a) {
		if (a instanceof NDArray) {
			return equals((NDArray)a);
		}
		if (dimensions!=a.dimensionality()) return false;
		if (dimensions==0) return (get()==a.get());
		int sc=sliceCount();
		if (a.sliceCount()!=sc) return false;
		for (int i=0; i<sc; i++) {
			if (!(slice(i).equals(a.slice(i)))) return false;
		}
		return true;
	}

	/**
	 * Returns a new NDArray with the same offset and with cloned copies of the
	 * data, shape and stride arrays.
	 */
	@Override
	public NDArray exactClone() {
		NDArray c=new NDArray(data.clone(),offset,shape.clone(),stride.clone());
		return c;
	}

	/**
	 * Returns the result of Array.create(this).
	 */
	@Override
	public INDArray clone() {
		return Array.create(this);
	}

	/**
	 * Multiplies every element of this array by d in place.
	 */
	@Override
	public void multiply(double d) {
		if (dimensions==0) {
			data[offset]*=d;
		} else if (dimensions==1) {
			int n=sliceCount();
			int st=getStride(0);
			for (int i=0; i<n; i++) {
				data[offset+i*st]*=d;
			}
		} else {
			int n=sliceCount();
			for (int i=0; i<n; i++) {
				slice(i).scale(d);
			}
		}
	}

	/**
	 * Sets elements of this array, starting at element position pos, from
	 * values[offset .. offset+length-1]. Does nothing when length is 0.
	 *
	 * @param pos the first element position to write
	 * @param values the source array
	 * @param offset the start position in the source array (note: this shadows the
	 *        array's own offset field, which is accessed as this.offset)
	 * @param length the number of elements to copy
	 * @throws IllegalArgumentException for a 0-dimensional array if length is not 1
	 *         or pos is not 0
	 */
	@Override
	public void setElements(int pos,double[] values, int offset, int length) {
		if (length==0) return;
		if (dimensions==0) {
			if (length!=1) throw new IllegalArgumentException("Must have one element!");
			if (pos!=0) throw new IllegalArgumentException("Element index out of bounds: "+pos);
			data[this.offset]=values[offset];
		} else if (dimensions==1) {
			asVector().setElements(pos,values,offset,length);
		} else {
			super.setElements(pos, values, offset, length);
		}
	}

	/**
	 * Writes the elements of this array to the destination buffer. A packed array
	 * is written with a single bulk put of the whole data array; otherwise each
	 * major slice is written in turn.
	 */
	@Override
	public void toDoubleBuffer(DoubleBuffer dest) {
		if (dimensions==0) {
			dest.put(data[offset]);
		} else if (isPackedArray()) {
			dest.put(data,0,data.length);
		} else {
			int sc=sliceCount();
			for (int i=0; i<sc; i++) {
				INDArray s=slice(i);
				s.toDoubleBuffer(dest);
			}
		}
	}

	/**
	 * Returns the underlying data array itself if this is a packed array,
	 * otherwise null.
	 */
	@Override
	public double[] asDoubleArray() {
		return isPackedArray()?data:null;
	}

	/**
	 * Returns a new list containing every major slice of this array, in order.
	 *
	 * @throws IllegalArgumentException if this array is 0-dimensional
	 */
	@Override
	public List<INDArray> getSlices() {
		if (dimensions==0) {
			throw new IllegalArgumentException(ErrorMessages.noSlices(this));
		} else {
			int n=getShape(0);
			// slice count is known up front, so size the list accordingly
			ArrayList<INDArray> al=new ArrayList<INDArray>(n);
			for (int i=0; i<n; i++) {
				al.add(slice(i));
			}
			return al;
		}
	}

	/**
	 * Returns an iterator over the elements of this array. A 0-dimensional array
	 * yields its single value; otherwise the superclass iterator is used.
	 */
	@Override
	public Iterator<Double> elementIterator() {
		if (dimensionality()==0) {
			return new SingleDoubleIterator(data[offset]);
		} else {
			return super.elementIterator();
		}
	}

	/**
	 * Checks the internal consistency of this array: shape and stride arrays must
	 * cover all dimensions, and both the start offset and the position of the last
	 * element must be valid indexes into the data array.
	 *
	 * @throws VectorzException if any check fails
	 */
	@Override public void validate() {
		if (dimensions>shape.length) throw new VectorzException("Insufficient shape data");
		if (dimensions>stride.length) throw new VectorzException("Insufficient stride data");
		if ((offset<0)||(offset>=data.length)) throw new VectorzException("Offset out of bounds");
		int[] endIndex=IntArrays.decrementAll(shape);
		int endOffset=offset+IntArrays.dotProduct(endIndex, stride);
		// endOffset addresses the last element, so data.length itself is already out of bounds
		if ((endOffset<0)||(endOffset>=data.length)) throw new VectorzException("End offset out of bounds");
		super.validate();
	}

	/**
	 * Returns the result of ImmutableArray.create(this).
	 */
	@Override
	public INDArray immutable() {
		return ImmutableArray.create(this);
	}

	/**
	 * Returns the underlying data array (the array itself, not a copy).
	 */
	@Override
	public double[] getArray() {
		return data;
	}

	/**
	 * Creates an NDArray over the given data with the specified offset, shape
	 * and strides.
	 */
	public static INDArray wrapStrided(double[] data, int offset,
			int[] shape, int[] strides) {
		return new NDArray(data,offset,shape,strides);
	}
}