Package org.apache.mahout.classifier.sgd

Source Code of org.apache.mahout.classifier.sgd.RunAdaptiveLogistic

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mahout.classifier.sgd;

import com.google.common.base.Charsets;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.commons.cli2.util.HelpFormatter;
import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression.Wrapper;
import org.apache.mahout.ep.State;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.util.HashMap;
import java.util.Map;

public final class RunAdaptiveLogistic {

  private static String inputFile;
  private static String modelFile;
  private static String outputFile;
  private static String idColumn;
  private static boolean maxScoreOnly;

  private RunAdaptiveLogistic() {
  }

  public static void main(String[] args) throws Exception {
    mainToOutput(args, new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true));
  }

  static void mainToOutput(String[] args, PrintWriter output) throws Exception {
    if (!parseArgs(args)) {
      return;
    }
    AdaptiveLogisticModelParameters lmp = AdaptiveLogisticModelParameters
        .loadFromFile(new File(modelFile));

    CsvRecordFactory csv = lmp.getCsvRecordFactory();
    csv.setIdName(idColumn);

    AdaptiveLogisticRegression lr = lmp.createAdaptiveLogisticRegression();

    State<Wrapper, CrossFoldLearner> best = lr.getBest();
    if (best == null) {
      output.println("AdaptiveLogisticRegression has not be trained probably.");
      return;
    }
    CrossFoldLearner learner = best.getPayload().getLearner();

    BufferedReader in = TrainAdaptiveLogistic.open(inputFile);
    BufferedWriter out = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(outputFile), Charsets.UTF_8));

    out.write(idColumn + ",target,score");
    out.newLine();

    String line = in.readLine();
    csv.firstLine(line);
    line = in.readLine();
    Map<String, Double> results = new HashMap<String, Double>();
    int k = 0;
    while (line != null) {
      Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
      csv.processLine(line, v, false);
      Vector scores = learner.classifyFull(v);
      results.clear();
      if (maxScoreOnly) {
        results.put(csv.getTargetLabel(scores.maxValueIndex()),
            scores.maxValue());
      } else {
        for (int i = 0; i < scores.size(); i++) {
          results.put(csv.getTargetLabel(i), scores.get(i));
        }
      }

      for (Map.Entry<String,Double> entry : results.entrySet()) {
        out.write(csv.getIdString(line) + ',' + entry.getKey() + ',' + entry.getValue());
        out.newLine();
      }
      k++;
      if (k % 100 == 0) {
        output.println(k + " records processed");
      }
      line = in.readLine();
    }
    out.flush();
    out.close();
    output.println(k + " records processed totally.");
  }

  private static boolean parseArgs(String[] args) {
    DefaultOptionBuilder builder = new DefaultOptionBuilder();

    Option help = builder.withLongName("help")
      .withDescription("print this list").create();

    Option quiet = builder.withLongName("quiet")
      .withDescription("be extra quiet").create();

    ArgumentBuilder argumentBuilder = new ArgumentBuilder();
    Option inputFileOption = builder
      .withLongName("input")
      .withRequired(true)
      .withArgument(
          argumentBuilder.withName("input").withMaximum(1)
            .create())
      .withDescription("where to get training data").create();

    Option modelFileOption = builder
      .withLongName("model")
      .withRequired(true)
      .withArgument(
          argumentBuilder.withName("model").withMaximum(1)
            .create())
      .withDescription("where to get the trained model").create();
   
    Option outputFileOption = builder
      .withLongName("output")
      .withRequired(true)
      .withDescription("the file path to output scores")
      .withArgument(argumentBuilder.withName("output").withMaximum(1).create())
      .create();
   
    Option idColumnOption = builder
      .withLongName("idcolumn")
      .withRequired(true)
      .withDescription("the name of the id column for each record")
      .withArgument(argumentBuilder.withName("idcolumn").withMaximum(1).create())
      .create();
   
    Option maxScoreOnlyOption = builder
      .withLongName("maxscoreonly")
      .withDescription("only output the target label with max scores")
      .create();

    Group normalArgs = new GroupBuilder()
      .withOption(help).withOption(quiet)
      .withOption(inputFileOption).withOption(modelFileOption)
      .withOption(outputFileOption).withOption(idColumnOption)
      .withOption(maxScoreOnlyOption)
      .create();

    Parser parser = new Parser();
    parser.setHelpOption(help);
    parser.setHelpTrigger("--help");
    parser.setGroup(normalArgs);
    parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
    CommandLine cmdLine = parser.parseAndHelp(args);

    if (cmdLine == null) {
      return false;
    }

    inputFile = getStringArgument(cmdLine, inputFileOption);
    modelFile = getStringArgument(cmdLine, modelFileOption);
    outputFile = getStringArgument(cmdLine, outputFileOption);
    idColumn = getStringArgument(cmdLine, idColumnOption);
    maxScoreOnly = getBooleanArgument(cmdLine, maxScoreOnlyOption);   
    return true;
  }

  private static boolean getBooleanArgument(CommandLine cmdLine, Option option) {
    return cmdLine.hasOption(option);
  }

  private static String getStringArgument(CommandLine cmdLine, Option inputFile) {
    return (String) cmdLine.getValue(inputFile);
  }

}
TOP

Related Classes of org.apache.mahout.classifier.sgd.RunAdaptiveLogistic

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and owned by ORACLE Inc. Contact coftware#gmail.com.