/*
 * Package mikera.arrayz.impl
 *
 * Source Code of mikera.arrayz.impl.AbstractArray
 */

package mikera.arrayz.impl;

import java.nio.DoubleBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

import mikera.arrayz.Array;
import mikera.arrayz.Arrayz;
import mikera.arrayz.INDArray;
import mikera.arrayz.ISparse;
import mikera.indexz.AIndex;
import mikera.indexz.Index;
import mikera.matrixx.AMatrix;
import mikera.matrixx.Matrix;
import mikera.matrixx.Matrixx;
import mikera.matrixx.impl.VectorMatrixMN;
import mikera.util.Maths;
import mikera.vectorz.AScalar;
import mikera.vectorz.AVector;
import mikera.vectorz.IOperator;
import mikera.vectorz.Op;
import mikera.vectorz.Ops;
import mikera.vectorz.Scalar;
import mikera.vectorz.Tools;
import mikera.vectorz.Vector;
import mikera.vectorz.Vectorz;
import mikera.vectorz.impl.SingleDoubleIterator;
import mikera.vectorz.util.ErrorMessages;
import mikera.vectorz.util.IntArrays;
import mikera.vectorz.util.LongArrays;
/**
* Abstract base class for INDArray implementations
*
* Contains generic implementations for most INDArray operations, enabling new INDArray implementations
* to inherit these (at least until more optimised implementations can be written).
*
* Most INDArray instances should ultimately inherit from AbstractArray
*
* @author Mike
* @param <T> The type of array slices
*/
public abstract class AbstractArray<T> implements INDArray, Iterable<T> {
  private static final long serialVersionUID = -958234961396539071L;

  public abstract double get();
 
  public abstract double get(int x);
 
  public abstract double get(int x, int y);
 
  @Override
  public double get(AIndex ix) {
    return get(ix.toArray());
  }
 
  @Override
  public double get(Index ix) {
    return get(ix.getData());
  }
 
  @Override
  public int getShape(int dim) {
    return getShape()[dim];
  }
 
  @Override
  public int[] getShapeClone() {
    int n=dimensionality();
    int[] sh=new int[n];
    for (int i=0; i<n; i++) {
      sh[i]=getShape(i);
    }
    return sh;
  }
 
  @Override
  public long[] getLongShape() {
    return LongArrays.copyOf(getShape());
  }
 
  @Override
  public boolean epsilonEquals(INDArray a) {
    return epsilonEquals(a,Vectorz.TEST_EPSILON);
  }
 
  @Override
  public boolean epsilonEquals(INDArray a, double epsilon) {
    if (dimensionality()==0) {
      double d=get()-a.get();
      return (Math.abs(d)<=epsilon);
    } else {
      int sc=sliceCount();
      if (a.sliceCount()!=sc) return false;
      for (int i=0; i<sc; i++) {
        INDArray s=slice(i);
        if (!s.epsilonEquals(a.slice(i),epsilon)) return false;
      }     
      return true;
    }
  }
 
  @Override
  public boolean isBoolean() {
    if (dimensionality()==0) return Tools.isBoolean(get());
    int sc=sliceCount();
    for (int i=0; i<sc; i++) {
      INDArray s=slice(i);
      if (!s.isBoolean()) return false;
    }
    return true;
  }
 
  @Override
  public boolean isSparse() {
    return (this instanceof ISparse);
  }
 
  @Override
  public boolean isDense() {
    return (this instanceof IDense);
  }
 
  @Override
  public boolean isMutable() {
    int n=sliceCount();
    for (int i=0; i<n; i++) {
      if (slice(i).isMutable()) return true;
    }
    return false;
  }

  @Override
  public boolean isFullyMutable() {
    int n=sliceCount();
    for (int i=0; i<n; i++) {
      if (!slice(i).isFullyMutable()) return false;
    }
    return true
  }
 
  @Override
  public void applyOp(Op op) {
    int n=sliceCount();
    for (int i=0; i<n; i++) {
      slice(i).applyOp(op);
    }
  }

  @Override
  public void applyOp(IOperator op) {
    int n=sliceCount();
    for (int i=0; i<n; i++) {
      slice(i).applyOp(op);
    }
  }
 

  @Override
  public void multiply(double d) {
    int n=sliceCount();
    for (int i=0; i<n; i++) {
      slice(i).multiply(d);
    }
  }
 
  @Override
  public INDArray multiplyCopy(double d) {
    INDArray r=clone();
    r.multiply(d);
    return r;
  }
 
  @Override
  public INDArray applyOpCopy(Op op) {
    INDArray r=clone();
    r.applyOp(op);
    return r;
  }

  @Override
  public boolean isElementConstrained() {
    int n=sliceCount();
    for (int i=0; i<n; i++) {
      if (slice(i).isElementConstrained()) return true;
    }
    return false;
  }
 
  @Override
  public boolean isSameShape(INDArray a) {
    int dims=dimensionality();
    if (dims!=a.dimensionality()) return false;
    for (int i=0; i<dims; i++) {
      if (getShape(i)!=a.getShape(i)) return false;
    }
    return true;
  }
 
  @Override
  public AVector asVector() {
    if (this instanceof IDenseArray) {
      IDenseArray a=(IDenseArray) this;
      return Vectorz.wrap(a.getArray(), a.getArrayOffset(), (int)elementCount());
    }
    int n=sliceCount();
    AVector result=slice(0).asVector();
    for (int i=1; i<n; i++) {
      result=result.join(slice(i).asVector());
    }
    return result;
  }
 
  @Override
  public void setElements(double[] values, int offset) {
    setElements(0,values,offset,(int)elementCount());
  }
 
  @Override
  public void setElements(int pos, double[] values, int offset, int length) {
    if (length==0) return;
    int ss=(int)(slice(0).elementCount());
    int s1=pos/ss;
    int s2=(pos+length-1)/ss;
    if (s1==s2) {
      slice(s1).setElements(pos-s1*ss,values,offset,length);
      return;
    }
   
    int si=offset;
    int l1 = (s1+1)*ss-pos;
    if (l1>0) {
      slice(s1).setElements(pos-s1*ss, values, si, l1);
      si+=l1;
    }
    for (int i=s1+1; i<s2; i++) {
      slice(i).setElements(values, si);
      si+=ss;
    }
    int l2=(pos+length)-(s2*ss);
    if (l2>0) {
      slice(s2).setElements(0,values,si,l2);
    }
  }
 
  @Override
  public boolean isZero() {
    if (dimensionality()==0) return get()==0.0;
    int sc=sliceCount();
    for (int i=0; i<sc; i++) {
      INDArray s=slice(i);
      if (!s.isZero()) return false;
    }
    return true;
  }
 
  @Override
  public INDArray ensureMutable() {
    if (isFullyMutable()&&!isView()) return this;
    return clone();
  }
 
  @Override
  public void fill(double value) {
    if (dimensionality()==0) {
      set(value);
    } else {
      int sc=sliceCount();
      for (int i=0; i<sc; i++) {
        INDArray s=slice(i);
        s.fill(value);
      }     
    }
  }
 
  @Override
  public INDArray innerProduct(double a) {
    INDArray result=clone();
    result.scale(a);
    return result;
  }
 
  @Override
  public INDArray innerProduct(INDArray a) {
    int dims=dimensionality();
    switch (dims) {
      case 0: {
        a=a.clone();
        a.scale(get());
        return a;
      }
      case 1: {
        return toVector().innerProduct(a);
      }
      case 2: {
        return Matrix.create(this).innerProduct(a);
      }
    }
    int sc=sliceCount();
    ArrayList<INDArray> sips=new ArrayList<INDArray>();
    for (int i=0; i<sc; i++) {
      sips.add(slice(i).innerProduct(a));
    }
    return SliceArray.create(sips);
  }
 
  @Override
  public INDArray innerProduct(AScalar s) {
    return innerProduct(s.get());
  }
 
  @Override
  public INDArray innerProduct(AVector a) {
    return innerProduct((INDArray) a);
  }
 
  @Override
  public INDArray outerProduct(INDArray a) {
    ArrayList<INDArray> al=new ArrayList<INDArray>();
    for (Object s:this) {
      if (s instanceof INDArray) {
        al.add(((INDArray)s).outerProduct(a));
      } else {
        double x=Tools.toDouble(s);
        INDArray sa=a.clone();
        sa.scale(x);
        al.add(sa);
      }
    }
    return Arrayz.create(al);
  }
 
  @Override
  public INDArray getTranspose() {
    return getTransposeCopy();
  }
 
  @Override
  public INDArray getTransposeView() {
    throw new UnsupportedOperationException();
  }
 
  @Override
  public INDArray getTransposeCopy() {
    Array nd=Array.create(this);
    return nd.getTransposeView();
  }
 
  public final void scale(double d) {
    multiply(d);
  }
 
  @Override
  public void scaleAdd(double factor, double constant) {
    multiply(factor);
    add(constant);
  }

  public void set(double value) {
    set(new int[0],value);
  }
 
  public void set(int x, double value) {
    set(new int[] {x},value);
  }
 
  public void set(int x, int y, double value) {
    set(new int[] {x,y},value)
  }
 
  public void set (INDArray a) {
    int tdims=this.dimensionality();
    int adims=a.dimensionality();
    if (adims<tdims) {
      int sc=sliceCount();
      for (int i=0; i<sc; i++) {
        INDArray s=slice(i);
        s.set(a);
      }
    } else if (adims==tdims) {
      if (tdims==0) {
        set(a.get());
        return;
      }
      int sc=sliceCount();
      for (int i=0; i<sc; i++) {
        INDArray s=slice(i);
        s.set(a.slice(i));
      }
    } else {
      throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this, a));
    }
  }
 
  @Override
  public void clamp(double min, double max) {
    if (dimensionality()==0) {
      set(Maths.bound(get(), min, max));
      return;
    }
   
    int len=sliceCount();
    for (int i = 0; i < len; i++) {
      slice(i).clamp(min, max);
    }
  }
 
  public void set(Object o) {
    if (o instanceof INDArray) {set((INDArray)o); return;}
    if (o instanceof Number) {
      set(((Number)o).doubleValue()); return;
    }
    if (o instanceof Iterable<?>) {
      int i=0;
      for (Object ob: ((Iterable<?>)o)) {
        slice(i).set(ob);
      }
      return;
    }
    if (o instanceof double[]) {
      setElements((double[])o);
      return;
    }
    throw new UnsupportedOperationException("Can't set to value for "+o.getClass().toString());   
  }
 
  public void setElements(double... values) {
    int vl=values.length;
    if (vl!=elementCount()) throw new IllegalArgumentException("Wrong array length: "+vl);
    setElements(0,values,0,vl);
  }
 
  public void square() {
    applyOp(Ops.SQUARE);
  }
 
  @Override
  public INDArray squareCopy() {
    INDArray r=clone();
    r.square();
    return r;
  }
 
  @Override
  public INDArray absCopy() {
    INDArray r=clone();
    r.abs();
    return r;
  }
 
  @Override
  public INDArray reciprocalCopy() {
    INDArray r=clone();
    r.reciprocal();
    return r;
  }
 
  @Override
  public INDArray signumCopy() {
    INDArray r=clone();
    r.signum();
    return r;
  }
 
  @Override
  public Iterator<T> iterator() {
    return new SliceIterator<T>(this);
  }
 
  @Override
  public Iterator<Double> elementIterator() {
    if (dimensionality()==0) {
      return new SingleDoubleIterator(get());
    } else {
      return new SliceElementIterator(this);
    }
  }
 
  public boolean equals(Object o) {
    if (!(o instanceof INDArray)) return false;
    return equals((INDArray)o);
  }
 
  @Override
  public boolean equalsArray(double[] data) {
    if (data.length!=elementCount()) return false;
    return equalsArray(data,0);
  }

  @Override
  public int hashCode() {
    return asVector().hashCode();
  }
 
  @Override
  public String toString() {
    if (dimensionality()==0) {
      return Double.toString(get());
    }
   
    StringBuilder sb=new StringBuilder();
    int length=sliceCount();
    sb.append('[');
    if (length>0) {
      sb.append(slice(0).toString());
      for (int i = 1; i < length; i++) {
        sb.append(',');
        sb.append(slice(i).toString());
      }
    }
    sb.append(']');
    return sb.toString();
  }
 
  @Override
  public INDArray clone() {
    return Arrayz.create(this);
  }
 
  @Override
  public INDArray copy() {
    if (!isMutable()) return this;
    return clone();
  }
 
  @Override
  public INDArray scaleCopy(double d) {
    INDArray r=clone();
    r.scale(d);
    return r;
  }
 
  @Override
  public INDArray negateCopy() {
    INDArray r=clone();
    r.negate();
    return r;
  }
 
  @Override
  public boolean equals(INDArray a) {
    int dims=dimensionality();
    if (a.dimensionality()!=dims) return false;
    if (dims==0) {
      return (get()==a.get());
    } else if (dims==1) {
      return equals(a.asVector());
    } else {
      int sc=sliceCount();
      for (int i=0; i<sc; i++) {
        if (!slice(i).equals(a.slice(i))) return false;
      }
    }
    return true;
  }
 
  public boolean equals(AVector a) {
    if (dimensionality()!=1) return false;
    int sc=sliceCount();
    if (a.length()!=sc) return false;
    for (int i=0; i<sc; i++) {
      if (!(get(i)==a.unsafeGet(i))) return false;
    }
    return true;
  }
 
  @Override
  public boolean equalsArray(double[] data, int offset) {
    int dims=dimensionality();
    if (dims==0) {
      return (data[offset]==get());
    } else if (dims==1) {
      return asVector().equalsArray(data, offset);
    } else {
      int sc=sliceCount();
      int skip=(int) slice(0).elementCount();
      for (int i=0; i<sc; i++) {
        if (!slice(i).equalsArray(data,offset+i*skip)) return false;
      }
    }
    return true;
  }
 
  @Override
  public void add(INDArray a) {
    int dims=dimensionality();
    if (dims==0) {
      add(a.get());
      return;
    }
   
    int adims=a.dimensionality();
    int n=sliceCount();
    int na=a.sliceCount();
    if (dims==adims) {
      if (n!=na) throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this, a));
      for (int i=0; i<n; i++) {
        slice(i).add(a.slice(i));
      }
    } else if (adims<dims) {
      for (int i=0; i<n; i++) {
        slice(i).add(a);
     
    } else {
      throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this, a));
    }
  }
 
  @Override
  public void add(double a) {
    int dims=dimensionality();
    if (dims ==0) {
      set(a+get());
    } else {
      int n=sliceCount();
      for (int i=0; i<n; i++) {
        slice(i).add(a);
     
    }
  }
 
  @Override
  public void addAt(int i, double v) {
    int ss=(int)(elementCount()/sliceCount());
    int s=i/ss;
    slice(s).addAt(i-s*ss, v);
  }
 
  @Override
  public INDArray addCopy(INDArray a) {
    INDArray r=this.broadcastCloneLike(a);
    r.add(a);
    return r;
  }
 
  @Override
  public INDArray subCopy(INDArray a) {
    INDArray r=this.broadcastCloneLike(a);
    r.sub(a);
    return r;
  }
 
  @Override
  public INDArray multiplyCopy(INDArray a) {
    INDArray r=this.broadcastCloneLike(a);
    r.multiply(a);
    return r;
  }
 
  @Override
  public INDArray divideCopy(INDArray a) {
    INDArray r=this.broadcastCloneLike(a);
    r.divide(a);
    return r;
  }
 
  @Override
  public void addToArray(double[] data, int offset) {
    int dims=dimensionality();
    if (dims ==0) {
      data[offset]+=get();
    } else {
      int n=sliceCount();
      INDArray s0=slice(0);
      int ec=(int) s0.elementCount();
      s0.addToArray(data, offset);
      for (int i=1; i<n; i++) {
        slice(i).addToArray(data, offset+i*ec);
     
    }
  }
 
  @Override
  public void pow(double exponent) {
    int dims=dimensionality();
    if (dims ==0) {
      set(Math.pow(get(), exponent));
    } else {
      int n=sliceCount();
      for (int i=0; i<n; i++) {
        slice(i).pow(exponent);
     
    }
  }
 
  @Override
  public void sub(double a) {
    add(-a);
  }
 
  @Override
  public void multiply(INDArray a) {
    int dims=dimensionality();
    if (dims==0) {set(get()*a.get()); return;}
    int adims=a.dimensionality();
    if (adims==0) {multiply(a.get()); return;}
   
    int n=sliceCount();
    int na=a.sliceCount();
    if (dims==adims) {
      if (n!=na) throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this, a));
      for (int i=0; i<n; i++) {
        slice(i).multiply(a.slice(i));
      }
    } else if (adims<dims) {
      for (int i=0; i<n; i++) {
        slice(i).multiply(a);
     
    } else {
      throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this, a));
    }
  }
 
  @Override
  public void divide(INDArray a) {
    int dims=dimensionality();
    if (dims==0) {set(get()/a.get()); return;}
    int adims=a.dimensionality();
    if (adims==0) {scale(1.0/a.get()); return;}
   
    int n=sliceCount();
    int na=a.sliceCount();
    if (dims==adims) {
      if (n!=na) throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this, a));
      for (int i=0; i<n; i++) {
        slice(i).divide(a.slice(i));
      }
    } else if (adims<dims) {
      for (int i=0; i<n; i++) {
        slice(i).divide(a);
     
    } else {
      throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this, a));
    }
  }
 
  @Override
  public void divide(double factor) {
    multiply(1.0/factor);
  }
 
  @Override
  public long nonZeroCount() {
    if (dimensionality()==0) {
      return (get()==0.0)?0:1;
    }
    long result=0;
    int n=sliceCount();
    for (int i=0; i<n; i++) {
      result+=slice(i).nonZeroCount();
    }
    return result;
  }
 
  public double density() {
    return ((double)nonZeroCount())/elementCount();
  }
 
  @Override
  public double elementSum() {
    if (dimensionality()==0) {
      return get();
    }
    double result=0;
    int n=sliceCount();
    for (int i=0; i<n; i++) {
      result+=slice(i).elementSum();
    }
    return result;
  }
 
  @Override
  public double elementProduct() {
    if (dimensionality()==0) {
      return get();
    }
    double result=1.0;
    int n=sliceCount();
    for (int i=0; i<n; i++) {
      result*=slice(i).elementProduct();
      if (result==0.0) return 0.0;
    }
    return result;
  }
 
  @Override
  public double elementMax(){
    if (dimensionality()==0) {
      return get();
    }
    double result=-Double.MAX_VALUE;
    int n=sliceCount();
    for (int i=0; i<n; i++) {
      double v=slice(i).elementMax();
      if (v>result) result=v;
    }
    return result; 
  }
 
  @Override
  public boolean elementsEqual(double value) {
    if (dimensionality()==0) {
      return get()==value;
    }
    int n=sliceCount();
    for (int i=0; i<n; i++) {
      if (!slice(i).elementsEqual(value)) return false;
    }
    return true;     
  }
 
  @Override
  public double elementMin(){
    if (dimensionality()==0) {
      return get();
    }
    double result=Double.MAX_VALUE;
    int n=sliceCount();
    for (int i=0; i<n; i++) {
      double v=slice(i).elementMin();
      if (v<result) result=v;
    }
    return result;
  }
 
  @Override
  public double elementSquaredSum() {
    if (dimensionality()==0) {
      double value=get();
      return value*value;
    }
    double result=0;
    int n=sliceCount();
    for (int i=0; i<n; i++) {
      result+=slice(i).elementSquaredSum();
    }
    return result;
  }


  @Override
  public void sub(INDArray a) {
    int dims=dimensionality();
    if (dims==0) {
      sub(a.get());
      return;
    }
   
    int n=sliceCount();
    int na=a.sliceCount();
    int adims=a.dimensionality();
    if (dims==adims) {
      if (n!=na) throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this, a));
      for (int i=0; i<n; i++) {
        slice(i).sub(a.slice(i));
      }
    } else if (adims<dims) {
      for (int i=0; i<n; i++) {
        slice(i).sub(a);
     
    } else {
      throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this, a));
   
  }
 
  @Override
  public void negate() {
    multiply(-1.0);
  }
 
  @Override
  public void reciprocal() {
    if (dimensionality()==0) {
      set(1.0/get());
    } else {
      int sc=sliceCount();
      for (int i=0; i<sc; i++) {
        slice(i).reciprocal();
      }
    }
  }
 
  @Override
  public void abs() {
    if (dimensionality()==0) {
      set(Math.abs(get()));
    } else {
      int sc=sliceCount();
      for (int i=0; i<sc; i++) {
        slice(i).abs();
      }
    }
  }
 
  @Override
  public void sqrt() {
    if (dimensionality()==0) {
      set(Math.sqrt(get()));
    } else {
      int sc=sliceCount();
      for (int i=0; i<sc; i++) {
        slice(i).sqrt();
      }
    }
  }
 
  @Override
  public void log() {
    if (dimensionality()==0) {
      set(Math.log(get()));
    } else {
      int sc=sliceCount();
      for (int i=0; i<sc; i++) {
        slice(i).log();
      }
    }
  }
 
  @Override
  public void exp() {
    if (dimensionality()==0) {
      set(Math.exp(get()));
    } else {
      int sc=sliceCount();
      for (int i=0; i<sc; i++) {
        slice(i).exp();
      }
    }
  }
 
  @Override
  public void signum() {
    if (dimensionality()==0) {
      set(Math.signum(get()));
    } else {
      int sc=sliceCount();
      for (int i=0; i<sc; i++) {
        slice(i).signum();
      }
    }
  }
 
 
  @Override
  public INDArray reshape(int... targetShape) {
    return Arrayz.createFromVector(asVector(), targetShape);
  }
 
  @Override
  public INDArray reorder(int dim, int[] order) {
    if (order.length==0) {
      int[] shape=getShapeClone();
      shape[0]=0;
      return Arrayz.createZeroArray(shape);
    }
    int dims=dimensionality();
    if ((dim<0)||(dim>=dims)) throw new IndexOutOfBoundsException(ErrorMessages.invalidDimension(this, dim));
    ArrayList<INDArray> al=new ArrayList<INDArray>();
    for (int i : order) {
      al.add(slice(dim,i));
    }
    int n=al.size();
    int[] shp=this.getShapeClone();
    shp[dim]=n;

    if (dims==2) {
      if (dim==0) {
        return VectorMatrixMN.create(al,shp);
      }
    }
    if (dim==0) {
      return SliceArray.create(al,shp);
    } else {
      Array a=Array.newArray(shp);
      for (int i=0; i<n; i++) {
        a.slice(dim, i).set(al.get(i));
      }
      return a;
    }
  } 
 
  @Override
  public INDArray reorder(int[] order) {
    return reorder(0,order);
  }

 
  @Override
  public List<?> getSlices(int dimension) {
    int l=getShape(dimension);
    ArrayList<INDArray> al=new ArrayList<INDArray>(l);
    for (int i=0; i<l; i++) {
      al.add(slice(dimension,i));
    }
    return al; 
  }
 
  @Override
  public List<?> getSlices() {
    int n=sliceCount();
    ArrayList<Object> al=new ArrayList<Object>(n);
    for (int i=0; i<n; i++) {
      al.add(slice(i));
    }
    return al;
  }
 
  @Override
  public List<INDArray> getSliceViews() {
    int n=sliceCount();
    ArrayList<INDArray> al=new ArrayList<INDArray>(n);
    for (int i=0; i<n; i++) {
      al.add(slice(i));
    }
    return al;
  }
 
  @Override
  public INDArray subArray(int[] offsets, int[] shape) {
    int n=dimensionality();
    if (offsets.length!=n) throw new IllegalArgumentException(ErrorMessages.invalidIndex(this, offsets));
    if (shape.length!=n) throw new IllegalArgumentException(ErrorMessages.invalidIndex(this, offsets));
   
    int[] thisShape=this.getShape();
    if (IntArrays.equals(shape, thisShape)) {
      if (IntArrays.isZero(offsets)) {
        return this;
      } else {
        throw new IllegalArgumentException("Invalid subArray offsets");
      }
    }
   
    ArrayList<INDArray> al=new ArrayList<INDArray>();
    int endIndex=offsets[0]+shape[0];
    int[] zzoffsets=IntArrays.removeIndex(offsets, 0);
    int[] zzshape=IntArrays.removeIndex(shape, 0);
    for (int i=offsets[0]; i<endIndex; i++) {
      al.add(slice(i).subArray(zzoffsets, zzshape));
    }
    return SliceArray.create(al);
  }
 
  @Override
  public INDArray join(INDArray a, int dimension) {
    return JoinedArray.join(this,a,dimension);
  }
 
  @Override
  public INDArray join(INDArray a) {
    return JoinedArray.join(this,a,0);   
  }
 
  @Override
  public INDArray rotateView(int dimension, int shift) {
    int dlen=getShape(dimension);
    int n=dimensionality();
   
    shift = Maths.mod(shift,dlen);
    if (shift==0) return this;
   
    int[] off=new int[n];
    int[] shp=getShapeClone();
   
    shp[dimension]=shift;
    INDArray right=subArray(off,shp);
    shp[dimension]=dlen-shift;
    off[dimension]=shift;
    INDArray left=subArray(off,shp);
    return left.join(right,dimension);
  }
 
  @Override
  public Vector toVector() {
    int n=(int)elementCount();
    double[] data=new double[n];
    this.getElements(data, 0);
    return Vector.wrap(data);
  }
 
  @Override
  public Array toArray() {
    return Array.create(this);
  }
 
  @Override
  public List<Double> asElementList() {
    return asVector().asElementList();
  }

  @Override
  public final double[] getElements() {
    return toDoubleArray();
  }
 
  @Override
  public void getElements(double[] dest, int offset) {
    if (dimensionality()==0) {
      dest[offset]=get();
      return;
    }
    int sc=sliceCount();
    for (int i=0; i<sc; i++) {
      INDArray s=slice(i);
      s.getElements(dest,offset);
      offset+=s.elementCount();
    }
  }
 
  @Override
  public void copyTo(double[] arr) {
    getElements(arr,0);
  }
 
  @Override
  public void toDoubleBuffer(DoubleBuffer dest) {
    int sc=sliceCount();
    for (int i=0; i<sc; i++) {
      INDArray s=slice(i);
      s.toDoubleBuffer(dest);
    }
  }

  @Override
  public double[] toDoubleArray() {
    double[] result=Array.createStorage(this.getShape());
    if (this.isSparse()) {
      addToArray(result,0);
    } else {
      getElements(result,0);
    }
    return result;
  }
 
  @Override
  public double[] asDoubleArray() {
    return null;
  }
 
  @Override
  public INDArray[] toSliceArray() {
    int n=sliceCount();
    INDArray[] al=new INDArray[n];
    for (int i=0; i<n; i++) {
      al[i]=slice(i);
    }
    return al;
  }
 
  @Override
  public Object sliceValue(int i) {
    if (dimensionality()==1) {
      return get(i);
    }
    return slice(i);
  }
 
  @Override
  public INDArray broadcast(int... targetShape) {
    int dims=dimensionality();
    int tdims=targetShape.length;
    if (tdims<dims) {
      throw new IllegalArgumentException(ErrorMessages.incompatibleBroadcast(this, targetShape));
    } else if (dims==tdims) {
      if (IntArrays.equals(targetShape, this.getShape())) return this;
      throw new IllegalArgumentException(ErrorMessages.incompatibleBroadcast(this, targetShape));
    } else {
      int n=targetShape[0];
      INDArray s=broadcast(Arrays.copyOfRange(targetShape, 1, tdims));
      return SliceArray.repeat(s,n);
    }
  }
 
  @Override
  public INDArray immutable() {
    if (!isMutable()) return this;
    return ImmutableArray.create(this);
  }
 
  @Override
  public INDArray mutable() {
    if (isFullyMutable()&&(!isView())) return this;
    return clone();
  }
 
  @Override
  public final INDArray mutableClone() {
    return clone();
  }
 
  @Override
  public INDArray sparse() {
    int dims=dimensionality();
    if (dims==0) return this;
    if (dims==1) {
      if (this instanceof ISparse) return this;
      return Vectorz.createSparse(this.asVector());
    }
    if (dims==2) {
      if (this instanceof ISparse) return this;
      return Matrixx.createSparse(this.getSliceViews());
    }
    int n=this.sliceCount();
    List<INDArray> sls=this.getSliceViews();
    for (int i=0; i<n; i++) {
      sls.set(i,sls.get(i).sparse());
    }
    return SliceArray.create(sls);
  }
 
  @Override
  public INDArray dense() {
    if (this instanceof IDense) return this;
    int dims=dimensionality();
    if (dims==0) {
      if (this instanceof AScalar) return this;
      return Scalar.create(get());
    }
    if (dims==1) {
      return Vector.create(this);
    }
    if (dims==2) {
      return Matrix.create(this);
    }
    return Array.create(this);
  }
 
  @Override
  public INDArray denseClone() {
    int dims=dimensionality();
    if (dims==0) {
      return Scalar.create(get());
    }
    if (dims==1) {
      return Vector.create(this);
    }
    if (dims==2) {
      return Matrix.create(this);
    }
    return Array.create(this);
  }
 
  @Override
  public INDArray sparseClone() {
    int dims=dimensionality();
    if (dims==0) return this;
    if (dims==1) {
      return Vectorz.createSparseMutable(this.asVector());
    }
    if (dims==2) {
      if (this instanceof AMatrix) return Matrixx.createSparseRows((AMatrix)this);
      return Matrixx.createSparseRows(this);
    }
    int n=this.sliceCount();
    List<INDArray> sls=this.getSliceViews();
    for (int i=0; i<n; i++) {
      sls.set(i,sls.get(i).sparseClone());
    }
    return SliceArray.create(sls);
  }
 
  @Override
  public INDArray broadcastLike(INDArray target) {
    return broadcast(target.getShape());
  }
 
  @Override
  public AMatrix broadcastLike(AMatrix target) {
    return Matrixx.toMatrix(broadcast(target.getShape()));
  }
 
  @Override
  public AVector broadcastLike(AVector target) {
    return Vectorz.toVector(broadcast(target.getShape()));
  }
 
  @Override
  public INDArray broadcastCloneLike(INDArray target) {
    int dims=dimensionality();
    int targetDims=target.dimensionality();
    INDArray r=this;
    if (dims<targetDims) r=r.broadcastLike(target);
    return r.clone();
  }
 
  @Override
  public INDArray broadcastCopyLike(INDArray target) {
    if (isMutable()) {
      return broadcastCloneLike(target);
    } else {
      return broadcastLike(target);
    }
  }
 
  @Override
  public void validate() {
    // TODO: any generic validation?
  }

  /**
   * Returns true if any element is this array is NaN or infinite
   * @return
   */
  @Override
  public boolean hasUncountable() {
    if (dimensionality()==0) return Double.isNaN(get()) || Double.isInfinite(get());
    int sc=sliceCount();
    for (int i=0; i<sc; i++) {
      INDArray s=slice(i);
      if (s.hasUncountable()) return true;
    }
    return false;
  }

  /**
     * Returns the sum of all the elements raised to a specified power
     * @return
     */
    public double elementPowSum(double p) {
        if (dimensionality()==0) {
            double value=get();
            return Math.pow(value, p);
        }
        double result=0;
        int n=sliceCount();
        for (int i=0; i<n; i++) {
            result+=slice(i).elementPowSum(p);
        }
        return result;
    }
   
    /**
     * Returns the sum of the absolute values of all the elements raised to a specified power
     * @return
     */
    public double elementAbsPowSum(double p) {
        if (dimensionality()==0) {
            double value=Math.abs(get());
            return Math.pow(value, p);
        }
        double result=0;
        int n=sliceCount();
        for (int i=0; i<n; i++) {
            result+=slice(i).elementAbsPowSum(p);
        }
        return result;
    }

}
/*
 * TOP
 *
 * Related Classes of mikera.arrayz.impl.AbstractArray
 *
 * TOP
 * Copyright © 2018 www.massapi.com. All rights reserved.
 * All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.
 */