// NOTE(review): this span is the tail of an earlier mode branch -- the opening
// `if` is above the visible window. The sibling branch below handles "train",
// so this one presumably performs prediction/labelling; confirm against the
// full file.
String featureDataPath = args[1];
String resultDataPath = args[2];
String modelPath = args[3];
// Load a previously trained model from modelPath.
SmallLayeredNeuralNetwork ann = new SmallLayeredNeuralNetwork(modelPath);
// process data in streaming approach
FileSystem fs = FileSystem.get(new URI(featureDataPath), conf);
// NOTE(review): reader and writer use the platform default charset and are
// not closed on exception (no try-with-resources / finally) -- resource leak
// if parsing or I/O throws mid-loop.
BufferedReader br = new BufferedReader(new InputStreamReader(
fs.open(new Path(featureDataPath))));
Path outputPath = new Path(resultDataPath);
// Overwrite semantics: recursively delete any existing output first.
if (fs.exists(outputPath)) {
fs.delete(outputPath, true);
}
BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(
fs.create(outputPath)));
// Each input line is a comma-separated feature vector; blank lines are
// skipped. Double.parseDouble will throw on malformed tokens (unhandled).
String line = null;
while ((line = br.readLine()) != null) {
if (line.trim().length() == 0) {
continue;
}
String[] tokens = line.trim().split(",");
double[] vals = new double[tokens.length];
for (int i = 0; i < tokens.length; ++i) {
vals[i] = Double.parseDouble(tokens[i]);
}
// Feed the instance through the network and write its output vector back
// as one comma-separated, newline-terminated line.
DoubleVector instance = new DenseDoubleVector(vals);
DoubleVector result = ann.getOutput(instance);
double[] arrResult = result.toArray();
StringBuilder sb = new StringBuilder();
for (int i = 0; i < arrResult.length; ++i) {
sb.append(arrResult[i]);
if (i != arrResult.length - 1) {
sb.append(",");
} else {
sb.append("\n");
}
}
bw.write(sb.toString());
}
br.close();
bw.close();
// Train mode. Required args: TRAINING_DATA_PATH MODEL_PATH FEATURE_DIM
// LABEL_DIM; optional trailing args (each position optional in order):
// MAX_ITERATION LEARNING_RATE MOMEMTUM_WEIGHT REGULARIZATION_WEIGHT.
} else if (mode.equals("train")) {
if (args.length < 5) {
printUsage();
return;
}
String trainingDataPath = args[1];
String trainedModelPath = args[2];
int featureDimension = Integer.parseInt(args[3]);
int labelDimension = Integer.parseInt(args[4]);
// Defaults applied when the optional hyper-parameter args are absent.
// NOTE(review): "momemtum" is misspelled throughout, but it matches the
// SmallLayeredNeuralNetwork#setMomemtumWeight API name used below, so it
// cannot be renamed here without touching that API.
int iteration = 1000;
double learningRate = 0.4;
double momemtumWeight = 0.2;
double regularizationWeight = 0.01;
// parse parameters
// Each optional arg is parsed independently; a malformed value prints a
// usage-style error and aborts instead of falling back to the default.
if (args.length >= 6) {
try {
iteration = Integer.parseInt(args[5]);
System.out.printf("Iteration: %d\n", iteration);
} catch (NumberFormatException e) {
System.err
.println("MAX_ITERATION format invalid. It should be a positive number.");
return;
}
}
if (args.length >= 7) {
try {
learningRate = Double.parseDouble(args[6]);
System.out.printf("Learning rate: %f\n", learningRate);
} catch (NumberFormatException e) {
System.err
.println("LEARNING_RATE format invalid. It should be a positive double in range (0, 1.0)");
return;
}
}
if (args.length >= 8) {
try {
momemtumWeight = Double.parseDouble(args[7]);
System.out.printf("Momemtum weight: %f\n", momemtumWeight);
} catch (NumberFormatException e) {
System.err
.println("MOMEMTUM_WEIGHT format invalid. It should be a positive double in range (0, 1.0)");
return;
}
}
if (args.length >= 9) {
try {
regularizationWeight = Double.parseDouble(args[8]);
System.out
.printf("Regularization weight: %f\n", regularizationWeight);
} catch (NumberFormatException e) {
System.err
.println("REGULARIZATION_WEIGHT format invalid. It should be a positive double in range (0, 1.0)");
return;
}
}
// train the model
// Topology: input layer of featureDimension units, one hidden layer also
// of featureDimension units, and an output layer of labelDimension units,
// all with Sigmoid squashing; cost is cross-entropy.
// NOTE(review): the boolean second argument to addLayer is true only for
// the last layer -- presumably it marks the final/output layer; confirm
// against the addLayer javadoc.
SmallLayeredNeuralNetwork ann = new SmallLayeredNeuralNetwork();
ann.setLearningRate(learningRate);
ann.setMomemtumWeight(momemtumWeight);
ann.setRegularizationWeight(regularizationWeight);
ann.addLayer(featureDimension, false,
FunctionFactory.createDoubleFunction("Sigmoid"));
ann.addLayer(featureDimension, false,
FunctionFactory.createDoubleFunction("Sigmoid"));
ann.addLayer(labelDimension, true,
FunctionFactory.createDoubleFunction("Sigmoid"));
ann.setCostFunction(FunctionFactory
.createDoubleDoubleFunction("CrossEntropy"));
ann.setModelPath(trainedModelPath);
// NOTE(review): task count, batch size, and convergence-check interval are
// hard-coded here; only the iteration cap is user-configurable.
Map<String, String> trainingParameters = new HashMap<String, String>();
trainingParameters.put("tasks", "5");
trainingParameters.put("training.max.iterations", "" + iteration);
trainingParameters.put("training.batch.size", "300");
trainingParameters.put("convergence.check.interval", "1000");
ann.train(new Path(trainingDataPath), trainingParameters);
}
}