JavaSparkContext context = SparkContextBuild.getContext(args);
JavaRDD<SampleVector> rdds = context.parallelize(trainList);
rdds.count();
logger.info("RDD ok.");
AutoEncoder da = new AutoEncoder(x_feature, n_hidden);
SGDTrainConfig config = new SGDTrainConfig();
config.setUseCG(true);
config.setDoCorruption(true);
config.setCorruption_level(0.25);
config.setCgEpochStep(50);
config.setCgTolerance(0);
config.setCgMaxIterations(10);
config.setMaxEpochs(50);
config.setNbrModelReplica(4);
config.setMinLoss(0.01);
config.setUseRegularization(true);
config.setMrDataStorage(StorageLevel.MEMORY_ONLY());
config.setPrintLoss(true);
config.setLossCalStep(3);
logger.info("Start to train dA.");
DownpourSGDTrain.train(da, rdds, config);
double[] reconstruct_x = new double[x_feature];
double totalError = 0;
for(SampleVector test : testList) {
da.reconstruct(test.getX(), reconstruct_x);
totalError += ClassVerify.squaredError(test.getX(), reconstruct_x);
}
logger.info("Mean square error is " + totalError / testList.size());
// NOTE(review): catch(Throwable) also traps JVM Errors (e.g. OutOfMemoryError),
// and the empty message string gives the log entry no context — consider
// catch(Exception e) with a descriptive message. Handler continues past the
// visible range of this chunk.
} catch(Throwable e) {
logger.error("", e);