boolean noOutput = false,
printClassifications = false, trainStatistics = true,
printMargins = false, printComplexityStatistics = false,
printGraph = false, classStatistics = false, printSource = false;
StringBuffer text = new StringBuffer();
DataSource trainSource = null, testSource = null;
ObjectInputStream objectInputStream = null;
BufferedInputStream xmlInputStream = null;
CostMatrix costMatrix = null;
StringBuffer schemeOptionsText = null;
Range attributesToOutput = null;
long trainTimeStart = 0, trainTimeElapsed = 0,
testTimeStart = 0, testTimeElapsed = 0;
String xml = "";
String[] optionsTmp = null;
Classifier classifierBackup;
Classifier classifierClassifications = null;
boolean printDistribution = false;
int actualClassIndex = -1; // 0-based class index
String splitPercentageString = "";
int splitPercentage = -1;
boolean preserveOrder = false;
boolean trainSetPresent = false;
boolean testSetPresent = false;
String thresholdFile;
String thresholdLabel;
StringBuffer predsBuff = null; // predictions from cross-validation
// help requested?
if (Utils.getFlag("h", options) || Utils.getFlag("help", options)) {
throw new Exception("\nHelp requested." + makeOptionString(classifier));
}
try {
// do we get the input from XML instead of normal parameters?
xml = Utils.getOption("xml", options);
if (!xml.equals("")) {
options = new XMLOptions(xml).toArray();
}
// is the input model only the XML-Options, i.e. w/o built model?
optionsTmp = new String[options.length];
for (int i = 0; i < options.length; i++) {
optionsTmp[i] = options[i];
}
if (Utils.getOption('l', optionsTmp).toLowerCase().endsWith(".xml")) {
// load options from serialized data ('-l' is automatically erased!)
XMLClassifier xmlserial = new XMLClassifier();
Classifier cl = (Classifier) xmlserial.read(Utils.getOption('l', options));
// merge options
optionsTmp = new String[options.length + cl.getOptions().length];
System.arraycopy(cl.getOptions(), 0, optionsTmp, 0, cl.getOptions().length);
System.arraycopy(options, 0, optionsTmp, cl.getOptions().length, options.length);
options = optionsTmp;
}
noCrossValidation = Utils.getFlag("no-cv", options);
// Get basic options (options the same for all schemes)
classIndexString = Utils.getOption('c', options);
if (classIndexString.length() != 0) {
if (classIndexString.equals("first")) {
classIndex = 1;
} else if (classIndexString.equals("last")) {
classIndex = -1;
} else {
classIndex = Integer.parseInt(classIndexString);
}
}
trainFileName = Utils.getOption('t', options);
objectInputFileName = Utils.getOption('l', options);
objectOutputFileName = Utils.getOption('d', options);
testFileName = Utils.getOption('T', options);
foldsString = Utils.getOption('x', options);
if (foldsString.length() != 0) {
folds = Integer.parseInt(foldsString);
}
seedString = Utils.getOption('s', options);
if (seedString.length() != 0) {
seed = Integer.parseInt(seedString);
}
if (trainFileName.length() == 0) {
if (objectInputFileName.length() == 0) {
throw new Exception("No training file and no object " +
"input file given.");
}
if (testFileName.length() == 0) {
throw new Exception("No training file and no test " +
"file given.");
}
} else if ((objectInputFileName.length() != 0) &&
((!(classifier instanceof UpdateableClassifier)) ||
(testFileName.length() == 0))) {
throw new Exception("Classifier not incremental, or no " +
"test file provided: can't " +
"use both train and model file.");
}
try {
if (trainFileName.length() != 0) {
trainSetPresent = true;
trainSource = new DataSource(trainFileName);
}
if (testFileName.length() != 0) {
testSetPresent = true;
testSource = new DataSource(testFileName);
}
if (objectInputFileName.length() != 0) {
InputStream is = new FileInputStream(objectInputFileName);
if (objectInputFileName.endsWith(".gz")) {
is = new GZIPInputStream(is);
}
// load from KOML?
if (!(objectInputFileName.endsWith(".koml") && KOML.isPresent())) {
objectInputStream = new ObjectInputStream(is);
xmlInputStream = null;
} else {
objectInputStream = null;
xmlInputStream = new BufferedInputStream(is);
}
}
} catch (Exception e) {
throw new Exception("Can't open file " + e.getMessage() + '.');
}
if (testSetPresent) {
template = test = testSource.getStructure();
if (classIndex != -1) {
test.setClassIndex(classIndex - 1);
} else {
if ((test.classIndex() == -1) || (classIndexString.length() != 0)) {
test.setClassIndex(test.numAttributes() - 1);
}
}
actualClassIndex = test.classIndex();
} else {
// percentage split
splitPercentageString = Utils.getOption("split-percentage", options);
if (splitPercentageString.length() != 0) {
if (foldsString.length() != 0) {
throw new Exception(
"Percentage split cannot be used in conjunction with " + "cross-validation ('-x').");
}
splitPercentage = Integer.parseInt(splitPercentageString);
if ((splitPercentage <= 0) || (splitPercentage >= 100)) {
throw new Exception("Percentage split value needs be >0 and <100.");
}
} else {
splitPercentage = -1;
}
preserveOrder = Utils.getFlag("preserve-order", options);
if (preserveOrder) {
if (splitPercentage == -1) {
throw new Exception("Percentage split ('-percentage-split') is missing.");
}
}
// create new train/test sources
if (splitPercentage > 0) {
testSetPresent = true;
Instances tmpInst = trainSource.getDataSet(actualClassIndex);
if (!preserveOrder) {
tmpInst.randomize(new Random(seed));
}
int trainSize = tmpInst.numInstances() * splitPercentage / 100;
int testSize = tmpInst.numInstances() - trainSize;
Instances trainInst = new Instances(tmpInst, 0, trainSize);
Instances testInst = new Instances(tmpInst, trainSize, testSize);
trainSource = new DataSource(trainInst);
testSource = new DataSource(testInst);
template = test = testSource.getStructure();
if (classIndex != -1) {
test.setClassIndex(classIndex - 1);
} else {
if ((test.classIndex() == -1) || (classIndexString.length() != 0)) {
test.setClassIndex(test.numAttributes() - 1);
}
}
actualClassIndex = test.classIndex();
}
}
if (trainSetPresent) {
template = train = trainSource.getStructure();
if (classIndex != -1) {
train.setClassIndex(classIndex - 1);
} else {
if ((train.classIndex() == -1) || (classIndexString.length() != 0)) {
train.setClassIndex(train.numAttributes() - 1);
}
}
actualClassIndex = train.classIndex();
if ((testSetPresent) && !test.equalHeaders(train)) {
throw new IllegalArgumentException("Train and test file not compatible!");
}
}
if (template == null) {
throw new Exception("No actual dataset provided to use as template");
}
costMatrix = handleCostOption(
Utils.getOption('m', options), template.numClasses());
classStatistics = Utils.getFlag('i', options);
noOutput = Utils.getFlag('o', options);
trainStatistics = !Utils.getFlag('v', options);
printComplexityStatistics = Utils.getFlag('k', options);
printMargins = Utils.getFlag('r', options);
printGraph = Utils.getFlag('g', options);
sourceClass = Utils.getOption('z', options);
printSource = (sourceClass.length() != 0);
printDistribution = Utils.getFlag("distribution", options);
thresholdFile = Utils.getOption("threshold-file", options);
thresholdLabel = Utils.getOption("threshold-label", options);
// Check -p option
try {
attributeRangeString = Utils.getOption('p', options);
} catch (Exception e) {
throw new Exception(e.getMessage() + "\nNOTE: the -p option has changed. " +
"It now expects a parameter specifying a range of attributes " +
"to list with the predictions. Use '-p 0' for none.");
}
if (attributeRangeString.length() != 0) {
printClassifications = true;
if (!attributeRangeString.equals("0")) {
attributesToOutput = new Range(attributeRangeString);
}
}
if (!printClassifications && printDistribution) {
throw new Exception("Cannot print distribution without '-p' option!");
}
// if no training file given, we don't have any priors
if ((!trainSetPresent) && (printComplexityStatistics)) {
throw new Exception("Cannot print complexity statistics ('-k') without training file ('-t')!");
}
// If a model file is given, we can't process
// scheme-specific options
if (objectInputFileName.length() != 0) {
Utils.checkForRemainingOptions(options);
} else {
// Set options for classifier
if (classifier instanceof OptionHandler) {
for (int i = 0; i < options.length; i++) {
if (options[i].length() != 0) {
if (schemeOptionsText == null) {
schemeOptionsText = new StringBuffer();
}
if (options[i].indexOf(' ') != -1) {
schemeOptionsText.append('"' + options[i] + "\" ");
} else {
schemeOptionsText.append(options[i] + " ");
}
}
}
((OptionHandler) classifier).setOptions(options);
}
}
Utils.checkForRemainingOptions(options);
} catch (Exception e) {
throw new Exception("\nWeka exception: " + e.getMessage() + makeOptionString(classifier));
}
// Setup up evaluation objects
Evaluation trainingEvaluation = new Evaluation(new Instances(template, 0), costMatrix);
Evaluation testingEvaluation = new Evaluation(new Instances(template, 0), costMatrix);
// disable use of priors if no training file given
if (!trainSetPresent) {
testingEvaluation.useNoPriors();
}
if (objectInputFileName.length() != 0) {
// Load classifier from file
if (objectInputStream != null) {
classifier = (Classifier) objectInputStream.readObject();
// try and read a header (if present)
Instances savedStructure = null;
try {
savedStructure = (Instances) objectInputStream.readObject();
} catch (Exception ex) {
// don't make a fuss
}
if (savedStructure != null) {
// test for compatibility with template
if (!template.equalHeaders(savedStructure)) {
throw new Exception("training and test set are not compatible");
}
}
objectInputStream.close();
} else {
// whether KOML is available has already been checked (objectInputStream would null otherwise)!
classifier = (Classifier) KOML.read(xmlInputStream);
xmlInputStream.close();
}
}
// backup of fully setup classifier for cross-validation
classifierBackup = Classifier.makeCopy(classifier);
// Build the classifier if no object file provided
if ((classifier instanceof UpdateableClassifier) &&
(testSetPresent || noCrossValidation) &&
(costMatrix == null) &&
(trainSetPresent)) {
// Build classifier incrementally
trainingEvaluation.setPriors(train);
testingEvaluation.setPriors(train);
trainTimeStart = System.currentTimeMillis();
if (objectInputFileName.length() == 0) {
classifier.buildClassifier(train);
}
Instance trainInst;
while (trainSource.hasMoreElements(train)) {
trainInst = trainSource.nextElement(train);
trainingEvaluation.updatePriors(trainInst);
testingEvaluation.updatePriors(trainInst);
((UpdateableClassifier) classifier).updateClassifier(trainInst);
}
trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
} else if (objectInputFileName.length() == 0) {
// Build classifier in one go
tempTrain = trainSource.getDataSet(actualClassIndex);
trainingEvaluation.setPriors(tempTrain);
testingEvaluation.setPriors(tempTrain);
trainTimeStart = System.currentTimeMillis();
classifier.buildClassifier(tempTrain);
trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
}
/* FOR LARGE DATA SETS
// backup of fully trained classifier for printing the classifications
if (printClassifications) {
classifierClassifications = Classifier.makeCopy(classifier);
}
*/
// Save the classifier if an object output file is provided
if (objectOutputFileName.length() != 0) {
OutputStream os = new FileOutputStream(objectOutputFileName);
// binary
if (!(objectOutputFileName.endsWith(".xml") || (objectOutputFileName.endsWith(".koml") && KOML.isPresent()))) {
if (objectOutputFileName.endsWith(".gz")) {
os = new GZIPOutputStream(os);
}
ObjectOutputStream objectOutputStream = new ObjectOutputStream(os);
objectOutputStream.writeObject(classifier);
if (template != null) {
objectOutputStream.writeObject(template);
}
objectOutputStream.flush();
objectOutputStream.close();
} // KOML/XML
else {
BufferedOutputStream xmlOutputStream = new BufferedOutputStream(os);
if (objectOutputFileName.endsWith(".xml")) {
XMLSerialization xmlSerial = new XMLClassifier();
xmlSerial.write(xmlOutputStream, classifier);
} else // whether KOML is present has already been checked
// if not present -> ".koml" is interpreted as binary - see above
if (objectOutputFileName.endsWith(".koml")) {
KOML.write(xmlOutputStream, classifier);
}
xmlOutputStream.close();
}
}
/* FOR LARGE DATA SETS
// If classifier is drawable output string describing graph
if ((classifier instanceof Drawable) && (printGraph)) {
return ((Drawable) classifier).graph();
}
// Output the classifier as equivalent source
if ((classifier instanceof Sourcable) && (printSource)) {
return wekaStaticWrapper((Sourcable) classifier, sourceClass);
}
// Output model
if (!(noOutput || printMargins)) {
if (classifier instanceof OptionHandler) {
if (schemeOptionsText != null) {
text.append("\nOptions: " + schemeOptionsText);
text.append("\n");
}
}
text.append("\n" + classifier.toString() + "\n");
}
if (!printMargins && (costMatrix != null)) {
text.append("\n=== Evaluation Cost Matrix ===\n\n");
text.append(costMatrix.toString());
}
*/ // FOR LARGE DATA SETS
// Output test instance predictions only
if (printClassifications) {
DataSource source = testSource;
predsBuff = new StringBuffer();
// no test set -> use train set
if (source == null && noCrossValidation) {
source = trainSource;
predsBuff.append("\n=== Predictions on training data ===\n\n");