* @param instances the data to process
* @return the modified data
* @throws Exception in case the processing goes wrong
*/
protected Instances processPLS1(Instances instances) throws Exception {
Matrix X, X_trans, x;
Matrix y;
Matrix W, w;
Matrix T, t, t_trans;
Matrix P, p, p_trans;
double b;
Matrix b_hat;
int i;
int j;
Matrix X_new;
Matrix tmp;
Instances result;
Instances tmpInst;
// initialization
if (!isFirstBatchDone()) {
// split up data
X = getX(instances);
y = getY(instances);
X_trans = X.transpose();
// init
W = new Matrix(instances.numAttributes() - 1, getNumComponents());
P = new Matrix(instances.numAttributes() - 1, getNumComponents());
T = new Matrix(instances.numInstances(), getNumComponents());
b_hat = new Matrix(getNumComponents(), 1);
for (j = 0; j < getNumComponents(); j++) {
// 1. step: wj
w = X_trans.times(y);
normalizeVector(w);
setVector(w, W, j);
// 2. step: tj
t = X.times(w);
t_trans = t.transpose();
setVector(t, T, j);
// 3. step: ^bj
b = t_trans.times(y).get(0, 0) / t_trans.times(t).get(0, 0);
b_hat.set(j, 0, b);
// 4. step: pj
p = X_trans.times(t).times((double) 1 / t_trans.times(t).get(0, 0));
p_trans = p.transpose();
setVector(p, P, j);
// 5. step: Xj+1
X = X.minus(t.times(p_trans));
y = y.minus(t.times(b));
}
// W*(P^T*W)^-1
tmp = W.times(((P.transpose()).times(W)).inverse());
// X_new = X*W*(P^T*W)^-1
X_new = getX(instances).times(tmp);
// factor = W*(P^T*W)^-1 * b_hat
m_PLS1_RegVector = tmp.times(b_hat);
// save matrices
m_PLS1_P = P;
m_PLS1_W = W;
m_PLS1_b_hat = b_hat;
if (getPerformPrediction())
result = toInstances(getOutputFormat(), X_new, y);
else
result = toInstances(getOutputFormat(), X_new, getY(instances));
}
// prediction
else {
result = new Instances(getOutputFormat());
for (i = 0; i < instances.numInstances(); i++) {
// work on each instance
tmpInst = new Instances(instances, 0);
tmpInst.add((Instance) instances.instance(i).copy());
x = getX(tmpInst);
X = new Matrix(1, getNumComponents());
T = new Matrix(1, getNumComponents());
for (j = 0; j < getNumComponents(); j++) {
setVector(x, X, j);
// 1. step: tj = xj * wj
t = x.times(getVector(m_PLS1_W, j));