/**
 * Database user id handed to {@code SQLNeuralDataSet} when loading the XOR training data.
 * Demo value only; do not reuse real credentials in source.
 */
public static final String SQL_UID = "xoruser";

/**
 * Database password handed to {@code SQLNeuralDataSet} alongside {@link #SQL_UID}.
 * Demo value only; do not reuse real credentials in source.
 */
public static final String SQL_PWD = "xorpassword";
/**
 * Trains a 2-2-1 {@code BasicNetwork} on XOR data read through a {@code SQLNeuralDataSet},
 * then prints the network's output next to the ideal value for every training pair.
 *
 * <p>Training uses resilient propagation with a {@code RequiredImprovementStrategy} and
 * repeats until the reported error is no longer above 0.01. The trainer and the SQL-backed
 * data set are released in {@code finally} blocks, so they are cleaned up even if training
 * or evaluation throws.
 *
 * @param args command-line arguments (unused)
 */
public static void main(final String[] args) {
    final BasicNetwork network = new BasicNetwork();
    network.addLayer(new BasicLayer(2));
    network.addLayer(new BasicLayer(2));
    network.addLayer(new BasicLayer(1));
    network.getStructure().finalizeStructure();
    network.reset();

    final MLDataSet trainingSet = new SQLNeuralDataSet(
            XORSQL.SQL,
            XORSQL.INPUT_SIZE,
            XORSQL.IDEAL_SIZE,
            XORSQL.SQL_DRIVER,
            XORSQL.SQL_URL,
            XORSQL.SQL_UID,
            XORSQL.SQL_PWD);
    try {
        // train the neural network
        final MLTrain train = new ResilientPropagation(network, trainingSet);
        // reset if improve is less than 1% over 5 cycles
        train.addStrategy(new RequiredImprovementStrategy(5));
        try {
            int epoch = 1;
            // NOTE(review): there is no epoch cap; termination relies on the error eventually
            // dropping to 0.01 (the strategy above is meant to escape plateaus).
            do {
                train.iteration();
                System.out.println("Epoch #" + epoch + " Error:" + train.getError());
                epoch++;
            } while (train.getError() > 0.01);
        } finally {
            // Tell the trainer no more iterations will follow so it can release what it holds.
            train.finishTraining();
        }

        // test the neural network
        System.out.println("Neural Network Results:");
        for (final MLDataPair pair : trainingSet) {
            final MLData output = network.compute(pair.getInput());
            System.out.println(pair.getInput().getData(0) + "," + pair.getInput().getData(1)
                    + ", actual=" + output.getData(0) + ",ideal=" + pair.getIdeal().getData(0));
        }
    } finally {
        // The data set is backed by a database; release whatever it opened.
        trainingSet.close();
    }
}