package mikera.vectorz;
import mikera.arrayz.INDArray;
import mikera.matrixx.AMatrix;
import mikera.transformz.ATransform;
import mikera.transformz.impl.AOpTransform;
import mikera.vectorz.impl.ADenseArrayVector;
import mikera.vectorz.ops.Composed;
import mikera.vectorz.ops.Derivative;
import mikera.vectorz.ops.Division;
import mikera.vectorz.ops.Inverse;
import mikera.vectorz.ops.Product;
import mikera.vectorz.ops.Sum;
/**
 * Abstract base class for representing a unary operation on double values.
 *
 * Subclasses must implement {@link #apply(double)} and {@link #averageValue()}.
 * All other methods have default implementations that subclasses may override, covering:
 * <ul>
 * <li>element-wise, in-place application to arrays, vectors, matrices and scalars;</li>
 * <li>inverse and derivative support;</li>
 * <li>output range and input domain metadata;</li>
 * <li>combinators that build new Ops from this one.</li>
 * </ul>
 *
 * @author Mike
 */
public abstract class Op implements IOperator {

	/**
	 * Applies this operator to a single value.
	 *
	 * @param x the input value
	 * @return the result f(x)
	 */
	public abstract double apply(double x);

	/**
	 * Applies the inverse of this Op.
	 *
	 * The default implementation always throws. Subclasses that support an inverse should
	 * override this method (and {@link #hasInverse()}). Overriding implementations are expected
	 * to return {@code Double.NaN} when no inverse exists for the specific value of y.
	 *
	 * @param y an output value of this Op
	 * @return the input x with f(x) == y, or {@code Double.NaN} if no such x exists
	 * @throws UnsupportedOperationException if the inverse is not defined (always, in this base implementation)
	 */
	public double applyInverse(double y) {
		throw new UnsupportedOperationException("Inverse not defined for operator: "+this.toString());
	}

	/**
	 * Applies this operator in place to every element of a vector.
	 *
	 * Dense array-backed vectors are routed to the raw-array path via
	 * {@link #applyTo(ADenseArrayVector)}; all other vectors delegate to {@code v.applyOp(this)}.
	 *
	 * @param v the vector to modify
	 */
	@Override
	public void applyTo(AVector v) {
		if (v instanceof ADenseArrayVector) {
			applyTo((ADenseArrayVector)v);
		} else {
			v.applyOp(this);
		}
	}

	/**
	 * Applies this operator in place to every element of a matrix, via {@code m.applyOp(this)}.
	 *
	 * @param m the matrix to modify
	 */
	public void applyTo(AMatrix m) {
		m.applyOp(this);
	}

	/**
	 * Applies this operator in place to the elements {@code start .. start+length-1} of a vector.
	 *
	 * @param v the vector to modify
	 * @param start index of the first element to modify; must be non-negative
	 * @param length number of elements to modify
	 * @throws IllegalArgumentException if {@code start} is negative
	 */
	@Override
	public void applyTo(AVector v, int start, int length) {
		if (start<0) throw new IllegalArgumentException("Negative start position: "+start);
		if ((start==0)&&(length==v.length())) {
			// Range covers the whole vector: operate on it directly rather than via subVector.
			v.applyOp(this);
		} else {
			v.subVector(start, length).applyOp(this);
		}
	}

	/**
	 * Applies this operator in place to a scalar, replacing its value with f(value).
	 *
	 * @param s the scalar to modify
	 */
	public void applyTo(AScalar s) {
		s.set(apply(s.get()));
	}

	/**
	 * Applies this operator in place to a dense vector by operating directly on its backing
	 * array, starting at the vector's array offset and covering {@code v.length()} elements.
	 *
	 * @param v the dense vector to modify
	 */
	public void applyTo(ADenseArrayVector v) {
		applyTo(v.getArray(), v.getArrayOffset(),v.length());
	}

	/**
	 * Applies this operator in place to every element of an arbitrary array.
	 *
	 * Dispatches to the most specific overload for vectors, matrices and scalars; any other
	 * array type delegates to {@code a.applyOp(this)}.
	 *
	 * @param a the array to modify
	 */
	public void applyTo(INDArray a) {
		if (a instanceof AVector) {
			applyTo((AVector)a);
		} else if (a instanceof AMatrix) {
			applyTo((AMatrix)a);
		} else if (a instanceof AScalar) {
			applyTo((AScalar)a);
		} else {
			a.applyOp(this);
		}
	}

	/**
	 * Applies this operator in place to {@code data[start .. start+length-1]}.
	 *
	 * @param data the array to modify
	 * @param start index of the first element to modify
	 * @param length number of elements to modify
	 */
	@Override
	public void applyTo(double[] data, int start, int length) {
		for (int i=0; i<length; i++) {
			double x=data[start+i];
			data[start+i]=apply(x);
		}
	}

	/**
	 * Applies this operator in place to every element of {@code data}.
	 *
	 * @param data the array to modify
	 */
	public void applyTo(double[] data) {
		applyTo(data,0,data.length);
	}

	/**
	 * Returns a transform that applies this Op to each component of a vector of the given size.
	 *
	 * @param dims the number of dimensions of the transform
	 * @return a new {@code AOpTransform} wrapping this Op
	 */
	@Override
	public ATransform getTransform(int dims) {
		return new AOpTransform(this,dims);
	}

	/**
	 * Returns an Op representing the inverse of this Op.
	 *
	 * @return a new {@code Inverse} wrapping this Op
	 * @throws UnsupportedOperationException if {@link #hasInverse()} is false
	 */
	@Override
	public Op getInverse() {
		if (hasInverse()) {
			return new Inverse(this);
		} else {
			throw new UnsupportedOperationException("No inverse available: "+this.getClass());
		}
	}

	/**
	 * Returns true if this Op supports {@link #derivative(double)}. The default is false.
	 *
	 * @return whether a derivative with respect to the input is available
	 */
	public boolean hasDerivative() {
		return false;
	}

	/**
	 * Returns true if this Op supports {@link #derivativeForOutput(double)}.
	 * The default is {@link #hasDerivative()}.
	 *
	 * @return whether a derivative expressed in terms of the output value is available
	 */
	public boolean hasDerivativeForOutput() {
		return hasDerivative();
	}

	/**
	 * Returns true if this Op supports {@link #applyInverse(double)} and {@link #getInverse()}.
	 * The default is false.
	 *
	 * @return whether an inverse is available
	 */
	public boolean hasInverse() {
		return false;
	}

	/**
	 * Returns the derivative of this Op for a given output value y,
	 * i.e. f'(g(y)) where f is the operator and g is the inverse of f.
	 *
	 * The default implementation throws. Subclasses for which {@link #hasDerivativeForOutput()}
	 * returns true must override this method.
	 *
	 * @param y an output value of this Op
	 * @return the derivative of f at the input that produces y
	 * @throws UnsupportedOperationException if no derivative is defined (always, in this base implementation)
	 */
	public double derivativeForOutput(double y) {
		// Reaching this base implementation while claiming output-derivative support means
		// a subclass forgot to override it. Check the output-specific flag, since an Op may
		// offer derivative(x) without offering derivativeForOutput(y).
		assert(!hasDerivativeForOutput());
		throw new UnsupportedOperationException("No derivative defined for "+this.toString());
	}

	/**
	 * Returns the derivative of this Op for a given input value x, i.e. f'(x) where f is the operator.
	 *
	 * The default implementation evaluates {@code derivativeForOutput(apply(x))}. Subclasses that
	 * implement only {@link #derivativeForOutput(double)} therefore get this method for free.
	 *
	 * @param x the input value
	 * @return the derivative f'(x)
	 * @throws UnsupportedOperationException if no derivative is defined
	 */
	public double derivative(double x) {
		// No assertion here: an Op that implements only derivativeForOutput legitimately reports
		// hasDerivative() == true and relies on this fallback. The base derivativeForOutput
		// already guards the truly-unsupported case.
		return derivativeForOutput(apply(x));
	}

	/**
	 * Returns true if the operator is stochastic, i.e. returns random values for at least some inputs.
	 * The default is false.
	 *
	 * @return whether this Op is stochastic
	 */
	public boolean isStochastic() {
		return false;
	}

	/**
	 * Returns an average value for this Op's output.
	 *
	 * NOTE(review): the precise meaning (e.g. the input distribution it is averaged over) is
	 * defined by each subclass; confirm against concrete implementations.
	 *
	 * @return the average value
	 */
	public abstract double averageValue();

	/**
	 * Returns the lower bound of this Op's output range. The default is negative infinity (unbounded).
	 *
	 * @return the minimum possible output value
	 */
	public double minValue() {
		return Double.NEGATIVE_INFINITY;
	}

	/**
	 * Returns the upper bound of this Op's output range. The default is positive infinity (unbounded).
	 *
	 * @return the maximum possible output value
	 */
	public double maxValue() {
		return Double.POSITIVE_INFINITY;
	}

	/**
	 * Returns the lower bound of this Op's input domain. The default is negative infinity (unbounded).
	 *
	 * @return the minimum valid input value
	 */
	public double minDomain() {
		return Double.NEGATIVE_INFINITY;
	}

	/**
	 * Returns the upper bound of this Op's input domain. The default is positive infinity (unbounded).
	 *
	 * @return the maximum valid input value
	 */
	public double maxDomain() {
		return Double.POSITIVE_INFINITY;
	}

	/**
	 * Returns true if the input domain has a finite bound on at least one side.
	 *
	 * @return true if {@link #minDomain()} or {@link #maxDomain()} is finite
	 */
	public boolean isDomainBounded() {
		return (minDomain()>=-Double.MAX_VALUE)||(maxDomain()<=Double.MAX_VALUE);
	}

	/**
	 * Checks whether every value in a double[] lies within the output range
	 * [{@link #minValue()}, {@link #maxValue()}] of this Op.
	 *
	 * NaN entries are not rejected, because every comparison with NaN is false.
	 *
	 * @param output the values to check
	 * @return true if no value falls outside the output range
	 */
	public boolean validateOutput(double[] output) {
		double min=minValue();
		double max=maxValue();
		for (double d: output) {
			if ((d<min)||(d>max)) return false;
		}
		return true;
	}

	/**
	 * Copies {@code src[offset .. offset+length-1]} into the same positions of {@code dest}.
	 * Each value is clamped to the output range [{@link #minValue()}, {@link #maxValue()}]
	 * of this Op; NaN values are copied unchanged.
	 *
	 * @param src source values
	 * @param dest destination array (may be the same array as {@code src})
	 * @param offset index of the first element to copy, used for both arrays
	 * @param length number of elements to copy
	 */
	public void constrainValues(double[] src, double[] dest, int offset, int length) {
		if (!isBounded()) {
			// No finite bound on either side: nothing to clamp, so a straight copy suffices.
			// Read from the same offset the clamping loop below uses, and stop here
			// instead of redundantly running the loop as well.
			System.arraycopy(src, offset, dest, offset, length);
			return;
		}
		double min=minValue();
		double max=maxValue();
		int end=offset+length;
		for (int i=offset; i<end; i++) {
			double v=src[i];
			if (v>max) {
				dest[i]=max;
			} else if (v<min) {
				dest[i]=min;
			} else {
				// In range, or NaN (all comparisons with NaN are false).
				dest[i]=v;
			}
		}
	}

	/**
	 * Returns true if the output range has a finite bound on at least one side.
	 *
	 * @return true if {@link #minValue()} or {@link #maxValue()} is finite
	 */
	public boolean isBounded() {
		return (minValue()>=-Double.MAX_VALUE)||(maxValue()<=Double.MAX_VALUE);
	}

	/**
	 * Returns an Op representing the derivative of this Op.
	 *
	 * @return a new {@code Derivative} wrapping this Op
	 */
	public Op getDerivativeOp() {
		return new Derivative(this);
	}

	/**
	 * Composes two Ops by delegating to {@code Ops.compose(op1, op2)}.
	 *
	 * @param op1 the first Op
	 * @param op2 the second Op
	 * @return the composed Op
	 */
	public static Op compose(Op op1, Op op2) {
		return Ops.compose(op1,op2);
	}

	/**
	 * Composes this Op with another, via {@code Composed.create(this, op)}.
	 *
	 * @param op the Op to compose with
	 * @return the composed Op
	 */
	public Op compose(Op op) {
		return Composed.create(this, op);
	}

	/**
	 * Returns an Op combining this Op and {@code op} via {@code Product.create(this, op)}.
	 * Presumably this is the pointwise product of the two outputs; confirm in {@code Product}.
	 *
	 * @param op the other Op
	 * @return the product Op
	 */
	public Op product(Op op) {
		return Product.create(this, op);
	}

	/**
	 * Returns an Op combining this Op and {@code op} via {@code Division.create(this, op)}.
	 * Presumably this is the pointwise quotient of the two outputs; confirm in {@code Division}.
	 *
	 * @param op the divisor Op
	 * @return the division Op
	 */
	public Op divide(Op op) {
		return Division.create(this, op);
	}

	/**
	 * Returns an Op combining this Op and {@code op} via {@code Sum.create(this, op)}.
	 * Presumably this is the pointwise sum of the two outputs; confirm in {@code Sum}.
	 *
	 * @param op the other Op
	 * @return the sum Op
	 */
	public Op sum(Op op) {
		return Sum.create(this, op);
	}

	/**
	 * Returns {@code getClass().toString()}, e.g. "class some.package.SomeOp".
	 */
	@Override public String toString() {
		return getClass().toString();
	}
}