// Copy the current state of things: read the GUI options and look up the
// objects this run works with into locals before the evaluation starts.
m_Log.statusMessage("Setting up...");
Classifier classifierToUse = classifier;
// Buffer looked up in the result history by name; the re-evaluation report
// is appended to it further down.
StringBuffer outBuff = m_History.getNamedBuffer(name);
// Both are assigned inside the try block below, once the user test loader
// has been read.
DataSource source = null;
Instances userTestStructure = null;
// Stays null unless saveVis (below) is true.
ClassifierErrorsPlotInstances plotInstances = null;
// Stays null unless cost-sensitive evaluation is selected. The editor's
// matrix is passed through the CostMatrix constructor (presumably a copy,
// so later edits in the GUI do not affect this run - confirm).
CostMatrix costMatrix = null;
if (m_EvalWRTCostsBut.isSelected()) {
costMatrix = new CostMatrix((CostMatrix) m_CostMatrixEditor
.getValue());
}
boolean outputConfusion = m_OutputConfusionBut.isSelected();
boolean outputPerClass = m_OutputPerClassBut.isSelected();
// The summary has no checkbox here; it is hard-wired on.
boolean outputSummary = true;
boolean outputEntropy = m_OutputEntropyBut.isSelected();
boolean saveVis = m_StorePredictionsBut.isSelected();
// Per-prediction text is wanted unless the configured output object is of
// class Null.
boolean outputPredictionsText = (m_ClassificationOutputEditor.getValue().getClass() != Null.class);
// Not assigned anywhere in the visible part of this method.
String grph = null;
// Created further down, once the test header is known.
Evaluation eval = null;
try {
// Remember whether the test loader is an IncrementalConverter; the report
// header further down prints the instance count differently in that case.
boolean incrementalLoader = (m_TestLoader instanceof IncrementalConverter);
// If a test loader with a readable structure exists, reset it, wrap it in a
// DataSource and take the header, using the class index stored for the
// test set as the initial class attribute.
if (m_TestLoader != null && m_TestLoader.getStructure() != null) {
m_TestLoader.reset();
source = new DataSource(m_TestLoader);
userTestStructure = source.getStructure();
userTestStructure.setClassIndex(m_TestClassIndex);
}
// A test source is mandatory: without it there is nothing to check for
// compatibility or to evaluate on.
if (source == null) {
throw new Exception("No user test set has been specified");
}
// Reconcile the test-set header with the model.
//
// With a training header: adopt its class index for the test set and, if the
// two headers differ, either wrap the classifier in an InputMappedClassifier
// (after asking the user, or following a remembered answer) or abort with an
// exception describing the mismatch.
//
// Without a training header: for a PMMLClassifier take the class attribute
// from its mining schema (aborting if the test set has no attribute of that
// name); otherwise fall back to the last attribute of the test set.
if (trainHeader != null) {
  boolean compatibilityProblem = false;
  if (trainHeader.classIndex() > userTestStructure.numAttributes() - 1) {
    // The training class index does not exist in the test set. This is only
    // recorded here; the decision to wrap or abort is taken below.
    compatibilityProblem = true;
  }
  // NOTE(review): if the class index is out of range for the test set,
  // setClassIndex presumably throws right here, before the wrap-or-abort
  // logic below is reached - confirm against Instances.setClassIndex.
  userTestStructure.setClassIndex(trainHeader.classIndex());
  if (!trainHeader.equalHeaders(userTestStructure)) {
    compatibilityProblem = true;
    // A classifier that already is an InputMappedClassifier is left alone.
    if (compatibilityProblem
        && !(classifierToUse instanceof weka.classifiers.misc.InputMappedClassifier)) {
      // Preference key under which the user's "don't ask again" answer is
      // stored and looked up.
      final String dontShowKey =
          "weka.gui.explorer.ClassifierPanel.AutoWrapInInputMappedClassifier";
      boolean wrapClassifier = false;
      if (!Utils.getDontShowDialog(dontShowKey)) {
        // Ask the user; the check box lets them make the answer permanent.
        JCheckBox dontShow = new JCheckBox("Do not show this message again");
        Object[] stuff = new Object[2];
        stuff[0] = "Data used to train model and test set are not compatible.\n"
            + "Would you like to automatically wrap the classifier in\n"
            + "an \"InputMappedClassifier\" before proceeding?.\n";
        stuff[1] = dontShow;
        // The fourth argument is the option *type*. YES_NO_OPTION is the
        // correct constant; YES_OPTION is a return value that merely shares
        // the same numeric value.
        int result = JOptionPane.showConfirmDialog(ClassifierPanel.this, stuff,
            "ClassifierPanel", JOptionPane.YES_NO_OPTION);
        if (result == JOptionPane.YES_OPTION) {
          wrapClassifier = true;
        }
        if (dontShow.isSelected()) {
          // Anything other than an explicit "yes" is remembered as "no".
          String response = (wrapClassifier) ? "yes" : "no";
          Utils.setDontShowDialogResponse(dontShowKey, response);
        }
      } else {
        // What did the user say - do they want to autowrap or not?
        String response = Utils.getDontShowDialogResponse(dontShowKey);
        if (response != null && response.equalsIgnoreCase("yes")) {
          wrapClassifier = true;
        }
      }
      if (wrapClassifier) {
        // Wrap the original classifier and hand the wrapper the training
        // header; the wrapper replaces classifierToUse from here on.
        weka.classifiers.misc.InputMappedClassifier temp =
            new weka.classifiers.misc.InputMappedClassifier();
        temp.setClassifier(classifierToUse);
        temp.setModelHeader(trainHeader);
        classifierToUse = temp;
      } else {
        throw new Exception("Train and test set are not compatible\n"
            + trainHeader.equalHeadersMsg(userTestStructure));
      }
    }
  }
} else {
  if (classifierToUse instanceof PMMLClassifier) {
    // set the class based on information in the mining schema
    Instances miningSchemaStructure =
        ((PMMLClassifier) classifierToUse).getMiningSchema().getMiningSchemaAsInstances();
    String className = miningSchemaStructure.classAttribute().name();
    Attribute classMatch = userTestStructure.attribute(className);
    if (classMatch == null) {
      throw new Exception("Can't find a match for the PMML target field "
          + className + " in the "
          + "test instances!");
    }
    userTestStructure.setClass(classMatch);
  } else {
    // No header and no schema to consult: use the last attribute as class.
    userTestStructure.setClassIndex(userTestStructure.numAttributes() - 1);
  }
}
// Tell loggers that track running tasks that work has started.
if (m_Log instanceof TaskLogger) {
((TaskLogger)m_Log).taskStarted();
}
m_Log.statusMessage("Evaluating on test data...");
m_Log.logMessage("Re-evaluating classifier (" + name
+ ") on test set");
// Evaluation is built from the test header; costMatrix is null unless
// cost-sensitive evaluation was selected above.
eval = new Evaluation(userTestStructure, costMatrix);
// set up the structure of the plottable instances for
// visualization if selected
// NOTE(review): plotInstances stays null when saveVis is false, yet the
// evaluation loop further down calls plotInstances.process(...) without a
// null check - confirm this cannot throw a NullPointerException.
if (saveVis) {
plotInstances = new ClassifierErrorsPlotInstances();
plotInstances.setInstances(userTestStructure);
plotInstances.setClassifier(classifierToUse);
plotInstances.setClassIndex(userTestStructure.classIndex());
plotInstances.setEvaluation(eval);
plotInstances.setUp();
}
// Append the header of the re-evaluation report to the result buffer:
// section title, relation name, instance count and attribute count.
outBuff.append("\n=== Re-evaluation on test set ===\n\n");
outBuff.append("User supplied test set\n");
outBuff.append("Relation: "
+ userTestStructure.relationName() + '\n');
// With an incremental loader the count is not known at this point;
// otherwise the full data set is requested from the source to count its
// instances.
if (incrementalLoader)
outBuff.append("Instances: unknown (yet). Reading incrementally\n");
else
outBuff.append("Instances: " + source.getDataSet().numInstances() + "\n");
outBuff.append("Attributes: "
+ userTestStructure.numAttributes()
+ "\n\n");
// Without a training header, and for anything but a PMML classifier, the
// test set was not checked against the model, so warn the reader.
if (trainHeader == null &&
!(classifierToUse instanceof
weka.classifiers.pmml.consumer.PMMLClassifier)) {
outBuff.append("NOTE - if test set is not compatible then results are "
+ "unpredictable\n\n");
}
// If per-prediction text was requested, take the configured output object
// from the editor and point it at the test header and the result buffer.
// Otherwise classificationOutput stays null.
AbstractOutput classificationOutput = null;
if (outputPredictionsText) {
classificationOutput = (AbstractOutput) m_ClassificationOutputEditor.getValue();
classificationOutput.setHeader(userTestStructure);
classificationOutput.setBuffer(outBuff);
/* classificationOutput.setAttributes("");
classificationOutput.setOutputDistribution(false);*/
// classificationOutput.printHeader();
}
// make adjustments if the classifier is an InputMappedClassifier
// (the value returned by setupEval replaces eval; what exactly it changes
// is not visible from here)
eval = setupEval(eval, classifierToUse, userTestStructure, costMatrix,
plotInstances, classificationOutput, false);
// NOTE(review): presumably priors are disabled because no training data is
// available during re-evaluation - confirm.
eval.useNoPriors();
// Emit the predictions section header before the per-instance loop starts.
if (outputPredictionsText) {
printPredictionsHeader(outBuff, classificationOutput, "user test set");
}
Instance instance;
int jj = 0;
while (source.hasMoreElements(userTestStructure)) {
instance = source.nextElement(userTestStructure);
plotInstances.process(instance, classifierToUse, eval);
if (outputPredictionsText) {
classificationOutput.printClassification(classifierToUse, instance, jj);
}
if ((++jj % 100) == 0) {