Package edu.uci.jforestsx.applications

Source Code of edu.uci.jforestsx.applications.Runner

/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package edu.uci.jforestsx.applications;

import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.io.PrintStream;
import java.util.List;
import java.util.Properties;

import joptsimple.OptionParser;
import joptsimple.OptionSet;
import edu.uci.jforestsx.config.TrainingConfig;
import edu.uci.jforestsx.dataset.Dataset;
import edu.uci.jforestsx.dataset.DatasetLoader;
import edu.uci.jforestsx.dataset.RankingDataset;
import edu.uci.jforestsx.dataset.RankingDatasetLoader;
import edu.uci.jforestsx.input.RankingRaw2BinConvertor;
import edu.uci.jforestsx.input.Raw2BinConvertor;
import edu.uci.jforestsx.learning.LearningUtils;
import edu.uci.jforestsx.learning.trees.Ensemble;
import edu.uci.jforestsx.learning.trees.decision.DecisionTree;
import edu.uci.jforestsx.learning.trees.regression.RegressionTree;
import edu.uci.jforestsx.sample.RankingSample;
import edu.uci.jforestsx.sample.Sample;
import edu.uci.jforestsx.util.IOUtils;

/**
* @author Yasser Ganjisaffar <ganjisaffar at gmail dot com>
*/

public class Runner {

  @SuppressWarnings("unchecked")
  private static void generateBin(OptionSet options) throws Exception {
    if (!options.has("folder")) {
      System.err.println("The input folder is not specified.");
      return;
    }

    if (!options.has("file")) {
      System.err.println("Input files are not specified.");
      return;
    }

    String folder = (String) options.valueOf("folder");
    List<String> filesList = (List<String>) options.valuesOf("file");
    String[] files = new String[filesList.size()];
    for (int i = 0; i < files.length; i++) {
      files[i] = filesList.get(i);
    }

    if (options.has("ranking")) {
      System.out.println("Generating binary files for ranking data sets...");
      new RankingRaw2BinConvertor().convert(folder, files);
    } else {
      System.out.println("Generating binary files...");
      new Raw2BinConvertor().convert(folder, files);
    }
  }

  private static void train(OptionSet options) throws Exception {
    if (!options.has("config-file")) {
      System.err.println("The configurations file is not specified.");
      return;
    }

    InputStream configInputStream = new FileInputStream((String) options.valueOf("config-file"));
    Properties configProperties = new Properties();
    configProperties.load(configInputStream);

    if (options.has("train-file")) {
      configProperties.put(TrainingConfig.TRAIN_FILENAME, options.valueOf("train-file"));
    }

    if (options.has("validation-file")) {
      configProperties.put(TrainingConfig.VALID_FILENAME, options.valueOf("validation-file"));
    }

    Ensemble ensemble;

    if (options.has("ranking")) {
      RankingApp app = new RankingApp();
      ensemble = app.run(configProperties);
    } else {
      ClassificationApp app = new ClassificationApp();
      ensemble = app.run(configProperties);
    }

    /*
     * Dump the output model if requested.
     */
    if (options.has("output-model")) {
      String outputModelFile = (String) options.valueOf("output-model");
      File file = new File(outputModelFile);
      PrintStream ensembleOutput = new PrintStream(file);
      ensembleOutput.println(ensemble);
      ensembleOutput.close();
    }

  }

  private static void predict(OptionSet options) throws Exception {

    if (!options.has("model-file")) {
      System.err.println("Model file is not specified.");
      return;
    }

    if (!options.has("tree-type")) {
      System.err.println("Types of trees in the ensemble is not specified.");
      return;
    }

    if (!options.has("test-file")) {
      System.err.println("Test file is not specified.");
      return;
    }

    /*
     * Load the ensemble
     */
    File modelFile = new File((String) options.valueOf("model-file"));
    Ensemble ensemble = new Ensemble();
    if (options.valueOf("tree-type").equals("RegressionTree")) {
      ensemble.loadFromFile(RegressionTree.class, modelFile);
    } else if (options.valueOf("tree-type").equals("DecisionTree")) {
      ensemble.loadFromFile(DecisionTree.class, modelFile);
    } else {
      System.err.println("Unknown tree type: " + options.valueOf("tree-type"));
    }

    /*
     * Load the data set
     */
    InputStream in = new IOUtils().getInputStream((String) options.valueOf("test-file"));
    Sample sample;
    if (options.has("ranking")) {
      RankingDataset dataset = new RankingDataset();
      RankingDatasetLoader.load(in, dataset);
      sample = new RankingSample(dataset);
    } else {
      Dataset dataset = new Dataset();
      DatasetLoader.load(in, dataset);
      sample = new Sample(dataset);
    }
    in.close();

    double[] predictions = new double[sample.size];
    LearningUtils.updateScores(sample, predictions, ensemble);

    PrintStream output;
    if (options.has("output-file")) {
      output = new PrintStream(new File((String) options.valueOf("output-file")));
    } else {
      output = System.out;
    }
   
    for (int i = 0; i < sample.size; i++) {
      output.println(predictions[i]);
    }

  }

  public static void main(String[] args) throws Exception {

    OptionParser parser = new OptionParser();

    parser.accepts("cmd").withRequiredArg();
    parser.accepts("ranking");

    /*
     * Bin generation arguments
     */
    parser.accepts("folder").withRequiredArg();
    parser.accepts("file").withRequiredArg();

    /*
     * Training arguments
     */
    parser.accepts("config-file").withRequiredArg();
    parser.accepts("train-file").withRequiredArg();
    parser.accepts("validation-file").withRequiredArg();
    parser.accepts("output-model").withRequiredArg();

    /*
     * Prediction arguments
     */
    parser.accepts("model-file").withRequiredArg();
    parser.accepts("tree-type").withRequiredArg();
    parser.accepts("test-file").withRequiredArg();
    parser.accepts("output-file").withRequiredArg();

    OptionSet options = parser.parse(args);

    if (!options.has("cmd")) {
      System.err.println("You must specify the command through 'cmd' parameter.");
      return;
    }

    if (options.valueOf("cmd").equals("generate-bin")) {
      generateBin(options);
    } else if (options.valueOf("cmd").equals("train")) {
      train(options);
    } else if (options.valueOf("cmd").equals("predict")) {
      predict(options);
    } else {
      System.err.println("Unknown command: " + options.valueOf("cmd"));
    }

    /*
     * Make sure that thread pool is terminated.
     */
    ClassificationApp.shutdown();
  }
}
 
TOP

Related Classes of edu.uci.jforestsx.applications.Runner

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.