String strDataPath = "/tmp/xor-training-by-xor";
Path dataPath = new Path(strDataPath);
// generate training data
DoubleVector[] trainingData = new DenseDoubleVector[] {
new DenseDoubleVector(new double[] { 0, 0, 0 }),
new DenseDoubleVector(new double[] { 0, 1, 1 }),
new DenseDoubleVector(new double[] { 1, 0, 1 }),
new DenseDoubleVector(new double[] { 1, 1, 0 }) };
try {
URI uri = new URI(strDataPath);
FileSystem fs = FileSystem.get(uri, conf);
fs.delete(dataPath, true);
if (!fs.exists(dataPath)) {
fs.createNewFile(dataPath);
SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf,
dataPath, LongWritable.class, VectorWritable.class);
for (int i = 0; i < 1000; ++i) {
VectorWritable vecWritable = new VectorWritable(trainingData[i % 4]);
writer.append(new LongWritable(i), vecWritable);
}
writer.close();
}
} catch (Exception e) {
e.printStackTrace();
}
// begin training
String modelPath = "/tmp/xorModel-training-by-xor.data";
double learningRate = 0.6;
double regularization = 0.02; // non-zero regularization weight
double momentum = 0.3; // non-zero momentum weight
String squashingFunctionName = "Tanh";
String costFunctionName = "SquaredError";
int[] layerSizeArray = new int[] { 2, 5, 1 };
SmallMultiLayerPerceptron mlp = new SmallMultiLayerPerceptron(learningRate,
regularization, momentum, squashingFunctionName, costFunctionName,
layerSizeArray);
Map<String, String> trainingParams = new HashMap<String, String>();
trainingParams.put("training.iteration", "1000");
trainingParams.put("training.mode", "minibatch.gradient.descent");
trainingParams.put("training.batch.size", "100");
trainingParams.put("tasks", "3");
trainingParams.put("modelPath", modelPath);
try {
mlp.train(dataPath, trainingParams);
} catch (Exception e) {
e.printStackTrace();
}
// test the model
for (int i = 0; i < trainingData.length; ++i) {
DenseDoubleVector testVec = (DenseDoubleVector) trainingData[i].slice(2);
try {
DenseDoubleVector actual = (DenseDoubleVector) mlp.output(testVec);
assertEquals(trainingData[i].toArray()[2], actual.get(0), 0.2);
} catch (Exception e) {
e.printStackTrace();
}
}
}