Package mikera.arrayz

Source Code of mikera.arrayz.TestArrays

package mikera.arrayz;

import java.io.StringReader;
import java.nio.DoubleBuffer;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

import mikera.arrayz.impl.IDense;
import mikera.arrayz.impl.IStridedArray;
import mikera.arrayz.impl.ImmutableArray;
import mikera.arrayz.impl.JoinedArray;
import mikera.arrayz.impl.SliceArray;
import mikera.indexz.Indexz;
import mikera.matrixx.Matrixx;
import mikera.matrixx.impl.VectorMatrixM3;
import mikera.vectorz.AVector;
import mikera.vectorz.Op;
import mikera.vectorz.Scalar;
import mikera.vectorz.TestOps;
import mikera.vectorz.TestingUtils;
import mikera.vectorz.Tools;
import mikera.vectorz.Vector;
import mikera.vectorz.Vectorz;
import mikera.vectorz.ops.Constant;
import mikera.vectorz.util.DoubleArrays;
import mikera.vectorz.util.IntArrays;
import mikera.vectorz.util.VectorzException;
import static org.junit.Assert.*;

import org.junit.Test;

/**
* Tests for general purpose INDArray implementations
*
* @author Mike
*
*/
public class TestArrays {

  private void testShape(INDArray a) {
    AVector v = a.asVector();
    int[] shape = a.getShape();

    for (int i = 0; i < a.dimensionality(); i++) {
      assertEquals(shape[i], a.getShape(i));
    }

    long[] longShape = a.getLongShape();
    for (int i = 0; i < shape.length; i++) {
      assertEquals(longShape[i], shape[i]);
    }

    assertEquals(a.dimensionality(), shape.length);
    long r = 1;
    for (int i = 0; i < shape.length; i++) {
      r *= shape[i];
    }
    assertEquals(v.length(), r);

    assertEquals(a, a.reshape(shape));
  }
 
  private void testReorder(INDArray a) {
    if (a.dimensionality()==0) return;
    INDArray c=a.copy();
   
    assertEquals(c,a.reorder(Indexz.createSequence(a.sliceCount()).toArray()));
   
    assertEquals(c,a);
  }
 
  private void testSubArray(INDArray a) {
    int n=a.dimensionality();
    INDArray a2=a.subArray(new int[n], a.getShape());
   
    assertEquals(a,a2);
  }
 
  private void testRotateView(INDArray a) {
    int n=a.dimensionality();
    if (n==0) return;
   
    for (int i=0; i<n; i++) {
      int size=a.getShape(i);
      assertEquals(a,a.rotateView(i, 0));
      assertEquals(a,a.rotateView(i, 1).rotateView(i, size-1));
    }
  }

  private void testSlices(INDArray a) {
    int dims=a.dimensionality();
    if ((a.elementCount() == 0) || (dims == 0)) return;

    INDArray sl = a.slice(0);
    assertEquals(sl,a.slice(0,0));
   
    // assertTrue(sl.isView()); not always... damn contained vectors
    assertTrue(sl.dimensionality() == (dims - 1));

    if (a.isFullyMutable()) {
      assert (sl.isFullyMutable());
      assert (sl.isMutable());
    }

    testArray(sl);

    List<?> slices = a.getSlices();
    assertEquals(a, Arrayz.create(slices));
   
    assertEquals(Arrayz.create(slices),Arrayz.create(a.getSlices(0)));
   
    List<?> vslices = a.getSliceViews();
    assertEquals(sl, vslices.get(0));
  }
 
  public void testAdd(INDArray a) {
    INDArray b=TestingUtils.createRandomLike(a, 16786);
    INDArray r=a.addCopy(b);
    assertTrue(r.isSameShape(a));
   
    double[] adata=a.getElements();
    double[] bdata=b.getElements();
    double[] rdata=r.getElements();
    for (int i=0; i<adata.length; i++) {
      assertTrue(rdata[i]==adata[i]+bdata[i]);
    }
   
    b.add(a);
    assertEquals(r,b)
  }
 
  private void testSub(INDArray a) {
    INDArray b=TestingUtils.createRandomLike(a, 16786);
    INDArray r=a.subCopy(b);
    assertTrue(r.isSameShape(a));
   
    double[] adata=a.getElements();
    double[] bdata=b.getElements();
    double[] rdata=r.getElements();
    for (int i=0; i<adata.length; i++) {
      assertTrue(rdata[i]==adata[i]-bdata[i]);
    }
   
    b.sub(a);
    b.negate();
    assertEquals(r,b);   
  }


  private void testAsVector(INDArray a) {
    AVector v = a.asVector();
    assertEquals(a.elementCount(), v.length());
    assertEquals(a.elementSum(), v.elementSum(),0.00000001);

    if (a.isMutable() && (v.length() > 0)) {
      assertTrue(v.isMutable());
      // assertTrue((a==v)||(v.isView())); not always...
    } else {
      if (v.length() > 0) {
        try {
          v.set(0, 10.0);
          fail("Shouldn't be able to set an immutable view vector");
        } catch (Throwable t) {
          // OK
        }
      }
    }
   
    assertEquals(Vector.create(a.getElements()),v);
  }

  private void testToArray(INDArray a) {
    double[] arr = new double[(int) a.elementCount()];
    a.copyTo(arr);
    assertEquals(a.nonZeroCount(),
        DoubleArrays.nonZeroCount(arr, 0, arr.length));

    assertEquals(a, Vector.wrap(arr).reshape(a.getShape()));
    assertEquals(a, Arrayz.wrap(arr, a.getShape()));

  }

  private void testClone(INDArray a) {
    // regular clone
    INDArray c = a.clone();
    assertTrue(c.isFullyMutable());
    assertTrue(c.isSameShape(a));
   
    if (!(c instanceof VectorMatrixM3)) { // allowed to still be a view
      assertFalse(c.isView());
    }
   
    assertTrue(c.equals(a));
    assertTrue(a.equals(c));
    assertEquals(a.hashCode(), c.hashCode());

    if (c == a) {
      // can only return same object if immutable
      assert (!a.isMutable());
    }

    INDArray ec = a.exactClone();
    if (a.isMutable()) assertTrue(ec != a);
    assertEquals(a, ec);
    assertEquals(c, ec);
    assertEquals(a.getClass(), ec.getClass());
   
    // sparse clone
    INDArray sc=a.sparseClone();
    assertTrue("Should have fully mutable sparseClone: "+a.getClass(), (sc.elementCount()==0)||sc.isFullyMutable());
    assertEquals(a,sc);
   
    // sparse coercion
    INDArray sp=a.sparse();
    assertEquals(a,sp);
   
    // immutable coercion
    INDArray imma=a.immutable();
    assertFalse(imma.isMutable());
    assertEquals(a,imma);
    try {
      imma.fill(2.0);
      fail();
    } catch (Throwable t) {
      // OK
    }
   
    // mutable coercion
    INDArray muta=a.exactClone().mutable();
    assertTrue(muta.isFullyMutable());
    muta.fill(2.0);
   
    // dense coercion
    INDArray densa=a.dense();
    assertTrue(densa instanceof IDense);
    assertEquals(a,densa);
   
    INDArray densc=a.denseClone();
    assertTrue(densc.isDense());
    assertEquals(a,densc);
  }

  private void testSetElements(INDArray a) {
    if (!a.isFullyMutable()) return;

    a = a.exactClone();
    INDArray c = a.clone();
    a.fill(Double.NaN);

    double[] arr = c.asVector().toDoubleArray();
    assertEquals(a.elementCount(), arr.length);
    a.setElements(arr);
    assertEquals(c.asVector(), a.asVector());
    assertEquals(c, a);
  }

  private void testGetElements(INDArray a) {
    int ecount = (int) a.elementCount();
    double[] data = new double[ecount + 1];
    Arrays.fill(data, Double.NaN);

    a.getElements(data, 1);
    assertTrue(Double.isNaN(data[0]));
    for (int i = 1; i < data.length; i++) {
      assertFalse(Double.isNaN(data[i]));
    }
    assertTrue(a.equalsArray(data,1));
   
    double[] data2=new double[ecount+1];
    data[0]=13;
    data2[0]=13;
    a.asVector().getElements(data2, 1);
    assertTrue(DoubleArrays.equals(data, data2));
   
    if (!a.isFullyMutable()) return;

    INDArray b = a.exactClone();
    b.fill(Double.NaN);
    b.setElements(data, 1);
    assertEquals(a, b);
    b.setElements(0,data, 1, ecount);
    assertEquals(a, b);
  }

  private void testApplyOp(INDArray a) {
    if (!a.isFullyMutable()) return;
    INDArray c = a.exactClone();
    INDArray d = a.exactClone();

    c.asVector().fill(5.0);
    d.applyOp(Constant.create(5.0));
    assertTrue(c.equals(d));
  }

  /**
   * For fully mutable arrays, applies every non-stochastic op in TestOps.ALL_OPS through
   * several different routes and checks that all routes agree:
   * op.applyTo(INDArray), INDArray.applyOp(op), op.applyTo(AVector), op.applyTo(double[]),
   * op.applyTo on the asVector() form, and element-by-element op.apply(double).
   * An op is skipped entirely if any element-wise result is NaN.
   */
  private void testApplyAllOps(INDArray a) {
    if (!a.isFullyMutable()) return;
    for (Op op : TestOps.ALL_OPS) {
      if (op.isStochastic()) continue;
      int n = (int) a.elementCount();

      // independent exact clones, one per application route
      INDArray b = a.exactClone();
      INDArray c = a.exactClone();
      INDArray d = a.exactClone();
      AVector v = Vector.create(b);
      assertEquals(n,v.length());
      double[] ds = new double[n];
      double[] tmp = new double[n];
      a.getElements(ds, 0);

      // reference result: apply the op one element at a time into tmp
      boolean nan=false;
      for (int i = 0; i < n; i++) {
        double x= op.apply(v.unsafeGet(i));
        if (Double.isNaN(x)) nan=true;
        tmp[i] = x;
      }
      if (nan) continue; // TODO: compare NaNs properly
      b.validate();
      op.applyTo(b);          // route 1: op applied to the array
      c.applyOp(op);          // route 2: array applies the op itself
      op.applyTo(v);          // route 3: op applied to the vector created from b
      op.applyTo(ds);         // route 4: op applied to the raw double[]
      op.applyTo(d.asVector()); // route 5: op applied to the asVector() form of d

      if (b.dimensionality()>0) for (int i=0; i<b.sliceCount(); i++) {
        assertEquals(b.slice(i),c.slice(i));
      }
      if (!b.equals(c)) {
        // NOTE(review): looks like a diagnostic path - redo route 1 from scratch with
        // validation around it before the assertion below reports the mismatch
        b=a.exactClone();
        b.validate();
        op.applyTo(b);
        b.validate();
        c.validate();
        assertEquals(b, c);
      }
      assertEquals(b, d);
      assertEquals(v, b.toVector());
      assertEquals(v, Vector.wrap(ds));
      assertEquals(v, Vector.wrap(tmp));
    }
  }

  private void testMutability(INDArray a) {
    if (a.isFullyMutable() && (a.elementCount() > 0)) {
      assertTrue(a.isMutable());
    }

    if (a.isElementConstrained()) {
      assertFalse(a.isFullyMutable());
    }
   
    INDArray b=a.mutable();
    assertTrue(mikera.vectorz.util.Testing.validateFullyMutable(b));
   
    if ((!a.isMutable())&&(a.elementCount()>0)) {
      try {
        a.asVector().set(0,Math.PI);
        // System.out.println(a.getClass());
        fail("Set on immutable array succeeded!");
      } catch (UnsupportedOperationException t) {
        // OK
      } catch (VectorzException t) {
        // Also OK
      }
    }
  }

  private void testHash(INDArray a) {
    assertEquals(a.asVector().hashCode(), a.hashCode());
  }

  private void testEquals(INDArray a) {
    assertEquals(a, a.exactClone());
    assertEquals(a, a.clone());
    assertEquals(a, a.sparse());

    assertTrue(a.epsilonEquals(a.exactClone()));

    if ((!a.isFullyMutable()) || (a.elementCount() == 0)) return;

    INDArray b = a.exactClone();
    b.add(Vectorz.TEST_EPSILON * 0.5);
    assertFalse(a.equals(b));
    assertFalse(b.equals(a));
    assertTrue(a.epsilonEquals(b));
    assertTrue(b.epsilonEquals(a));

    b.add(Vectorz.TEST_EPSILON * 1.5);
    assertFalse(a.equals(b));
    assertFalse(b.equals(a));
    assertFalse(a.epsilonEquals(b));
    assertFalse(b.epsilonEquals(a));

  }

  private void testParserRoundTrip(INDArray a) {
    String s = a.toString();
    if (a.elementCount()==0) return;
    assertEquals(a, Arrayz.load(new StringReader(s)));
    assertEquals(a, Arrayz.parse(s));
  }

  private void testBufferRoundTrip(INDArray a) {
    int len = (int) a.elementCount();
    DoubleBuffer buf = DoubleBuffer.allocate(len);
    assertEquals(len, buf.remaining());
    a.toDoubleBuffer(buf);
    assertEquals(0, buf.remaining());
    buf.flip();
    AVector vv = Vectorz.create(buf);
    assertEquals(a.asVector(), vv);
    assertEquals(a, vv.reshape(a.getShape()));
  }

  private void testMultiply(INDArray a) {
    assertTrue(a.scaleCopy(0.0).isZero());
    assertEquals(a, a.scaleCopy(1.0));

    if (!a.isFullyMutable()) return;
    INDArray m = a.exactClone();

    m.multiply(2.0);
    m.multiply(Scalar.create(0.5));
    assertEquals(a, m);
   
    m.multiply(0.0);
    assertTrue(m.isZero());
   
    INDArray m1 = a.exactClone();
    INDArray m2 = a.exactClone();
    Vectorz.fillRandom(m1.asVector());
    m1.add(1.0);
   
    m2.multiply(m1);
    m2.divide(m1);
    assertTrue(m2.epsilonEquals(a))
  }
 
  private void testReciprocal(INDArray a) {
    a=a.exactClone();
   
    INDArray ra=a.reciprocalCopy();
   
    if (a.isFullyMutable()&&(!ra.hasUncountable())) {
      a.reciprocal();
      assertTrue(a.epsilonEquals(ra));
    }
  }
 
  private void testBoolean(INDArray a) {
    assertEquals(a.isBoolean(),DoubleArrays.isBoolean(Tools.getElements(a)));
  }

  private void testBroadcast(INDArray a) {
    int dims = a.dimensionality();
    int[] ts = new int[dims + 2];
    ts[0] = 1;
    ts[1] = 2;
    System.arraycopy(a.getShape(), 0, ts, 2, dims);

    INDArray b = a.broadcast(ts);
    int[] bs = b.getShape();
    for (int i = 0; i < ts.length; i++) {
      assertEquals(ts[i], bs[i]);
    }

    assertEquals(a, b.slice(0).slice(1));
  }
 
  private void testIllegalBroadcast(INDArray a) {
    if (a.dimensionality()==0) return; // scalars can broadcast to anything :-)
    int[] ts=IntArrays.consArray(1,IntArrays.incrementAll(a.getShape()));
    try {
      a.broadcast(ts);
      fail("Broadcast should not be possible!!");
    } catch (IllegalArgumentException t) {
      // OK
    }
   
    try {
      a.broadcastLike(Array.newArray(ts));
      fail("Broadcast should not be possible!!");
    } catch (IllegalArgumentException t) {
      // OK
    }
  }
 
  private void testBroadcastLike(INDArray a) {
    INDArray up=SliceArray.of(a,a);   
    INDArray b=a.broadcastLike(up);   
    assertEquals(up,b);
    assertEquals(a,up.slice(0));
    assertEquals(a,up.slice(1));
   
    INDArray up2=Arrayz.create(SliceArray.of(a));   
    INDArray b2=a.broadcastLike(up2);   
    assertEquals(up2,b2);
   
    INDArray bcl=a.broadcastCloneLike(a);
    assertEquals(a,bcl);
    assertTrue((a.elementCount()==0)||(a!=bcl));
  }

  private void testSums(INDArray a) {
    INDArray b = a.clone();
    assertEquals(a.elementSum(), b.elementSum(), 0.000001);
    assertEquals(a.elementSquaredSum(), b.elementSquaredSum(), 0.000001);
    assertEquals(a.elementSum(), a.toVector().elementSum(), 0.000001);
    assertEquals(a.elementSquaredSum(), a.toVector().elementSquaredSum(),
        0.000001);

    b.multiply(a);
    assertEquals(a.elementSquaredSum(), b.elementSum(), 0.000001);
  }

  private void testTranspose(INDArray a) {
    assertEquals(a, a.getTranspose().getTranspose());
    try {
      assertEquals(a.getTransposeCopy(), a.getTransposeView());
    } catch (UnsupportedOperationException x) {
      // OK, not all matrices have transpose views
    }
  }

  private void testClamp(INDArray a) {
    if ((!a.isFullyMutable()) || (a.elementCount() == 0)) return;

    INDArray b = a.clone();
    a = a.exactClone();

    b.clamp(-Double.MAX_VALUE, Double.MAX_VALUE);
    assertEquals(a, b);

    a.fill(13.0);
    a.clamp(14, 15);
    assertEquals(14.0, Vectorz.minValue(a.toVector()), 0.0001);
    a.clamp(12, 17);
    assertEquals(14.0, Vectorz.maxValue(a.toVector()), 0.0001);
  }
 
  private void testElementIterator(INDArray m) {
    Iterator<Double> it=m.elementIterator();
   
    int i=0;
    AVector av=m.asVector();
    while (it.hasNext()) {
      double v=it.next();
      assertEquals(av.get(i++),v,0.0);
    }
    assertEquals(m.elementCount(),i);
  }
 
  private void testElementSums(INDArray m) {
    double es=m.elementSum();
    double ess=m.elementSquaredSum();
   
    assertEquals(es,m.asVector().elementSum(),0.0001);
    assertEquals(m.getClass().toString(),es,m.elementPowSum(1.0),0.0001);
    assertEquals(ess,m.elementAbsPowSum(2.0),0.0001);

  }
 
  private void testStridedArray(INDArray mm) {
    if (!(mm instanceof IStridedArray)) {
      assertNull(mm.asDoubleArray());
      return;
    }
    IStridedArray m=(IStridedArray)mm;
   
    int dims=m.dimensionality();
    int[] shape=m.getShape();
    int[] strides=m.getStrides();
   
    if (m.elementCount()==0) return;
   
    if (mm.isMutable()) {
      double[] data=m.getArray();
      for (int i=0; i<dims; i++) {
        assertEquals(m.getStride(i),strides[i]);
      }
     
      if (m.isPackedArray()) {
        assertNotNull(m.asDoubleArray());
        assertTrue(m.asDoubleArray()==m.getArray());
      } else {
        assertNull(m.asDoubleArray());
      }
     
      int[] ix = IntArrays.rand(shape);
      int off=m.getArrayOffset()+IntArrays.dotProduct(strides,ix);
      assertEquals(data[off],m.get(ix),0.0);
   
  }

  private void testMathsFunctions(INDArray a) {
    if (!a.isFullyMutable()) return;
    a = a.exactClone();

    if (a.nonZeroCount() == a.elementCount()) {
      INDArray a1 = a.exactClone();
      a1.reciprocal();
      a1.reciprocal();
      assertTrue(a.toVector().epsilonEquals(a1.toVector()));
    }

    INDArray b = a.exactClone();
    AVector v = b.toVector();
    b.pow(2.5);
    v.pow(2.5);
    assertTrue(v.epsilonEquals(b.toVector()));

    b = a.exactClone();
    v = b.toVector();
    b.square();
    v.square();
    assertEquals(v, b.toVector());
  }
 
  private void testIndexedAccess(INDArray a) {
    if ((a.elementCount() <= 1)) return;
    int dims=a.dimensionality();

    // 0d indexed access
    try {
      a.get();
      fail("0d get should fail for array with shape: "+Arrays.toString(a.getShape()));
    } catch (Throwable t) { /* OK */ }
   
    try {
      a.set(1.0);
      fail("0d set should fail for array with shape: "+Arrays.toString(a.getShape()));
    } catch (Throwable t) { /* OK */ }

    try {
      a.get(IntArrays.EMPTY_INT_ARRAY);
      fail("0d get should fail for array with shape: "+Arrays.toString(a.getShape()));
    } catch (Throwable t) { /* OK */ }
   
    try {
      a.set(IntArrays.EMPTY_INT_ARRAY,1.0);
      fail("0d set should fail for array with shape: "+Arrays.toString(a.getShape()));
    } catch (Throwable t) { /* OK */ }

    // 1D indexed access
    if (a.dimensionality()>1) {
      try {
        a.get(0);
        fail("1d get should fail for array with shape: "+Arrays.toString(a.getShape()));
      } catch (Throwable t) { /* OK */ }
     
      try {
        a.set(0,1.0);
        fail("1d set should fail for array with shape: "+Arrays.toString(a.getShape()));
      } catch (Throwable t) { /* OK */ }
    }
   
    assertEquals(a.getElements()[0],a.get(new int[dims]),0.0);
   
    try {
      int[] ix=new int[dims];
      ix[0]=-1;
      fail("Should get exception on invalid index");
    } catch (Throwable t) { /* OK */ }
  }

  /**
   * Runs the generic checks of this class against a single array: validates it first and
   * then calls each helper in the fixed order below.
   *
   * testSlices calls back into this method for the first slice of the array; that helper
   * returns early for arrays with no elements or no dimensions.
   *
   * @param a the array to check
   */
  public void testArray(INDArray a) {
    a.validate();
    testTranspose(a);
    testAsVector(a);
    testAdd(a);
    testSub(a);
    testToArray(a);
    testMultiply(a);
    testReciprocal(a);
    testApplyOp(a);
    testApplyAllOps(a);
    testElementIterator(a);
    testElementSums(a);
    testStridedArray(a);
    testBoolean(a);
    testSums(a);
    testEquals(a);
    testReorder(a);
    testIndexedAccess(a);
    testMathsFunctions(a);
    testSetElements(a);
    testGetElements(a);
    testBroadcast(a);
    testBroadcastLike(a);
    testIllegalBroadcast(a);
    testShape(a);
    testClamp(a);
    testHash(a);
    testClone(a);
    testMutability(a);
    testSlices(a);
    testSubArray(a);
    testRotateView(a);
    testParserRoundTrip(a);
    testBufferRoundTrip(a);
  }

  @Test
  public void g_SliceArray() {
    SliceArray<AVector> sa = SliceArray.of(
        Vectorz.createUniformRandomVector(10),
        Vectorz.createUniformRandomVector(10));
    testArray(sa);
    testArray(Array.create(sa));
  }

  @Test
  public void g_NDArray() {
    NDArray nd1 = NDArray.newArray(3);
    Vectorz.fillIndexes(nd1.asVector());
    testArray(nd1);
    testArray(Array.create(nd1));
   
    NDArray nd2 = NDArray.newArray(3, 3);
    assertEquals(9,nd2.elementCount());
    Vectorz.fillIndexes(nd2.asVector());
    testArray(nd2);
    testArray(Array.create(nd2));
   
    NDArray nd3 = NDArray.newArray(3, 3, 3);
    Vectorz.fillIndexes(nd3.asVector());
    testArray(nd3);
    testArray(Array.create(nd3));
  }

  @Test
  public void g_NDScalar() {
    NDArray ndscalar = NDArray.newArray();
    ndscalar.set(1.0);
    testArray(ndscalar);
    testArray(Array.create(ndscalar));
  }
   
  @Test
  public void g_JoinedArray() {
    testArray(JoinedArray.join(Vector.of(1,2),Vector.of(1,2,3,4,5),0));
    testArray(JoinedArray.join(NDArray.newArray(3, 3),NDArray.newArray(3, 3),1));
  }
 
  @Test
  public void g_ImmutableArray() {
    // immutable array tests
    testArray(ImmutableArray.create(Matrixx.createRandomMatrix(4, 5)));
    testArray(ImmutableArray.create(Vectorz.createUniformRandomVector(4)));
    testArray(ImmutableArray.create(Scalar.create(4)));
  }
   
  @Test
  public void g_ZeroArray() {
    // zero array tests
    testArray(Arrayz.createZeroArray());
    testArray(Arrayz.createZeroArray(2));
    testArray(Arrayz.createZeroArray(2,3));
    testArray(Arrayz.createZeroArray(1,2,3));
    testArray(Arrayz.createZeroArray(1,2,4,1));

  }
}
TOP

Related Classes of mikera.arrayz.TestArrays

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.