// Build the Spark driver context from the command-line args and distribute
// the in-memory training set across the cluster as an RDD.
JavaSparkContext context = SparkContextBuild.getContext(args);
JavaRDD<SampleVector> rdds = context.parallelize(trainList);
// count() is a Spark action: it forces the lazy parallelize() to actually
// materialize/ship the partitions before training begins.
rdds.count();
logger.info("RDD ok.");
// Logistic-regression model sized to the input/output feature dimensions.
LR lr = new LR(x_feature, y_feature);
// Configure the distributed (Downpour-style) SGD trainer.
SGDTrainConfig config = new SGDTrainConfig();
// CG settings — presumably conjugate-gradient optimization within SGD epochs;
// TODO confirm semantics against SGDTrainConfig.
config.setUseCG(true);
config.setCgEpochStep(100);
// NOTE(review): tolerance of 0 means CG only stops on the iteration cap below —
// verify this is intentional rather than a placeholder.
config.setCgTolerance(0);
config.setCgMaxIterations(30);
config.setMaxEpochs(100);
// Number of parallel model replicas (parameter-server style workers).
config.setNbrModelReplica(4);
// Early-stop threshold: training halts once loss drops below this value.
config.setMinLoss(0.01);
config.setUseRegularization(true);
// Keep the distributed training data cached deserialized in executor memory.
config.setMrDataStorage(StorageLevel.MEMORY_ONLY());
// Log the loss every 3 epochs (presumably — confirm unit of LossCalStep).
config.setPrintLoss(true);
config.setLossCalStep(3);
// Periodically checkpoint the learned parameters (weights/bias) to wb.bin.
config.setParamOutput(true);
config.setParamOutputStep(3);
config.setParamOutputPath("wb.bin");
logger.info("Start to train lr.");
// Blocking call: runs distributed SGD over the RDD until MaxEpochs or MinLoss.
DownpourSGDTrain.train(lr, rdds, config);
// Evaluate the trained model on the held-out test set, tallying correct vs.
// incorrect classifications. (Loop body continues beyond this chunk.)
int trueCount = 0;
int falseCount = 0;
// Reusable output buffer: predict() writes the per-class scores into it each
// iteration instead of allocating a fresh array.
double[] predict_y = new double[y_feature];
for(SampleVector test : testList) {
lr.predict(test.getX(), predict_y);
// ClassVerify.classTrue presumably compares the predicted class (e.g. argmax
// of predict_y) against the true label vector — TODO confirm its contract.
if(ClassVerify.classTrue(test.getY(), predict_y)) {
trueCount++;
}
else {
falseCount++;