int testMode = 0;
int numFolds = 10;
double percent = 66;
int classIndex = m_ClassCombo.getSelectedIndex();
Classifier classifier = (Classifier) m_ClassifierEditor.getValue();
Classifier template = null;
try {
template = Classifier.makeCopy(classifier);
} catch (Exception ex) {
m_Log.logMessage(Messages.getInstance().getString("ClassifierPanel_StartClassifier_Log_LogMessage_Text_First") + ex.getMessage());
}
Classifier fullClassifier = null;
StringBuffer outBuff = new StringBuffer();
String name = (new SimpleDateFormat("HH:mm:ss - "))
.format(new Date());
String cname = classifier.getClass().getName();
if (cname.startsWith("weka.classifiers.")) {
name += cname.substring("weka.classifiers.".length());
} else {
name += cname;
}
String cmd = m_ClassifierEditor.getValue().getClass().getName();
if (m_ClassifierEditor.getValue() instanceof OptionHandler)
cmd += " " + Utils.joinOptions(((OptionHandler) m_ClassifierEditor.getValue()).getOptions());
Evaluation eval = null;
try {
if (m_CVBut.isSelected()) {
testMode = 1;
numFolds = Integer.parseInt(m_CVText.getText());
if (numFolds <= 1) {
throw new Exception(Messages.getInstance().getString("ClassifierPanel_StartClassifier_Exception_Text_First"));
}
} else if (m_PercentBut.isSelected()) {
testMode = 2;
percent = Double.parseDouble(m_PercentText.getText());
if ((percent <= 0) || (percent >= 100)) {
throw new Exception(Messages.getInstance().getString("ClassifierPanel_StartClassifier_Exception_Text_Second"));
}
} else if (m_TrainBut.isSelected()) {
testMode = 3;
} else if (m_TestSplitBut.isSelected()) {
testMode = 4;
// Check the test instance compatibility
if (source == null) {
throw new Exception(Messages.getInstance().getString("ClassifierPanel_StartClassifier_Exception_Text_Third"));
}
if (!inst.equalHeaders(userTestStructure)) {
throw new Exception(Messages.getInstance().getString("ClassifierPanel_StartClassifier_Exception_Text_Fourth"));
}
userTestStructure.setClassIndex(classIndex);
} else {
throw new Exception(Messages.getInstance().getString("ClassifierPanel_StartClassifier_Exception_Text_Fifth"));
}
inst.setClassIndex(classIndex);
// set up the structure of the plottable instances for
// visualization
if (saveVis) {
predInstances = setUpVisualizableInstances(inst);
predInstances.setClassIndex(inst.classIndex()+1);
}
// Output some header information
m_Log.logMessage(Messages.getInstance().getString("ClassifierPanel_StartClassifier_Log_LogMessage_Text_Second") + cname);
m_Log.logMessage(Messages.getInstance().getString("ClassifierPanel_StartClassifier_Log_LogMessage_Text_Third") + cmd);
if (m_Log instanceof TaskLogger) {
((TaskLogger)m_Log).taskStarted();
}
outBuff.append(Messages.getInstance().getString("ClassifierPanel_StartClassifier_OutBuffer_Text_First"));
outBuff.append(Messages.getInstance().getString("ClassifierPanel_StartClassifier_OutBuffer_Text_Second") + cname);
if (classifier instanceof OptionHandler) {
String [] o = ((OptionHandler) classifier).getOptions();
outBuff.append(" " + Utils.joinOptions(o));
}
outBuff.append("\n");
outBuff.append(Messages.getInstance().getString("ClassifierPanel_StartClassifier_OutBuffer_Text_Fourth") + inst.relationName() + '\n');
outBuff.append(Messages.getInstance().getString("ClassifierPanel_StartClassifier_OutBuffer_Text_Sixth") + inst.numInstances() + '\n');
outBuff.append(Messages.getInstance().getString("ClassifierPanel_StartClassifier_OutBuffer_Text_Eigth") + inst.numAttributes() + '\n');
if (inst.numAttributes() < 100) {
for (int i = 0; i < inst.numAttributes(); i++) {
outBuff.append(" " + inst.attribute(i).name()
+ '\n');
}
} else {
outBuff.append(Messages.getInstance().getString("ClassifierPanel_StartClassifier_OutBuffer_Text_Twelveth"));
}
outBuff.append(Messages.getInstance().getString("ClassifierPanel_StartClassifier_OutBuffer_Text_Thirteenth"));
switch (testMode) {
case 3: // Test on training
outBuff.append(Messages.getInstance().getString("ClassifierPanel_StartClassifier_OutBuffer_Text_Fourteenth"));
break;
case 1: // CV mode
outBuff.append("" + numFolds + Messages.getInstance().getString("ClassifierPanel_StartClassifier_OutBuffer_Text_Sixteenth"));
break;
case 2: // Percent split
outBuff.append(Messages.getInstance().getString("ClassifierPanel_StartClassifier_OutBuffer_Text_Seventeenth") + percent
+ Messages.getInstance().getString("ClassifierPanel_StartClassifier_OutBuffer_Text_Eighteenth"));
break;
case 4: // Test on user split
if (source.isIncremental())
outBuff.append(Messages.getInstance().getString("ClassifierPanel_StartClassifier_OutBuffer_Text_Nineteenth"));
else
outBuff.append(Messages.getInstance().getString("ClassifierPanel_StartClassifier_OutBuffer_Text_Twentyth")
+ source.getDataSet().numInstances() + Messages.getInstance().getString("ClassifierPanel_StartClassifier_OutBuffer_Text_TwentyFirst"));
break;
}
if (costMatrix != null) {
outBuff.append(Messages.getInstance().getString("ClassifierPanel_StartClassifier_OutBuffer_Text_TwentySecond"))
.append(costMatrix.toString()).append("\n");
}
outBuff.append("\n");
m_History.addResult(name, outBuff);
m_History.setSingle(name);
// Build the model and output it.
if (outputModel || (testMode == 3) || (testMode == 4)) {
m_Log.statusMessage(Messages.getInstance().getString("ClassifierPanel_StartClassifier_Log_StatusMessage_Text_Second"));
trainTimeStart = System.currentTimeMillis();
classifier.buildClassifier(inst);
trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
}
if (outputModel) {
outBuff.append(Messages.getInstance().getString("ClassifierPanel_StartClassifier_OutBuffer_Text_TwentySixth"));
outBuff.append(classifier.toString() + "\n");
outBuff.append(Messages.getInstance().getString("ClassifierPanel_StartClassifier_OutBuffer_Text_TwentyEighth") +
Utils.doubleToString(trainTimeElapsed / 1000.0,2)
+ " " + Messages.getInstance().getString("ClassifierPanel_StartClassifier_OutBuffer_Text_TwentyNineth"));
m_History.updateResult(name);
if (classifier instanceof Drawable) {
grph = null;
try {
grph = ((Drawable)classifier).graph();
} catch (Exception ex) {
}
}
// copy full model for output
SerializedObject so = new SerializedObject(classifier);
fullClassifier = (Classifier) so.getObject();
}
switch (testMode) {
case 3: // Test on training
m_Log.statusMessage("Evaluating on training data...");
eval = new Evaluation(inst, costMatrix);
if (outputPredictionsText) {
printPredictionsHeader(outBuff, inst, "training set");
}
for (int jj=0;jj<inst.numInstances();jj++) {
processClassifierPrediction(inst.instance(jj), classifier,
eval, predInstances, plotShape,
plotSize);
if (outputPredictionsText) {
outBuff.append(predictionText(classifier, inst.instance(jj), jj+1));
}
if ((jj % 100) == 0) {
m_Log.statusMessage("Evaluating on training data. Processed "
+jj+" instances...");
}
}
if (outputPredictionsText) {
outBuff.append("\n");
}
outBuff.append("=== Evaluation on training set ===\n");
break;
case 1: // CV mode
m_Log.statusMessage("Randomizing instances...");
int rnd = 1;
try {
rnd = Integer.parseInt(m_RandomSeedText.getText().trim());
// System.err.println("Using random seed "+rnd);
} catch (Exception ex) {
m_Log.logMessage("Trouble parsing random seed value");
rnd = 1;
}
Random random = new Random(rnd);
inst.randomize(random);
if (inst.attribute(classIndex).isNominal()) {
m_Log.statusMessage("Stratifying instances...");
inst.stratify(numFolds);
}
eval = new Evaluation(inst, costMatrix);
if (outputPredictionsText) {
printPredictionsHeader(outBuff, inst, "test data");
}
// Make some splits and do a CV
for (int fold = 0; fold < numFolds; fold++) {
m_Log.statusMessage(Messages.getInstance().getString("ClassifierPanel_StartClassifier_Log_StatusMessage_Text_Eighth")
+ (fold + 1) + Messages.getInstance().getString("ClassifierPanel_StartClassifier_Log_StatusMessage_Text_Nineth"));
Instances train = inst.trainCV(numFolds, fold, random);
eval.setPriors(train);
m_Log.statusMessage(Messages.getInstance().getString("ClassifierPanel_StartClassifier_Log_StatusMessage_Text_Tenth")
+ (fold + 1) + Messages.getInstance().getString("ClassifierPanel_StartClassifier_Log_StatusMessage_Text_Twelveth"));
Classifier current = null;
try {
current = Classifier.makeCopy(template);
} catch (Exception ex) {
m_Log.logMessage(Messages.getInstance().getString("ClassifierPanel_StartClassifier_Log_LogMessage_Text_Fifth") + ex.getMessage());
}
current.buildClassifier(train);
Instances test = inst.testCV(numFolds, fold);
m_Log.statusMessage(Messages.getInstance().getString("ClassifierPanel_StartClassifier_Log_StatusMessage_Text_Eleventh")
+ (fold + 1) + Messages.getInstance().getString("ClassifierPanel_StartClassifier_Log_StatusMessage_Text_Twelveth"));
for (int jj=0;jj<test.numInstances();jj++) {
processClassifierPrediction(test.instance(jj), current,
eval, predInstances, plotShape,
plotSize);
if (outputPredictionsText) {
outBuff.append(predictionText(current, test.instance(jj), jj+1));
}
}
}
if (outputPredictionsText) {
outBuff.append("\n");
}
if (inst.attribute(classIndex).isNominal()) {
outBuff.append(Messages.getInstance().getString("ClassifierPanel_StartClassifier_OutBuffer_Text_ThirtyThird"));
} else {
outBuff.append(Messages.getInstance().getString("ClassifierPanel_StartClassifier_OutBuffer_Text_ThirtyFourth"));
}
break;
case 2: // Percent split
if (!m_PreserveOrderBut.isSelected()) {
m_Log.statusMessage(Messages.getInstance().getString("ClassifierPanel_StartClassifier_Log_StatusMessage_Text_Thirteenth"));
try {
rnd = Integer.parseInt(m_RandomSeedText.getText().trim());
} catch (Exception ex) {
m_Log.logMessage(Messages.getInstance().getString("ClassifierPanel_StartClassifier_Log_StatusMessage_Text_Fourteenth"));
rnd = 1;
}
inst.randomize(new Random(rnd));
}
int trainSize = (int) Math.round(inst.numInstances() * percent / 100);
int testSize = inst.numInstances() - trainSize;
Instances train = new Instances(inst, 0, trainSize);
Instances test = new Instances(inst, trainSize, testSize);
m_Log.statusMessage(Messages.getInstance().getString("ClassifierPanel_StartClassifier_Log_StatusMessage_Text_Fifteenth") + trainSize+ Messages.getInstance().getString("ClassifierPanel_StartClassifier_Log_StatusMessage_Text_Sixteenth"));
Classifier current = null;
try {
current = Classifier.makeCopy(template);
} catch (Exception ex) {
m_Log.logMessage(Messages.getInstance().getString("ClassifierPanel_StartClassifier_Log_LogMessage_Text_Sixth") + ex.getMessage());
}
current.buildClassifier(train);
eval = new Evaluation(train, costMatrix);
m_Log.statusMessage(Messages.getInstance().getString("ClassifierPanel_StartClassifier_Log_StatusMessage_Text_Seventeenth"));
if (outputPredictionsText) {
printPredictionsHeader(outBuff, inst, Messages.getInstance().getString("ClassifierPanel_StartClassifier_PrintPredictionsHeader_Text_First"));