package mikera.arrayz.impl;
import java.nio.DoubleBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import mikera.arrayz.Array;
import mikera.arrayz.Arrayz;
import mikera.arrayz.INDArray;
import mikera.arrayz.ISparse;
import mikera.indexz.AIndex;
import mikera.indexz.Index;
import mikera.matrixx.AMatrix;
import mikera.matrixx.Matrix;
import mikera.matrixx.Matrixx;
import mikera.matrixx.impl.VectorMatrixMN;
import mikera.util.Maths;
import mikera.vectorz.AScalar;
import mikera.vectorz.AVector;
import mikera.vectorz.IOperator;
import mikera.vectorz.Op;
import mikera.vectorz.Ops;
import mikera.vectorz.Scalar;
import mikera.vectorz.Tools;
import mikera.vectorz.Vector;
import mikera.vectorz.Vectorz;
import mikera.vectorz.impl.SingleDoubleIterator;
import mikera.vectorz.util.ErrorMessages;
import mikera.vectorz.util.IntArrays;
import mikera.vectorz.util.LongArrays;
/**
 * Abstract base class for INDArray implementations.
 *
 * Contains generic implementations for most INDArray operations, enabling new INDArray implementations
 * to inherit these (at least until more optimised implementations can be written).
 *
 * Most of the generic implementations in this class work by iterating over the major slices of the
 * array (sliceCount() / slice(i)) and handling the 0-dimensional case directly via get() / set(double).
 *
 * Most INDArray instances should ultimately inherit from AbstractArray
 *
 * @author Mike
 * @param <T> The type of array slices
 */
public abstract class AbstractArray<T> implements INDArray, Iterable<T> {
    private static final long serialVersionUID = -958234961396539071L;

    /** Gets the scalar value of this array; used by this class whenever dimensionality()==0. */
    public abstract double get();

    /** Gets the element at index x; used by this class for 1-dimensional arrays. */
    public abstract double get(int x);

    /** Gets the element at position (x, y). */
    public abstract double get(int x, int y);

    /** Gets the element addressed by the given index, delegating to the int-array based getter. */
    @Override
    public double get(AIndex ix) {
        return get(ix.toArray());
    }

    /** Gets the element addressed by the given index, delegating to the int-array based getter. */
    @Override
    public double get(Index ix) {
        return get(ix.getData());
    }

    /** Returns the size of the given dimension, read from getShape(). */
    @Override
    public int getShape(int dim) {
        return getShape()[dim];
    }

    /** Returns a freshly allocated int[] holding the size of each dimension. */
    @Override
    public int[] getShapeClone() {
        int n=dimensionality();
        int[] sh=new int[n];
        for (int i=0; i<n; i++) {
            sh[i]=getShape(i);
        }
        return sh;
    }

    /** Returns the shape of this array converted to a long[]. */
    @Override
    public long[] getLongShape() {
        return LongArrays.copyOf(getShape());
    }

    /** Tests approximate equality using Vectorz.TEST_EPSILON as the tolerance. */
    @Override
    public boolean epsilonEquals(INDArray a) {
        return epsilonEquals(a,Vectorz.TEST_EPSILON);
    }

    /**
     * Tests approximate equality: every pair of corresponding elements must differ by at most epsilon.
     * Returns false if the slice counts differ at any level of the recursion.
     */
    @Override
    public boolean epsilonEquals(INDArray a, double epsilon) {
        if (dimensionality()==0) {
            double d=get()-a.get();
            return (Math.abs(d)<=epsilon);
        } else {
            int sc=sliceCount();
            if (a.sliceCount()!=sc) return false;
            for (int i=0; i<sc; i++) {
                INDArray s=slice(i);
                if (!s.epsilonEquals(a.slice(i),epsilon)) return false;
            }
            return true;
        }
    }

    /** Returns true if every element satisfies Tools.isBoolean. */
    @Override
    public boolean isBoolean() {
        if (dimensionality()==0) return Tools.isBoolean(get());
        int sc=sliceCount();
        for (int i=0; i<sc; i++) {
            INDArray s=slice(i);
            if (!s.isBoolean()) return false;
        }
        return true;
    }

    /** Returns true if this instance implements ISparse. */
    @Override
    public boolean isSparse() {
        return (this instanceof ISparse);
    }

    /** Returns true if this instance implements IDense. */
    @Override
    public boolean isDense() {
        return (this instanceof IDense);
    }

    /** Returns true if at least one slice reports itself as mutable. */
    @Override
    public boolean isMutable() {
        int n=sliceCount();
        for (int i=0; i<n; i++) {
            if (slice(i).isMutable()) return true;
        }
        return false;
    }

    /** Returns true only if every slice reports itself as fully mutable. */
    @Override
    public boolean isFullyMutable() {
        int n=sliceCount();
        for (int i=0; i<n; i++) {
            if (!slice(i).isFullyMutable()) return false;
        }
        return true;
    }

    /** Applies the operator to each slice in turn. */
    @Override
    public void applyOp(Op op) {
        int n=sliceCount();
        for (int i=0; i<n; i++) {
            slice(i).applyOp(op);
        }
    }

    /** Applies the operator to each slice in turn. */
    @Override
    public void applyOp(IOperator op) {
        int n=sliceCount();
        for (int i=0; i<n; i++) {
            slice(i).applyOp(op);
        }
    }

    /** Multiplies each slice by the given factor. */
    @Override
    public void multiply(double d) {
        int n=sliceCount();
        for (int i=0; i<n; i++) {
            slice(i).multiply(d);
        }
    }

    /** Returns a clone of this array multiplied by the given factor; this array is not modified. */
    @Override
    public INDArray multiplyCopy(double d) {
        INDArray r=clone();
        r.multiply(d);
        return r;
    }

    /** Returns a clone of this array with the operator applied; this array is not modified. */
    @Override
    public INDArray applyOpCopy(Op op) {
        INDArray r=clone();
        r.applyOp(op);
        return r;
    }

    /** Returns true if at least one slice reports itself as element-constrained. */
    @Override
    public boolean isElementConstrained() {
        int n=sliceCount();
        for (int i=0; i<n; i++) {
            if (slice(i).isElementConstrained()) return true;
        }
        return false;
    }

    /** Returns true if the other array has the same dimensionality and the same size in every dimension. */
    @Override
    public boolean isSameShape(INDArray a) {
        int dims=dimensionality();
        if (dims!=a.dimensionality()) return false;
        for (int i=0; i<dims; i++) {
            if (getShape(i)!=a.getShape(i)) return false;
        }
        return true;
    }

    /**
     * Returns a vector over all elements of this array. For IDenseArray instances the backing array is
     * wrapped directly; otherwise the vectors of the individual slices are joined together.
     */
    @Override
    public AVector asVector() {
        if (this instanceof IDenseArray) {
            IDenseArray a=(IDenseArray) this;
            return Vectorz.wrap(a.getArray(), a.getArrayOffset(), (int)elementCount());
        }
        int n=sliceCount();
        AVector result=slice(0).asVector();
        for (int i=1; i<n; i++) {
            result=result.join(slice(i).asVector());
        }
        return result;
    }

    /** Sets all elements of this array from values, starting at the given offset in the source array. */
    @Override
    public void setElements(double[] values, int offset) {
        setElements(0,values,offset,(int)elementCount());
    }

    /**
     * Sets a range of elements of this array from a source array.
     *
     * @param pos    first element position in this array to set
     * @param values source data
     * @param offset position of the first source value in values
     * @param length number of elements to copy
     */
    @Override
    public void setElements(int pos, double[] values, int offset, int length) {
        if (length==0) return;
        int ss=(int)(slice(0).elementCount()); // elements per slice
        int s1=pos/ss;                         // first slice touched
        int s2=(pos+length-1)/ss;              // last slice touched
        if (s1==s2) {
            // whole range lies within a single slice
            slice(s1).setElements(pos-s1*ss,values,offset,length);
            return;
        }
        int si=offset; // running read position in values
        // leading, possibly partial, slice
        int l1 = (s1+1)*ss-pos;
        if (l1>0) {
            slice(s1).setElements(pos-s1*ss, values, si, l1);
            si+=l1;
        }
        // slices that are covered completely
        for (int i=s1+1; i<s2; i++) {
            slice(i).setElements(values, si);
            si+=ss;
        }
        // trailing, possibly partial, slice
        int l2=(pos+length)-(s2*ss);
        if (l2>0) {
            slice(s2).setElements(0,values,si,l2);
        }
    }

    /** Returns true if every element equals 0.0. */
    @Override
    public boolean isZero() {
        if (dimensionality()==0) return get()==0.0;
        int sc=sliceCount();
        for (int i=0; i<sc; i++) {
            INDArray s=slice(i);
            if (!s.isZero()) return false;
        }
        return true;
    }

    /** Returns this array if it is fully mutable and not a view, otherwise a clone. */
    @Override
    public INDArray ensureMutable() {
        if (isFullyMutable()&&!isView()) return this;
        return clone();
    }

    /** Sets every element of this array to the given value. */
    @Override
    public void fill(double value) {
        if (dimensionality()==0) {
            set(value);
        } else {
            int sc=sliceCount();
            for (int i=0; i<sc; i++) {
                INDArray s=slice(i);
                s.fill(value);
            }
        }
    }

    /** Returns a clone of this array scaled by the given value. */
    @Override
    public INDArray innerProduct(double a) {
        INDArray result=clone();
        result.scale(a);
        return result;
    }

    /**
     * Computes the inner product with another array. Dimensionalities 0, 1 and 2 are handled by
     * scaling, by the Vector implementation and by the Matrix implementation respectively;
     * higher dimensionalities are computed slice by slice.
     */
    @Override
    public INDArray innerProduct(INDArray a) {
        int dims=dimensionality();
        switch (dims) {
            case 0: {
                a=a.clone();
                a.scale(get());
                return a;
            }
            case 1: {
                return toVector().innerProduct(a);
            }
            case 2: {
                return Matrix.create(this).innerProduct(a);
            }
        }
        int sc=sliceCount();
        ArrayList<INDArray> sips=new ArrayList<INDArray>();
        for (int i=0; i<sc; i++) {
            sips.add(slice(i).innerProduct(a));
        }
        return SliceArray.create(sips);
    }

    /** Computes the inner product with a scalar, i.e. scales a clone of this array by its value. */
    @Override
    public INDArray innerProduct(AScalar s) {
        return innerProduct(s.get());
    }

    /** Computes the inner product with a vector via the generic INDArray implementation. */
    @Override
    public INDArray innerProduct(AVector a) {
        return innerProduct((INDArray) a);
    }

    /**
     * Computes the outer product with another array by iterating over this array's slices.
     * Slices that are not INDArray instances are converted with Tools.toDouble and used to scale a clone of a.
     */
    @Override
    public INDArray outerProduct(INDArray a) {
        ArrayList<INDArray> al=new ArrayList<INDArray>();
        for (Object s:this) {
            if (s instanceof INDArray) {
                al.add(((INDArray)s).outerProduct(a));
            } else {
                double x=Tools.toDouble(s);
                INDArray sa=a.clone();
                sa.scale(x);
                al.add(sa);
            }
        }
        return Arrayz.create(al);
    }

    /** Returns the transpose of this array; this implementation delegates to getTransposeCopy(). */
    @Override
    public INDArray getTranspose() {
        return getTransposeCopy();
    }

    /** Not supported by this base implementation: always throws UnsupportedOperationException. */
    @Override
    public INDArray getTransposeView() {
        throw new UnsupportedOperationException();
    }

    /** Copies this array into a new Array and returns the transpose view of that copy. */
    @Override
    public INDArray getTransposeCopy() {
        Array nd=Array.create(this);
        return nd.getTransposeView();
    }

    /** Scales this array by the given factor; equivalent to multiply(d). */
    public final void scale(double d) {
        multiply(d);
    }

    /** Multiplies every element by factor and then adds constant. */
    @Override
    public void scaleAdd(double factor, double constant) {
        multiply(factor);
        add(constant);
    }

    /** Sets the value using an empty index array. */
    public void set(double value) {
        set(new int[0],value);
    }

    /** Sets the element at index x. */
    public void set(int x, double value) {
        set(new int[] {x},value);
    }

    /** Sets the element at position (x, y). */
    public void set(int x, int y, double value) {
        set(new int[] {x,y},value);
    }

    /**
     * Sets this array from another array. A source of lower dimensionality is set into every slice;
     * a source of equal dimensionality is copied slice by slice.
     *
     * @throws IllegalArgumentException if the source has a higher dimensionality than this array
     */
    public void set (INDArray a) {
        int tdims=this.dimensionality();
        int adims=a.dimensionality();
        if (adims<tdims) {
            int sc=sliceCount();
            for (int i=0; i<sc; i++) {
                INDArray s=slice(i);
                s.set(a);
            }
        } else if (adims==tdims) {
            if (tdims==0) {
                set(a.get());
                return;
            }
            int sc=sliceCount();
            for (int i=0; i<sc; i++) {
                INDArray s=slice(i);
                s.set(a.slice(i));
            }
        } else {
            throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this, a));
        }
    }

    /** Bounds every element to the range given by min and max using Maths.bound. */
    @Override
    public void clamp(double min, double max) {
        if (dimensionality()==0) {
            set(Maths.bound(get(), min, max));
            return;
        }
        int len=sliceCount();
        for (int i = 0; i < len; i++) {
            slice(i).clamp(min, max);
        }
    }

    /**
     * Sets this array from an arbitrary object: an INDArray, a Number, an Iterable whose successive
     * elements are set into successive slices, or a double[] of element values.
     *
     * @throws UnsupportedOperationException if the object is of none of the supported types
     */
    public void set(Object o) {
        if (o instanceof INDArray) {set((INDArray)o); return;}
        if (o instanceof Number) {
            set(((Number)o).doubleValue()); return;
        }
        if (o instanceof Iterable<?>) {
            int i=0;
            for (Object ob: ((Iterable<?>)o)) {
                // each element of the iterable goes into the next slice
                slice(i++).set(ob);
            }
            return;
        }
        if (o instanceof double[]) {
            setElements((double[])o);
            return;
        }
        throw new UnsupportedOperationException("Can't set to value for "+o.getClass().toString());
    }

    /**
     * Sets all elements of this array from the given values.
     *
     * @throws IllegalArgumentException if the number of values differs from elementCount()
     */
    public void setElements(double... values) {
        int vl=values.length;
        if (vl!=elementCount()) throw new IllegalArgumentException("Wrong array length: "+vl);
        setElements(0,values,0,vl);
    }

    /** Applies Ops.SQUARE to this array in place. */
    public void square() {
        applyOp(Ops.SQUARE);
    }

    /** Returns a clone of this array on which square() has been called. */
    @Override
    public INDArray squareCopy() {
        INDArray r=clone();
        r.square();
        return r;
    }

    /** Returns a clone of this array on which abs() has been called. */
    @Override
    public INDArray absCopy() {
        INDArray r=clone();
        r.abs();
        return r;
    }

    /** Returns a clone of this array on which reciprocal() has been called. */
    @Override
    public INDArray reciprocalCopy() {
        INDArray r=clone();
        r.reciprocal();
        return r;
    }

    /** Returns a clone of this array on which signum() has been called. */
    @Override
    public INDArray signumCopy() {
        INDArray r=clone();
        r.signum();
        return r;
    }

    /** Returns an iterator over the slices of this array. */
    @Override
    public Iterator<T> iterator() {
        return new SliceIterator<T>(this);
    }

    /** Returns an iterator over the individual elements of this array. */
    @Override
    public Iterator<Double> elementIterator() {
        if (dimensionality()==0) {
            return new SingleDoubleIterator(get());
        } else {
            return new SliceElementIterator(this);
        }
    }

    /** Returns true if o is an INDArray that is equal to this array according to equals(INDArray). */
    @Override
    public boolean equals(Object o) {
        if (!(o instanceof INDArray)) return false;
        return equals((INDArray)o);
    }

    /** Returns true if data has exactly elementCount() entries and they equal the elements of this array. */
    @Override
    public boolean equalsArray(double[] data) {
        if (data.length!=elementCount()) return false;
        return equalsArray(data,0);
    }

    /** Returns the hash code of the vector returned by asVector(). */
    @Override
    public int hashCode() {
        return asVector().hashCode();
    }

    /** Formats the array as nested, comma-separated, square-bracketed lists of its slices. */
    @Override
    public String toString() {
        if (dimensionality()==0) {
            return Double.toString(get());
        }
        StringBuilder sb=new StringBuilder();
        int length=sliceCount();
        sb.append('[');
        if (length>0) {
            sb.append(slice(0).toString());
            for (int i = 1; i < length; i++) {
                sb.append(',');
                sb.append(slice(i).toString());
            }
        }
        sb.append(']');
        return sb.toString();
    }

    /** Returns a new array created from this one via Arrayz.create. */
    @Override
    public INDArray clone() {
        return Arrayz.create(this);
    }

    /** Returns this array itself if it is not mutable, otherwise a clone. */
    @Override
    public INDArray copy() {
        if (!isMutable()) return this;
        return clone();
    }

    /** Returns a clone of this array scaled by the given factor. */
    @Override
    public INDArray scaleCopy(double d) {
        INDArray r=clone();
        r.scale(d);
        return r;
    }

    /** Returns a clone of this array on which negate() has been called. */
    @Override
    public INDArray negateCopy() {
        INDArray r=clone();
        r.negate();
        return r;
    }

    /**
     * Tests exact equality with another array: the dimensionality, the number of slices at each
     * level and every element value must match.
     */
    @Override
    public boolean equals(INDArray a) {
        int dims=dimensionality();
        if (a.dimensionality()!=dims) return false;
        if (dims==0) {
            return (get()==a.get());
        } else if (dims==1) {
            return equals(a.asVector());
        } else {
            int sc=sliceCount();
            // arrays with a different number of slices can never be equal
            if (a.sliceCount()!=sc) return false;
            for (int i=0; i<sc; i++) {
                if (!slice(i).equals(a.slice(i))) return false;
            }
        }
        return true;
    }

    /** Tests exact equality with a vector; always false unless this array is 1-dimensional with the same length. */
    public boolean equals(AVector a) {
        if (dimensionality()!=1) return false;
        int sc=sliceCount();
        if (a.length()!=sc) return false;
        for (int i=0; i<sc; i++) {
            if (!(get(i)==a.unsafeGet(i))) return false;
        }
        return true;
    }

    /** Tests whether the elements of this array equal the values in data starting at the given offset. */
    @Override
    public boolean equalsArray(double[] data, int offset) {
        int dims=dimensionality();
        if (dims==0) {
            return (data[offset]==get());
        } else if (dims==1) {
            return asVector().equalsArray(data, offset);
        } else {
            int sc=sliceCount();
            // nothing to compare if there are no slices (also avoids calling slice(0) on an empty array)
            if (sc==0) return true;
            int skip=(int) slice(0).elementCount();
            for (int i=0; i<sc; i++) {
                if (!slice(i).equalsArray(data,offset+i*skip)) return false;
            }
        }
        return true;
    }

    /**
     * Adds another array to this array in place. An argument of lower dimensionality is added to every slice.
     *
     * @throws IllegalArgumentException if the argument has a higher dimensionality, or the same
     *         dimensionality but a different slice count
     */
    @Override
    public void add(INDArray a) {
        int dims=dimensionality();
        if (dims==0) {
            add(a.get());
            return;
        }
        int adims=a.dimensionality();
        int n=sliceCount();
        if (dims==adims) {
            if (n!=a.sliceCount()) throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this, a));
            for (int i=0; i<n; i++) {
                slice(i).add(a.slice(i));
            }
        } else if (adims<dims) {
            for (int i=0; i<n; i++) {
                slice(i).add(a);
            }
        } else {
            throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this, a));
        }
    }

    /** Adds a constant to every element of this array. */
    @Override
    public void add(double a) {
        int dims=dimensionality();
        if (dims ==0) {
            set(a+get());
        } else {
            int n=sliceCount();
            for (int i=0; i<n; i++) {
                slice(i).add(a);
            }
        }
    }

    /** Adds v to the element at flat position i, locating the owning slice from the per-slice element count. */
    @Override
    public void addAt(int i, double v) {
        int ss=(int)(elementCount()/sliceCount());
        int s=i/ss;
        slice(s).addAt(i-s*ss, v);
    }

    /** Returns the sum of this array and a as a new array (see broadcastCloneLike); this array is not modified. */
    @Override
    public INDArray addCopy(INDArray a) {
        INDArray r=this.broadcastCloneLike(a);
        r.add(a);
        return r;
    }

    /** Returns this array minus a as a new array (see broadcastCloneLike); this array is not modified. */
    @Override
    public INDArray subCopy(INDArray a) {
        INDArray r=this.broadcastCloneLike(a);
        r.sub(a);
        return r;
    }

    /** Returns the element-wise product with a as a new array (see broadcastCloneLike); this array is not modified. */
    @Override
    public INDArray multiplyCopy(INDArray a) {
        INDArray r=this.broadcastCloneLike(a);
        r.multiply(a);
        return r;
    }

    /** Returns the element-wise quotient by a as a new array (see broadcastCloneLike); this array is not modified. */
    @Override
    public INDArray divideCopy(INDArray a) {
        INDArray r=this.broadcastCloneLike(a);
        r.divide(a);
        return r;
    }

    /** Adds the elements of this array to the values in data, starting at the given offset. */
    @Override
    public void addToArray(double[] data, int offset) {
        int dims=dimensionality();
        if (dims ==0) {
            data[offset]+=get();
        } else {
            int n=sliceCount();
            // nothing to add if there are no slices (also avoids calling slice(0) on an empty array)
            if (n==0) return;
            INDArray s0=slice(0);
            int ec=(int) s0.elementCount();
            s0.addToArray(data, offset);
            for (int i=1; i<n; i++) {
                slice(i).addToArray(data, offset+i*ec);
            }
        }
    }

    /** Raises every element of this array to the given exponent using Math.pow. */
    @Override
    public void pow(double exponent) {
        int dims=dimensionality();
        if (dims ==0) {
            set(Math.pow(get(), exponent));
        } else {
            int n=sliceCount();
            for (int i=0; i<n; i++) {
                slice(i).pow(exponent);
            }
        }
    }

    /** Subtracts a constant from every element; implemented as add(-a). */
    @Override
    public void sub(double a) {
        add(-a);
    }

    /**
     * Multiplies this array element-wise by another array in place. An argument of lower
     * dimensionality is applied to every slice.
     *
     * @throws IllegalArgumentException if the argument has a higher dimensionality, or the same
     *         dimensionality but a different slice count
     */
    @Override
    public void multiply(INDArray a) {
        int dims=dimensionality();
        if (dims==0) {set(get()*a.get()); return;}
        int adims=a.dimensionality();
        if (adims==0) {multiply(a.get()); return;}
        int n=sliceCount();
        if (dims==adims) {
            if (n!=a.sliceCount()) throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this, a));
            for (int i=0; i<n; i++) {
                slice(i).multiply(a.slice(i));
            }
        } else if (adims<dims) {
            for (int i=0; i<n; i++) {
                slice(i).multiply(a);
            }
        } else {
            throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this, a));
        }
    }

    /**
     * Divides this array element-wise by another array in place. An argument of lower
     * dimensionality is applied to every slice.
     *
     * @throws IllegalArgumentException if the argument has a higher dimensionality, or the same
     *         dimensionality but a different slice count
     */
    @Override
    public void divide(INDArray a) {
        int dims=dimensionality();
        if (dims==0) {set(get()/a.get()); return;}
        int adims=a.dimensionality();
        if (adims==0) {scale(1.0/a.get()); return;}
        int n=sliceCount();
        if (dims==adims) {
            if (n!=a.sliceCount()) throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this, a));
            for (int i=0; i<n; i++) {
                slice(i).divide(a.slice(i));
            }
        } else if (adims<dims) {
            for (int i=0; i<n; i++) {
                slice(i).divide(a);
            }
        } else {
            throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this, a));
        }
    }

    /** Divides every element by the given factor; implemented as multiplication by 1.0/factor. */
    @Override
    public void divide(double factor) {
        multiply(1.0/factor);
    }

    /** Returns the number of elements that are not equal to 0.0. */
    @Override
    public long nonZeroCount() {
        if (dimensionality()==0) {
            return (get()==0.0)?0:1;
        }
        long result=0;
        int n=sliceCount();
        for (int i=0; i<n; i++) {
            result+=slice(i).nonZeroCount();
        }
        return result;
    }

    /** Returns nonZeroCount() divided by elementCount(). */
    public double density() {
        return ((double)nonZeroCount())/elementCount();
    }

    /** Returns the sum of all elements. */
    @Override
    public double elementSum() {
        if (dimensionality()==0) {
            return get();
        }
        double result=0;
        int n=sliceCount();
        for (int i=0; i<n; i++) {
            result+=slice(i).elementSum();
        }
        return result;
    }

    /** Returns the product of all elements, stopping early once the running product is 0.0. */
    @Override
    public double elementProduct() {
        if (dimensionality()==0) {
            return get();
        }
        double result=1.0;
        int n=sliceCount();
        for (int i=0; i<n; i++) {
            result*=slice(i).elementProduct();
            if (result==0.0) return 0.0;
        }
        return result;
    }

    /**
     * Returns the maximum element value. For an array with no slices the historical sentinel
     * -Double.MAX_VALUE is returned.
     */
    @Override
    public double elementMax(){
        if (dimensionality()==0) {
            return get();
        }
        int n=sliceCount();
        // sentinel retained for backward compatibility when there is nothing to compare
        if (n==0) return -Double.MAX_VALUE;
        // start from -Infinity so that arrays containing only -Infinity are handled correctly
        double result=Double.NEGATIVE_INFINITY;
        for (int i=0; i<n; i++) {
            double v=slice(i).elementMax();
            if (v>result) result=v;
        }
        return result;
    }

    /** Returns true if every element equals the given value. */
    @Override
    public boolean elementsEqual(double value) {
        if (dimensionality()==0) {
            return get()==value;
        }
        int n=sliceCount();
        for (int i=0; i<n; i++) {
            if (!slice(i).elementsEqual(value)) return false;
        }
        return true;
    }

    /**
     * Returns the minimum element value. For an array with no slices the historical sentinel
     * Double.MAX_VALUE is returned.
     */
    @Override
    public double elementMin(){
        if (dimensionality()==0) {
            return get();
        }
        int n=sliceCount();
        // sentinel retained for backward compatibility when there is nothing to compare
        if (n==0) return Double.MAX_VALUE;
        // start from +Infinity so that arrays containing only +Infinity are handled correctly
        double result=Double.POSITIVE_INFINITY;
        for (int i=0; i<n; i++) {
            double v=slice(i).elementMin();
            if (v<result) result=v;
        }
        return result;
    }

    /** Returns the sum of the squares of all elements. */
    @Override
    public double elementSquaredSum() {
        if (dimensionality()==0) {
            double value=get();
            return value*value;
        }
        double result=0;
        int n=sliceCount();
        for (int i=0; i<n; i++) {
            result+=slice(i).elementSquaredSum();
        }
        return result;
    }

    /**
     * Subtracts another array from this array in place. An argument of lower dimensionality is
     * subtracted from every slice.
     *
     * @throws IllegalArgumentException if the argument has a higher dimensionality, or the same
     *         dimensionality but a different slice count
     */
    @Override
    public void sub(INDArray a) {
        int dims=dimensionality();
        if (dims==0) {
            sub(a.get());
            return;
        }
        int n=sliceCount();
        int adims=a.dimensionality();
        if (dims==adims) {
            if (n!=a.sliceCount()) throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this, a));
            for (int i=0; i<n; i++) {
                slice(i).sub(a.slice(i));
            }
        } else if (adims<dims) {
            for (int i=0; i<n; i++) {
                slice(i).sub(a);
            }
        } else {
            throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this, a));
        }
    }

    /** Negates every element; implemented as multiply(-1.0). */
    @Override
    public void negate() {
        multiply(-1.0);
    }

    /** Replaces every element x with 1.0/x. */
    @Override
    public void reciprocal() {
        if (dimensionality()==0) {
            set(1.0/get());
        } else {
            int sc=sliceCount();
            for (int i=0; i<sc; i++) {
                slice(i).reciprocal();
            }
        }
    }

    /** Replaces every element with its absolute value (Math.abs). */
    @Override
    public void abs() {
        if (dimensionality()==0) {
            set(Math.abs(get()));
        } else {
            int sc=sliceCount();
            for (int i=0; i<sc; i++) {
                slice(i).abs();
            }
        }
    }

    /** Replaces every element with its square root (Math.sqrt). */
    @Override
    public void sqrt() {
        if (dimensionality()==0) {
            set(Math.sqrt(get()));
        } else {
            int sc=sliceCount();
            for (int i=0; i<sc; i++) {
                slice(i).sqrt();
            }
        }
    }

    /** Replaces every element with its natural logarithm (Math.log). */
    @Override
    public void log() {
        if (dimensionality()==0) {
            set(Math.log(get()));
        } else {
            int sc=sliceCount();
            for (int i=0; i<sc; i++) {
                slice(i).log();
            }
        }
    }

    /** Replaces every element x with Math.exp(x). */
    @Override
    public void exp() {
        if (dimensionality()==0) {
            set(Math.exp(get()));
        } else {
            int sc=sliceCount();
            for (int i=0; i<sc; i++) {
                slice(i).exp();
            }
        }
    }

    /** Replaces every element with its sign (Math.signum). */
    @Override
    public void signum() {
        if (dimensionality()==0) {
            set(Math.signum(get()));
        } else {
            int sc=sliceCount();
            for (int i=0; i<sc; i++) {
                slice(i).signum();
            }
        }
    }

    /** Creates an array of the target shape from the elements of asVector(). */
    @Override
    public INDArray reshape(int... targetShape) {
        return Arrayz.createFromVector(asVector(), targetShape);
    }

    /**
     * Builds an array whose slices along dimension dim are the slices of this array selected by order.
     * An empty order produces a zero array whose size along dim is 0.
     *
     * @param dim   the dimension along which to reorder
     * @param order indices of the slices (along dim) to include, in result order
     * @throws IndexOutOfBoundsException if dim is not a valid dimension of this array
     */
    @Override
    public INDArray reorder(int dim, int[] order) {
        int dims=dimensionality();
        // validate first so that the empty-order shortcut cannot mask an invalid dimension
        if ((dim<0)||(dim>=dims)) throw new IndexOutOfBoundsException(ErrorMessages.invalidDimension(this, dim));
        if (order.length==0) {
            int[] shape=getShapeClone();
            // the reordered dimension (not necessarily dimension 0) becomes empty
            shape[dim]=0;
            return Arrayz.createZeroArray(shape);
        }
        ArrayList<INDArray> al=new ArrayList<INDArray>();
        for (int i : order) {
            al.add(slice(dim,i));
        }
        int n=al.size();
        int[] shp=this.getShapeClone();
        shp[dim]=n;
        if (dims==2) {
            if (dim==0) {
                return VectorMatrixMN.create(al,shp);
            }
        }
        if (dim==0) {
            return SliceArray.create(al,shp);
        } else {
            // for other dimensions the selected slices are copied into a new Array
            Array a=Array.newArray(shp);
            for (int i=0; i<n; i++) {
                a.slice(dim, i).set(al.get(i));
            }
            return a;
        }
    }

    /** Reorders along dimension 0. */
    @Override
    public INDArray reorder(int[] order) {
        return reorder(0,order);
    }

    /** Returns a list of the slices of this array along the given dimension. */
    @Override
    public List<?> getSlices(int dimension) {
        int l=getShape(dimension);
        ArrayList<INDArray> al=new ArrayList<INDArray>(l);
        for (int i=0; i<l; i++) {
            al.add(slice(dimension,i));
        }
        return al;
    }

    /** Returns a list containing slice(i) for every major slice of this array. */
    @Override
    public List<?> getSlices() {
        int n=sliceCount();
        ArrayList<Object> al=new ArrayList<Object>(n);
        for (int i=0; i<n; i++) {
            al.add(slice(i));
        }
        return al;
    }

    /** Returns a list containing slice(i) for every major slice of this array, typed as INDArray. */
    @Override
    public List<INDArray> getSliceViews() {
        int n=sliceCount();
        ArrayList<INDArray> al=new ArrayList<INDArray>(n);
        for (int i=0; i<n; i++) {
            al.add(slice(i));
        }
        return al;
    }

    /**
     * Returns the sub-array starting at offsets with the given shape. If the requested shape equals
     * the shape of this array, this array itself is returned (the offsets must then all be zero).
     *
     * @throws IllegalArgumentException if offsets or shape do not have one entry per dimension, or
     *         if a full-size shape is requested with non-zero offsets
     */
    @Override
    public INDArray subArray(int[] offsets, int[] shape) {
        int n=dimensionality();
        if (offsets.length!=n) throw new IllegalArgumentException(ErrorMessages.invalidIndex(this, offsets));
        if (shape.length!=n) throw new IllegalArgumentException(ErrorMessages.invalidIndex(this, shape));
        int[] thisShape=this.getShape();
        if (IntArrays.equals(shape, thisShape)) {
            if (IntArrays.isZero(offsets)) {
                return this;
            } else {
                throw new IllegalArgumentException("Invalid subArray offsets");
            }
        }
        ArrayList<INDArray> al=new ArrayList<INDArray>();
        int endIndex=offsets[0]+shape[0];
        // offsets / shape for the remaining dimensions, applied to each selected slice
        int[] zzoffsets=IntArrays.removeIndex(offsets, 0);
        int[] zzshape=IntArrays.removeIndex(shape, 0);
        for (int i=offsets[0]; i<endIndex; i++) {
            al.add(slice(i).subArray(zzoffsets, zzshape));
        }
        return SliceArray.create(al);
    }

    /** Joins this array with another along the given dimension using JoinedArray. */
    @Override
    public INDArray join(INDArray a, int dimension) {
        return JoinedArray.join(this,a,dimension);
    }

    /** Joins this array with another along dimension 0 using JoinedArray. */
    @Override
    public INDArray join(INDArray a) {
        return JoinedArray.join(this,a,0);
    }

    /**
     * Returns this array rotated along the given dimension: the first (shift mod length) positions
     * are moved to the end. Returns this array itself if the dimension is empty or the effective
     * shift is zero.
     */
    @Override
    public INDArray rotateView(int dimension, int shift) {
        int dlen=getShape(dimension);
        // an empty dimension cannot be rotated (and must not be used as a modulus)
        if (dlen==0) return this;
        int n=dimensionality();
        shift = Maths.mod(shift,dlen);
        if (shift==0) return this;
        int[] off=new int[n];
        int[] shp=getShapeClone();
        // leading part [0, shift) ends up at the end of the result
        shp[dimension]=shift;
        INDArray right=subArray(off,shp);
        // trailing part [shift, dlen) ends up at the start of the result
        shp[dimension]=dlen-shift;
        off[dimension]=shift;
        INDArray left=subArray(off,shp);
        return left.join(right,dimension);
    }

    /** Copies all elements into a new double[] and wraps it in a Vector. */
    @Override
    public Vector toVector() {
        int n=(int)elementCount();
        double[] data=new double[n];
        this.getElements(data, 0);
        return Vector.wrap(data);
    }

    /** Creates an Array from this array via Array.create. */
    @Override
    public Array toArray() {
        return Array.create(this);
    }

    /** Returns the element list of asVector(). */
    @Override
    public List<Double> asElementList() {
        return asVector().asElementList();
    }

    /** Returns the elements of this array as a double[]; equivalent to toDoubleArray(). */
    @Override
    public final double[] getElements() {
        return toDoubleArray();
    }

    /** Copies all elements of this array into dest, slice after slice, starting at the given offset. */
    @Override
    public void getElements(double[] dest, int offset) {
        if (dimensionality()==0) {
            dest[offset]=get();
            return;
        }
        int sc=sliceCount();
        for (int i=0; i<sc; i++) {
            INDArray s=slice(i);
            s.getElements(dest,offset);
            offset+=s.elementCount();
        }
    }

    /** Copies all elements of this array into arr, starting at index 0. */
    @Override
    public void copyTo(double[] arr) {
        getElements(arr,0);
    }

    /** Writes all elements of this array to the buffer, slice after slice. */
    @Override
    public void toDoubleBuffer(DoubleBuffer dest) {
        if (dimensionality()==0) {
            // a 0-dimensional array has a single value and no slices to recurse into
            dest.put(get());
            return;
        }
        int sc=sliceCount();
        for (int i=0; i<sc; i++) {
            INDArray s=slice(i);
            s.toDoubleBuffer(dest);
        }
    }

    /**
     * Returns a new double[] containing the elements of this array. Storage is created with
     * Array.createStorage; sparse arrays are added into it, other arrays are copied into it.
     */
    @Override
    public double[] toDoubleArray() {
        double[] result=Array.createStorage(this.getShape());
        if (this.isSparse()) {
            addToArray(result,0);
        } else {
            getElements(result,0);
        }
        return result;
    }

    /** Always returns null in this base implementation. */
    @Override
    public double[] asDoubleArray() {
        return null;
    }

    /** Returns an INDArray[] containing slice(i) for every major slice. */
    @Override
    public INDArray[] toSliceArray() {
        int n=sliceCount();
        INDArray[] al=new INDArray[n];
        for (int i=0; i<n; i++) {
            al[i]=slice(i);
        }
        return al;
    }

    /** Returns the i-th slice as an object: the (boxed) element value for 1-dimensional arrays, otherwise slice(i). */
    @Override
    public Object sliceValue(int i) {
        if (dimensionality()==1) {
            return get(i);
        }
        return slice(i);
    }

    /**
     * Broadcasts this array to the target shape. A target of the same dimensionality must match
     * the shape of this array exactly; a target of higher dimensionality is built by repeating the
     * broadcast to the trailing dimensions.
     *
     * @throws IllegalArgumentException if the target shape has fewer dimensions than this array,
     *         or the same number of dimensions but a different shape
     */
    @Override
    public INDArray broadcast(int... targetShape) {
        int dims=dimensionality();
        int tdims=targetShape.length;
        if (tdims<dims) {
            throw new IllegalArgumentException(ErrorMessages.incompatibleBroadcast(this, targetShape));
        } else if (dims==tdims) {
            if (IntArrays.equals(targetShape, this.getShape())) return this;
            throw new IllegalArgumentException(ErrorMessages.incompatibleBroadcast(this, targetShape));
        } else {
            int n=targetShape[0];
            INDArray s=broadcast(Arrays.copyOfRange(targetShape, 1, tdims));
            return SliceArray.repeat(s,n);
        }
    }

    /** Returns this array itself if it is not mutable, otherwise an ImmutableArray created from it. */
    @Override
    public INDArray immutable() {
        if (!isMutable()) return this;
        return ImmutableArray.create(this);
    }

    /** Returns this array if it is fully mutable and not a view, otherwise a clone. */
    @Override
    public INDArray mutable() {
        if (isFullyMutable()&&(!isView())) return this;
        return clone();
    }

    /** Returns clone(). */
    @Override
    public final INDArray mutableClone() {
        return clone();
    }

    /**
     * Returns a sparse form of this array. 0-dimensional arrays, and 1- or 2-dimensional arrays that
     * already implement ISparse, are returned as-is; higher dimensional arrays are converted slice by slice.
     */
    @Override
    public INDArray sparse() {
        int dims=dimensionality();
        if (dims==0) return this;
        if (dims==1) {
            if (this instanceof ISparse) return this;
            return Vectorz.createSparse(this.asVector());
        }
        if (dims==2) {
            if (this instanceof ISparse) return this;
            return Matrixx.createSparse(this.getSliceViews());
        }
        int n=this.sliceCount();
        List<INDArray> sls=this.getSliceViews();
        for (int i=0; i<n; i++) {
            sls.set(i,sls.get(i).sparse());
        }
        return SliceArray.create(sls);
    }

    /**
     * Returns a dense form of this array: this array itself if it implements IDense (or is a
     * 0-dimensional AScalar), otherwise a new Scalar, Vector, Matrix or Array according to dimensionality.
     */
    @Override
    public INDArray dense() {
        if (this instanceof IDense) return this;
        int dims=dimensionality();
        if (dims==0) {
            if (this instanceof AScalar) return this;
            return Scalar.create(get());
        }
        if (dims==1) {
            return Vector.create(this);
        }
        if (dims==2) {
            return Matrix.create(this);
        }
        return Array.create(this);
    }

    /** Returns a new Scalar, Vector, Matrix or Array (according to dimensionality) created from this array. */
    @Override
    public INDArray denseClone() {
        int dims=dimensionality();
        if (dims==0) {
            return Scalar.create(get());
        }
        if (dims==1) {
            return Vector.create(this);
        }
        if (dims==2) {
            return Matrix.create(this);
        }
        return Array.create(this);
    }

    /**
     * Returns a sparse clone of this array, built with Vectorz.createSparseMutable for 1 dimension,
     * Matrixx.createSparseRows for 2 dimensions and slice by slice for higher dimensionalities.
     * NOTE(review): the 0-dimensional case returns this array itself rather than a new instance - confirm intended.
     */
    @Override
    public INDArray sparseClone() {
        int dims=dimensionality();
        if (dims==0) return this;
        if (dims==1) {
            return Vectorz.createSparseMutable(this.asVector());
        }
        if (dims==2) {
            if (this instanceof AMatrix) return Matrixx.createSparseRows((AMatrix)this);
            return Matrixx.createSparseRows(this);
        }
        int n=this.sliceCount();
        List<INDArray> sls=this.getSliceViews();
        for (int i=0; i<n; i++) {
            sls.set(i,sls.get(i).sparseClone());
        }
        return SliceArray.create(sls);
    }

    /** Broadcasts this array to the shape of the target array. */
    @Override
    public INDArray broadcastLike(INDArray target) {
        return broadcast(target.getShape());
    }

    /** Broadcasts this array to the shape of the target matrix and converts the result with Matrixx.toMatrix. */
    @Override
    public AMatrix broadcastLike(AMatrix target) {
        return Matrixx.toMatrix(broadcast(target.getShape()));
    }

    /** Broadcasts this array to the shape of the target vector and converts the result with Vectorz.toVector. */
    @Override
    public AVector broadcastLike(AVector target) {
        return Vectorz.toVector(broadcast(target.getShape()));
    }

    /**
     * Returns a clone of this array, first broadcast to the shape of the target if the target has
     * a higher dimensionality.
     */
    @Override
    public INDArray broadcastCloneLike(INDArray target) {
        int dims=dimensionality();
        int targetDims=target.dimensionality();
        INDArray r=this;
        if (dims<targetDims) r=r.broadcastLike(target);
        return r.clone();
    }

    /** Returns broadcastCloneLike(target) if this array is mutable, otherwise broadcastLike(target). */
    @Override
    public INDArray broadcastCopyLike(INDArray target) {
        if (isMutable()) {
            return broadcastCloneLike(target);
        } else {
            return broadcastLike(target);
        }
    }

    /** Performs no checks in this base implementation. */
    @Override
    public void validate() {
        // TODO: any generic validation?
    }

    /**
     * Returns true if any element in this array is NaN or infinite
     * @return true if at least one element is NaN or infinite, false otherwise
     */
    @Override
    public boolean hasUncountable() {
        if (dimensionality()==0) {
            double value=get();
            return Double.isNaN(value) || Double.isInfinite(value);
        }
        int sc=sliceCount();
        for (int i=0; i<sc; i++) {
            INDArray s=slice(i);
            if (s.hasUncountable()) return true;
        }
        return false;
    }

    /**
     * Returns the sum of all the elements raised to a specified power
     * @param p the exponent applied to each element
     * @return the sum of Math.pow(x, p) over all elements x
     */
    public double elementPowSum(double p) {
        if (dimensionality()==0) {
            double value=get();
            return Math.pow(value, p);
        }
        double result=0;
        int n=sliceCount();
        for (int i=0; i<n; i++) {
            result+=slice(i).elementPowSum(p);
        }
        return result;
    }

    /**
     * Returns the sum of the absolute values of all the elements raised to a specified power
     * @param p the exponent applied to each absolute value
     * @return the sum of Math.pow(Math.abs(x), p) over all elements x
     */
    public double elementAbsPowSum(double p) {
        if (dimensionality()==0) {
            double value=Math.abs(get());
            return Math.pow(value, p);
        }
        double result=0;
        int n=sliceCount();
        for (int i=0; i<n; i++) {
            result+=slice(i).elementAbsPowSum(p);
        }
        return result;
    }
}