package mikera.arrayz.impl;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import mikera.arrayz.Array;
import mikera.arrayz.Arrayz;
import mikera.arrayz.INDArray;
import mikera.matrixx.Matrix;
import mikera.vectorz.AVector;
import mikera.vectorz.IOperator;
import mikera.vectorz.Op;
import mikera.vectorz.impl.Vector0;
import mikera.vectorz.util.ErrorMessages;
import mikera.vectorz.util.IntArrays;
import mikera.vectorz.util.VectorzException;
/**
* A general n-dimensional double array implemented as a collection of (n-1) dimensional slices
*
* @author Mike
*
* @param <T>
*/
public final class SliceArray<T extends INDArray> extends BaseShapedArray {
private static final long serialVersionUID = -2343678749417219155L;
private final long[] longShape;
private final T[] slices;
private SliceArray(int[] shape, T[] slices) {
super(shape);
this.slices=slices;
int dims=shape.length;
this.longShape=new long[dims];
for (int i=0; i<dims; i++) {
longShape[i]=shape[i];
}
}
@SuppressWarnings("unchecked")
public static <T extends INDArray> SliceArray<T> create(INDArray a) {
return new SliceArray<T>(a.getShape(),(T[]) a.toSliceArray());
}
public static <T extends INDArray> SliceArray<T> of(T... slices) {
return new SliceArray<T>(IntArrays.consArray(slices.length,slices[0].getShape()),slices.clone());
}
public static SliceArray<INDArray> repeat (INDArray s, int n) {
ArrayList<INDArray> al=new ArrayList<INDArray>();
for (int i=0; i<n; i++) {
al.add(s);
}
return SliceArray.create(al);
}
public static SliceArray<INDArray> create(List<INDArray> slices) {
int slen=slices.size();
if (slen==0) throw new VectorzException("Empty list of slices provided to SliceArray");
INDArray[] arr=new INDArray[slen];
return new SliceArray<INDArray>(IntArrays.consArray(slen,slices.get(0).getShape()),slices.toArray(arr));
}
public static SliceArray<INDArray> create(List<INDArray> slices, int[] shape) {
int slen=slices.size();
INDArray[] arr=new INDArray[slen];
return new SliceArray<INDArray>(shape,slices.toArray(arr));
}
@Override
public int dimensionality() {
return shape.length;
}
@Override
public long[] getLongShape() {
return longShape;
}
@Override
public double get(int... indexes) {
int d=indexes.length;
T slice=slices[indexes[0]];
switch (d) {
case 0: throw new VectorzException("Can't do 0D get on SliceArray!");
case 1: return slice.get();
case 2: return slice.get(indexes[1]);
case 3: return slice.get(indexes[1],indexes[2]);
default: return slice.get(Arrays.copyOfRange(indexes,1,d));
}
}
@Override
public void set(double value) {
for (T s:slices) {
s.set(value);
}
}
public void fill(double value) {
for (T s:slices) {
s.fill(value);
}
}
@Override
public void set(int[] indexes, double value) {
int d=indexes.length;
if (d==0) {
set(value);
}
T slice=slices[indexes[0]];
switch (d) {
case 0: throw new VectorzException("Can't do 0D set on SliceArray!");
case 1: slice.set(value); return;
case 2: slice.set(indexes[1],value); return;
case 3: slice.set(indexes[1],indexes[2],value); return;
default: slice.set(Arrays.copyOfRange(indexes,1,d),value); return;
}
}
@Override
public AVector asVector() {
AVector v=Vector0.INSTANCE;
for (INDArray a:slices) {
v=v.join(a.asVector());
}
return v;
}
@Override
public INDArray reshape(int... dimensions) {
return Arrayz.createFromVector(asVector(), dimensions);
}
@Override
public INDArray slice(int majorSlice) {
return slices[majorSlice];
}
@Override
public INDArray slice(int dimension, int index) {
if (dimension<0) throw new IllegalArgumentException(ErrorMessages.invalidDimension(this, dimension));
if (dimension==0) return slice(index);
ArrayList<INDArray> al=new ArrayList<INDArray>();
for (INDArray s:this) {
al.add(s.slice(dimension-1,index));
}
return SliceArray.create(al);
}
@Override
public long elementCount() {
return IntArrays.arrayProduct(shape);
}
public INDArray innerProduct(INDArray a) {
int dims=dimensionality();
switch (dims) {
case 0: {
a=a.clone();
a.scale(get());
return a;
}
case 1: {
return this.toVector().innerProduct(a);
}
case 2: {
return Matrix.create(this).innerProduct(a);
}
}
ArrayList<INDArray> al=new ArrayList<INDArray>();
for (INDArray s:this) {
al.add(s.innerProduct(a));
}
return Arrayz.create(al);
}
public INDArray outerProduct(INDArray a) {
ArrayList<INDArray> al=new ArrayList<INDArray>();
for (INDArray s:this) {
al.add(s.outerProduct(a));
}
return Arrayz.create(al);
}
@Override
public boolean isMutable() {
for (INDArray a:slices) {
if (a.isMutable()) return true;
}
return false;
}
@Override
public boolean isFullyMutable() {
for (INDArray a:slices) {
if (!a.isFullyMutable()) return false;
}
return true;
}
@Override
public boolean isZero() {
for (INDArray a:slices) {
if (!a.isZero()) return false;
}
return true;
}
@Override
public boolean isBoolean() {
for (INDArray a:slices) {
if (!a.isBoolean()) return false;
}
return true;
}
@Override
public boolean isElementConstrained() {
for (INDArray a:slices) {
if (a.isElementConstrained()) return true;
}
return false;
}
@Override
public boolean isView() {
return true;
}
@Override
public void applyOp(Op op) {
for (INDArray a:slices) {
a.applyOp(op);
}
}
@Override
public void applyOp(IOperator op) {
for (INDArray a:slices) {
a.applyOp(op);
}
}
@Override
public void multiply(double d) {
for (INDArray a:slices) {
a.scale(d);
}
}
@Override
public boolean equals(INDArray a) {
if (!Arrays.equals(a.getShape(), this.getShape())) return false;
for (int i=0; i<slices.length; i++) {
if (!slices[i].equals(a.slice(i))) return false;
}
return true;
}
@SuppressWarnings("unchecked")
@Override
public SliceArray<T> exactClone() {
T[] newSlices=slices.clone();
for (int i=0; i<slices.length; i++) {
newSlices[i]=(T) newSlices[i].exactClone();
}
return new SliceArray<T>(shape,newSlices);
}
@Override
public List<?> getSlices() {
ArrayList<Object> al=new ArrayList<Object>();
if (dimensionality()==1) {
int n=sliceCount();
for (int i=0; i<n ; i++) {
al.add(Double.valueOf(slices[i].get()));
}
return al;
}
for (INDArray sl:this) {
al.add(sl);
}
return al;
}
@Override
public INDArray[] toSliceArray() {
int n=sliceCount();
INDArray[] al=new INDArray[n];
for (int i=0; i<n; i++) {
al[i]=slice(i);
}
return al;
}
@Override
public double[] toDoubleArray() {
double[] result=Array.createStorage(this.getShape());
int skip=(int)slice(0).elementCount();
for (int i=0; i<slices.length; i++) {
INDArray s=slices[i];
if (s.isSparse()) {
s.addToArray(result,skip*i);
} else {
s.getElements(result,skip*i);
}
}
return result;
}
@Override
public boolean equalsArray(double[] values, int offset) {
int skip=(int)slice(0).elementCount();
int di=offset;
for (int i=0; i<slices.length; i++) {
if (!slices[i].equalsArray(values,di)) return false;
di+=skip;
}
return true;
}
@Override
public void validate() {
if (shape.length!=longShape.length) throw new VectorzException("Shape mismatch");
long ec=0;
for (int i=0; i<slices.length; i++) {
T s=slices[i];
ec+=s.elementCount();
slices[i].validate();
int[] ss=s.getShape();
for (int j=0; j<ss.length; j++) {
if(getShape(j+1)!=ss[j]) throw new VectorzException("Slice shape mismatch");
}
}
if (ec!=elementCount()) throw new VectorzException("Element count mismatch");
super.validate();
}
@Override
public double get() {
throw new IllegalArgumentException("0d get not supported on "+getClass());
}
@Override
public double get(int x) {
return slices[x].get();
}
@Override
public double get(int x, int y) {
return slices[x].get(y);
}
}