name += cname;
}
String cmd = m_ClassifierEditor.getValue().getClass().getName();
if (m_ClassifierEditor.getValue() instanceof OptionHandler)
cmd += " " + Utils.joinOptions(((OptionHandler) m_ClassifierEditor.getValue()).getOptions());
Evaluation eval = null;
try {
if (m_CVBut.isSelected()) {
testMode = 1;
numFolds = Integer.parseInt(m_CVText.getText());
if (numFolds <= 1) {
throw new Exception("Number of folds must be greater than 1");
}
} else if (m_PercentBut.isSelected()) {
testMode = 2;
percent = Double.parseDouble(m_PercentText.getText());
if ((percent <= 0) || (percent >= 100)) {
throw new Exception("Percentage must be between 0 and 100");
}
} else if (m_TrainBut.isSelected()) {
testMode = 3;
} else if (m_TestSplitBut.isSelected()) {
testMode = 4;
// Check the test instance compatibility
if (source == null) {
throw new Exception("No user test set has been specified");
}
if (!inst.equalHeaders(userTestStructure)) {
throw new Exception("Train and test set are not compatible");
}
userTestStructure.setClassIndex(classIndex);
} else {
throw new Exception("Unknown test mode");
}
inst.setClassIndex(classIndex);
// set up the structure of the plottable instances for
// visualization
if (saveVis) {
predInstances = setUpVisualizableInstances(inst);
predInstances.setClassIndex(inst.classIndex()+1);
}
// Output some header information
m_Log.logMessage("Started " + cname);
m_Log.logMessage("Command: " + cmd);
if (m_Log instanceof TaskLogger) {
((TaskLogger)m_Log).taskStarted();
}
outBuff.append("=== Run information ===\n\n");
outBuff.append("Scheme: " + cname);
if (classifier instanceof OptionHandler) {
String [] o = ((OptionHandler) classifier).getOptions();
outBuff.append(" " + Utils.joinOptions(o));
}
outBuff.append("\n");
outBuff.append("Relation: " + inst.relationName() + '\n');
outBuff.append("Instances: " + inst.numInstances() + '\n');
outBuff.append("Attributes: " + inst.numAttributes() + '\n');
if (inst.numAttributes() < 100) {
for (int i = 0; i < inst.numAttributes(); i++) {
outBuff.append(" " + inst.attribute(i).name()
+ '\n');
}
} else {
outBuff.append(" [list of attributes omitted]\n");
}
outBuff.append("Test mode: ");
switch (testMode) {
case 3: // Test on training
outBuff.append("evaluate on training data\n");
break;
case 1: // CV mode
outBuff.append("" + numFolds + "-fold cross-validation\n");
break;
case 2: // Percent split
outBuff.append("split " + percent
+ "% train, remainder test\n");
break;
case 4: // Test on user split
if (source.isIncremental())
outBuff.append("user supplied test set: "
+ " size unknown (reading incrementally)\n");
else
outBuff.append("user supplied test set: "
+ source.getDataSet().numInstances() + " instances\n");
break;
}
if (costMatrix != null) {
outBuff.append("Evaluation cost matrix:\n")
.append(costMatrix.toString()).append("\n");
}
outBuff.append("\n");
m_History.addResult(name, outBuff);
m_History.setSingle(name);
// Build the model and output it.
if (outputModel || (testMode == 3) || (testMode == 4)) {
m_Log.statusMessage("Building model on training data...");
trainTimeStart = System.currentTimeMillis();
classifier.buildClassifier(inst);
trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
}
if (outputModel) {
outBuff.append("=== Classifier model (full training set) ===\n\n");
outBuff.append(classifier.toString() + "\n");
outBuff.append("\nTime taken to build model: " +
Utils.doubleToString(trainTimeElapsed / 1000.0,2)
+ " seconds\n\n");
m_History.updateResult(name);
if (classifier instanceof Drawable) {
grph = null;
try {
grph = ((Drawable)classifier).graph();
} catch (Exception ex) {
}
}
// copy full model for output
SerializedObject so = new SerializedObject(classifier);
fullClassifier = (Classifier) so.getObject();
}
switch (testMode) {
case 3: // Test on training
m_Log.statusMessage("Evaluating on training data...");
eval = new Evaluation(inst, costMatrix);
if (outputPredictionsText) {
printPredictionsHeader(outBuff, inst, "training set");
}
for (int jj=0;jj<inst.numInstances();jj++) {
processClassifierPrediction(inst.instance(jj), classifier,
eval, predInstances, plotShape,
plotSize);
if (outputPredictionsText) {
outBuff.append(predictionText(classifier, inst.instance(jj), jj+1));
}
if ((jj % 100) == 0) {
m_Log.statusMessage("Evaluating on training data. Processed "
+jj+" instances...");
}
}
if (outputPredictionsText) {
outBuff.append("\n");
}
outBuff.append("=== Evaluation on training set ===\n");
break;
case 1: // CV mode
m_Log.statusMessage("Randomizing instances...");
int rnd = 1;
try {
rnd = Integer.parseInt(m_RandomSeedText.getText().trim());
// System.err.println("Using random seed "+rnd);
} catch (Exception ex) {
m_Log.logMessage("Trouble parsing random seed value");
rnd = 1;
}
Random random = new Random(rnd);
inst.randomize(random);
if (inst.attribute(classIndex).isNominal()) {
m_Log.statusMessage("Stratifying instances...");
inst.stratify(numFolds);
}
eval = new Evaluation(inst, costMatrix);
if (outputPredictionsText) {
printPredictionsHeader(outBuff, inst, "test data");
}
// Make some splits and do a CV
for (int fold = 0; fold < numFolds; fold++) {
m_Log.statusMessage("Creating splits for fold "
+ (fold + 1) + "...");
Instances train = inst.trainCV(numFolds, fold, random);
eval.setPriors(train);
m_Log.statusMessage("Building model for fold "
+ (fold + 1) + "...");
Classifier current = null;
try {
current = Classifier.makeCopy(template);
} catch (Exception ex) {
m_Log.logMessage("Problem copying classifier: " + ex.getMessage());
}
current.buildClassifier(train);
Instances test = inst.testCV(numFolds, fold);
m_Log.statusMessage("Evaluating model for fold "
+ (fold + 1) + "...");
for (int jj=0;jj<test.numInstances();jj++) {
processClassifierPrediction(test.instance(jj), current,
eval, predInstances, plotShape,
plotSize);
if (outputPredictionsText) {
outBuff.append(predictionText(current, test.instance(jj), jj+1));
}
}
}
if (outputPredictionsText) {
outBuff.append("\n");
}
if (inst.attribute(classIndex).isNominal()) {
outBuff.append("=== Stratified cross-validation ===\n");
} else {
outBuff.append("=== Cross-validation ===\n");
}
break;
case 2: // Percent split
if (!m_PreserveOrderBut.isSelected()) {
m_Log.statusMessage("Randomizing instances...");
try {
rnd = Integer.parseInt(m_RandomSeedText.getText().trim());
} catch (Exception ex) {
m_Log.logMessage("Trouble parsing random seed value");
rnd = 1;
}
inst.randomize(new Random(rnd));
}
int trainSize = (int) Math.round(inst.numInstances() * percent / 100);
int testSize = inst.numInstances() - trainSize;
Instances train = new Instances(inst, 0, trainSize);
Instances test = new Instances(inst, trainSize, testSize);
m_Log.statusMessage("Building model on training split ("+trainSize+" instances)...");
Classifier current = null;
try {
current = Classifier.makeCopy(template);
} catch (Exception ex) {
m_Log.logMessage("Problem copying classifier: " + ex.getMessage());
}
current.buildClassifier(train);
eval = new Evaluation(train, costMatrix);
m_Log.statusMessage("Evaluating on test split...");
if (outputPredictionsText) {
printPredictionsHeader(outBuff, inst, "test split");
}
for (int jj=0;jj<test.numInstances();jj++) {
processClassifierPrediction(test.instance(jj), current,
eval, predInstances, plotShape,
plotSize);
if (outputPredictionsText) {
outBuff.append(predictionText(current, test.instance(jj), jj+1));
}
if ((jj % 100) == 0) {
m_Log.statusMessage("Evaluating on test split. Processed "
+jj+" instances...");
}
}
if (outputPredictionsText) {
outBuff.append("\n");
}
outBuff.append("=== Evaluation on test split ===\n");
break;
case 4: // Test on user split
m_Log.statusMessage("Evaluating on test data...");
eval = new Evaluation(inst, costMatrix);
if (outputPredictionsText) {
printPredictionsHeader(outBuff, inst, "test set");
}
Instance instance;
int jj = 0;
while (source.hasMoreElements(userTestStructure)) {
instance = source.nextElement(userTestStructure);
processClassifierPrediction(instance, classifier,
eval, predInstances, plotShape,
plotSize);
if (outputPredictionsText) {
outBuff.append(predictionText(classifier, instance, jj+1));
}
if ((++jj % 100) == 0) {
m_Log.statusMessage("Evaluating on test data. Processed "
+jj+" instances...");
}
}
if (outputPredictionsText) {
outBuff.append("\n");
}
outBuff.append("=== Evaluation on test set ===\n");
break;
default:
throw new Exception("Test mode not implemented");
}
if (outputSummary) {
outBuff.append(eval.toSummaryString(outputEntropy) + "\n");
}
if (inst.attribute(classIndex).isNominal()) {
if (outputPerClass) {
outBuff.append(eval.toClassDetailsString() + "\n");
}
if (outputConfusion) {
outBuff.append(eval.toMatrixString() + "\n");
}
}
if ( (fullClassifier instanceof Sourcable)
&& m_OutputSourceCode.isSelected()) {
outBuff.append("=== Source code ===\n\n");
outBuff.append(
Evaluation.wekaStaticWrapper(
((Sourcable) fullClassifier),
m_SourceCodeClass.getText()));
}
m_History.updateResult(name);
m_Log.logMessage("Finished " + cname);
m_Log.statusMessage("OK");
} catch (Exception ex) {
ex.printStackTrace();
m_Log.logMessage(ex.getMessage());
JOptionPane.showMessageDialog(ClassifierPanel.this,
"Problem evaluating classifier:\n"
+ ex.getMessage(),
"Evaluate classifier",
JOptionPane.ERROR_MESSAGE);
m_Log.statusMessage("Problem evaluating classifier");
} finally {
try {
if (!saveVis && outputModel) {
FastVector vv = new FastVector();
vv.addElement(fullClassifier);
Instances trainHeader = new Instances(m_Instances, 0);
trainHeader.setClassIndex(classIndex);
vv.addElement(trainHeader);
if (grph != null) {
vv.addElement(grph);
}
m_History.addObject(name, vv);
} else if (saveVis && predInstances != null &&
predInstances.numInstances() > 0) {
if (predInstances.attribute(predInstances.classIndex())
.isNumeric()) {
postProcessPlotInfo(plotSize);
}
m_CurrentVis = new VisualizePanel();
m_CurrentVis.setName(name+" ("+inst.relationName()+")");
m_CurrentVis.setLog(m_Log);
PlotData2D tempd = new PlotData2D(predInstances);
tempd.setShapeSize(plotSize);
tempd.setShapeType(plotShape);
tempd.setPlotName(name+" ("+inst.relationName()+")");
tempd.addInstanceNumberAttribute();
m_CurrentVis.addPlot(tempd);
m_CurrentVis.setColourIndex(predInstances.classIndex()+1);
FastVector vv = new FastVector();
if (outputModel) {
vv.addElement(fullClassifier);
Instances trainHeader = new Instances(m_Instances, 0);
trainHeader.setClassIndex(classIndex);
vv.addElement(trainHeader);
if (grph != null) {
vv.addElement(grph);
}
}
vv.addElement(m_CurrentVis);
if ((eval != null) && (eval.predictions() != null)) {
vv.addElement(eval.predictions());
vv.addElement(inst.classAttribute());
}
m_History.addObject(name, vv);
}
} catch (Exception ex) {