Package org.apache.mahout.pig

Source Code of org.apache.mahout.pig.LogisticRegressionEval

package org.apache.mahout.pig;

import com.google.common.base.Joiner;
import com.google.common.base.Splitter;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.PathFilter;
import org.apache.hadoop.io.Text;
import org.apache.mahout.classifier.sgd.PolymorphicWritable;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator;
import org.apache.mahout.math.Vector;
import org.apache.pig.EvalFunc;
import org.apache.pig.data.DataByteArray;
import org.apache.pig.data.Tuple;
import org.apache.pig.impl.util.UDFContext;

import java.io.IOException;
import java.util.List;
import java.util.Map;

/**
* Evaluates a logistic regression classifier
*/
public class LogisticRegressionEval extends EvalFunc<String> {
    private Classifier model;
    private String location;
    private String key;

    public LogisticRegressionEval(String modelParams) throws IOException {
        System.out.printf("Model params = %s\n", modelParams);

        Splitter onComma = Splitter.on(",").trimResults().omitEmptyStrings();
        Splitter onEquals = Splitter.on("=").trimResults();
        Joiner withSpaces = Joiner.on(" ");

        Map<String, String> options = Maps.newHashMap();

        for (String option : onComma.split(modelParams)) {
            List<String> values = Lists.newArrayList(onEquals.split(option));
            options.put(values.get(0), values.get(1));
        }

        if (options.containsKey("sequence")) {
            location = options.get("sequence");
            if (options.containsKey("key")) {
                key = options.get("key");
                options.remove("sequence");
                options.remove(("key"));
            } else {
                throw new BadClassifierSpecException("Must specify key for model in a sequence file");
            }
        } else if (options.containsKey("file")) {
            Configuration conf = UDFContext.getUDFContext().getJobConf();
            final FSDataInputStream in = FileSystem.get(conf).open(new Path(options.get("file")));
            try {
                model = PolymorphicWritable.read(in, Classifier.class);
            } finally {
                in.close();
            }
            options.remove("file");

        } else {
            throw new BadClassifierSpecException("Must specify existing model");
        }

        if (options.size() > 0) {
            throw new BadClassifierSpecException("Extra options supplied: " + withSpaces.join(options.keySet()));
        }
    }


    /**
     * This callback method must be implemented by all subclasses. This
     * is the method that will be invoked on every Tuple of a given dataset.
     * Since the dataset may be divided up in a variety of ways the programmer
     * should not make assumptions about state that is maintained between
     * invocations of this method.
     *
     * @param input the Tuple to be processed.
     * @return result, of type T.
     * @throws java.io.IOException
     */
    @Override
    public String exec(Tuple input) throws IOException {
        if (model == null) {
            Path sequence = new Path(location);
            Configuration conf = UDFContext.getUDFContext().getJobConf();
            sequence = sequence.makeQualified(FileSystem.get(conf));
            System.out.printf("Model file = %s\n", sequence);
            SequenceFileDirIterator<Text, Classifier> x = new SequenceFileDirIterator<Text, Classifier>(sequence, PathType.GLOB, new PathFilter() {
                public boolean accept(Path path) {
                    System.out.printf("Scanning %s\n", path);
                    return !path.toString().startsWith("_");
                }
            }, null, true, conf);
            int n = 0;
            while (x.hasNext()) {
                Pair<Text, Classifier> pair = x.next();
                n++;
                if (pair.getFirst().toString().equals(key)) {
                    System.out.printf("Found model with categories = %s\n", pair.getSecond().getCategories());
                    model = pair.getSecond();
                    break;
                }
            }
            if (model == null) {
                throw new IOException(String.format("Can't find model with correct key = %s out of %d searched\n", key, n));
            }
            x.close();
        }
        final Vector instance = PigVector.fromBytes(((DataByteArray) input.get(0)));
        final int target = model.getModel().classifyFull(instance).maxValueIndex();
        return model.getCategories().get(target);
    }
}
TOP

Related Classes of org.apache.mahout.pig.LogisticRegressionEval

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.