Source Code of org.apache.mahout.classifier.sgd.OnlineLogisticRegressionTest

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mahout.classifier.sgd;

import com.google.common.base.Charsets;
import com.google.common.base.Splitter;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.io.Closer;
import com.google.common.io.Resources;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.vectorizer.encoders.Dictionary;
import org.junit.Assert;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.lang.reflect.Field;
import java.util.Collections;
import java.util.List;
import java.util.Random;


public final class OnlineLogisticRegressionTest extends OnlineBaseTest {

  private static final Logger logger = LoggerFactory.getLogger(OnlineLogisticRegressionTest.class);

  /**
   * The CrossFoldLearner is probably the best learner to use for new applications.
   *
   * @throws IOException If test resources aren't readable.
   */
  @Test
  public void crossValidation() throws IOException {
    Vector target = readStandardData();

    // 5 cross-validation folds over a 2-category, 8-feature model with an L1 prior
    CrossFoldLearner lr = new CrossFoldLearner(5, 2, 8, new L1())
      .lambda(1.0e-3)
      .learningRate(50);


    train(getInput(), target, lr);

    System.out.printf("%.2f %.5f\n", lr.auc(), lr.logLikelihood());
    test(getInput(), target, lr, 0.05, 0.3);

  }

  @Test
  public void crossValidatedAuc() throws IOException {
    RandomUtils.useTestSeed();
    Random gen = RandomUtils.getRandom();

    Matrix data = readCsv("cancer.csv");
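    // stepOffset and decayExponent anneal the learning rate; in Mahout's SGD the
    // effective rate decays roughly like learningRate * (step + stepOffset)^(-decayExponent)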
    CrossFoldLearner lr = new CrossFoldLearner(5, 2, 10, new L1())
      .stepOffset(10)
      .decayExponent(0.7)
      .lambda(1.0e-3)
      .learningRate(5);
    int k = 0;
    int[] ordering = permute(gen, data.numRows());
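    // column 9 of cancer.csv holds the 0/1 target; the row index doubles as the
    // tracking key that CrossFoldLearner presumably uses to assign rows to folds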
    for (int epoch = 0; epoch < 100; epoch++) {
      for (int row : ordering) {
        lr.train(row, (int) data.get(row, 9), data.viewRow(row));
        System.out.printf("%d,%d,%.3f\n", epoch, k++, lr.auc());
      }
      assertEquals(1, lr.auc(), 0.2);
    }
    assertEquals(1, lr.auc(), 0.1);
  }

  /**
   * Verifies that a classifier with known coefficients does the right thing.
   */
  @Test
  public void testClassify() {
    OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 2, new L2(1));
    // set up some internal coefficients as if we had learned them
    lr.setBeta(0, 0, -1);
    lr.setBeta(1, 0, -2);
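    // beta has numCategories - 1 rows; category 0 is the implicit reference class,
    // so row 0 scores category 1 and row 1 scores category 2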

    // zero vector gives no information.  All classes are equal.
    Vector v = lr.classify(new DenseVector(new double[]{0, 0}));
    assertEquals(1 / 3.0, v.get(0), 1.0e-8);
    assertEquals(1 / 3.0, v.get(1), 1.0e-8);

    v = lr.classifyFull(new DenseVector(new double[]{0, 0}));
    assertEquals(1.0, v.zSum(), 1.0e-8);
    assertEquals(1 / 3.0, v.get(0), 1.0e-8);
    assertEquals(1 / 3.0, v.get(1), 1.0e-8);
    assertEquals(1 / 3.0, v.get(2), 1.0e-8);

    // weights for second vector component are still zero so all classifications are equally likely
    v = lr.classify(new DenseVector(new double[]{0, 1}));
    assertEquals(1 / 3.0, v.get(0), 1.0e-3);
    assertEquals(1 / 3.0, v.get(1), 1.0e-3);

    v = lr.classifyFull(new DenseVector(new double[]{0, 1}));
    assertEquals(1.0, v.zSum(), 1.0e-8);
    assertEquals(1 / 3.0, v.get(0), 1.0e-3);
    assertEquals(1 / 3.0, v.get(1), 1.0e-3);
    assertEquals(1 / 3.0, v.get(2), 1.0e-3);

    // but the weights on the first component are non-zero
    v = lr.classify(new DenseVector(new double[]{1, 0}));
    assertEquals(Math.exp(-1) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(0), 1.0e-8);
    assertEquals(Math.exp(-2) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(1), 1.0e-8);

    v = lr.classifyFull(new DenseVector(new double[]{1, 0}));
    assertEquals(1.0, v.zSum(), 1.0e-8);
    assertEquals(1 / (1 + Math.exp(-1) + Math.exp(-2)), v.get(0), 1.0e-8);
    assertEquals(Math.exp(-1) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(1), 1.0e-8);
    assertEquals(Math.exp(-2) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(2), 1.0e-8);

    lr.setBeta(0, 1, 1);

    v = lr.classifyFull(new DenseVector(new double[]{1, 1}));
    assertEquals(1.0, v.zSum(), 1.0e-8);
    assertEquals(Math.exp(0) / (1 + Math.exp(0) + Math.exp(-2)), v.get(1), 1.0e-3);
    assertEquals(Math.exp(-2) / (1 + Math.exp(0) + Math.exp(-2)), v.get(2), 1.0e-3);
    assertEquals(1 / (1 + Math.exp(0) + Math.exp(-2)), v.get(0), 1.0e-3);

    lr.setBeta(1, 1, 3);

    v = lr.classifyFull(new DenseVector(new double[]{1, 1}));
    assertEquals(1.0, v.zSum(), 1.0e-8);
    assertEquals(Math.exp(0) / (1 + Math.exp(0) + Math.exp(1)), v.get(1), 1.0e-8);
    assertEquals(Math.exp(1) / (1 + Math.exp(0) + Math.exp(1)), v.get(2), 1.0e-8);
    assertEquals(1 / (1 + Math.exp(0) + Math.exp(1)), v.get(0), 1.0e-8);
  }

  @Test
  public void iris() throws IOException {
    // this test trains a 3-way classifier on the famous Iris dataset.
    // a similar exercise can be accomplished in R using this code:
    //    library(nnet)
    //    correct = rep(0,100)
    //    for (j in 1:100) {
    //      i = order(runif(150))
    //      train = iris[i[1:100],]
    //      test = iris[i[101:150],]
    //      m = multinom(Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width, train)
    //      correct[j] = mean(predict(m, newdata=test) == test$Species)
    //    }
    //    hist(correct)
    //
    // Note that depending on the training/test split, performance can be better or worse.
    // There is about a 5% chance of getting accuracy < 90% and about 20% chance of getting accuracy
    // of 100%
    //
    // This test uses a deterministic split that is neither outstandingly good nor bad


    RandomUtils.useTestSeed();
    Splitter onComma = Splitter.on(",");

    // read the data
    List<String> raw = Resources.readLines(Resources.getResource("iris.csv"), Charsets.UTF_8);

    // holds features
    List<Vector> data = Lists.newArrayList();

    // holds target variable
    List<Integer> target = Lists.newArrayList();

    // for decoding target values
    Dictionary dict = new Dictionary();

    // for permuting data later
    List<Integer> order = Lists.newArrayList();

    // walk the data rows, skipping the CSV header
    for (String line : raw.subList(1, raw.size())) {
      // order gets a list of indexes
      order.add(order.size());

      // parse the predictor variables
      Vector v = new DenseVector(5);
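      // element 0 is held at 1 to serve as the bias (intercept) term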
      v.set(0, 1);
      int i = 1;
      Iterable<String> values = onComma.split(line);
      for (String value : Iterables.limit(values, 4)) {
        v.set(i++, Double.parseDouble(value));
      }
      data.add(v);

      // and the target
      target.add(dict.intern(Iterables.get(values, 4)));
    }

    // randomize the order ... original data has each species all together
    // note that this randomization is deterministic
    Random random = RandomUtils.getRandom();
    Collections.shuffle(order, random);

    // select training and test data
    List<Integer> train = order.subList(0, 100);
    List<Integer> test = order.subList(100, 150);
    logger.warn("Training set = {}", train);
    logger.warn("Test set = {}", test);

    // now train many times and collect information on accuracy each time
    int[] correct = new int[test.size() + 1];
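    // correct[x] counts the runs in which exactly x held-out examples were classified correctly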
    for (int run = 0; run < 200; run++) {
      OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 5, new L2(1));
      // 30 training passes should converge to > 95% accuracy nearly always but never to 100%
      for (int pass = 0; pass < 30; pass++) {
        Collections.shuffle(train, random);
        for (int k : train) {
          lr.train(target.get(k), data.get(k));
        }
      }

      // check the accuracy on held out data
      int x = 0;
      int[] count = new int[3];
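      // count records how many test examples land in each of the three predicted classes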
      for (Integer k : test) {
        int r = lr.classifyFull(data.get(k)).maxValueIndex();
        count[r]++;
        x += r == target.get(k) ? 1 : 0;
      }
      correct[x]++;
    }

    // verify we never saw worse than 95% correct,
    for (int i = 0; i < Math.floor(0.95 * test.size()); i++) {
      assertEquals(String.format("%d trials had unacceptable accuracy of only %.0f%%: ", correct[i], 100.0 * i / test.size()), 0, correct[i]);
    }
    // nor perfect
    assertEquals(String.format("%d trials had unrealistic accuracy of 100%%", correct[test.size()]), 0, correct[test.size()]);
  }

  @Test
  public void testTrain() throws Exception {
    Vector target = readStandardData();


    // lambda here needs to be relatively small to avoid swamping the actual signal, but can be
    // larger than usual because the data are dense.  The learning rate doesn't matter too much
    // for this example, but should generally be < 1
    // --passes 1 --rate 50 --lambda 0.001 --input sgd-y.csv --features 21 --output model --noBias
    //   --target y --categories 2 --predictors  V2 V3 V4 V5 V6 V7 --types n
    OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 8, new L1())
      .lambda(1.0e-3)
      .learningRate(50);

    train(getInput(), target, lr);
    test(getInput(), target, lr, 0.05, 0.3);
  }

  /**
   * Tests serialization and deserialization of the learner's hyperparameters.
   */
  @Test
  public void testSerializationAndDeSerialization() throws Exception {
    OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 8, new L1())
      .lambda(1.0e-3)
      .stepOffset(11)
      .alpha(0.01)
      .learningRate(50)
      .decayExponent(-0.02);
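
    // The fluent setters map onto private fields that are checked below by reflection:
    // alpha -> decayFactor, learningRate -> mu0, decayExponent -> forgettingExponent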

    // close the learner to apply any pending (lazy) regularization before serializing
    lr.close();

    byte[] output;
    Closer closer = Closer.create();

    try {
      ByteArrayOutputStream byteArrayOutputStream = closer.register(new ByteArrayOutputStream());
      DataOutputStream dataOutputStream = closer.register(new DataOutputStream(byteArrayOutputStream));
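      // PolymorphicWritable stores the concrete class alongside the model state so
      // that read() below can reconstruct the right type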
      PolymorphicWritable.write(dataOutputStream, lr);
      output = byteArrayOutputStream.toByteArray();
    } finally {
      closer.close();
    }

    OnlineLogisticRegression read;

    try {
      ByteArrayInputStream byteArrayInputStream = closer.register(new ByteArrayInputStream(output));
      DataInputStream dataInputStream = closer.register(new DataInputStream(byteArrayInputStream));
      read = closer.register(PolymorphicWritable.read(dataInputStream, OnlineLogisticRegression.class));
    } finally {
      closer.close();
    }

    // lambda
    Assert.assertEquals(1.0e-3, read.getLambda(), 1.0e-7);

    // Use reflection to inspect the private hyperparameter fields
    // stepOffset
    Field stepOffset = lr.getClass().getDeclaredField("stepOffset");
    stepOffset.setAccessible(true);
    int stepOffsetVal = (Integer) stepOffset.get(lr);
    Assert.assertEquals(11, stepOffsetVal);

    // decayFactor (alpha)
    Field decayFactor = lr.getClass().getDeclaredField("decayFactor");
    decayFactor.setAccessible(true);
    double decayFactorVal = (Double) decayFactor.get(lr);
    Assert.assertEquals(0.01, decayFactorVal, 1.0e-7);

    // learning rate (mu0)
    Field mu0 = lr.getClass().getDeclaredField("mu0");
    mu0.setAccessible(true);
    double mu0Val = (Double) mu0.get(lr);
    Assert.assertEquals(50, mu0Val, 1.0e-7);

    // forgettingExponent (decayExponent)
    Field forgettingExponent = lr.getClass().getDeclaredField("forgettingExponent");
    forgettingExponent.setAccessible(true);
    double forgettingExponentVal = (Double) forgettingExponent.get(lr);
    Assert.assertEquals(-0.02, forgettingExponentVal, 1.0e-7);
  }
}