// create new train/test sources
if (splitPercentage > 0) {
testSetPresent = true;
Instances tmpInst = trainSource.getDataSet(actualClassIndex);
if (!preserveOrder)
tmpInst.randomize(new Random(seed));
int trainSize =
(int) Math.round(tmpInst.numInstances() * splitPercentage / 100);
int testSize = tmpInst.numInstances() - trainSize;
Instances trainInst = new Instances(tmpInst, 0, trainSize);
Instances testInst = new Instances(tmpInst, trainSize, testSize);
trainSource = new DataSource(trainInst);
testSource = new DataSource(testInst);
template = test = testSource.getStructure();
if (classIndex != -1) {
test.setClassIndex(classIndex - 1);
} else {
if ( (test.classIndex() == -1) || (classIndexString.length() != 0) )
test.setClassIndex(test.numAttributes() - 1);
}
actualClassIndex = test.classIndex();
}
}
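// Read the training structure and determine the class attribute: an explicit
// class index (presumably the 1-based value of the '-c' option, hence the '- 1')
// takes precedence, otherwise the last attribute is used. Unless an
// InputMappedClassifier handles the mapping itself, the train and test headers
// must be compatible.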
if (trainSetPresent) {
template = train = trainSource.getStructure();
if (classIndex != -1) {
train.setClassIndex(classIndex - 1);
} else {
if ( (train.classIndex() == -1) || (classIndexString.length() != 0) )
train.setClassIndex(train.numAttributes() - 1);
}
actualClassIndex = train.classIndex();
if (!(classifier instanceof weka.classifiers.misc.InputMappedClassifier)) {
if ((testSetPresent) && !test.equalHeaders(train)) {
throw new IllegalArgumentException("Train and test files are not compatible!\n" + test.equalHeadersMsg(train));
}
}
}
if (template == null) {
throw new Exception("No actual dataset provided to use as template");
}
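// Parse the general evaluation options: cost matrix (-m), per-class statistics
// (-i), suppression of model output (-o), suppression of training statistics
// (-v, note the negated flag), complexity statistics (-k), margin distribution
// (-r), graph output (-g), source-code output (-z) and threshold file/label.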
costMatrix = handleCostOption(
Utils.getOption('m', options), template.numClasses());
classStatistics = Utils.getFlag('i', options);
noOutput = Utils.getFlag('o', options);
trainStatistics = !Utils.getFlag('v', options);
printComplexityStatistics = Utils.getFlag('k', options);
printMargins = Utils.getFlag('r', options);
printGraph = Utils.getFlag('g', options);
sourceClass = Utils.getOption('z', options);
printSource = (sourceClass.length() != 0);
thresholdFile = Utils.getOption("threshold-file", options);
thresholdLabel = Utils.getOption("threshold-label", options);
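// Per-instance prediction output: the newer '-classifications' option names an
// AbstractOutput subclass (plus its options), while the legacy '-p <range>' and
// '-distribution' options are mapped onto a PlainText output. Either form
// switches off the normal model output. Presumably invoked as something like
//   java weka.classifiers.trees.J48 -t train.arff -T test.arff -p 0
// to print the predictions for all test instances.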
String classifications = Utils.getOption("classifications", options);
String classificationsOld = Utils.getOption("p", options);
if (classifications.length() > 0) {
noOutput = true;
classificationOutput = AbstractOutput.fromCommandline(classifications);
classificationOutput.setHeader(template);
}
// backwards compatible with old "-p range" and "-distribution" options
else if (classificationsOld.length() > 0) {
noOutput = true;
classificationOutput = new PlainText();
classificationOutput.setHeader(template);
if (!classificationsOld.equals("0"))
classificationOutput.setAttributes(classificationsOld);
classificationOutput.setOutputDistribution(Utils.getFlag("distribution", options));
}
// -distribution flag needs -p option
else {
if (Utils.getFlag("distribution", options))
throw new Exception("Cannot print distribution without '-p' option!");
}
// if no training file given, we don't have any priors
if ( (!trainSetPresent) && (printComplexityStatistics) )
throw new Exception("Cannot print complexity statistics ('-k') without training file ('-t')!");
// If a model file is given, we can't process
// scheme-specific options
if (objectInputFileName.length() != 0) {
Utils.checkForRemainingOptions(options);
} else {
// Set options for classifier
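// Any remaining options are handed to the classifier itself; they are also
// recorded in schemeOptionsText (options containing spaces are quoted) so they
// can be echoed later in the "Options:" line of the output.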
if (classifier instanceof OptionHandler) {
for (int i = 0; i < options.length; i++) {
if (options[i].length() != 0) {
if (schemeOptionsText == null) {
schemeOptionsText = new StringBuffer();
}
if (options[i].indexOf(' ') != -1) {
schemeOptionsText.append('"' + options[i] + "\" ");
} else {
schemeOptionsText.append(options[i] + " ");
}
}
}
((OptionHandler)classifier).setOptions(options);
}
}
Utils.checkForRemainingOptions(options);
} catch (Exception e) {
throw new Exception("\nWeka exception: " + e.getMessage()
+ makeOptionString(classifier, false));
}
if (objectInputFileName.length() != 0) {
// Load classifier from file
if (objectInputStream != null) {
classifier = (Classifier) objectInputStream.readObject();
// try to read a saved header (if present)
Instances savedStructure = null;
try {
savedStructure = (Instances) objectInputStream.readObject();
} catch (Exception ex) {
// don't make a fuss
}
if (savedStructure != null) {
// test for compatibility with template
if (!template.equalHeaders(savedStructure)) {
throw new Exception("Current data and the header stored with the model are not compatible\n" + template.equalHeadersMsg(savedStructure));
}
}
objectInputStream.close();
}
else if (xmlInputStream != null) {
// whether KOML is available has already been checked when the streams were opened
// (otherwise the '.koml' file would have been read as binary and objectInputStream would not be null)
classifier = (Classifier) KOML.read(xmlInputStream);
xmlInputStream.close();
}
}
// Set up evaluation objects
Evaluation trainingEvaluation = new Evaluation(new Instances(template, 0), costMatrix);
Evaluation testingEvaluation = new Evaluation(new Instances(template, 0), costMatrix);
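// For an InputMappedClassifier the evaluation objects are re-created with the
// header of the underlying model, so that statistics are computed in the
// attribute space the classifier was actually trained on.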
if (classifier instanceof weka.classifiers.misc.InputMappedClassifier) {
Instances mappedClassifierHeader =
((weka.classifiers.misc.InputMappedClassifier)classifier).
getModelHeader(new Instances(template, 0));
trainingEvaluation = new Evaluation(new Instances(mappedClassifierHeader, 0), costMatrix);
testingEvaluation = new Evaluation(new Instances(mappedClassifierHeader, 0), costMatrix);
}
// disable use of priors if no training file given
if (!trainSetPresent)
testingEvaluation.useNoPriors();
// backup of the fully set up classifier for cross-validation
classifierBackup = AbstractClassifier.makeCopy(classifier);
// Build the classifier if no object file provided
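// The classifier is trained incrementally (one instance at a time, updating the
// class priors as it goes) when it is updateable, a test set is present or
// cross-validation is disabled, and no cost matrix is used; otherwise the full
// training set is loaded and the classifier is built in one go.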
if ((classifier instanceof UpdateableClassifier) &&
(testSetPresent || noCrossValidation) &&
(costMatrix == null) &&
(trainSetPresent)) {
// Build classifier incrementally
trainingEvaluation.setPriors(train);
testingEvaluation.setPriors(train);
trainTimeStart = System.currentTimeMillis();
if (objectInputFileName.length() == 0) {
classifier.buildClassifier(train);
}
Instance trainInst;
while (trainSource.hasMoreElements(train)) {
trainInst = trainSource.nextElement(train);
trainingEvaluation.updatePriors(trainInst);
testingEvaluation.updatePriors(trainInst);
((UpdateableClassifier)classifier).updateClassifier(trainInst);
}
trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
} else if (objectInputFileName.length() == 0) {
// Build classifier in one go
tempTrain = trainSource.getDataSet(actualClassIndex);
if (classifier instanceof weka.classifiers.misc.InputMappedClassifier &&
!trainingEvaluation.getHeader().equalHeaders(tempTrain)) {
// we need to make a new dataset that maps the training instances to
// the structure expected by the mapped classifier - this is only
// to ensure that the structure and priors computed by the *testing*
// evaluation object are correct with respect to the mapped classifier
Instances mappedClassifierDataset =
((weka.classifiers.misc.InputMappedClassifier)classifier).
getModelHeader(new Instances(template, 0));
for (int zz = 0; zz < tempTrain.numInstances(); zz++) {
Instance mapped = ((weka.classifiers.misc.InputMappedClassifier)classifier).
constructMappedInstance(tempTrain.instance(zz));
mappedClassifierDataset.add(mapped);
}
tempTrain = mappedClassifierDataset;
}
trainingEvaluation.setPriors(tempTrain);
testingEvaluation.setPriors(tempTrain);
trainTimeStart = System.currentTimeMillis();
classifier.buildClassifier(tempTrain);
trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
}
// backup of fully trained classifier for printing the classifications
if (classificationOutput != null) {
classifierClassifications = AbstractClassifier.makeCopy(classifier);
if (classifier instanceof weka.classifiers.misc.InputMappedClassifier) {
classificationOutput.setHeader(trainingEvaluation.getHeader());
}
}
// Save the classifier if an object output file is provided
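// '.xml' files are written via XMLSerialization, '.koml' via KOML (when the
// KOML library is present); anything else (a '.gz' name is additionally wrapped
// in a GZIPOutputStream) uses standard Java serialization, with the dataset
// header appended after the model.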
if (objectOutputFileName.length() != 0) {
OutputStream os = new FileOutputStream(objectOutputFileName);
// binary
if (!(objectOutputFileName.endsWith(".xml") || (objectOutputFileName.endsWith(".koml") && KOML.isPresent()))) {
if (objectOutputFileName.endsWith(".gz")) {
os = new GZIPOutputStream(os);
}
ObjectOutputStream objectOutputStream = new ObjectOutputStream(os);
objectOutputStream.writeObject(classifier);
if (template != null) {
objectOutputStream.writeObject(template);
}
objectOutputStream.flush();
objectOutputStream.close();
}
// KOML/XML
else {
BufferedOutputStream xmlOutputStream = new BufferedOutputStream(os);
if (objectOutputFileName.endsWith(".xml")) {
XMLSerialization xmlSerial = new XMLClassifier();
xmlSerial.write(xmlOutputStream, classifier);
}
else
// whether KOML is present has already been checked
// if not present -> ".koml" is interpreted as binary - see above
if (objectOutputFileName.endsWith(".koml")) {
KOML.write(xmlOutputStream, classifier);
}
xmlOutputStream.close();
}
}
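// Two shortcut outputs bypass the normal evaluation: -g returns the graph of a
// Drawable classifier, -z returns generated Java source for a Sourcable one.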
// If the classifier is drawable, output a string describing the graph
if ((classifier instanceof Drawable) && (printGraph)){
return ((Drawable)classifier).graph();
}
// Output the classifier as equivalent source
if ((classifier instanceof Sourcable) && (printSource)){
return wekaStaticWrapper((Sourcable) classifier, sourceClass);
}
// Output model
if (!(noOutput || printMargins)) {
if (classifier instanceof OptionHandler) {
if (schemeOptionsText != null) {
text.append("\nOptions: "+schemeOptionsText);
text.append("\n");
}
}
text.append("\n" + classifier.toString() + "\n");
}
if (!printMargins && (costMatrix != null)) {
text.append("\n=== Evaluation Cost Matrix ===\n\n");
text.append(costMatrix.toString());
}
// Output test instance predictions only
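// Predictions are printed for the test set if one is available; with no test
// set and cross-validation disabled, the training data is used instead.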
if (classificationOutput != null) {
DataSource source = testSource;
predsBuff = new StringBuffer();
classificationOutput.setBuffer(predsBuff);
// no test set -> use train set
if (source == null && noCrossValidation) {
source = trainSource;
predsBuff.append("\n=== Predictions on training data ===\n\n");
} else {
predsBuff.append("\n=== Predictions on test data ===\n\n");
}
if (source != null)
classificationOutput.print(classifierClassifications, source);
}
// Compute error estimate from training data
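// An incrementally trained classifier is evaluated by streaming the training
// source a second time; otherwise the full training set is loaded and
// evaluated in one call.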
if ((trainStatistics) && (trainSetPresent)) {
if ((classifier instanceof UpdateableClassifier) &&
(testSetPresent) &&
(costMatrix == null)) {
// Classifier was trained incrementally, so we have to
// reset the source.
trainSource.reset();
// Incremental testing
train = trainSource.getStructure(actualClassIndex);
testTimeStart = System.currentTimeMillis();
Instance trainInst;
while (trainSource.hasMoreElements(train)) {
trainInst = trainSource.nextElement(train);
trainingEvaluation.evaluateModelOnce((Classifier)classifier, trainInst);
}
testTimeElapsed = System.currentTimeMillis() - testTimeStart;
} else {
testTimeStart = System.currentTimeMillis();
trainingEvaluation.evaluateModel(
classifier, trainSource.getDataSet(actualClassIndex));
testTimeElapsed = System.currentTimeMillis() - testTimeStart;
}
// Print the results of the training evaluation
if (printMargins) {
return trainingEvaluation.toCumulativeMarginDistributionString();
} else {
if (classificationOutput == null) {
text.append("\nTime taken to build model: "
+ Utils.doubleToString(trainTimeElapsed / 1000.0,2)
+ " seconds");
if (splitPercentage > 0)
text.append("\nTime taken to test model on training split: ");
else
text.append("\nTime taken to test model on training data: ");
text.append(Utils.doubleToString(testTimeElapsed / 1000.0,2) + " seconds");
if (splitPercentage > 0)
text.append(trainingEvaluation.toSummaryString("\n\n=== Error on training"
+ " split ===\n", printComplexityStatistics));
else
text.append(trainingEvaluation.toSummaryString("\n\n=== Error on training"
+ " data ===\n", printComplexityStatistics));
if (template.classAttribute().isNominal()) {
if (classStatistics) {
text.append("\n\n" + trainingEvaluation.toClassDetailsString());
}
if (!noCrossValidation)
text.append("\n\n" + trainingEvaluation.toMatrixString());
}
}
}
}
// Compute proper error estimates
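// With an explicit test set (or a percentage split), each test instance is
// evaluated and its prediction recorded; without one, the untrained backup of
// the classifier is cross-validated on the training data instead.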
if (testSource != null) {
// Testing is on the supplied test data
testSource.reset();
test = testSource.getStructure(test.classIndex());
Instance testInst;
while (testSource.hasMoreElements(test)) {
testInst = testSource.nextElement(test);
testingEvaluation.evaluateModelOnceAndRecordPrediction(
(Classifier)classifier, testInst);
}
if (splitPercentage > 0) {
if (classificationOutput == null) {
text.append("\n\n" + testingEvaluation.
toSummaryString("=== Error on test split ===\n",
printComplexityStatistics));
}
} else {
if (classificationOutput == null) {
text.append("\n\n" + testingEvaluation.
toSummaryString("=== Error on test data ===\n",
printComplexityStatistics));
}
}
} else if (trainSource != null) {
if (!noCrossValidation) {
// Testing is via cross-validation on training data
Random random = new Random(seed);
// use untrained (!) classifier for cross-validation
classifier = AbstractClassifier.makeCopy(classifierBackup);
if (classificationOutput == null) {
testingEvaluation.crossValidateModel(classifier,
trainSource.getDataSet(actualClassIndex),