Package: weka.classifiers

Usage examples of weka.classifiers.Classifier


   *
   * @param args optional commandline parameters
   */
  public static void main(String [] args) {
    Instances     inst;
    Classifier     classifier;
    int      runs;
    int      folds;
    String     tmpStr;
    boolean    compute;
    Instances     result;
View Full Code Here


    int testMode = 0;
    int numFolds = 10;
          double percent = 66;
    int classIndex = m_ClassCombo.getSelectedIndex();
    Classifier classifier = (Classifier) m_ClassifierEditor.getValue();
    Classifier template = null;
    try {
      template = Classifier.makeCopy(classifier);
    } catch (Exception ex) {
      m_Log.logMessage("Problem copying classifier: " + ex.getMessage());
    }
    Classifier fullClassifier = null;
    StringBuffer outBuff = new StringBuffer();
    String name = (new SimpleDateFormat("HH:mm:ss - "))
    .format(new Date());
    String cname = classifier.getClass().getName();
    if (cname.startsWith("weka.classifiers.")) {
      name += cname.substring("weka.classifiers.".length());
    } else {
      name += cname;
    }
          String cmd = m_ClassifierEditor.getValue().getClass().getName();
          if (m_ClassifierEditor.getValue() instanceof OptionHandler)
            cmd += " " + Utils.joinOptions(((OptionHandler) m_ClassifierEditor.getValue()).getOptions());
    Evaluation eval = null;
    try {
      if (m_CVBut.isSelected()) {
        testMode = 1;
        numFolds = Integer.parseInt(m_CVText.getText());
        if (numFolds <= 1) {
    throw new Exception("Number of folds must be greater than 1");
        }
      } else if (m_PercentBut.isSelected()) {
        testMode = 2;
        percent = Double.parseDouble(m_PercentText.getText());
        if ((percent <= 0) || (percent >= 100)) {
    throw new Exception("Percentage must be between 0 and 100");
        }
      } else if (m_TrainBut.isSelected()) {
        testMode = 3;
      } else if (m_TestSplitBut.isSelected()) {
        testMode = 4;
        // Check the test instance compatibility
        if (source == null) {
    throw new Exception("No user test set has been specified");
        }
        if (!inst.equalHeaders(userTestStructure)) {
    throw new Exception("Train and test set are not compatible");
        }
              userTestStructure.setClassIndex(classIndex);
      } else {
        throw new Exception("Unknown test mode");
      }
      inst.setClassIndex(classIndex);

      // set up the structure of the plottable instances for
      // visualization
            if (saveVis) {
              predInstances = setUpVisualizableInstances(inst);
              predInstances.setClassIndex(inst.classIndex()+1);
            }

      // Output some header information
      m_Log.logMessage("Started " + cname);
      m_Log.logMessage("Command: " + cmd);
      if (m_Log instanceof TaskLogger) {
        ((TaskLogger)m_Log).taskStarted();
      }
      outBuff.append("=== Run information ===\n\n");
      outBuff.append("Scheme:       " + cname);
      if (classifier instanceof OptionHandler) {
        String [] o = ((OptionHandler) classifier).getOptions();
        outBuff.append(" " + Utils.joinOptions(o));
      }
      outBuff.append("\n");
      outBuff.append("Relation:     " + inst.relationName() + '\n');
      outBuff.append("Instances:    " + inst.numInstances() + '\n');
      outBuff.append("Attributes:   " + inst.numAttributes() + '\n');
      if (inst.numAttributes() < 100) {
        for (int i = 0; i < inst.numAttributes(); i++) {
    outBuff.append("              " + inst.attribute(i).name()
             + '\n');
        }
      } else {
        outBuff.append("              [list of attributes omitted]\n");
      }

      outBuff.append("Test mode:    ");
      switch (testMode) {
        case 3: // Test on training
    outBuff.append("evaluate on training data\n");
    break;
        case 1: // CV mode
    outBuff.append("" + numFolds + "-fold cross-validation\n");
    break;
        case 2: // Percent split
    outBuff.append("split " + percent
        + "% train, remainder test\n");
    break;
        case 4: // Test on user split
    if (source.isIncremental())
      outBuff.append("user supplied test set: "
          + " size unknown (reading incrementally)\n");
    else
      outBuff.append("user supplied test set: "
          + source.getDataSet().numInstances() + " instances\n");
    break;
      }
            if (costMatrix != null) {
               outBuff.append("Evaluation cost matrix:\n")
               .append(costMatrix.toString()).append("\n");
            }
      outBuff.append("\n");
      m_History.addResult(name, outBuff);
      m_History.setSingle(name);
     
      // Build the model and output it.
      if (outputModel || (testMode == 3) || (testMode == 4)) {
        m_Log.statusMessage("Building model on training data...");

        trainTimeStart = System.currentTimeMillis();
        classifier.buildClassifier(inst);
        trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
      }

      if (outputModel) {
        outBuff.append("=== Classifier model (full training set) ===\n\n");
        outBuff.append(classifier.toString() + "\n");
        outBuff.append("\nTime taken to build model: " +
           Utils.doubleToString(trainTimeElapsed / 1000.0,2)
           + " seconds\n\n");
        m_History.updateResult(name);
        if (classifier instanceof Drawable) {
    grph = null;
    try {
      grph = ((Drawable)classifier).graph();
    } catch (Exception ex) {
    }
        }
        // copy full model for output
        SerializedObject so = new SerializedObject(classifier);
        fullClassifier = (Classifier) so.getObject();
      }
     
      switch (testMode) {
        case 3: // Test on training
        m_Log.statusMessage("Evaluating on training data...");
        eval = new Evaluation(inst, costMatrix);
       
        if (outputPredictionsText) {
    printPredictionsHeader(outBuff, inst, "training set");
        }

        for (int jj=0;jj<inst.numInstances();jj++) {
    processClassifierPrediction(inst.instance(jj), classifier,
              eval, predInstances, plotShape,
              plotSize);
   
    if (outputPredictionsText) {
      outBuff.append(predictionText(classifier, inst.instance(jj), jj+1));
    }
    if ((jj % 100) == 0) {
      m_Log.statusMessage("Evaluating on training data. Processed "
              +jj+" instances...");
    }
        }
        if (outputPredictionsText) {
    outBuff.append("\n");
        }
        outBuff.append("=== Evaluation on training set ===\n");
        break;

        case 1: // CV mode
        m_Log.statusMessage("Randomizing instances...");
        int rnd = 1;
        try {
    rnd = Integer.parseInt(m_RandomSeedText.getText().trim());
    // System.err.println("Using random seed "+rnd);
        } catch (Exception ex) {
    m_Log.logMessage("Trouble parsing random seed value");
    rnd = 1;
        }
        Random random = new Random(rnd);
        inst.randomize(random);
        if (inst.attribute(classIndex).isNominal()) {
    m_Log.statusMessage("Stratifying instances...");
    inst.stratify(numFolds);
        }
        eval = new Evaluation(inst, costMatrix);
     
        if (outputPredictionsText) {
    printPredictionsHeader(outBuff, inst, "test data");
        }

        // Make some splits and do a CV
        for (int fold = 0; fold < numFolds; fold++) {
    m_Log.statusMessage("Creating splits for fold "
            + (fold + 1) + "...");
    Instances train = inst.trainCV(numFolds, fold, random);
    eval.setPriors(train);
    m_Log.statusMessage("Building model for fold "
            + (fold + 1) + "...");
    Classifier current = null;
    try {
      current = Classifier.makeCopy(template);
    } catch (Exception ex) {
      m_Log.logMessage("Problem copying classifier: " + ex.getMessage());
    }
    current.buildClassifier(train);
    Instances test = inst.testCV(numFolds, fold);
    m_Log.statusMessage("Evaluating model for fold "
            + (fold + 1) + "...");
    for (int jj=0;jj<test.numInstances();jj++) {
      processClassifierPrediction(test.instance(jj), current,
                eval, predInstances, plotShape,
                plotSize);
      if (outputPredictionsText) {
        outBuff.append(predictionText(current, test.instance(jj), jj+1));
      }
    }
        }
        if (outputPredictionsText) {
    outBuff.append("\n");
        }
        if (inst.attribute(classIndex).isNominal()) {
    outBuff.append("=== Stratified cross-validation ===\n");
        } else {
    outBuff.append("=== Cross-validation ===\n");
        }
        break;
   
        case 2: // Percent split
        if (!m_PreserveOrderBut.isSelected()) {
    m_Log.statusMessage("Randomizing instances...");
    try {
      rnd = Integer.parseInt(m_RandomSeedText.getText().trim());
    } catch (Exception ex) {
      m_Log.logMessage("Trouble parsing random seed value");
      rnd = 1;
    }
    inst.randomize(new Random(rnd));
        }
        int trainSize = (int) Math.round(inst.numInstances() * percent / 100);
        int testSize = inst.numInstances() - trainSize;
        Instances train = new Instances(inst, 0, trainSize);
        Instances test = new Instances(inst, trainSize, testSize);
        m_Log.statusMessage("Building model on training split ("+trainSize+" instances)...");
        Classifier current = null;
        try {
    current = Classifier.makeCopy(template);
        } catch (Exception ex) {
    m_Log.logMessage("Problem copying classifier: " + ex.getMessage());
        }
        current.buildClassifier(train);
        eval = new Evaluation(train, costMatrix);
        m_Log.statusMessage("Evaluating on test split...");
      
        if (outputPredictionsText) {
    printPredictionsHeader(outBuff, inst, "test split");
View Full Code Here

    VisualizePanel temp_vp = null;
    String temp_grph = null;
    FastVector temp_preds = null;
    Attribute temp_classAtt = null;
    Classifier temp_classifier = null;
    Instances temp_trainHeader = null;
     
    if (o != null) {
      for (int i = 0; i < o.size(); i++) {
  Object temp = o.elementAt(i);
  if (temp instanceof Classifier) {
    temp_classifier = (Classifier)temp;
  } else if (temp instanceof Instances) { // training header
    temp_trainHeader = (Instances)temp;
  } else if (temp instanceof VisualizePanel) { // normal errors
    temp_vp = (VisualizePanel)temp;
  } else if (temp instanceof String) { // graphable output
    temp_grph = (String)temp;
  } else if (temp instanceof FastVector) { // predictions
    temp_preds = (FastVector)temp;
  } else if (temp instanceof Attribute) { // class attribute
    temp_classAtt = (Attribute)temp;
  }
      }
    }

    final VisualizePanel vp = temp_vp;
    final String grph = temp_grph;
    final FastVector preds = temp_preds;
    final Attribute classAtt = temp_classAtt;
    final Classifier classifier = temp_classifier;
    final Instances trainHeader = temp_trainHeader;
   
    JMenuItem saveModel = new JMenuItem("Save model");
    if (classifier != null) {
      saveModel.addActionListener(new ActionListener() {
View Full Code Here

  protected void loadClassifier() {

    int returnVal = m_FileChooser.showOpenDialog(this);
    if (returnVal == JFileChooser.APPROVE_OPTION) {
      File selected = m_FileChooser.getSelectedFile();
      Classifier classifier = null;
      Instances trainHeader = null;

      m_Log.statusMessage("Loading model from file...");

      try {
  InputStream is = new FileInputStream(selected);
  if (selected.getName().endsWith(".gz")) {
    is = new GZIPInputStream(is);
  }
  ObjectInputStream objectInputStream = new ObjectInputStream(is);
  classifier = (Classifier) objectInputStream.readObject();
  try { // see if we can load the header
    trainHeader = (Instances) objectInputStream.readObject();
  } catch (Exception e) {} // don't fuss if we can't
  objectInputStream.close();
      } catch (Exception e) {
 
  JOptionPane.showMessageDialog(null, e, "Load Failed",
              JOptionPane.ERROR_MESSAGE);
     

      m_Log.statusMessage("OK");
     
      if (classifier != null) {
  m_Log.logMessage("Loaded model from file '" + selected.getName()+ "'");
  String name = (new SimpleDateFormat("HH:mm:ss - ")).format(new Date());
  String cname = classifier.getClass().getName();
  if (cname.startsWith("weka.classifiers."))
    cname = cname.substring("weka.classifiers.".length());
  name += cname + " from file '" + selected.getName() + "'";
  StringBuffer outBuff = new StringBuffer();

  outBuff.append("=== Model information ===\n\n");
  outBuff.append("Filename:     " + selected.getName() + "\n");
  outBuff.append("Scheme:       " + classifier.getClass().getName());
  if (classifier instanceof OptionHandler) {
    String [] o = ((OptionHandler) classifier).getOptions();
    outBuff.append(" " + Utils.joinOptions(o));
  }
  outBuff.append("\n");
  if (trainHeader != null) {
    outBuff.append("Relation:     " + trainHeader.relationName() + '\n');
    outBuff.append("Attributes:   " + trainHeader.numAttributes() + '\n');
    if (trainHeader.numAttributes() < 100) {
      for (int i = 0; i < trainHeader.numAttributes(); i++) {
        outBuff.append("              " + trainHeader.attribute(i).name()
           + '\n');
      }
    } else {
      outBuff.append("              [list of attributes omitted]\n");
    }
  } else {
    outBuff.append("\nTraining data unknown\n");
  }

  outBuff.append("\n=== Classifier model ===\n\n");
  outBuff.append(classifier.toString() + "\n");
 
  m_History.addResult(name, outBuff);
  m_History.setSingle(name);
  FastVector vv = new FastVector();
  vv.addElement(classifier);
View Full Code Here

   *
   * @return the classifier string.
   */
  protected String getClassifierSpec() {
   
    Classifier c = getClassifier();
    if (c instanceof OptionHandler) {
      return c.getClass().getName() + " "
  + Utils.joinOptions(((OptionHandler)c).getOptions());
    }
    return c.getClass().getName();
  }
View Full Code Here

   *
   * @return     the classifier string.
   */
  protected String getClassifierSpec() {
    String  result;
    Classifier   c;
   
    c      = getClassifier();
    result = c.getClass().getName();
    if (c instanceof OptionHandler)
      result += " " + Utils.joinOptions(((OptionHandler) c).getOptions());
   
    return result;
  }
View Full Code Here

  {
    for (Dataset dataset : datasets)
    {
      // Set parameters
      int folds = 10;
      Classifier baseClassifier = new LinearRegression();
     
      // Set up the random number generator
        long seed = new Date().getTime();     
      Random random = new Random(seed)
       
      // Add IDs to the instances
      AddID.main(new String[] {"-i", MODELS_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + ".arff",
                     "-o", MODELS_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + "-plusIDs.arff" });

      String location = MODELS_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + "-plusIDs.arff";
     
          Instances data = DataSource.read(location);
         
          if (data == null) {
              throw new IOException("Could not load data from: " + location);
          }
         
          data.setClassIndex(data.numAttributes() - 1);
           
          // Instantiate the Remove filter
          Remove removeIDFilter = new Remove();
          removeIDFilter.setAttributeIndices("first");
     
      // Randomize the data
      data.randomize(random);
   
      // Perform cross-validation
        Instances predictedData = null;
        Evaluation eval = new Evaluation(data);
       
        for (int n = 0; n < folds; n++)
        {
          Instances train = data.trainCV(folds, n, random);
            Instances test = data.testCV(folds, n);
           
            // Apply log filter
          Filter logFilter = new LogFilter();
            logFilter.setInputFormat(train);
            train = Filter.useFilter(train, logFilter);       
            logFilter.setInputFormat(test);
            test = Filter.useFilter(test, logFilter);
           
            // Copy the classifier
            Classifier classifier = AbstractClassifier.makeCopy(baseClassifier);
                                  
            // Instantiate the FilteredClassifier
            FilteredClassifier filteredClassifier = new FilteredClassifier();
            filteredClassifier.setFilter(removeIDFilter);
            filteredClassifier.setClassifier(classifier);
View Full Code Here

  public static void runClassifierCV(WekaClassifier wekaClassifier, Dataset dataset)
    throws Exception
  {
    // Set parameters
    int folds = 10;
    Classifier baseClassifier = getClassifier(wekaClassifier);
   
    // Set up the random number generator
      long seed = new Date().getTime();     
    Random random = new Random(seed)
     
    // Add IDs to the instances
    AddID.main(new String[] {"-i", MODELS_DIR + "/" + dataset.toString() + ".arff",
                  "-o", MODELS_DIR + "/" + dataset.toString() + "-plusIDs.arff" });
    Instances data = DataSource.read(MODELS_DIR + "/" + dataset.toString() + "-plusIDs.arff");
    data.setClassIndex(data.numAttributes() - 1);       
   
        // Instantiate the Remove filter
        Remove removeIDFilter = new Remove();
      removeIDFilter.setAttributeIndices("first");
   
    // Randomize the data
    data.randomize(random);
 
    // Perform cross-validation
      Instances predictedData = null;
      Evaluation eval = new Evaluation(data);
     
      for (int n = 0; n < folds; n++)
      {
        Instances train = data.trainCV(folds, n, random);
          Instances test = data.testCV(folds, n);
         
          // Apply log filter
//        Filter logFilter = new LogFilter();
//          logFilter.setInputFormat(train);
//          train = Filter.useFilter(train, logFilter);       
//          logFilter.setInputFormat(test);
//          test = Filter.useFilter(test, logFilter);
         
          // Copy the classifier
          Classifier classifier = AbstractClassifier.makeCopy(baseClassifier);
                                
          // Instantiate the FilteredClassifier
          FilteredClassifier filteredClassifier = new FilteredClassifier();
          filteredClassifier.setFilter(removeIDFilter);
          filteredClassifier.setClassifier(classifier);
View Full Code Here

  {
    for (Dataset dataset : datasets)
    {
      // Set parameters
      int folds = 10;
      Classifier baseClassifier = new LinearRegression();
     
      // Set up the random number generator
        long seed = new Date().getTime();     
      Random random = new Random(seed)
       
      // Add IDs to the instances
      AddID.main(new String[] {"-i", MODELS_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + ".arff",
                     "-o", MODELS_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + "-plusIDs.arff" });
      Instances data = DataSource.read(MODELS_DIR + "/" + mode.toString().toLowerCase() + "/" + dataset.toString() + "-plusIDs.arff");
      data.setClassIndex(data.numAttributes() - 1);       
     
          // Instantiate the Remove filter
          Remove removeIDFilter = new Remove();
          removeIDFilter.setAttributeIndices("first");
     
      // Randomize the data
      data.randomize(random);
   
      // Perform cross-validation
        Instances predictedData = null;
        Evaluation eval = new Evaluation(data);
       
        for (int n = 0; n < folds; n++)
        {
          Instances train = data.trainCV(folds, n, random);
            Instances test = data.testCV(folds, n);
           
            // Apply log filter
          Filter logFilter = new LogFilter();
            logFilter.setInputFormat(train);
            train = Filter.useFilter(train, logFilter);       
            logFilter.setInputFormat(test);
            test = Filter.useFilter(test, logFilter);
           
            // Copy the classifier
            Classifier classifier = AbstractClassifier.makeCopy(baseClassifier);
                                  
            // Instantiate the FilteredClassifier
            FilteredClassifier filteredClassifier = new FilteredClassifier();
            filteredClassifier.setFilter(removeIDFilter);
            filteredClassifier.setClassifier(classifier);
View Full Code Here

//  }
 
  public static void runClassifier(WekaClassifier wekaClassifier, Dataset trainDataset, Dataset testDataset)
      throws Exception
  {
    Classifier baseClassifier = ClassifierSimilarityMeasure.getClassifier(wekaClassifier);
   
    // Set up the random number generator
      long seed = new Date().getTime();     
    Random random = new Random(seed)
       
    // Add IDs to the train instances and get the instances
    AddID.main(new String[] {"-i", MODELS_DIR + "/" + trainDataset.toString() + ".arff",
                  "-o", MODELS_DIR + "/" + trainDataset.toString() + "-plusIDs.arff" });
    Instances train = DataSource.read(MODELS_DIR + "/" + trainDataset.toString() + "-plusIDs.arff");
    train.setClassIndex(train.numAttributes() - 1)
   
    // Add IDs to the test instances and get the instances
    AddID.main(new String[] {"-i", MODELS_DIR + "/" + testDataset.toString() + ".arff",
                  "-o", MODELS_DIR + "/" + testDataset.toString() + "-plusIDs.arff" });
    Instances test = DataSource.read(MODELS_DIR + "/" + testDataset.toString() + "-plusIDs.arff");
    test.setClassIndex(test.numAttributes() - 1);   
   
    // Instantiate the Remove filter
        Remove removeIDFilter = new Remove();
      removeIDFilter.setAttributeIndices("first");
       
    // Randomize the data
    test.randomize(random);
   
    // Apply log filter
//      Filter logFilter = new LogFilter();
//      logFilter.setInputFormat(train);
//      train = Filter.useFilter(train, logFilter);       
//      logFilter.setInputFormat(test);
//      test = Filter.useFilter(test, logFilter);
       
        // Copy the classifier
        Classifier classifier = AbstractClassifier.makeCopy(baseClassifier);
         
        // Instantiate the FilteredClassifier
        FilteredClassifier filteredClassifier = new FilteredClassifier();
        filteredClassifier.setFilter(removeIDFilter);
        filteredClassifier.setClassifier(classifier);
          
        // Build the classifier
        filteredClassifier.buildClassifier(train);
   
        // Prepare the output buffer
        AbstractOutput output = new PlainText();
        output.setBuffer(new StringBuffer());
        output.setHeader(test);
        output.setAttributes("first");
       
    Evaluation eval = new Evaluation(train);
        eval.evaluateModel(filteredClassifier, test, output);
       
        // Convert predictions to CSV
        // Format: inst#, actual, predicted, error, probability, (ID)
        String[] scores = new String[new Double(eval.numInstances()).intValue()];
        double[] probabilities = new double[new Double(eval.numInstances()).intValue()];
        for (String line : output.getBuffer().toString().split("\n"))
        {
          String[] linesplit = line.split("\\s+");

          // If there's been an error, the length of linesplit is 6, otherwise 5,
          // due to the error flag "+"
         
          int id;
          String expectedValue, classification;
          double probability;
         
          if (line.contains("+"))
          {
               id = Integer.parseInt(linesplit[6].substring(1, linesplit[6].length() - 1));
            expectedValue = linesplit[2].substring(2);
            classification = linesplit[3].substring(2);
            probability = Double.parseDouble(linesplit[5]);
          } else {
            id = Integer.parseInt(linesplit[5].substring(1, linesplit[5].length() - 1));
            expectedValue = linesplit[2].substring(2);
            classification = linesplit[3].substring(2);
            probability = Double.parseDouble(linesplit[4]);
          }
         
          scores[id - 1] = classification;
          probabilities[id - 1] = probability;
        }
               
        System.out.println(eval.toSummaryString());
      System.out.println(eval.toMatrixString());
     
      // Output classifications
      StringBuilder sb = new StringBuilder();
      for (String score : scores)
        sb.append(score.toString() + LF);
     
      FileUtils.writeStringToFile(
        new File(OUTPUT_DIR + "/" + testDataset.toString() + "/" + wekaClassifier.toString() + "/" + testDataset.toString() + ".csv"),
        sb.toString());
     
      // Output probabilities
      sb = new StringBuilder();
      for (Double probability : probabilities)
        sb.append(probability.toString() + LF);
     
      FileUtils.writeStringToFile(
        new File(OUTPUT_DIR + "/" + testDataset.toString() + "/" + wekaClassifier.toString() + "/" + testDataset.toString() + ".probabilities.csv"),
        sb.toString());
     
      // Output predictions
      FileUtils.writeStringToFile(
        new File(OUTPUT_DIR + "/" + testDataset.toString() + "/" + wekaClassifier.toString() + "/" + testDataset.toString() + ".predictions.txt"),
        output.getBuffer().toString());
     
      // Output meta information
      sb = new StringBuilder();
      sb.append(classifier.toString() + LF);
      sb.append(eval.toSummaryString() + LF);
      sb.append(eval.toMatrixString() + LF);
     
      FileUtils.writeStringToFile(
        new File(OUTPUT_DIR + "/" + testDataset.toString() + "/" + wekaClassifier.toString() + "/" + testDataset.toString() + ".meta.txt"),
View Full Code Here

TOP

Related Classes of weka.classifiers.Classifier

Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle, Inc. Contact: coftware#gmail.com.