package mikera.vectorz.impl;
import java.nio.DoubleBuffer;
import mikera.vectorz.AVector;
import mikera.vectorz.Op;
import mikera.vectorz.util.ErrorMessages;
import mikera.vectorz.util.IntArrays;
import mikera.vectorz.util.VectorzException;
/**
 * A vector that represents the concatenation of many vectors.
 *
 * <p>The segments are held in {@code vecs}. When split positions are computed by this class,
 * {@code splits[i]} is the position in this vector at which segment {@code i} starts and
 * {@code splits[n]} is the total length. Most operations simply delegate to every segment in
 * turn, shifting offsets by the segment's split position.
 *
 * <p>NOTE(review): {@code IntArrays.indexLookup(splits, i)} is used throughout to find the segment
 * that contains position {@code i}; its exact contract is defined outside this file.
 *
 * @author Mike
 *
 */
public final class JoinedMultiVector extends AJoinedVector {
    private static final long serialVersionUID = 6226205676178066609L;

    /**
     * The number of segments
     */
    private final int n;

    /**
     * The set of vectors that define all the segments of this JoinedMultiVector
     */
    private final AVector[] vecs;

    // array of split positions [0, ...... , length] with length n+1
    private final int[] splits;

    /**
     * Creates a joined vector from already-prepared state. Both arrays are stored as given
     * (no defensive copy is made).
     *
     * @param length total length of the joined vector
     * @param vecs   the segments
     * @param splits split positions, expected to have {@code vecs.length+1} entries
     */
    private JoinedMultiVector(int length, AVector[] vecs, int[] splits) {
        super(length);
        n=vecs.length;
        this.vecs=vecs;
        this.splits=splits;
    }

    /**
     * Creates a joined vector over the given segments, computing the total length and the
     * split positions from the segment lengths. The array is stored as given.
     *
     * @param vs the segments, in order
     */
    private JoinedMultiVector(AVector[] vs) {
        this(sumOfLengths(vs),vs,new int[vs.length+1]);
        // splits[0] stays 0; each subsequent entry is the running total of segment lengths
        int j=0;
        for (int i=0; i<n ; i++) {
            j+=vs[i].length();
            splits[i+1]=j;
        }
    }

    /**
     * Returns the sum of the lengths of all the given vectors.
     */
    private static final int sumOfLengths(AVector[] vs) {
        int result=0;
        for (AVector v:vs) {
            result+=v.length();
        }
        return result;
    }

    /**
     * Returns true only if every segment reports itself as fully mutable.
     */
    @Override
    public boolean isFullyMutable() {
        for (AVector v:vecs) {
            if (!(v.isFullyMutable())) return false;
        }
        return true;
    }

    /**
     * Copies every segment into {@code dest}, one after another, starting at {@code offset}.
     */
    @Override
    public void copyTo(AVector dest, int offset) {
        for (AVector v:vecs) {
            v.copyTo(dest, offset);
            offset+=v.length();
        }
    }

    /**
     * Writes every segment, in order, to the given buffer.
     */
    @Override
    public void toDoubleBuffer(DoubleBuffer dest) {
        for (AVector v:vecs) {
            v.toDoubleBuffer(dest);
        }
    }

    /**
     * Adds the range {@code [offset, offset+length)} of this vector to {@code array},
     * starting at {@code arrayOffset}.
     *
     * @throws IndexOutOfBoundsException if the range is not within this vector
     */
    @Override
    public void addToArray(int offset, double[] array, int arrayOffset, int length) {
        // overflow-safe equivalent of (offset+length > this.length)
        if ((offset<0)||(length<0)||(offset>this.length-length)) throw new IndexOutOfBoundsException(ErrorMessages.invalidRange(this,offset, length));
        // an empty range contributes nothing; returning here also avoids looking up position offset-1
        if (length==0) return;
        int end=offset+length;
        int i1=IntArrays.indexLookup(splits, offset);
        int i2=IntArrays.indexLookup(splits, end-1);
        if (i1==i2) {
            // the whole range lies within a single segment
            vecs[i1].addToArray(offset-splits[i1], array,arrayOffset,length);
            return;
        }
        // partial first segment, partial last segment, then all complete segments in between
        vecs[i1].addToArray(offset-splits[i1], array, arrayOffset,splits[i1+1]-offset);
        vecs[i2].addToArray(0, array, arrayOffset+splits[i2]-offset,end-splits[i2]);
        for (int i=i1+1;i<i2; i++) {
            int io=splits[i]-offset;
            vecs[i].addToArray(array, arrayOffset+io);
        }
    }

    /**
     * Adds every segment to {@code dest} using the given stride; segment {@code i} starts at
     * {@code offset+stride*splits[i]}.
     */
    @Override
    public void addToArray(double[] dest, int offset, int stride) {
        for (int i=0; i<n; i++) {
            vecs[i].addToArray(dest, offset+stride*splits[i],stride);
        }
    }

    /**
     * Adds {@code factor} times the range {@code [offset, offset+length)} of this vector to
     * {@code array}, starting at {@code arrayOffset}.
     *
     * @throws IndexOutOfBoundsException if the range is not within this vector
     */
    @Override
    public void addMultipleToArray(double factor,int offset, double[] array, int arrayOffset, int length) {
        // overflow-safe equivalent of (offset+length > this.length)
        if ((offset<0)||(length<0)||(offset>this.length-length)) throw new IndexOutOfBoundsException(ErrorMessages.invalidRange(this,offset, length));
        // an empty range contributes nothing; returning here also avoids looking up position offset-1
        if (length==0) return;
        int end=offset+length;
        int i1=IntArrays.indexLookup(splits, offset);
        int i2=IntArrays.indexLookup(splits, end-1);
        if (i1==i2) {
            // the whole range lies within a single segment
            vecs[i1].addMultipleToArray(factor,offset-splits[i1], array,arrayOffset,length);
            return;
        }
        // partial first segment, partial last segment, then all complete segments in between
        vecs[i1].addMultipleToArray(factor,offset-splits[i1], array, arrayOffset,splits[i1+1]-offset);
        vecs[i2].addMultipleToArray(factor,0, array, arrayOffset+splits[i2]-offset,end-splits[i2]);
        for (int i=i1+1;i<i2; i++) {
            int io=splits[i]-offset;
            vecs[i].addMultipleToArray(factor,0,array, arrayOffset+io,vecs[i].length());
        }
    }

    /**
     * Adds {@code v} to the element at position {@code i}, delegating to the owning segment.
     */
    @Override
    public void addAt(int i, double v) {
        int j=IntArrays.indexLookup(splits, i);
        int joff=i-splits[j];
        vecs[j].addAt(joff,v);
    }

    /**
     * Delegates {@code getElements} to each segment, with segment {@code i} targeting
     * {@code offset+splits[i]}.
     */
    @Override
    public void getElements(double[] data, int offset) {
        for (int i=0; i<n; i++) {
            vecs[i].getElements(data, offset+splits[i]);
        }
    }

    /**
     * Delegates {@code multiplyTo} to each segment, with segment {@code i} targeting
     * {@code offset+splits[i]}.
     */
    @Override
    public void multiplyTo(double[] data, int offset) {
        for (int i=0; i<n; i++) {
            vecs[i].multiplyTo(data, offset+splits[i]);
        }
    }

    /**
     * Delegates {@code divideTo} to each segment, with segment {@code i} targeting
     * {@code offset+splits[i]}.
     */
    @Override
    public void divideTo(double[] data, int offset) {
        for (int i=0; i<n; i++) {
            vecs[i].divideTo(data, offset+splits[i]);
        }
    }

    /**
     * Copies the range {@code [start, start+length)} into {@code dest} at {@code destOffset},
     * implemented via {@link #subVector(int, int)}.
     */
    @Override
    public void copyTo(int start, AVector dest, int destOffset, int length) {
        subVector(start,length).copyTo(dest, destOffset);
    }

    /**
     * Returns a vector covering the range {@code [start, start+length)} of this vector.
     *
     * <p>A zero length yields {@code Vector0.INSTANCE} (before any bounds check), the full
     * length yields {@code this}, a range inside one segment yields that segment's own
     * sub-vector, and anything else yields a new JoinedMultiVector over the covered segments.
     *
     * @throws IndexOutOfBoundsException if a non-empty range is not within this vector
     */
    @Override
    public AVector subVector(int start, int length) {
        if (length==0) return Vector0.INSTANCE;
        // overflow-safe equivalent of (start+length > this.length)
        if ((start<0)||(length<0)||(start>this.length-length)) throw new IndexOutOfBoundsException(ErrorMessages.invalidRange(this,start, length));
        if (length==this.length) return this;
        int end=start+length;
        int i1=IntArrays.indexLookup(splits, start);
        int i2=IntArrays.indexLookup(splits, end-1);
        if (i1==i2) {
            return vecs[i1].subVector(start-splits[i1], length);
        }
        int nn =i2-i1+1;
        AVector[] nvecs=new AVector[nn];
        // trimmed first and last segments, whole segments in between
        nvecs[0]=vecs[i1].subVector(start-splits[i1], splits[i1+1]-start);
        nvecs[nn-1]=vecs[i2].subVector(0, end-splits[i2]);
        for (int i=1; i<(i2-i1); i++) {
            nvecs[i]=vecs[i1+i];
        }
        return new JoinedMultiVector(nvecs);
    }

    /**
     * Joins {@code v} onto the end of this vector.
     *
     * <p>Joined vectors are flattened via the dedicated {@code join} overloads. Otherwise the
     * last segment is asked to join with {@code v}: if it returns a non-null result that result
     * replaces the last segment, else {@code v} is appended as an extra segment.
     */
    @Override
    public AVector tryEfficientJoin(AVector v) {
        if (v instanceof JoinedMultiVector) return join((JoinedMultiVector)v);
        if (v instanceof JoinedVector) return join((JoinedVector)v);
        AVector ej=vecs[n-1].tryEfficientJoin(v);
        if (ej!=null) {
            AVector[] nvecs=vecs.clone();
            nvecs[n-1]=ej;
            return new JoinedMultiVector(nvecs);
        } else {
            AVector[] nvecs=new AVector[n+1];
            System.arraycopy(vecs, 0, nvecs, 0, n);
            nvecs[n]=v;
            return new JoinedMultiVector(nvecs);
        }
    }

    /**
     * Joins another JoinedMultiVector onto the end of this one, producing a single
     * JoinedMultiVector containing the segments of both.
     */
    public AVector join(JoinedMultiVector v) {
        AVector[] nvecs=new AVector[n+v.n];
        System.arraycopy(vecs, 0, nvecs, 0, n);
        System.arraycopy(v.vecs, 0, nvecs, n, v.n);
        return new JoinedMultiVector(nvecs);
    }

    /**
     * Joins a two-part JoinedVector onto the end of this vector. If the last segment can be
     * joined with {@code v.left} the joined result replaces the last segment; otherwise
     * {@code v.left} becomes a new segment. {@code v.right} is always appended.
     */
    public AVector join(JoinedVector v) {
        AVector ej=vecs[n-1].tryEfficientJoin(v.left);
        if (ej!=null) {
            AVector[] nvecs=new AVector[n+1];
            System.arraycopy(vecs, 0, nvecs, 0, n);
            nvecs[n-1]=ej;
            nvecs[n]=v.right;
            return new JoinedMultiVector(nvecs);
        } else {
            AVector[] nvecs=new AVector[n+2];
            System.arraycopy(vecs, 0, nvecs, 0, n);
            nvecs[n]=v.left;
            nvecs[n+1]=v.right;
            return new JoinedMultiVector(nvecs);
        }
    }

    /**
     * Adds {@code a} to this vector. The length match is only checked by an {@code assert}.
     */
    @Override
    public void add(AVector a) {
        assert(length()==a.length());
        if (a instanceof JoinedMultiVector) {
            add((JoinedMultiVector)a);
        } else {
            add(a,0);
        }
    }

    /**
     * Adds another JoinedMultiVector. When both vectors have identical split positions the
     * addition is done segment by segment; otherwise it falls back to the offset-based add.
     */
    public void add(JoinedMultiVector a) {
        if (IntArrays.equals(splits, a.splits)) {
            for (int i=0; i<n; i++) {
                vecs[i].add(a.vecs[i]);
            }
        } else {
            add(a,0);
        }
    }

    /**
     * Applies {@code scaleAdd(factor, constant)} to every segment.
     */
    @Override
    public void scaleAdd(double factor, double constant) {
        for (int i=0; i<n; i++) {
            vecs[i].scaleAdd(factor,constant);
        }
    }

    /**
     * Adds a constant to every segment.
     */
    @Override
    public void add(double constant) {
        for (int i=0; i<n; i++) {
            vecs[i].add(constant);
        }
    }

    /**
     * Applies {@code reciprocal()} to every segment.
     */
    @Override
    public void reciprocal() {
        for (int i=0; i<n; i++) {
            vecs[i].reciprocal();
        }
    }

    /**
     * Applies {@code clamp(min, max)} to every segment.
     */
    @Override
    public void clamp(double min, double max) {
        for (int i=0; i<n; i++) {
            vecs[i].clamp(min, max);
        }
    }

    /**
     * Computes the dot product with {@code v}, using the array-based implementation when
     * {@code v} is an ADenseArrayVector and the superclass implementation otherwise.
     */
    @Override
    public double dotProduct (AVector v) {
        if (v instanceof ADenseArrayVector) {
            ADenseArrayVector av=(ADenseArrayVector)v;
            return dotProduct(av.getArray(),av.getArrayOffset());
        }
        return super.dotProduct(v);
    }

    /**
     * Sums the dot products of each segment with {@code data}, segment {@code i} reading
     * from {@code offset+splits[i]}.
     */
    @Override
    public double dotProduct(double[] data, int offset) {
        double result=0.0;
        for (int i=0; i<n; i++) {
            result+=vecs[i].dotProduct(data, offset+splits[i]);
        }
        return result;
    }

    /**
     * Adds elements of {@code a} to each segment, segment {@code i} reading from
     * {@code aOffset+splits[i]}.
     */
    @Override
    public void add(AVector a,int aOffset) {
        for (int i=0; i<n; i++) {
            vecs[i].add(a, aOffset+splits[i]);
        }
    }

    /**
     * Adds elements of {@code data} to each segment, segment {@code i} reading from
     * {@code offset+splits[i]}.
     */
    @Override
    public void add(double[] data,int offset) {
        for (int i=0; i<n; i++) {
            vecs[i].add(data, offset+splits[i]);
        }
    }

    /**
     * Adds the whole of {@code a} into this vector starting at position {@code offset}.
     */
    @Override
    public void add(int offset, AVector a) {
        add(offset,a,0,a.length());
    }

    /**
     * Adds {@code factor} times {@code a} to this vector, segment by segment.
     */
    @Override
    public void addMultiple(AVector a, double factor) {
        for (int i=0; i<n; i++) {
            vecs[i].addMultiple(a, splits[i],factor);
        }
    }

    /**
     * Adds {@code factor} times elements of {@code a} to each segment, segment {@code i}
     * reading from {@code aOffset+splits[i]}.
     */
    @Override
    public void addMultiple(AVector a, int aOffset, double factor) {
        for (int i=0; i<n; i++) {
            vecs[i].addMultiple(a, aOffset+splits[i],factor);
        }
    }

    /**
     * Delegates {@code addProduct} to each segment, reading both {@code a} and {@code b}
     * from the segment's split position.
     */
    @Override
    public void addProduct(AVector a, AVector b, double factor) {
        for (int i=0; i<n; i++) {
            int off=splits[i];
            vecs[i].addProduct(a, off,b,off,factor);
        }
    }

    /**
     * Delegates {@code addProduct} to each segment, shifting {@code aOffset} and
     * {@code bOffset} by the segment's split position.
     */
    @Override
    public void addProduct(AVector a, int aOffset, AVector b, int bOffset, double factor) {
        for (int i=0; i<n; i++) {
            int off=splits[i];
            vecs[i].addProduct(a, aOffset+off,b,bOffset+off,factor);
        }
    }

    /**
     * Applies {@code signum()} to every segment.
     */
    @Override
    public void signum() {
        for (int i=0; i<n; i++) {
            vecs[i].signum();
        }
    }

    /**
     * Applies {@code abs()} to every segment.
     */
    @Override
    public void abs() {
        for (int i=0; i<n; i++) {
            vecs[i].abs();
        }
    }

    /**
     * Applies {@code exp()} to every segment.
     */
    @Override
    public void exp() {
        for (int i=0; i<n; i++) {
            vecs[i].exp();
        }
    }

    /**
     * Applies {@code log()} to every segment.
     */
    @Override
    public void log() {
        for (int i=0; i<n; i++) {
            vecs[i].log();
        }
    }

    /**
     * Applies {@code negate()} to every segment.
     */
    @Override
    public void negate() {
        for (int i=0; i<n; i++) {
            vecs[i].negate();
        }
    }

    /**
     * Applies the given operator to every segment.
     */
    @Override
    public void applyOp(Op op) {
        for (int i=0; i<n; i++) {
            vecs[i].applyOp(op);
        }
    }

    /**
     * Returns the sum of the element sums of all segments.
     */
    @Override
    public double elementSum() {
        double result=0.0;
        for (int i=0; i<n; i++) {
            result+=vecs[i].elementSum();
        }
        return result;
    }

    /**
     * Returns the product of the element products of all segments, stopping early once the
     * running product becomes zero.
     */
    @Override
    public double elementProduct() {
        double result=1.0;
        for (int i=0; i<n; i++) {
            result*=vecs[i].elementProduct();
            if (result==0.0) return 0.0;
        }
        return result;
    }

    /**
     * Returns the largest of the segments' {@code elementMax()} values.
     */
    @Override
    public double elementMax() {
        double result=vecs[0].elementMax();
        // segment 0 already seeds the result, so start from segment 1
        for (int i=1; i<n; i++) {
            double m=vecs[i].elementMax();
            if (m>result) result=m;
        }
        return result;
    }

    /**
     * Returns the smallest of the segments' {@code elementMin()} values.
     */
    @Override
    public double elementMin() {
        double result=vecs[0].elementMin();
        // segment 0 already seeds the result, so start from segment 1
        for (int i=1; i<n; i++) {
            double m=vecs[i].elementMin();
            if (m<result) result=m;
        }
        return result;
    }

    /**
     * Returns the sum of the squared magnitudes of all segments.
     */
    @Override
    public double magnitudeSquared() {
        double result=0.0;
        for (int i=0; i<n; i++) {
            result+=vecs[i].magnitudeSquared();
        }
        return result;
    }

    /**
     * Returns the total of the segments' non-zero counts.
     */
    @Override
    public long nonZeroCount() {
        long result=0;
        for (int i=0; i<n; i++) {
            result+=vecs[i].nonZeroCount();
        }
        return result;
    }

    /**
     * Returns the element at position {@code i} after checking the index.
     */
    @Override
    public double get(int i) {
        checkIndex(i);
        int j=IntArrays.indexLookup(splits, i);
        return vecs[j].unsafeGet(i-splits[j]);
    }

    /**
     * Sets this vector from {@code src}, which must have the same length.
     */
    @Override
    public void set(AVector src) {
        checkSameLength(src);
        set(src,0);
    }

    /**
     * Returns the element at position {@code i} without checking the index first.
     */
    @Override
    public double unsafeGet(int i) {
        int j=IntArrays.indexLookup(splits, i);
        return vecs[j].unsafeGet(i-splits[j]);
    }

    /**
     * Sets each segment from {@code src}, segment {@code i} reading from
     * {@code srcOffset+splits[i]}.
     */
    @Override
    public void set(AVector src, int srcOffset) {
        for (int i=0; i<n; i++) {
            vecs[i].set(src,srcOffset+splits[i]);
        }
    }

    /**
     * Sets the element at position {@code i} after checking the index.
     */
    @Override
    public void set(int i, double value) {
        checkIndex(i);
        unsafeSet(i,value);
    }

    /**
     * Sets the element at position {@code i} without checking the index first.
     */
    @Override
    public void unsafeSet(int i, double value) {
        int j=IntArrays.indexLookup(splits, i);
        vecs[j].unsafeSet(i-splits[j], value);
    }

    /**
     * Fills every segment with the given value.
     */
    @Override
    public void fill(double value) {
        for (int i=0; i<n; i++) {
            vecs[i].fill(value);
        }
    }

    /**
     * Applies {@code square()} to every segment.
     */
    @Override
    public void square() {
        for (int i=0; i<n; i++) {
            vecs[i].square();
        }
    }

    /**
     * Applies {@code sqrt()} to every segment.
     */
    @Override
    public void sqrt() {
        for (int i=0; i<n; i++) {
            vecs[i].sqrt();
        }
    }

    /**
     * Applies {@code tanh()} to every segment.
     */
    @Override
    public void tanh() {
        for (int i=0; i<n; i++) {
            vecs[i].tanh();
        }
    }

    /**
     * Applies {@code logistic()} to every segment.
     */
    @Override
    public void logistic() {
        for (int i=0; i<n; i++) {
            vecs[i].logistic();
        }
    }

    /**
     * Multiplies every segment by the given value.
     */
    @Override
    public void multiply(double value) {
        for (int i=0; i<n; i++) {
            vecs[i].multiply(value);
        }
    }

    /**
     * Returns a newly allocated array of this vector's length, filled by copying each
     * segment to its split position.
     */
    @Override
    public double[] toDoubleArray() {
        double[] data=new double[length];
        for (int i=0; i<n; i++) {
            vecs[i].copyTo(data,splits[i]);
        }
        return data;
    }

    /**
     * Tests equality with another vector, dispatching to a specialised comparison for
     * JoinedMultiVector and ADenseArrayVector arguments.
     */
    @Override
    public boolean equals(AVector v) {
        if (v instanceof JoinedMultiVector) return equals((JoinedMultiVector)v);
        if (v instanceof ADenseArrayVector) {
            ADenseArrayVector av=(ADenseArrayVector) v;
            return equalsArray(av.getArray(),av.getArrayOffset());
        }
        return super.equals(v);
    }

    /**
     * Tests equality with another JoinedMultiVector. When both have identical split
     * positions the comparison is done segment by segment; otherwise the superclass
     * comparison is used.
     */
    public boolean equals(JoinedMultiVector v) {
        if (IntArrays.equals(splits, v.splits)) {
            for (int i=0; i<n; i++) {
                if (!vecs[i].equals(v.vecs[i])) return false;
            }
            // every aligned segment matched, so there is no need to compare again via super
            return true;
        }
        return super.equals(v);
    }

    /**
     * Returns true if every segment equals the part of {@code data} starting at
     * {@code offset+splits[i]}.
     */
    @Override
    public boolean equalsArray(double[] data, int offset) {
        for (int i=0; i<n; i++) {
            if (!vecs[i].equalsArray(data,offset+splits[i])) return false;
        }
        return true;
    }

    /**
     * Returns a new JoinedMultiVector whose segments are the {@code exactClone()} of each
     * segment of this vector.
     */
    @Override
    public JoinedMultiVector exactClone() {
        AVector[] nvecs=new AVector[n];
        for (int i=0; i<n; i++) {
            nvecs[i]=vecs[i].exactClone();
        }
        return new JoinedMultiVector(nvecs);
    }

    /**
     * Creates a JoinedMultiVector that uses the given array directly as its segment array
     * (the array is not copied).
     */
    public static AVector wrap(AVector... vecs) {
        return new JoinedMultiVector(vecs);
    }

    /**
     * Creates a JoinedMultiVector over a shallow copy of the given array; the segment
     * vectors themselves are not copied.
     */
    public static AVector create(AVector... vecs) {
        return new JoinedMultiVector(vecs.clone());
    }

    /**
     * Runs the superclass validation and then checks that the final split position equals
     * the length of this vector.
     *
     * @throws VectorzException if the final split position does not match the length
     */
    @Override
    public void validate() {
        super.validate();
        if (splits[n]!=length) throw new VectorzException("Unexpected final slit position - not equal to JoinedMultVector length");
    }

    /**
     * Returns the number of segments.
     */
    @Override
    public int segmentCount() {
        return n;
    }

    /**
     * Returns segment {@code k}.
     */
    @Override
    public AVector getSegment(int k) {
        return vecs[k];
    }

    /**
     * Builds a new JoinedMultiVector from the supplied segments. The total length and split
     * positions are recomputed from the segments, and the array is stored as given.
     */
    @Override
    protected AJoinedVector reconstruct(AVector... segments) {
        return new JoinedMultiVector(segments);
    }
}