this.updateWeightMatrices(updateMatrices);
}
@Override
public DoubleMatrix[] trainByInstance(DoubleVector trainingInstance) {
DoubleVector transformedVector = this.featureTransformer
.transform(trainingInstance.sliceUnsafe(this.layerSizeList.get(0) - 1));
int inputDimension = this.layerSizeList.get(0) - 1;
int outputDimension;
DoubleVector inputInstance = null;
DoubleVector labels = null;
if (this.learningStyle == LearningStyle.SUPERVISED) {
outputDimension = this.layerSizeList.get(this.layerSizeList.size() - 1);
// validate training instance
Preconditions.checkArgument(
inputDimension + outputDimension == trainingInstance.getDimension(),
String
.format(
"The dimension of training instance is %d, but requires %d.",
trainingInstance.getDimension(), inputDimension
+ outputDimension));
inputInstance = new DenseDoubleVector(this.layerSizeList.get(0));
inputInstance.set(0, 1); // add bias
// get the features from the transformed vector
for (int i = 0; i < inputDimension; ++i) {
inputInstance.set(i + 1, transformedVector.get(i));
}
// get the labels from the original training instance
labels = trainingInstance.sliceUnsafe(inputInstance.getDimension() - 1,
trainingInstance.getDimension() - 1);
} else if (this.learningStyle == LearningStyle.UNSUPERVISED) {
// labels are identical to input features
outputDimension = inputDimension;
// validate training instance
Preconditions.checkArgument(inputDimension == trainingInstance
.getDimension(), String.format(
"The dimension of training instance is %d, but requires %d.",
trainingInstance.getDimension(), inputDimension));
inputInstance = new DenseDoubleVector(this.layerSizeList.get(0));
inputInstance.set(0, 1); // add bias
// get the features from the transformed vector
for (int i = 0; i < inputDimension; ++i) {
inputInstance.set(i + 1, transformedVector.get(i));
}
// get the labels by copying the transformed vector
labels = transformedVector.deepCopy();
}
List<DoubleVector> internalResults = this.getOutputInternal(inputInstance);
DoubleVector output = internalResults.get(internalResults.size() - 1);
// get the training error
calculateTrainingError(labels,
output.deepCopy().sliceUnsafe(1, output.getDimension() - 1));
if (this.trainingMethod.equals(TrainingMethod.GRADIENT_DESCENT)) {
return this.trainByInstanceGradientDescent(labels, internalResults);
} else {
throw new IllegalArgumentException(