/*
* This file is part of ALOE.
*
* ALOE is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
* ALOE is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
* You should have received a copy of the GNU General Public License
* along with ALOE. If not, see <http://www.gnu.org/licenses/>.
*
* Copyright (c) 2012 SCCL, University of Washington (http://depts.washington.edu/sccl)
*/
package etc.aloe;
import etc.aloe.controllers.CrossValidationController;
import etc.aloe.controllers.TrainingController;
import etc.aloe.data.EvaluationReport;
import etc.aloe.data.FeatureSpecification;
import etc.aloe.data.MessageSet;
import etc.aloe.data.Model;
import etc.aloe.data.ROC;
import etc.aloe.data.SegmentSet;
import etc.aloe.options.ModeOptions;
import etc.aloe.options.TrainOptions;
import etc.aloe.processes.Segmentation;
import java.io.File;
import java.util.List;
import java.util.Map;
import weka.core.Instances;
/**
 * Training mode for ALOE. Loads and segments the input messages, runs cross
 * validation over the segments, and then trains a final model on the complete
 * segment set, saving the outputs of both phases.
 *
 * @author Michael Brooks <mjbrooks@uw.edu>
 */
public class AloeTrain extends Aloe {

    /**
     * Runs the full training workflow: preparation, cross validation, final
     * training, and saving of all outputs.
     *
     * @param modeOptions the options for this run; must be a {@link TrainOptions}
     * @throws IllegalArgumentException if {@code modeOptions} is not a
     * {@link TrainOptions}
     */
    @Override
    public void run(ModeOptions modeOptions) {
        System.out.println("== Preparation ==");

        // Anything other than training options cannot be handled here.
        if (!(modeOptions instanceof TrainOptions)) {
            throw new IllegalArgumentException("Options must be for Training");
        }
        TrainOptions trainOptions = (TrainOptions) modeOptions;

        saveCommand(trainOptions.outputCommandFile);

        // Load the messages and split them into segments.
        MessageSet allMessages = this.loadMessages(trainOptions.inputCSVFile);
        Segmentation segmenter = factory.constructSegmentation();
        SegmentSet segmentSet = segmenter.segment(allMessages);

        // Cross validation: build, configure, and run the controller.
        CrossValidationController validator = new CrossValidationController();
        factory.configureCrossValidation(validator);
        validator.setSegmentSet(segmentSet);
        validator.run();

        // Cross-validation output is only written when a report is available.
        EvaluationReport report = validator.getEvaluationReport();
        if (report != null) {
            writeCrossValidationResults(report, trainOptions, allMessages);
        }

        // Final model: build, configure, and run training on every segment.
        TrainingController trainer = new TrainingController();
        factory.configureTraining(trainer);
        trainer.setSegmentSet(segmentSet);
        trainer.run();

        writeTrainingResults(trainer, trainOptions);
    }

    /**
     * Saves and prints the aggregated cross-validation report, then writes the
     * optional ROC and test-set outputs when the options request them.
     *
     * @param report the non-null report produced by cross validation
     * @param trainOptions the training options naming the output locations
     * @param allMessages the full message set the segments were built from
     */
    private void writeCrossValidationResults(EvaluationReport report,
            TrainOptions trainOptions, MessageSet allMessages) {
        System.out.println("== Saving Results of Cross Validation ==");
        saveEvaluationReport(report, trainOptions.outputEvaluationReportFile);

        System.out.println("Aggregated cross-validation report:");
        System.out.println(report);
        System.out.println("---------");

        if (trainOptions.makeROC) {
            writeRocCurves(report, trainOptions);
        }
        if (trainOptions.outputTests) {
            writeTestSets(report, trainOptions, allMessages);
        }
    }

    /**
     * Saves every ROC in the report to its own file inside the ROC output
     * directory; each file is named after the ROC plus the ROC suffix.
     *
     * @param report the cross-validation report supplying the ROCs
     * @param trainOptions the training options naming the ROC directory
     */
    private void writeRocCurves(EvaluationReport report, TrainOptions trainOptions) {
        trainOptions.outputROCDir.mkdirs();
        for (ROC curve : report.getROCs()) {
            File target = new File(trainOptions.outputROCDir,
                    curve.getName() + FileNames.ROC_SUFFIX);
            saveROC(curve, target);
        }
    }

    /**
     * Saves the messages of each cross-validation test set to its own file in
     * the tests output directory, followed by one file holding the messages of
     * all test sets combined.
     *
     * @param report the cross-validation report supplying the test sets and
     * their names
     * @param trainOptions the training options naming the tests directory
     * @param allMessages the full message set passed to each segment set when
     * retrieving its messages
     */
    private void writeTestSets(EvaluationReport report, TrainOptions trainOptions,
            MessageSet allMessages) {
        trainOptions.outputTestsDir.mkdirs();

        List<SegmentSet> folds = report.getTestSets();
        List<String> foldNames = report.getTestSetNames();

        // Accumulates the segments of every test set for the combined file.
        SegmentSet everyFold = new SegmentSet();
        for (int idx = 0; idx < folds.size(); idx++) {
            String foldFileName = foldNames.get(idx) + FileNames.TEST_DATA_SUFFIX;
            SegmentSet fold = folds.get(idx);
            everyFold.addAll(fold.getSegments());

            File foldFile = new File(trainOptions.outputTestsDir, foldFileName);
            saveMessages(fold.getMessages(allMessages), foldFile);
        }

        File combinedFile = new File(trainOptions.outputTestsDir,
                FileNames.OUTPUT_TEST_DATA_COMBINED_NAME);
        saveMessages(everyFold.getMessages(allMessages), combinedFile);
    }

    /**
     * Saves the artifacts of the final training run: feature specification,
     * model, top features, feature weights, and (when requested by the
     * options) the feature values.
     *
     * @param trainer the training controller after it has been run
     * @param trainOptions the training options naming the output files
     */
    private void writeTrainingResults(TrainingController trainer,
            TrainOptions trainOptions) {
        System.out.println("== Saving Output ==");

        // Fetch every artifact first, then save them in the same order.
        FeatureSpecification featureSpec = trainer.getFeatureSpecification();
        Model trainedModel = trainer.getModel();
        List<String> bestFeatures = trainer.getTopFeatures();
        List<Map.Entry<String, Double>> weights = trainer.getFeatureWeights();

        saveFeatureSpecification(featureSpec, trainOptions.outputFeatureSpecFile);
        saveModel(trainedModel, trainOptions.outputModelFile);
        saveTopFeatures(bestFeatures, trainOptions.outputTopFeaturesFile);
        saveFeatureWeights(weights, trainOptions.outputFeatureWeightsFile);

        if (trainOptions.outputFeatureValues) {
            Instances values = trainer.getFeatureValues();
            saveInstances(values, trainOptions.outputFeatureValuesFile);
        }
    }
}