Package weka.core.converters, class ConverterUtils

Examples of weka.core.converters.ConverterUtils.DataSource
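
Before the longer excerpts below, here is a minimal, self-contained sketch of the typical DataSource workflow: construct it with a file location, pull the whole dataset with getDataSet(), and set the class attribute. The file name is only a placeholder.

import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class DataSourceBasics {
  public static void main(String[] args) throws Exception {
    // DataSource picks a suitable loader from the file extension (ARFF, CSV, XRFF, ...).
    DataSource source = new DataSource("mydata.arff");   // placeholder file name
    Instances data = source.getDataSet();
    // Most schemes need an explicit class attribute; use the last attribute here.
    if (data.classIndex() == -1) {
      data.setClassIndex(data.numAttributes() - 1);
    }
    System.out.println("Loaded " + data.numInstances() + " instances with "
        + data.numAttributes() + " attributes.");
  }
}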


    boolean noOutput = false,
    trainStatistics = true,
    printMargins = false, printComplexityStatistics = false,
    printGraph = false, classStatistics = false, printSource = false;
    StringBuffer text = new StringBuffer();
    DataSource trainSource = null, testSource = null;
    ObjectInputStream objectInputStream = null;
    BufferedInputStream xmlInputStream = null;
    CostMatrix costMatrix = null;
    StringBuffer schemeOptionsText = null;
    long trainTimeStart = 0, trainTimeElapsed = 0,
    testTimeStart = 0, testTimeElapsed = 0;
    String xml = "";
    String[] optionsTmp = null;
    Classifier classifierBackup;
    Classifier classifierClassifications = null;
    int actualClassIndex = -1;  // 0-based class index
    String splitPercentageString = "";
    double splitPercentage = -1;
    boolean preserveOrder = false;
    boolean trainSetPresent = false;
    boolean testSetPresent = false;
    String thresholdFile;
    String thresholdLabel;
    StringBuffer predsBuff = null; // predictions from cross-validation
    AbstractOutput classificationOutput = null;

    // help requested?
    if (Utils.getFlag("h", options) || Utils.getFlag("help", options)) {

      // global info requested as well?
      boolean globalInfo = Utils.getFlag("synopsis", options) ||
        Utils.getFlag("info", options);

      throw new Exception("\nHelp requested."
          + makeOptionString(classifier, globalInfo));
    }

    try {
      // do we get the input from XML instead of normal parameters?
      xml = Utils.getOption("xml", options);
      if (!xml.equals(""))
        options = new XMLOptions(xml).toArray();

      // is the input model only the XML-Options, i.e. w/o built model?
      optionsTmp = new String[options.length];
      for (int i = 0; i < options.length; i++)
        optionsTmp[i] = options[i];

      String tmpO = Utils.getOption('l', optionsTmp);
      //if (Utils.getOption('l', optionsTmp).toLowerCase().endsWith(".xml")) {
      if (tmpO.endsWith(".xml")) {
        // try to load file as PMML first
        boolean success = false;
        try {
          PMMLModel pmmlModel = PMMLFactory.getPMMLModel(tmpO);
          if (pmmlModel instanceof PMMLClassifier) {
            classifier = ((PMMLClassifier)pmmlModel);
            success = true;
          }
        } catch (IllegalArgumentException ex) {
          success = false;
        }
        if (!success) {
          // load options from serialized data  ('-l' is automatically erased!)
          XMLClassifier xmlserial = new XMLClassifier();
          OptionHandler cl = (OptionHandler) xmlserial.read(Utils.getOption('l', options));

          // merge options
          optionsTmp = new String[options.length + cl.getOptions().length];
          System.arraycopy(cl.getOptions(), 0, optionsTmp, 0, cl.getOptions().length);
          System.arraycopy(options, 0, optionsTmp, cl.getOptions().length, options.length);
          options = optionsTmp;
        }
      }

      noCrossValidation = Utils.getFlag("no-cv", options);
      // Get basic options (options the same for all schemes)
      classIndexString = Utils.getOption('c', options);
      if (classIndexString.length() != 0) {
        if (classIndexString.equals("first"))
          classIndex = 1;
        else if (classIndexString.equals("last"))
          classIndex = -1;
        else
          classIndex = Integer.parseInt(classIndexString);
      }
      trainFileName = Utils.getOption('t', options);
      objectInputFileName = Utils.getOption('l', options);
      objectOutputFileName = Utils.getOption('d', options);
      testFileName = Utils.getOption('T', options);
      foldsString = Utils.getOption('x', options);
      if (foldsString.length() != 0) {
        folds = Integer.parseInt(foldsString);
      }
      seedString = Utils.getOption('s', options);
      if (seedString.length() != 0) {
        seed = Integer.parseInt(seedString);
      }
      if (trainFileName.length() == 0) {
        if (objectInputFileName.length() == 0) {
          throw new Exception("No training file and no object input file given.");
        }
        if (testFileName.length() == 0) {
          throw new Exception("No training file and no test file given.");
        }
      } else if ((objectInputFileName.length() != 0) &&
          ((!(classifier instanceof UpdateableClassifier)) ||
           (testFileName.length() == 0))) {
        throw new Exception("Classifier not incremental, or no " +
            "test file provided: can't "+
            "use both train and model file.");
      }
      try {
        if (trainFileName.length() != 0) {
          trainSetPresent = true;
          trainSource = new DataSource(trainFileName);
        }
        if (testFileName.length() != 0) {
          testSetPresent = true;
          testSource = new DataSource(testFileName);
        }
        if (objectInputFileName.length() != 0) {
          if (objectInputFileName.endsWith(".xml")) {
            // if this is the case then it means that a PMML classifier was
            // successfully loaded earlier in the code
            objectInputStream = null;
            xmlInputStream = null;
          } else {
            InputStream is = new FileInputStream(objectInputFileName);
            if (objectInputFileName.endsWith(".gz")) {
              is = new GZIPInputStream(is);
            }
            // load from KOML?
            if (!(objectInputFileName.endsWith(".koml") && KOML.isPresent()) ) {
              objectInputStream = new ObjectInputStream(is);
              xmlInputStream    = null;
            }
            else {
              objectInputStream = null;
              xmlInputStream    = new BufferedInputStream(is);
            }
          }
        }
      } catch (Exception e) {
        throw new Exception("Can't open file " + e.getMessage() + '.');
      }
      if (testSetPresent) {
        template = test = testSource.getStructure();
        if (classIndex != -1) {
          test.setClassIndex(classIndex - 1);
        } else {
          if ( (test.classIndex() == -1) || (classIndexString.length() != 0) )
            test.setClassIndex(test.numAttributes() - 1);
        }
        actualClassIndex = test.classIndex();
      }
      else {
        // percentage split
        splitPercentageString = Utils.getOption("split-percentage", options);
        if (splitPercentageString.length() != 0) {
          if (foldsString.length() != 0)
            throw new Exception(
                "Percentage split cannot be used in conjunction with "
                + "cross-validation ('-x').");
          splitPercentage = Double.parseDouble(splitPercentageString);
          if ((splitPercentage <= 0) || (splitPercentage >= 100))
            throw new Exception("Percentage split value needs be >0 and <100.");
        }
        else {
          splitPercentage = -1;
        }
        preserveOrder = Utils.getFlag("preserve-order", options);
        if (preserveOrder) {
          if (splitPercentage == -1)
            throw new Exception("Percentage split ('-percentage-split') is missing.");
        }
        // create new train/test sources
        if (splitPercentage > 0) {
          testSetPresent = true;
          Instances tmpInst = trainSource.getDataSet(actualClassIndex);
          if (!preserveOrder)
            tmpInst.randomize(new Random(seed));
          int trainSize =
            (int) Math.round(tmpInst.numInstances() * splitPercentage / 100);
          int testSize  = tmpInst.numInstances() - trainSize;
          Instances trainInst = new Instances(tmpInst, 0, trainSize);
          Instances testInst  = new Instances(tmpInst, trainSize, testSize);
          trainSource = new DataSource(trainInst);
          testSource  = new DataSource(testInst);
          template = test = testSource.getStructure();
          if (classIndex != -1) {
            test.setClassIndex(classIndex - 1);
          } else {
            if ( (test.classIndex() == -1) || (classIndexString.length() != 0) )
              test.setClassIndex(test.numAttributes() - 1);
          }
          actualClassIndex = test.classIndex();
        }
      }
      if (trainSetPresent) {
        template = train = trainSource.getStructure();
        if (classIndex != -1) {
          train.setClassIndex(classIndex - 1);
        } else {
          if ( (train.classIndex() == -1) || (classIndexString.length() != 0) )
            train.setClassIndex(train.numAttributes() - 1);
        }
        actualClassIndex = train.classIndex();
        if (!(classifier instanceof weka.classifiers.misc.InputMappedClassifier)) {
          if ((testSetPresent) && !test.equalHeaders(train)) {
            throw new IllegalArgumentException("Train and test file not compatible!\n" + test.equalHeadersMsg(train));
          }
        }
      }
      if (template == null) {
        throw new Exception("No actual dataset provided to use as template");
      }
      costMatrix = handleCostOption(
          Utils.getOption('m', options), template.numClasses());

      classStatistics = Utils.getFlag('i', options);
      noOutput = Utils.getFlag('o', options);
      trainStatistics = !Utils.getFlag('v', options);
      printComplexityStatistics = Utils.getFlag('k', options);
      printMargins = Utils.getFlag('r', options);
      printGraph = Utils.getFlag('g', options);
      sourceClass = Utils.getOption('z', options);
      printSource = (sourceClass.length() != 0);
      thresholdFile = Utils.getOption("threshold-file", options);
      thresholdLabel = Utils.getOption("threshold-label", options);

      String classifications = Utils.getOption("classifications", options);
      String classificationsOld = Utils.getOption("p", options);
      if (classifications.length() > 0) {
        noOutput = true;
        classificationOutput = AbstractOutput.fromCommandline(classifications);
        classificationOutput.setHeader(template);
      }
      // backwards compatible with old "-p range" and "-distribution" options
      else if (classificationsOld.length() > 0) {
        noOutput = true;
        classificationOutput = new PlainText();
        classificationOutput.setHeader(template);
        if (!classificationsOld.equals("0"))
          classificationOutput.setAttributes(classificationsOld);
        classificationOutput.setOutputDistribution(Utils.getFlag("distribution", options));
      }
      // -distribution flag needs -p option
      else {
        if (Utils.getFlag("distribution", options))
          throw new Exception("Cannot print distribution without '-p' option!");
      }

      // if no training file given, we don't have any priors
      if ( (!trainSetPresent) && (printComplexityStatistics) )
        throw new Exception("Cannot print complexity statistics ('-k') without training file ('-t')!");

      // If a model file is given, we can't process
      // scheme-specific options
      if (objectInputFileName.length() != 0) {
        Utils.checkForRemainingOptions(options);
      } else {

        // Set options for classifier
        if (classifier instanceof OptionHandler) {
          for (int i = 0; i < options.length; i++) {
            if (options[i].length() != 0) {
              if (schemeOptionsText == null) {
                schemeOptionsText = new StringBuffer();
              }
              if (options[i].indexOf(' ') != -1) {
                schemeOptionsText.append('"' + options[i] + "\" ");
              } else {
                schemeOptionsText.append(options[i] + " ");
              }
            }
          }
          ((OptionHandler)classifier).setOptions(options);
        }
      }

      Utils.checkForRemainingOptions(options);
    } catch (Exception e) {
      throw new Exception("\nWeka exception: " + e.getMessage()
          + makeOptionString(classifier, false));
    }

    if (objectInputFileName.length() != 0) {
      // Load classifier from file
      if (objectInputStream != null) {
        classifier = (Classifier) objectInputStream.readObject();
        // try and read a header (if present)
        Instances savedStructure = null;
        try {
          savedStructure = (Instances) objectInputStream.readObject();
        } catch (Exception ex) {
          // don't make a fuss
        }
        if (savedStructure != null) {
          // test for compatibility with template
          if (!template.equalHeaders(savedStructure)) {
            throw new Exception("training and test set are not compatible\n" + template.equalHeadersMsg(savedStructure));
          }
        }
        objectInputStream.close();
      }
      else if (xmlInputStream != null) {
        // whether KOML is available has already been checked (objectInputStream would be null otherwise)!
        classifier = (Classifier) KOML.read(xmlInputStream);
        xmlInputStream.close();
      }
    }
   
    // Set up evaluation objects
    Evaluation trainingEvaluation = new Evaluation(new Instances(template, 0), costMatrix);
    Evaluation testingEvaluation = new Evaluation(new Instances(template, 0), costMatrix);
    if (classifier instanceof weka.classifiers.misc.InputMappedClassifier) {
      Instances mappedClassifierHeader =
        ((weka.classifiers.misc.InputMappedClassifier)classifier).
          getModelHeader(new Instances(template, 0));
           
      trainingEvaluation = new Evaluation(new Instances(mappedClassifierHeader, 0), costMatrix);
      testingEvaluation = new Evaluation(new Instances(mappedClassifierHeader, 0), costMatrix);
    }

    // disable use of priors if no training file given
    if (!trainSetPresent)
      testingEvaluation.useNoPriors();

    // backup of fully setup classifier for cross-validation
    classifierBackup = AbstractClassifier.makeCopy(classifier);

    // Build the classifier if no object file provided
    if ((classifier instanceof UpdateableClassifier) &&
        (testSetPresent || noCrossValidation) &&
        (costMatrix == null) &&
        (trainSetPresent)) {
      // Build classifier incrementally
      trainingEvaluation.setPriors(train);
      testingEvaluation.setPriors(train);
      trainTimeStart = System.currentTimeMillis();
      if (objectInputFileName.length() == 0) {
        classifier.buildClassifier(train);
      }
      Instance trainInst;
      while (trainSource.hasMoreElements(train)) {
        trainInst = trainSource.nextElement(train);
        trainingEvaluation.updatePriors(trainInst);
        testingEvaluation.updatePriors(trainInst);
        ((UpdateableClassifier)classifier).updateClassifier(trainInst);
      }
      trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
    } else if (objectInputFileName.length() == 0) {
      // Build classifier in one go
      tempTrain = trainSource.getDataSet(actualClassIndex);
     
      if (classifier instanceof weka.classifiers.misc.InputMappedClassifier &&
          !trainingEvaluation.getHeader().equalHeaders(tempTrain)) {
        // we need to make a new dataset that maps the training instances to
        // the structure expected by the mapped classifier - this is only
        // to ensure that the structure and priors computed by the *testing*
        // evaluation object is correct with respect to the mapped classifier
        Instances mappedClassifierDataset =
          ((weka.classifiers.misc.InputMappedClassifier)classifier).
            getModelHeader(new Instances(template, 0));
        for (int zz = 0; zz < tempTrain.numInstances(); zz++) {
          Instance mapped = ((weka.classifiers.misc.InputMappedClassifier)classifier).
            constructMappedInstance(tempTrain.instance(zz));
          mappedClassifierDataset.add(mapped);
        }
        tempTrain = mappedClassifierDataset;
      }
     
      trainingEvaluation.setPriors(tempTrain);
      testingEvaluation.setPriors(tempTrain);
      trainTimeStart = System.currentTimeMillis();
      classifier.buildClassifier(tempTrain);
      trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
    }

    // backup of fully trained classifier for printing the classifications
    if (classificationOutput != null) {
      classifierClassifications = AbstractClassifier.makeCopy(classifier);
      if (classifier instanceof weka.classifiers.misc.InputMappedClassifier) {
        classificationOutput.setHeader(trainingEvaluation.getHeader());
      }
    }

    // Save the classifier if an object output file is provided
    if (objectOutputFileName.length() != 0) {
      OutputStream os = new FileOutputStream(objectOutputFileName);
      // binary
      if (!(objectOutputFileName.endsWith(".xml") || (objectOutputFileName.endsWith(".koml") && KOML.isPresent()))) {
        if (objectOutputFileName.endsWith(".gz")) {
          os = new GZIPOutputStream(os);
        }
        ObjectOutputStream objectOutputStream = new ObjectOutputStream(os);
        objectOutputStream.writeObject(classifier);
        if (template != null) {
          objectOutputStream.writeObject(template);
        }
        objectOutputStream.flush();
        objectOutputStream.close();
      }
      // KOML/XML
      else {
        BufferedOutputStream xmlOutputStream = new BufferedOutputStream(os);
        if (objectOutputFileName.endsWith(".xml")) {
          XMLSerialization xmlSerial = new XMLClassifier();
          xmlSerial.write(xmlOutputStream, classifier);
        }
        else
          // whether KOML is present has already been checked
          // if not present -> ".koml" is interpreted as binary - see above
          if (objectOutputFileName.endsWith(".koml")) {
            KOML.write(xmlOutputStream, classifier);
          }
        xmlOutputStream.close();
      }
    }

    // If classifier is drawable output string describing graph
    if ((classifier instanceof Drawable) && (printGraph)){
      return ((Drawable)classifier).graph();
    }

    // Output the classifier as equivalent source
    if ((classifier instanceof Sourcable) && (printSource)){
      return wekaStaticWrapper((Sourcable) classifier, sourceClass);
    }

    // Output model
    if (!(noOutput || printMargins)) {
      if (classifier instanceof OptionHandler) {
        if (schemeOptionsText != null) {
          text.append("\nOptions: "+schemeOptionsText);
          text.append("\n");
        }
      }
      text.append("\n" + classifier.toString() + "\n");
    }

    if (!printMargins && (costMatrix != null)) {
      text.append("\n=== Evaluation Cost Matrix ===\n\n");
      text.append(costMatrix.toString());
    }

    // Output test instance predictions only
    if (classificationOutput != null) {
      DataSource source = testSource;
      predsBuff = new StringBuffer();
      classificationOutput.setBuffer(predsBuff);
      // no test set -> use train set
      if (source == null && noCrossValidation) {
        source = trainSource;
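
The (truncated) excerpt above appears to come from Weka's command-line evaluation code and uses DataSource both for batch loading (getDataSet) and for incremental training of UpdateableClassifier schemes (hasMoreElements/nextElement against a structure obtained from getStructure). A minimal sketch of the incremental pattern, with a placeholder file name and no actual classifier:

import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class IncrementalRead {
  public static void main(String[] args) throws Exception {
    DataSource source = new DataSource("train.arff");   // placeholder file name
    // getStructure() returns only the header; the class index set on it is
    // passed along with every call that streams instances.
    Instances structure = source.getStructure();
    structure.setClassIndex(structure.numAttributes() - 1);
    int count = 0;
    while (source.hasMoreElements(structure)) {
      Instance inst = source.nextElement(structure);
      // an UpdateableClassifier would be fed 'inst' here
      count++;
    }
    System.out.println("Streamed " + count + " instances.");
  }
}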


  public void run() {
    // Copy the current state of things
    m_Log.statusMessage("Setting up...");
    CostMatrix costMatrix = null;
    Instances inst = new Instances(m_Instances);
    DataSource source = null;
    Instances userTestStructure = null;
    ClassifierErrorsPlotInstances plotInstances = null;
   
    // for timing
    long trainTimeStart = 0, trainTimeElapsed = 0;

          try {
            if (m_TestLoader != null && m_TestLoader.getStructure() != null) {
              m_TestLoader.reset();
              source = new DataSource(m_TestLoader);
              userTestStructure = source.getStructure();
              userTestStructure.setClassIndex(m_TestClassIndex);
            }
          } catch (Exception ex) {
            ex.printStackTrace();
          }
    if (m_EvalWRTCostsBut.isSelected()) {
      costMatrix = new CostMatrix((CostMatrix) m_CostMatrixEditor
          .getValue());
    }
    boolean outputModel = m_OutputModelBut.isSelected();
    boolean outputConfusion = m_OutputConfusionBut.isSelected();
    boolean outputPerClass = m_OutputPerClassBut.isSelected();
    boolean outputSummary = true;
          boolean outputEntropy = m_OutputEntropyBut.isSelected();
    boolean saveVis = m_StorePredictionsBut.isSelected();
    boolean outputPredictionsText = (m_ClassificationOutputEditor.getValue().getClass() != Null.class);

    String grph = null;

    int testMode = 0;
    int numFolds = 10;
          double percent = 66;
    int classIndex = m_ClassCombo.getSelectedIndex();
    inst.setClassIndex(classIndex);
    Classifier classifier = (Classifier) m_ClassifierEditor.getValue();
    Classifier template = null;
    try {
      template = AbstractClassifier.makeCopy(classifier);
    } catch (Exception ex) {
      m_Log.logMessage("Problem copying classifier: " + ex.getMessage());
    }
    Classifier fullClassifier = null;
    StringBuffer outBuff = new StringBuffer();
    AbstractOutput classificationOutput = null;
    if (outputPredictionsText) {
      classificationOutput = (AbstractOutput) m_ClassificationOutputEditor.getValue();
      Instances header = new Instances(inst, 0);
      header.setClassIndex(classIndex);
      classificationOutput.setHeader(header);
      classificationOutput.setBuffer(outBuff);
    }
    String name = (new SimpleDateFormat("HH:mm:ss - ")).format(new Date());
    String cname = "";
          String cmd = "";
    Evaluation eval = null;
    try {
      if (m_CVBut.isSelected()) {
        testMode = 1;
        numFolds = Integer.parseInt(m_CVText.getText());
        if (numFolds <= 1) {
    throw new Exception("Number of folds must be greater than 1");
        }
      } else if (m_PercentBut.isSelected()) {
        testMode = 2;
        percent = Double.parseDouble(m_PercentText.getText());
        if ((percent <= 0) || (percent >= 100)) {
    throw new Exception("Percentage must be between 0 and 100");
        }
      } else if (m_TrainBut.isSelected()) {
        testMode = 3;
      } else if (m_TestSplitBut.isSelected()) {
        testMode = 4;
        // Check the test instance compatibility
        if (source == null) {
          throw new Exception("No user test set has been specified");
        }
       
        if (!(classifier instanceof weka.classifiers.misc.InputMappedClassifier)) {
          if (!inst.equalHeaders(userTestStructure)) {
            boolean wrapClassifier = false;
            if (!Utils.
                getDontShowDialog("weka.gui.explorer.ClassifierPanel.AutoWrapInInputMappedClassifier")) {
              JCheckBox dontShow = new JCheckBox("Do not show this message again");
              Object[] stuff = new Object[2];
              stuff[0] = "Train and test set are not compatible.\n" +
              "Would you like to automatically wrap the classifier in\n" +
              "an \"InputMappedClassifier\" before proceeding?.\n";
              stuff[1] = dontShow;

              int result = JOptionPane.showConfirmDialog(ClassifierPanel.this, stuff,
                  "ClassifierPanel", JOptionPane.YES_OPTION);
             
              if (result == JOptionPane.YES_OPTION) {
                wrapClassifier = true;
              }
             
              if (dontShow.isSelected()) {
                String response = (wrapClassifier) ? "yes" : "no";
                Utils.
                  setDontShowDialogResponse("weka.gui.explorer.ClassifierPanel.AutoWrapInInputMappedClassifier",
                      response);
              }

            } else {
              // What did the user say - do they want to autowrap or not?
              String response =
                Utils.getDontShowDialogResponse("weka.gui.explorer.ClassifierPanel.AutoWrapInInputMappedClassifier");
              if (response != null && response.equalsIgnoreCase("yes")) {
                wrapClassifier = true;
              }
            }

            if (wrapClassifier) {
              weka.classifiers.misc.InputMappedClassifier temp =
                new weka.classifiers.misc.InputMappedClassifier();

              // pass on the known test structure so that we get the
              // correct mapping report from the toString() method
              // of InputMappedClassifier
              temp.setClassifier(classifier);
              temp.setTestStructure(userTestStructure);
              classifier = temp;
            } else {
              throw new Exception("Train and test set are not compatible\n" + inst.equalHeadersMsg(userTestStructure));
            }
          }
        }
             
      } else {
        throw new Exception("Unknown test mode");
      }

      cname = classifier.getClass().getName();
      if (cname.startsWith("weka.classifiers.")) {
        name += cname.substring("weka.classifiers.".length());
      } else {
        name += cname;
      }
      cmd = classifier.getClass().getName();
      if (classifier instanceof OptionHandler)
        cmd += " " + Utils.joinOptions(((OptionHandler) classifier).getOptions());
     
      // set up the structure of the plottable instances for
      // visualization
      plotInstances = ExplorerDefaults.getClassifierErrorsPlotInstances();
      plotInstances.setInstances(inst);
      plotInstances.setClassifier(classifier);
      plotInstances.setClassIndex(inst.classIndex());
      plotInstances.setSaveForVisualization(saveVis);

      // Output some header information
      m_Log.logMessage("Started " + cname);
      m_Log.logMessage("Command: " + cmd);
      if (m_Log instanceof TaskLogger) {
        ((TaskLogger)m_Log).taskStarted();
      }
      outBuff.append("=== Run information ===\n\n");
      outBuff.append("Scheme:       " + cname);
      if (classifier instanceof OptionHandler) {
        String [] o = ((OptionHandler) classifier).getOptions();
        outBuff.append(" " + Utils.joinOptions(o));
      }
      outBuff.append("\n");
      outBuff.append("Relation:     " + inst.relationName() + '\n');
      outBuff.append("Instances:    " + inst.numInstances() + '\n');
      outBuff.append("Attributes:   " + inst.numAttributes() + '\n');
      if (inst.numAttributes() < 100) {
        for (int i = 0; i < inst.numAttributes(); i++) {
    outBuff.append("              " + inst.attribute(i).name()
             + '\n');
        }
      } else {
        outBuff.append("              [list of attributes omitted]\n");
      }

      outBuff.append("Test mode:    ");
      switch (testMode) {
        case 3: // Test on training
    outBuff.append("evaluate on training data\n");
    break;
        case 1: // CV mode
    outBuff.append("" + numFolds + "-fold cross-validation\n");
    break;
        case 2: // Percent split
    outBuff.append("split " + percent
        + "% train, remainder test\n");
    break;
        case 4: // Test on user split
    if (source.isIncremental())
      outBuff.append("user supplied test set: "
          + " size unknown (reading incrementally)\n");
    else
      outBuff.append("user supplied test set: "
          + source.getDataSet().numInstances() + " instances\n");
    break;
      }
            if (costMatrix != null) {
               outBuff.append("Evaluation cost matrix:\n")
               .append(costMatrix.toString()).append("\n");
            }
      outBuff.append("\n");
      m_History.addResult(name, outBuff);
      m_History.setSingle(name);
     
      // Build the model and output it.
      if (outputModel || (testMode == 3) || (testMode == 4)) {
        m_Log.statusMessage("Building model on training data...");

        trainTimeStart = System.currentTimeMillis();
        classifier.buildClassifier(inst);
        trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
      }

      if (outputModel) {
        outBuff.append("=== Classifier model (full training set) ===\n\n");
        outBuff.append(classifier.toString() + "\n");
        outBuff.append("\nTime taken to build model: " +
           Utils.doubleToString(trainTimeElapsed / 1000.0,2)
           + " seconds\n\n");
        m_History.updateResult(name);
        if (classifier instanceof Drawable) {
    grph = null;
    try {
      grph = ((Drawable)classifier).graph();
    } catch (Exception ex) {
    }
        }
        // copy full model for output
        SerializedObject so = new SerializedObject(classifier);
        fullClassifier = (Classifier) so.getObject();
      }
     
      switch (testMode) {
        case 3: // Test on training
        m_Log.statusMessage("Evaluating on training data...");
        eval = new Evaluation(inst, costMatrix);
       
        // make adjustments if the classifier is an InputMappedClassifier
        eval = setupEval(eval, classifier, inst, costMatrix,
            plotInstances, classificationOutput, false);
       
        //plotInstances.setEvaluation(eval);
              plotInstances.setUp();
       
        if (outputPredictionsText) {
    printPredictionsHeader(outBuff, classificationOutput, "training set");
        }

        for (int jj=0;jj<inst.numInstances();jj++) {
    plotInstances.process(inst.instance(jj), classifier, eval);
   
    if (outputPredictionsText) {
      classificationOutput.printClassification(classifier, inst.instance(jj), jj);
    }
    if ((jj % 100) == 0) {
      m_Log.statusMessage("Evaluating on training data. Processed "
              +jj+" instances...");
    }
        }
        if (outputPredictionsText)
    classificationOutput.printFooter();
        if (outputPredictionsText && classificationOutput.generatesOutput()) {
    outBuff.append("\n");
        }
        outBuff.append("=== Evaluation on training set ===\n");
        break;

        case 1: // CV mode
        m_Log.statusMessage("Randomizing instances...");
        int rnd = 1;
        try {
    rnd = Integer.parseInt(m_RandomSeedText.getText().trim());
    // System.err.println("Using random seed "+rnd);
        } catch (Exception ex) {
    m_Log.logMessage("Trouble parsing random seed value");
    rnd = 1;
        }
        Random random = new Random(rnd);
        inst.randomize(random);
        if (inst.attribute(classIndex).isNominal()) {
    m_Log.statusMessage("Stratifying instances...");
    inst.stratify(numFolds);
        }
        eval = new Evaluation(inst, costMatrix);
       
         // make adjustments if the classifier is an InputMappedClassifier
              eval = setupEval(eval, classifier, inst, costMatrix,
                  plotInstances, classificationOutput, false);
       
//        plotInstances.setEvaluation(eval);
              plotInstances.setUp();
     
        if (outputPredictionsText) {
    printPredictionsHeader(outBuff, classificationOutput, "test data");
        }

        // Make some splits and do a CV
        for (int fold = 0; fold < numFolds; fold++) {
    m_Log.statusMessage("Creating splits for fold "
            + (fold + 1) + "...");
    Instances train = inst.trainCV(numFolds, fold, random);
   
    // make adjustments if the classifier is an InputMappedClassifier
          eval = setupEval(eval, classifier, train, costMatrix,
              plotInstances, classificationOutput, true);
         
//    eval.setPriors(train);
    m_Log.statusMessage("Building model for fold "
            + (fold + 1) + "...");
    Classifier current = null;
    try {
      current = AbstractClassifier.makeCopy(template);
    } catch (Exception ex) {
      m_Log.logMessage("Problem copying classifier: " + ex.getMessage());
    }
    current.buildClassifier(train);
    Instances test = inst.testCV(numFolds, fold);
    m_Log.statusMessage("Evaluating model for fold "
            + (fold + 1) + "...");
    for (int jj=0;jj<test.numInstances();jj++) {
      plotInstances.process(test.instance(jj), current, eval);
      if (outputPredictionsText) {
        classificationOutput.printClassification(current, test.instance(jj), jj);
      }
    }
        }
        if (outputPredictionsText)
    classificationOutput.printFooter();
        if (outputPredictionsText) {
    outBuff.append("\n");
        }
        if (inst.attribute(classIndex).isNominal()) {
    outBuff.append("=== Stratified cross-validation ===\n");
        } else {
    outBuff.append("=== Cross-validation ===\n");
        }
        break;
   
        case 2: // Percent split
        if (!m_PreserveOrderBut.isSelected()) {
    m_Log.statusMessage("Randomizing instances...");
    try {
      rnd = Integer.parseInt(m_RandomSeedText.getText().trim());
    } catch (Exception ex) {
      m_Log.logMessage("Trouble parsing random seed value");
      rnd = 1;
    }
    inst.randomize(new Random(rnd));
        }
        int trainSize = (int) Math.round(inst.numInstances() * percent / 100);
        int testSize = inst.numInstances() - trainSize;
        Instances train = new Instances(inst, 0, trainSize);
        Instances test = new Instances(inst, trainSize, testSize);
        m_Log.statusMessage("Building model on training split ("+trainSize+" instances)...");
        Classifier current = null;
        try {
    current = AbstractClassifier.makeCopy(template);
        } catch (Exception ex) {
    m_Log.logMessage("Problem copying classifier: " + ex.getMessage());
        }
        current.buildClassifier(train);
        eval = new Evaluation(train, costMatrix);
       
        // make adjustments if the classifier is an InputMappedClassifier
              eval = setupEval(eval, classifier, train, costMatrix,
                  plotInstances, classificationOutput, false);
                     
//        plotInstances.setEvaluation(eval);
              plotInstances.setUp();
        m_Log.statusMessage("Evaluating on test split...");
      
        if (outputPredictionsText) {
    printPredictionsHeader(outBuff, classificationOutput, "test split");
        }
    
        for (int jj=0;jj<test.numInstances();jj++) {
    plotInstances.process(test.instance(jj), current, eval);
    if (outputPredictionsText) {
      classificationOutput.printClassification(current, test.instance(jj), jj);
    }
    if ((jj % 100) == 0) {
      m_Log.statusMessage("Evaluating on test split. Processed "
              +jj+" instances...");
    }
        }
        if (outputPredictionsText)
    classificationOutput.printFooter();
        if (outputPredictionsText) {
    outBuff.append("\n");
        }
        outBuff.append("=== Evaluation on test split ===\n");
        break;
   
        case 4: // Test on user split
        m_Log.statusMessage("Evaluating on test data...");
        eval = new Evaluation(inst, costMatrix);
        // make adjustments if the classifier is an InputMappedClassifier
              eval = setupEval(eval, classifier, inst, costMatrix,
                  plotInstances, classificationOutput, false);
             
//        plotInstances.setEvaluation(eval);
              plotInstances.setUp();
       
        if (outputPredictionsText) {
    printPredictionsHeader(outBuff, classificationOutput, "test set");
        }

        Instance instance;
        int jj = 0;
        while (source.hasMoreElements(userTestStructure)) {
    instance = source.nextElement(userTestStructure);
    plotInstances.process(instance, classifier, eval);
    if (outputPredictionsText) {
      classificationOutput.printClassification(classifier, instance, jj);
    }
    if ((++jj % 100) == 0) {
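
In the Explorer panel code above, the test-set DataSource is built from an already configured Loader rather than a file name (new DataSource(m_TestLoader)), and getStructure() supplies just the header so the class index can be set before any evaluation. A sketch of that constructor using an ArffLoader; the file name is a placeholder:

import java.io.File;
import weka.core.Instances;
import weka.core.converters.ArffLoader;
import weka.core.converters.ConverterUtils.DataSource;

public class DataSourceFromLoader {
  public static void main(String[] args) throws Exception {
    ArffLoader loader = new ArffLoader();
    loader.setFile(new File("test.arff"));   // placeholder file name
    DataSource source = new DataSource(loader);
    Instances header = source.getStructure();   // header only, no instances
    header.setClassIndex(header.numAttributes() - 1);
    Instances test = source.getDataSet();        // full dataset from the same loader
    test.setClassIndex(header.classIndex());
    System.out.println("Test instances: " + test.numInstances());
  }
}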

            // Copy the current state of things
            m_Log.statusMessage("Setting up...");
            Classifier classifierToUse = classifier;

            StringBuffer outBuff = m_History.getNamedBuffer(name);
            DataSource source = null;
            Instances userTestStructure = null;
            ClassifierErrorsPlotInstances plotInstances = null;

            CostMatrix costMatrix = null;
            if (m_EvalWRTCostsBut.isSelected()) {
              costMatrix = new CostMatrix((CostMatrix) m_CostMatrixEditor
                                          .getValue());
            }   
            boolean outputConfusion = m_OutputConfusionBut.isSelected();
            boolean outputPerClass = m_OutputPerClassBut.isSelected();
            boolean outputSummary = true;
            boolean outputEntropy = m_OutputEntropyBut.isSelected();
            boolean saveVis = m_StorePredictionsBut.isSelected();
            boolean outputPredictionsText = (m_ClassificationOutputEditor.getValue().getClass() != Null.class);
            String grph = null;   
            Evaluation eval = null;

            try {

              boolean incrementalLoader = (m_TestLoader instanceof IncrementalConverter);
              if (m_TestLoader != null && m_TestLoader.getStructure() != null) {
                m_TestLoader.reset();
                source = new DataSource(m_TestLoader);
                userTestStructure = source.getStructure();
                userTestStructure.setClassIndex(m_TestClassIndex);
              }
              // Check the test instance compatibility
              if (source == null) {
                throw new Exception("No user test set has been specified");
              }
              if (trainHeader != null) {
                boolean compatibilityProblem = false;
                if (trainHeader.classIndex() >
                    userTestStructure.numAttributes()-1) {
                  compatibilityProblem = true;
                  //throw new Exception("Train and test set are not compatible");
                }
                userTestStructure.setClassIndex(trainHeader.classIndex());
                if (!trainHeader.equalHeaders(userTestStructure)) {
                  compatibilityProblem = true;
                  // throw new Exception("Train and test set are not compatible:\n" + trainHeader.equalHeadersMsg(userTestStructure));
                 
                  if (compatibilityProblem &&
                      !(classifierToUse instanceof weka.classifiers.misc.InputMappedClassifier)) {

                    boolean wrapClassifier = false;
                    if (!Utils.
                        getDontShowDialog("weka.gui.explorer.ClassifierPanel.AutoWrapInInputMappedClassifier")) {
                      JCheckBox dontShow = new JCheckBox("Do not show this message again");
                      Object[] stuff = new Object[2];
                      stuff[0] = "Data used to train model and test set are not compatible.\n" +
                      "Would you like to automatically wrap the classifier in\n" +
                      "an \"InputMappedClassifier\" before proceeding?.\n";
                      stuff[1] = dontShow;

                      int result = JOptionPane.showConfirmDialog(ClassifierPanel.this, stuff,
                          "ClassifierPanel", JOptionPane.YES_OPTION);
                     
                      if (result == JOptionPane.YES_OPTION) {
                        wrapClassifier = true;
                      }
                     
                      if (dontShow.isSelected()) {
                        String response = (wrapClassifier) ? "yes" : "no";
                        Utils.
                          setDontShowDialogResponse("weka.gui.explorer.ClassifierPanel.AutoWrapInInputMappedClassifier",
                              response);
                      }

                    } else {
                      // What did the user say - do they want to autowrap or not?
                      String response =
                        Utils.getDontShowDialogResponse("weka.gui.explorer.ClassifierPanel.AutoWrapInInputMappedClassifier");
                      if (response != null && response.equalsIgnoreCase("yes")) {
                        wrapClassifier = true;
                      }
                    }

                    if (wrapClassifier) {
                      weka.classifiers.misc.InputMappedClassifier temp =
                        new weka.classifiers.misc.InputMappedClassifier();

                      temp.setClassifier(classifierToUse);
                      temp.setModelHeader(trainHeader);
                      classifierToUse = temp;
                    } else {
                      throw new Exception("Train and test set are not compatible\n" +
                          trainHeader.equalHeadersMsg(userTestStructure));
                    }
                  }
                }
              } else {
          if (classifierToUse instanceof PMMLClassifier) {
            // set the class based on information in the mining schema
            Instances miningSchemaStructure =
              ((PMMLClassifier)classifierToUse).getMiningSchema().getMiningSchemaAsInstances();
            String className = miningSchemaStructure.classAttribute().name();
            Attribute classMatch = userTestStructure.attribute(className);
            if (classMatch == null) {
              throw new Exception("Can't find a match for the PMML target field "
            + className + " in the "
            + "test instances!");
            }
            userTestStructure.setClass(classMatch);
          } else {
            userTestStructure.
              setClassIndex(userTestStructure.numAttributes()-1);
          }
              }
              if (m_Log instanceof TaskLogger) {
                ((TaskLogger)m_Log).taskStarted();
              }
              m_Log.statusMessage("Evaluating on test data...");
              m_Log.logMessage("Re-evaluating classifier (" + name
                               + ") on test set");
              eval = new Evaluation(userTestStructure, costMatrix);
     
              // set up the structure of the plottable instances for
              // visualization if selected
              if (saveVis) {
          plotInstances = new ClassifierErrorsPlotInstances();
          plotInstances.setInstances(userTestStructure);
          plotInstances.setClassifier(classifierToUse);
          plotInstances.setClassIndex(userTestStructure.classIndex());
          plotInstances.setEvaluation(eval);
          plotInstances.setUp();
              }
             
     
              outBuff.append("\n=== Re-evaluation on test set ===\n\n");
              outBuff.append("User supplied test set\n")
              outBuff.append("Relation:     "
                             + userTestStructure.relationName() + '\n');
              if (incrementalLoader)
          outBuff.append("Instances:     unknown (yet). Reading incrementally\n");
              else
          outBuff.append("Instances:    " + source.getDataSet().numInstances() + "\n");
              outBuff.append("Attributes:   "
            + userTestStructure.numAttributes()
            + "\n\n");
              if (trainHeader == null &&
                  !(classifierToUse instanceof
                      weka.classifiers.pmml.consumer.PMMLClassifier)) {
                outBuff.append("NOTE - if test set is not compatible then results are "
                               + "unpredictable\n\n");
              }

              AbstractOutput classificationOutput = null;
              if (outputPredictionsText) {
          classificationOutput = (AbstractOutput) m_ClassificationOutputEditor.getValue();
          classificationOutput.setHeader(userTestStructure);
          classificationOutput.setBuffer(outBuff);
/*          classificationOutput.setAttributes("");
          classificationOutput.setOutputDistribution(false);*/
//          classificationOutput.printHeader();         
              }
             
              // make adjustments if the classifier is an InputMappedClassifier
              eval = setupEval(eval, classifierToUse, userTestStructure, costMatrix,
                  plotInstances, classificationOutput, false);
              eval.useNoPriors();
             
              if (outputPredictionsText) {
                printPredictionsHeader(outBuff, classificationOutput, "user test set");
              }

        Instance instance;
        int jj = 0;
        while (source.hasMoreElements(userTestStructure)) {
    instance = source.nextElement(userTestStructure);
    plotInstances.process(instance, classifierToUse, eval);
    if (outputPredictionsText) {
      classificationOutput.printClassification(classifierToUse, instance, jj);
    }
    if ((++jj % 100) == 0) {
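
The re-evaluation excerpt pairs a previously trained classifier with a user-supplied test DataSource. A compact sketch of the same idea outside the GUI, assuming a serialized model whose training header is compatible with the test file (both file names are placeholders):

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.Instances;
import weka.core.SerializationHelper;
import weka.core.converters.ConverterUtils.DataSource;

public class ReevaluateSavedModel {
  public static void main(String[] args) throws Exception {
    Classifier cls = (Classifier) SerializationHelper.read("model.ser");  // placeholder
    Instances test = new DataSource("test.arff").getDataSet();            // placeholder
    test.setClassIndex(test.numAttributes() - 1);
    Evaluation eval = new Evaluation(test);
    eval.evaluateModel(cls, test);
    System.out.println(eval.toSummaryString("\n=== Re-evaluation on test set ===\n", false));
  }
}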

    // TODO: false => !config.arffOverwrite()
    file = new BufferedWriter(new FileWriter(path + player + name, false));
    file.write(attributes);
    file.flush();
   
    DataSource source = new DataSource(path + player + name);
      instances = source.getDataSet();
      // make it clean
      instances.delete();
     
      predictions = new ArrayList<Prediction>();
     
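
The snippet above writes an ARFF header to disk, loads it back through a DataSource and then clears the instances to obtain an empty dataset. For simple one-shot loading there is also the static convenience method DataSource.read; a brief sketch (the path is a placeholder) that also shows getting an empty header-only copy without mutating the loaded data:

import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class QuickLoad {
  public static void main(String[] args) throws Exception {
    Instances data = DataSource.read("some/path/data.arff");   // placeholder path
    data.setClassIndex(data.numAttributes() - 1);
    // header-only copy, similar in effect to instances.delete() above
    Instances headerOnly = new Instances(data, 0);
    System.out.println(headerOnly);
  }
}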

//    System.out.println("Creating model for " + player + name);
    Instances data;
    if (config.solveConceptDrift())
      data = instances;
    else {
      DataSource source = new DataSource(path + player + name);
        data = source.getDataSet();
    }
      if (rmAttributes.length > 0) {
        String[] optionsDel = new String[2];
      optionsDel[0] = "-R";                                  
      optionsDel[1] = "";
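
The concept-drift snippet prepares "-R" options, presumably for an attribute-removal filter applied to the loaded data. A sketch of that pattern with weka.filters.unsupervised.attribute.Remove; the index range and file name are illustrative:

import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

public class DropAttributes {
  public static void main(String[] args) throws Exception {
    Instances data = DataSource.read("model-data.arff");   // placeholder path
    Remove remove = new Remove();
    // "-R" expects a 1-based attribute index range, e.g. "1,3-5"
    remove.setOptions(new String[]{"-R", "1,3-5"});
    remove.setInputFormat(data);
    Instances reduced = Filter.useFilter(data, remove);
    System.out.println("Attributes after removal: " + reduced.numAttributes());
  }
}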

  private Map<ClusterClass, List<StoredDomainCluster>> executeClassifier(String prepfeatures, String modelfile,
      List<StoredDomainCluster> clusters){
    Map<ClusterClass, List<StoredDomainCluster>> retval =
        new HashMap<ClusterClass, List<StoredDomainCluster>>();
    try{
      DataSource source = new DataSource(new ByteArrayInputStream(prepfeatures.getBytes()));
      Instances data = source.getDataSet();
      if (data.classIndex() == -1){
        data.setClassIndex(data.numAttributes() - 1);
      }
      String[] options = weka.core.Utils.splitOptions("-p 0");
      J48 cls = (J48)weka.core.SerializationHelper.read(modelfile);
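
The classifier-execution snippet shows that DataSource also accepts an InputStream, so an ARFF document held entirely in memory can be parsed without touching the file system. A self-contained sketch:

import java.io.ByteArrayInputStream;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class InMemoryArff {
  public static void main(String[] args) throws Exception {
    String arff =
        "@relation demo\n"
      + "@attribute x numeric\n"
      + "@attribute class {yes,no}\n"
      + "@data\n"
      + "1.0,yes\n"
      + "2.5,no\n";
    DataSource source = new DataSource(new ByteArrayInputStream(arff.getBytes()));
    Instances data = source.getDataSet();
    if (data.classIndex() == -1) {
      data.setClassIndex(data.numAttributes() - 1);
    }
    System.out.println(data);
  }
}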

    ReputationGraph repGraph = new ReputationGraph(new ReputationEdgeFactory());
   
    Util.assertNotNull(arffFileName);
    Util.assertFileExists(arffFileName);
   
    DataSource source;
    try
    {
      source = new DataSource(arffFileName);
      Instances instances = source.getDataSet();
      logger.debug("Number of instances in arff file is " + instances.numInstances());
     
      Enumeration enu = instances.enumerateInstances();
      //get all the feedback lines
     

 
  public static void testbedARFFfromFile() throws Exception
  {
    Attribute groupID = new Attribute("groupID", (FastVector) null);
   
    DataSource source = new DataSource("testMe.arff");
    Instances instances = source.getDataSet();
    //System.out.println(instances);
    System.out.println(instances.instance(0));
   
       
    Enumeration enu = instances.enumerateInstances();

    TrustGraph trustGraph = new TrustGraph(new TrustEdgeFactory());
   
    Util.assertNotNull(arffFileName);
    Util.assertFileExists(arffFileName);
   
    DataSource source;
    try
    {
      source = new DataSource(arffFileName);
      Instances instances = source.getDataSet();
      logger.debug("Number of instances in arff file is " + instances.numInstances());
     
      Enumeration enu = instances.enumerateInstances();
      //get all the feedback lines
     

    {
      throw new Exception("Incomplete properties file. Check documentation.");
    }
   
    //parse and create strategies
    DataSource source = new DataSource(strFileName);
    Instances instances = source.getDataSet();
    Enumeration enu = instances.enumerateInstances();
    while(enu.hasMoreElements())
    {
      Instance temp = (Instance)enu.nextElement();
      logger.info("Parsing " + temp);
      if(temp.numValues()!=4) throw new Exception("Strategy line does not have 4 elements. This is illegal.");
      Double[] strategyInstance = new Double[4];
      for(int i=0;i<temp.numValues();i++)
      {
        //number of values == 4
        strategyInstance[i] = temp.value(i);     
        if(strategyInstance[i]==null) throw new Exception("A parameter in strategy line is null.");
      }
      /*
       * @relation strategy
       * @attribute strategyId NUMERIC
       * @attribute numberOfFeedbacks NUMERIC
       * @attribute feedbackMean NUMERIC
       * @attribute feedbackStdDev NUMERIC
       */
     
      //TODO do checks to make sure feedbackMean and feedbackStdDev are in [0,1].
      Strategy s = new Strategy(strategyInstance[0].intValue(),
          strategyInstance[1].intValue(), strategyInstance[2], strategyInstance[3]);
     
      strategies.add(s);
    }
   
    //parse and create groups
    source = new DataSource(grpFileName);
    instances = source.getDataSet();
    enu = instances.enumerateInstances();
    while(enu.hasMoreElements())
    {
      Instance temp = (Instance)enu.nextElement();
      logger.info("Parsing " + temp);
      if(temp.numValues()!=2) throw new Exception("Group line does not have 2 elements. This is illegal.");
      Double[] groupInstance = new Double[2];
      for(int i=0;i<temp.numValues();i++)
      {
        //number of values == 2
        groupInstance[i] = temp.value(i);   
        if(groupInstance[i]==null) throw new Exception("A parameter in group line is null.");
      }
      /*
       * @relation group
       * @attribute groupdId NUMERIC
       * @attribute numberAgents NUMERIC
       */
      Group g = new Group(groupInstance[0].intValue(), groupInstance[1].intValue());
      groups.add(g);
    }
   
    //parse and create group-strategy assignments
    source = new DataSource(strGrpFileName);
    instances = source.getDataSet();
    enu = instances.enumerateInstances();
    while(enu.hasMoreElements())
    {
      Instance temp = (Instance)enu.nextElement();
      logger.info("Parsing " + temp);
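
The strategy/group parsing examples iterate with enumerateInstances() and read values via numValues()/value(i), which is fine for dense instances; for sparse instances numValues() only counts stored values, so numAttributes() is the safer bound when every attribute is needed. A small iteration sketch with a placeholder file name:

import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class IterateInstances {
  public static void main(String[] args) throws Exception {
    Instances data = DataSource.read("strategies.arff");   // placeholder file name
    for (int i = 0; i < data.numInstances(); i++) {
      Instance row = data.instance(i);
      for (int a = 0; a < row.numAttributes(); a++) {
        System.out.print(row.value(a) + (a < row.numAttributes() - 1 ? "," : "\n"));
      }
    }
  }
}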
