int testMode = 0;
int numFolds = 10;
double percent = 66;
int classIndex = m_ClassCombo.getSelectedIndex();
inst.setClassIndex(classIndex);
Classifier classifier = (Classifier) m_ClassifierEditor.getValue();
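// Keep a pristine copy of the configured classifier as a template;
// cross-validation folds and percentage splits train fresh copies so the
// object held by the editor is never modified.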
Classifier template = null;
try {
template = AbstractClassifier.makeCopy(classifier);
} catch (Exception ex) {
m_Log.logMessage("Problem copying classifier: " + ex.getMessage());
}
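// If textual predictions were requested, give the chosen output formatter an
// empty copy of the data as its header and point it at the shared output buffer.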
Classifier fullClassifier = null;
StringBuffer outBuff = new StringBuffer();
AbstractOutput classificationOutput = null;
if (outputPredictionsText) {
classificationOutput = (AbstractOutput) m_ClassificationOutputEditor.getValue();
Instances header = new Instances(inst, 0);
header.setClassIndex(classIndex);
classificationOutput.setHeader(header);
classificationOutput.setBuffer(outBuff);
}
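// History entries are named with the start time followed by the scheme name.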
String name = (new SimpleDateFormat("HH:mm:ss - ")).format(new Date());
String cname = "";
String cmd = "";
Evaluation eval = null;
try {
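// Work out the evaluation mode from the radio buttons:
// 1 = cross-validation, 2 = percentage split, 3 = training set, 4 = user-supplied test set.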
if (m_CVBut.isSelected()) {
testMode = 1;
numFolds = Integer.parseInt(m_CVText.getText());
if (numFolds <= 1) {
throw new Exception("Number of folds must be greater than 1");
}
} else if (m_PercentBut.isSelected()) {
testMode = 2;
percent = Double.parseDouble(m_PercentText.getText());
if ((percent <= 0) || (percent >= 100)) {
throw new Exception("Percentage must be between 0 and 100");
}
} else if (m_TrainBut.isSelected()) {
testMode = 3;
} else if (m_TestSplitBut.isSelected()) {
testMode = 4;
// Check the test instance compatibility
if (source == null) {
throw new Exception("No user test set has been specified");
}
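// Unless the scheme can already map inputs itself, offer to wrap it in an
// InputMappedClassifier when the training and test headers differ.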
if (!(classifier instanceof weka.classifiers.misc.InputMappedClassifier)) {
if (!inst.equalHeaders(userTestStructure)) {
boolean wrapClassifier = false;
if (!Utils.getDontShowDialog(
"weka.gui.explorer.ClassifierPanel.AutoWrapInInputMappedClassifier")) {
JCheckBox dontShow = new JCheckBox("Do not show this message again");
Object[] stuff = new Object[2];
stuff[0] = "Train and test set are not compatible.\n" +
"Would you like to automatically wrap the classifier in\n" +
"an \"InputMappedClassifier\" before proceeding?.\n";
stuff[1] = dontShow;
int result = JOptionPane.showConfirmDialog(ClassifierPanel.this, stuff,
"ClassifierPanel", JOptionPane.YES_NO_OPTION);
if (result == JOptionPane.YES_OPTION) {
wrapClassifier = true;
}
if (dontShow.isSelected()) {
String response = (wrapClassifier) ? "yes" : "no";
Utils.setDontShowDialogResponse(
"weka.gui.explorer.ClassifierPanel.AutoWrapInInputMappedClassifier", response);
}
} else {
// The dialog is suppressed - use the previously stored response to decide whether to auto-wrap
String response =
Utils.getDontShowDialogResponse("weka.gui.explorer.ClassifierPanel.AutoWrapInInputMappedClassifier");
if (response != null && response.equalsIgnoreCase("yes")) {
wrapClassifier = true;
}
}
if (wrapClassifier) {
weka.classifiers.misc.InputMappedClassifier temp =
new weka.classifiers.misc.InputMappedClassifier();
// pass on the known test structure so that we get the
// correct mapping report from the toString() method
// of InputMappedClassifier
temp.setClassifier(classifier);
temp.setTestStructure(userTestStructure);
classifier = temp;
} else {
throw new Exception("Train and test set are not compatible\n" + inst.equalHeadersMsg(userTestStructure));
}
}
}
} else {
throw new Exception("Unknown test mode");
}
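// Build the display name (scheme minus the weka.classifiers. prefix) and the
// full command line for the log.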
cname = classifier.getClass().getName();
if (cname.startsWith("weka.classifiers.")) {
name += cname.substring("weka.classifiers.".length());
} else {
name += cname;
}
cmd = classifier.getClass().getName();
if (classifier instanceof OptionHandler)
cmd += " " + Utils.joinOptions(((OptionHandler) classifier).getOptions());
// set up the structure of the plottable instances for
// visualization
plotInstances = ExplorerDefaults.getClassifierErrorsPlotInstances();
plotInstances.setInstances(inst);
plotInstances.setClassifier(classifier);
plotInstances.setClassIndex(inst.classIndex());
plotInstances.setSaveForVisualization(saveVis);
// Output some header information
m_Log.logMessage("Started " + cname);
m_Log.logMessage("Command: " + cmd);
if (m_Log instanceof TaskLogger) {
((TaskLogger)m_Log).taskStarted();
}
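// Summarize the run: scheme and options, relation, instance and attribute
// counts, and the selected test mode.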
outBuff.append("=== Run information ===\n\n");
outBuff.append("Scheme: " + cname);
if (classifier instanceof OptionHandler) {
String [] o = ((OptionHandler) classifier).getOptions();
outBuff.append(" " + Utils.joinOptions(o));
}
outBuff.append("\n");
outBuff.append("Relation: " + inst.relationName() + '\n');
outBuff.append("Instances: " + inst.numInstances() + '\n');
outBuff.append("Attributes: " + inst.numAttributes() + '\n');
if (inst.numAttributes() < 100) {
for (int i = 0; i < inst.numAttributes(); i++) {
outBuff.append(" " + inst.attribute(i).name()
+ '\n');
}
} else {
outBuff.append(" [list of attributes omitted]\n");
}
outBuff.append("Test mode: ");
switch (testMode) {
case 3: // Test on training
outBuff.append("evaluate on training data\n");
break;
case 1: // CV mode
outBuff.append("" + numFolds + "-fold cross-validation\n");
break;
case 2: // Percent split
outBuff.append("split " + percent
+ "% train, remainder test\n");
break;
case 4: // Test on user split
if (source.isIncremental())
outBuff.append("user supplied test set: "
+ " size unknown (reading incrementally)\n");
else
outBuff.append("user supplied test set: "
+ source.getDataSet().numInstances() + " instances\n");
break;
}
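// Show the cost matrix when cost-sensitive evaluation has been configured.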
if (costMatrix != null) {
outBuff.append("Evaluation cost matrix:\n")
.append(costMatrix.toString()).append("\n");
}
outBuff.append("\n");
m_History.addResult(name, outBuff);
m_History.setSingle(name);
// Build the model and output it.
if (outputModel || (testMode == 3) || (testMode == 4)) {
m_Log.statusMessage("Building model on training data...");
trainTimeStart = System.currentTimeMillis();
classifier.buildClassifier(inst);
trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
}
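// Append the model text and build time, grab the graph for Drawable schemes,
// and keep a serialized copy of the fully trained classifier.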
if (outputModel) {
outBuff.append("=== Classifier model (full training set) ===\n\n");
outBuff.append(classifier.toString() + "\n");
outBuff.append("\nTime taken to build model: " +
Utils.doubleToString(trainTimeElapsed / 1000.0,2)
+ " seconds\n\n");
m_History.updateResult(name);
if (classifier instanceof Drawable) {
grph = null;
try {
grph = ((Drawable)classifier).graph();
} catch (Exception ex) {
// ignore - the graph is optional; failing to generate it should not abort the run
}
}
// copy full model for output
SerializedObject so = new SerializedObject(classifier);
fullClassifier = (Classifier) so.getObject();
}
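// Perform the evaluation itself, according to the selected test mode.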
switch (testMode) {
case 3: // Test on training
m_Log.statusMessage("Evaluating on training data...");
eval = new Evaluation(inst, costMatrix);
// make adjustments if the classifier is an InputMappedClassifier
eval = setupEval(eval, classifier, inst, costMatrix,
plotInstances, classificationOutput, false);
//plotInstances.setEvaluation(eval);
plotInstances.setUp();
if (outputPredictionsText) {
printPredictionsHeader(outBuff, classificationOutput, "training set");
}
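// Classify every training instance, feeding the error plot and, if requested,
// the textual prediction output.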
for (int jj=0;jj<inst.numInstances();jj++) {
plotInstances.process(inst.instance(jj), classifier, eval);
if (outputPredictionsText) {
classificationOutput.printClassification(classifier, inst.instance(jj), jj);
}
if ((jj % 100) == 0) {
m_Log.statusMessage("Evaluating on training data. Processed "
+jj+" instances...");
}
}
if (outputPredictionsText)
classificationOutput.printFooter();
if (outputPredictionsText && classificationOutput.generatesOutput()) {
outBuff.append("\n");
}
outBuff.append("=== Evaluation on training set ===\n");
break;
case 1: // CV mode
m_Log.statusMessage("Randomizing instances...");
int rnd = 1;
try {
rnd = Integer.parseInt(m_RandomSeedText.getText().trim());
// System.err.println("Using random seed "+rnd);
} catch (Exception ex) {
m_Log.logMessage("Trouble parsing random seed value");
rnd = 1;
}
Random random = new Random(rnd);
inst.randomize(random);
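// Stratify so each fold has roughly the same class distribution (nominal class only).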
if (inst.attribute(classIndex).isNominal()) {
m_Log.statusMessage("Stratifying instances...");
inst.stratify(numFolds);
}
eval = new Evaluation(inst, costMatrix);
// make adjustments if the classifier is an InputMappedClassifier
eval = setupEval(eval, classifier, inst, costMatrix,
plotInstances, classificationOutput, false);
// plotInstances.setEvaluation(eval);
plotInstances.setUp();
if (outputPredictionsText) {
printPredictionsHeader(outBuff, classificationOutput, "test data");
}
// Make some splits and do a CV
for (int fold = 0; fold < numFolds; fold++) {
m_Log.statusMessage("Creating splits for fold "
+ (fold + 1) + "...");
Instances train = inst.trainCV(numFolds, fold, random);
// make adjustments if the classifier is an InputMappedClassifier
eval = setupEval(eval, classifier, train, costMatrix,
plotInstances, classificationOutput, true);
// eval.setPriors(train);
m_Log.statusMessage("Building model for fold "
+ (fold + 1) + "...");
Classifier current = null;
try {
current = AbstractClassifier.makeCopy(template);
} catch (Exception ex) {
m_Log.logMessage("Problem copying classifier: " + ex.getMessage());
}
current.buildClassifier(train);
Instances test = inst.testCV(numFolds, fold);
m_Log.statusMessage("Evaluating model for fold "
+ (fold + 1) + "...");
for (int jj=0;jj<test.numInstances();jj++) {
plotInstances.process(test.instance(jj), current, eval);
if (outputPredictionsText) {
classificationOutput.printClassification(current, test.instance(jj), jj);
}
}
}
if (outputPredictionsText)
classificationOutput.printFooter();
if (outputPredictionsText) {
outBuff.append("\n");
}
if (inst.attribute(classIndex).isNominal()) {
outBuff.append("=== Stratified cross-validation ===\n");
} else {
outBuff.append("=== Cross-validation ===\n");
}
break;
case 2: // Percent split
if (!m_PreserveOrderBut.isSelected()) {
m_Log.statusMessage("Randomizing instances...");
try {
rnd = Integer.parseInt(m_RandomSeedText.getText().trim());
} catch (Exception ex) {
m_Log.logMessage("Trouble parsing random seed value");
rnd = 1;
}
inst.randomize(new Random(rnd));
}
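// The first trainSize instances form the training split, the remainder the test split.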
int trainSize = (int) Math.round(inst.numInstances() * percent / 100);
int testSize = inst.numInstances() - trainSize;
Instances train = new Instances(inst, 0, trainSize);
Instances test = new Instances(inst, trainSize, testSize);
m_Log.statusMessage("Building model on training split ("+trainSize+" instances)...");
Classifier current = null;
try {
current = AbstractClassifier.makeCopy(template);
} catch (Exception ex) {
m_Log.logMessage("Problem copying classifier: " + ex.getMessage());
}
current.buildClassifier(train);
eval = new Evaluation(train, costMatrix);
// make adjustments if the classifier is an InputMappedClassifier
eval = setupEval(eval, classifier, train, costMatrix,
plotInstances, classificationOutput, false);