Package mikera.arrayz

Source Code of mikera.arrayz.Arrayz

package mikera.arrayz;

import java.io.Reader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

import mikera.arrayz.impl.SliceArray;
import mikera.arrayz.impl.ZeroArray;
import mikera.matrixx.Matrix;
import mikera.matrixx.Matrixx;
import mikera.matrixx.impl.StridedMatrix;
import mikera.matrixx.impl.ZeroMatrix;
import mikera.vectorz.AScalar;
import mikera.vectorz.AVector;
import mikera.vectorz.Scalar;
import mikera.vectorz.Vector;
import mikera.vectorz.Vectorz;
import mikera.vectorz.impl.ArrayIndexScalar;
import mikera.vectorz.impl.ArraySubVector;
import mikera.vectorz.impl.ImmutableScalar;
import mikera.vectorz.impl.Vector0;
import mikera.vectorz.impl.ZeroVector;
import mikera.vectorz.util.ErrorMessages;
import mikera.vectorz.util.IntArrays;
import mikera.vectorz.util.VectorzException;
import us.bpsm.edn.parser.Parseable;
import us.bpsm.edn.parser.Parser;
import us.bpsm.edn.parser.Parsers;

/**
* Static function class for array operations
*
* @author Mike
*/
public class Arrayz {
  /**
   * Creates an array from the given data
   *
   * Handles double arrays, INDArray instances, and lists
   *
   * @param object
   * @return
   */
  @SuppressWarnings("unchecked")
  public static INDArray create(Object object) {
    if (object instanceof INDArray) return create((INDArray)object);
   
    if (object instanceof double[]) return Vector.of((double[])object);
    if (object instanceof List<?>) {
      List<?> list=(List<Object>) object;
      if (list.size()==0) return Vector0.INSTANCE;
      Object o1=list.get(0);
      if ((o1 instanceof AScalar)||(o1 instanceof Number)) {
        return Vectorz.create((List<Object>)object);
      } else if (o1 instanceof AVector) {
        return Matrixx.create((List<Object>)object);
      } else if (o1 instanceof INDArray) {
        return SliceArray.create((List<INDArray>)object);       
      } else {
        ArrayList<INDArray> al=new ArrayList<INDArray>();
        for (Object o: list) {
          al.add(create(o));
        }
        return Arrayz.create(al);
      }
    }
   
    if (object instanceof Number) return Scalar.create(((Number)object).doubleValue());
   
    if (object.getClass().isArray()) {
      return create(Arrays.asList((Object[])object));
    }
   
    throw new VectorzException("Don't know how to create array from: "+object.getClass());
  }
 
  /**
   * Create a new array instance with the given shape. New array will be filled with zeroes.
   *
   * @param shape
   * @return
   */
  public static INDArray newArray(int... shape) {
    int dims=shape.length;
   
    switch (dims) {
      case 0: return Scalar.create(0.0);
      case 1: return Vector.createLength(shape[0]);
      case 2: return Matrix.create(shape[0], shape[1]);
      default: return Array.newArray(shape);
    }
  }
 
  public static INDArray create(INDArray a) {
    int dims=a.dimensionality();
    switch (dims) {
    case 0:
      return Scalar.create(a.get());
    case 1:
      return Vector.wrap(a.toDoubleArray());
    case 2:
      return Matrix.wrap(a.getShape(0), a.getShape(1), a.toDoubleArray());
    default:
      return Array.wrap(a.toDoubleArray(),a.getShape());
    }
  }
 
  /**
   * Creates an array using the given data as slices.
   *
   * @param data
   * @return
   */
  public static INDArray create(Object... data) {
    return create((Object)data);
  }
 
  /**
   * Creates an INDArray instance wrapping the given double data, with the provided shape.
   *
   * @param data
   * @param shape
   * @return
   */
  public static INDArray wrap(double[] data, int[] shape) {
    int dlength=data.length;
    switch (shape.length) {
      case 0:
        return ArrayIndexScalar.wrap(data,0);
       
      case 1:
        int n=shape[0];
        if (dlength<n) throw new IllegalArgumentException(ErrorMessages.insufficientElements(dlength));
        if (n==dlength) {
          return Vector.wrap(data);
        } else {
          return ArraySubVector.wrap(data, 0, n);
        }
       
      case 2:
        int rc=shape[0], cc=shape[1];
        int ec=rc*cc;
        if (dlength<ec) throw new IllegalArgumentException(ErrorMessages.insufficientElements(dlength));
        if (ec==dlength) {
          return Matrix.wrap(rc,cc, data);
        } else {
          return StridedMatrix.wrap(data, shape[0], shape[1], 0, shape[1], 1);
        }
   
      default:
        long eec=IntArrays.arrayProduct(shape);
        if (dlength<eec) throw new IllegalArgumentException(ErrorMessages.insufficientElements(dlength));
        if (eec==dlength) {
          return Array.wrap(data, shape);
        } else {
          return NDArray.wrap(data, shape);
        }
    }
  }

  /**
   * Creates a new array using the elements in the specified vector.
   * Truncates or zero-pads the data as required to fill the new array
   * @param data
   * @param rows
   * @param columns
   * @return
   */
  public static INDArray createFromVector(AVector a, int... shape) {
    int dims=shape.length;
    if (dims==0) {
      return Scalar.createFromVector(a);
    } else if (dims==1) {
      return Vector.createFromVector(a,shape[0]);
    } else if (dims==2) {
      return Matrixx.createFromVector(a, shape[0], shape[1]);
    } else {
      return Array.createFromVector(a,shape);
    }
  }
 
  public static INDArray load(Reader reader) {
    Parseable pbr=Parsers.newParseable(reader);
    Parser p = Parsers.newParser(Parsers.defaultConfiguration());
    return Arrayz.create(p.nextValue(pbr));
  }
 
  /**
   * Parse an array from a String. String should be in edn format
   *
   * @param ednString
   * @return
   */
  public static INDArray parse(String ednString) {
    return load(new StringReader(ednString))
  }

  public static INDArray wrapStrided(double[] data, int offset, int[] shape, int[] strides) {
    int dims=shape.length;
    if (dims==0) {
      return ArrayIndexScalar.wrap(data, offset);
    } else if (dims==1) {
      return Vectorz.wrapStrided(data, offset, shape[0], strides[0]);
    } else if (dims==2) {
      return Matrixx.wrapStrided(data, shape[0],shape[1], offset, strides[0],strides[1]);
    } else {
      if (isPackedLayout(data,offset,shape,strides)) {
        return Array.wrap(data, shape);
      } else {
        return NDArray.wrapStrided(data,offset,shape,strides);
      }
    }
  }
 
  public static boolean isPackedLayout(double[] data, int offset, int[] shape, int[] strides) {
    if (offset!=0) return false;
    int dims=shape.length;
    int st=1;
    for (int i=dims-1; i>=0; i--) {
      if (strides[i]!=st) return false;
      st*=shape[i];
    }
    return (st==data.length);
  }

  /**
   * Checks if the given set of strides represents a fully packed, row major layout for the given shape
   * @param shape
   * @param strides
   * @return
   */
  public static boolean isPackedStrides(int[] shape, int[] strides) {
    int dims=shape.length;
    int st=1;
    for (int i=dims-1; i>=0; i--) {
      if (strides[i]!=st) return false;
      st*=shape[i];
    }
    return true;
  }

  public static INDArray createSparse(INDArray a) {
    int dims=a.dimensionality();
    if (dims==0) {
      return Scalar.create(a.get());
    } else if (dims==1) {
      return Vectorz.createSparse(a.asVector());
    } else if (dims==2) {
      return Matrixx.createSparse(Matrixx.toMatrix(a));
    } else {
      int n=a.sliceCount();
      List<INDArray> slices=a.getSliceViews();
      for (int i=0; i<n; i++) {
        slices.set(i, slices.get(i).sparseClone());
      }
      return SliceArray.create(slices)
    }
  }

  /**
   * Creates an immutable zero-filled array of the given shape
   *
   * @param shape
   * @return
   */
  public static INDArray createZeroArray(int... shape) {
    switch (shape.length) {
      case 0: return ImmutableScalar.ZERO;
      case 1: return ZeroVector.create(shape[0]);
      case 2: return ZeroMatrix.create(shape[0],shape[1]);
      default: return ZeroArray.create(shape);
    }
  }

  public static void fillRandom(INDArray a, long seed) {
    Vectorz.fillRandom(a.asVector(),seed);
  }
 
  public static void fillRandom(INDArray a, Random random) {
    Vectorz.fillRandom(a.asVector(),random);
  }

  public static void fillNormal(INDArray a, long seed) {
    Vectorz.fillNormal(a.asVector(),seed);
  }
 
  public static void fillNormal(INDArray a, Random random) {
    Vectorz.fillNormal(a.asVector(),random);
  }
}
TOP

Related Classes of mikera.arrayz.Arrayz

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and owned by Oracle Inc. Contact coftware#gmail.com.