Package org.apache.mahout.pig

Source Code of org.apache.mahout.pig.LogisticRegressionTest

package org.apache.mahout.pig;

import com.google.common.collect.ImmutableList;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.classifier.sgd.PolymorphicWritable;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.DoubleFunction;
import org.apache.pig.data.*;
import org.junit.Test;

import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.util.List;
import java.util.Random;

import static junit.framework.Assert.assertEquals;
import static junit.framework.Assert.fail;
import static org.junit.Assert.assertTrue;

public class LogisticRegressionTest {
    @Test
    public void testConstructor() throws IOException {
        LogisticRegression lr = new LogisticRegression("categories = a b c, features=10000");
        assertEquals(3, lr.getModel().numCategories());
        assertEquals(10000, lr.getModel().numFeatures());
        assertEquals(1, lr.getIterations());
        assertTrue(lr.isInMemory());


        lr = new LogisticRegression("categories = a b c, features=10000, decayExponent=0.3 ,stepOffset=123, learningRate=2.1, lambda=0.001, iterations=12, inMemory=false");
        assertEquals(2.1 * Math.pow(123, -0.3), lr.getModel().currentLearningRate());
        assertEquals(0.001, lr.getModel().getLambda());

        try {
            new LogisticRegression("categories = a , features=10000");
            fail("Should have failed");
        } catch (BadClassifierSpecException e) {
            assertTrue("Single target category", e.getMessage().startsWith("Must have more than one target category"));
        }

        try {
            new LogisticRegression("categories = a b, features=10000, x=3");
            fail("Should have failed");
        } catch (BadClassifierSpecException e) {
            assertEquals("Extra options", "Extra options supplied: x", e.getMessage());
        }

        try {
            new LogisticRegression("categories = a ");
            fail("Should have failed");
        } catch (BadClassifierSpecException e) {
            assertEquals("No features", "Must specify previous model location using \"file\" or supply \"categories\" and \"features\"", e.getMessage());
        }
    }

    @Test
    public void testTraining() throws IOException {
        DoubleFunction randomValue = new DoubleFunction() {
            private Random gen = new Random(1);

            public double apply(double arg1) {
                return gen.nextGaussian();
            }
        };

        // start with a random direction
        Vector n = new DenseVector(4);
        n.assign(randomValue);

        // make up data.  result is 1 if in the same rough direction as n, else 0
        String[] target = new String[]{"0", "1"};
        DataBag examples = new DefaultDataBag();
        for (int i = 0; i < 400; i++) {
            Vector v = new DenseVector(4);
            v.assign(randomValue);

            Tuple x = new DefaultTuple();
            x.append(target[v.dot(n) > 0 ? 1 : 0]);
            x.append(PigVector.toBytes(v));
            examples.add(x);
        }
        Tuple data = new DefaultTuple();
        data.append(examples);

        // train model.  training from tmp file allows absolute repeatability
        LogisticRegression lr = new LogisticRegression("categories = 0 1, features=4, inMemory=false, iterations=5");
        lr.accumulate(data);
        DataByteArray r = lr.getValue();
        DataInputStream in = new DataInputStream(new ByteArrayInputStream(r.get()));
        Classifier c = PolymorphicWritable.read(in, Classifier.class);
        assertEquals(2, c.getCategories().size());
        assertEquals("0", c.getCategories().get(0));
        assertEquals("1", c.getCategories().get(1));
        OnlineLogisticRegression model = c.getModel();
        assertEquals(lr.getModel().currentLearningRate(), model.currentLearningRate(), 1e-10);
        in.close();

        // with that many data points, model should point in the same direction as the original vector
        Vector v = model.getBeta().viewRow(0);
        double z = n.dot(v) / (n.norm(2) * v.norm(2));
        assertEquals(1.0, z, 1e-2);

        // just for grins, we should check whether the model actually computes the correct values
        List<String> categories = ImmutableList.of("0", "1");
        for (Tuple example : examples) {
            double score = model.classifyScalar(PigVector.fromBytes((DataByteArray) example.get(1)));
            int actual = categories.indexOf(example.get(0));
            score = score * actual + (1 - score) * (1 - actual);
            assertTrue(score > 0.4);
        }
    }
}
TOP

Related Classes of org.apache.mahout.pig.LogisticRegressionTest

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.