Package org.apache.mahout.classifier.sgd

Source Code of org.apache.mahout.classifier.sgd.TestASFEmail

/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mahout.classifier.sgd;

import com.google.common.base.Charsets;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.commons.cli2.util.HelpFormatter;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.PathFilter;
import org.apache.hadoop.io.Text;
import org.apache.mahout.classifier.ClassifierResult;
import org.apache.mahout.classifier.ResultAnalyzer;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.vectorizer.encoders.Dictionary;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;

/**
* Run the ASF email, as trained by TrainASFEmail
*/
public final class TestASFEmail {

  private String inputFile;
  private String modelFile;

  private TestASFEmail() {}

  public static void main(String[] args) throws IOException {
    TestASFEmail runner = new TestASFEmail();
    if (runner.parseArgs(args)) {
      runner.run(new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true));
    }
  }

  public void run(PrintWriter output) throws IOException {

    File base = new File(inputFile);
    //contains the best model
    OnlineLogisticRegression classifier =
        ModelSerializer.readBinary(new FileInputStream(modelFile), OnlineLogisticRegression.class);


    Dictionary asfDictionary = new Dictionary();
    Configuration conf = new Configuration();
    PathFilter testFilter = new PathFilter() {
      @Override
      public boolean accept(Path path) {
        return path.getName().contains("test");
      }
    };
    SequenceFileDirIterator<Text, VectorWritable> iter =
        new SequenceFileDirIterator<Text, VectorWritable>(new Path(base.toString()), PathType.LIST, testFilter,
        null, true, conf);

    long numItems = 0;
    while (iter.hasNext()) {
      Pair<Text, VectorWritable> next = iter.next();
      asfDictionary.intern(next.getFirst().toString());
      numItems++;
    }

    System.out.println(numItems + " test files");
    ResultAnalyzer ra = new ResultAnalyzer(asfDictionary.values(), "DEFAULT");
    iter = new SequenceFileDirIterator<Text, VectorWritable>(new Path(base.toString()), PathType.LIST, testFilter,
            null, true, conf);
    while (iter.hasNext()) {
      Pair<Text, VectorWritable> next = iter.next();
      String ng = next.getFirst().toString();

      int actual = asfDictionary.intern(ng);
      Vector result = classifier.classifyFull(next.getSecond().get());
      int cat = result.maxValueIndex();
      double score = result.maxValue();
      double ll = classifier.logLikelihood(actual, next.getSecond().get());
      ClassifierResult cr = new ClassifierResult(asfDictionary.values().get(cat), score, ll);
      ra.addInstance(asfDictionary.values().get(actual), cr);

    }
    output.println(ra);
  }

  boolean parseArgs(String[] args) {
    DefaultOptionBuilder builder = new DefaultOptionBuilder();

    Option help = builder.withLongName("help").withDescription("print this list").create();

    ArgumentBuilder argumentBuilder = new ArgumentBuilder();
    Option inputFileOption = builder.withLongName("input")
            .withRequired(true)
            .withArgument(argumentBuilder.withName("input").withMaximum(1).create())
            .withDescription("where to get training data")
            .create();

    Option modelFileOption = builder.withLongName("model")
            .withRequired(true)
            .withArgument(argumentBuilder.withName("model").withMaximum(1).create())
            .withDescription("where to get a model")
            .create();

    Group normalArgs = new GroupBuilder()
            .withOption(help)
            .withOption(inputFileOption)
            .withOption(modelFileOption)
            .create();

    Parser parser = new Parser();
    parser.setHelpOption(help);
    parser.setHelpTrigger("--help");
    parser.setGroup(normalArgs);
    parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
    CommandLine cmdLine = parser.parseAndHelp(args);

    if (cmdLine == null) {
      return false;
    }

    inputFile = (String) cmdLine.getValue(inputFileOption);
    modelFile = (String) cmdLine.getValue(modelFileOption);
    return true;
  }

}
TOP

Related Classes of org.apache.mahout.classifier.sgd.TestASFEmail

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.