Package org.apache.mahout.math.random

Source Code of org.apache.mahout.math.random.ChineseRestaurantTest

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mahout.math.random;

import com.google.common.collect.HashMultiset;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Multiset;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.MahoutTestCase;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.QRDecomposition;
import org.junit.Test;

import java.util.Collections;
import java.util.List;
import java.util.Set;

public final class ChineseRestaurantTest extends MahoutTestCase {

  @Test
  public void testDepth() {
    List<Integer> totals = Lists.newArrayList();
    for (int i = 0; i < 1000; i++) {
      ChineseRestaurant x = new ChineseRestaurant(10);
      Multiset<Integer> counts = HashMultiset.create();
      for (int j = 0; j < 100; j++) {
        counts.add(x.sample());
      }
      List<Integer> tmp = Lists.newArrayList();
      for (Integer k : counts.elementSet()) {
        tmp.add(counts.count(k));
      }
      Collections.sort(tmp, Collections.reverseOrder());
      while (totals.size() < tmp.size()) {
        totals.add(0);
      }
      int j = 0;
      for (Integer k : tmp) {
        totals.set(j, totals.get(j) + k);
        j++;
      }
    }

    // these are empirically derived values, not principled ones
    assertEquals(25000.0, (double) totals.get(0), 1000);
    assertEquals(24000.0, (double) totals.get(1), 1000);
    assertEquals(8000.0, (double) totals.get(2), 200);
    assertEquals(1000.0, (double) totals.get(15), 50);
    assertEquals(1000.0, (double) totals.get(20), 40);
  }

  @Test
  public void testExtremeDiscount() {
    ChineseRestaurant x = new ChineseRestaurant(100, 1);
    Multiset<Integer> counts = HashMultiset.create();
    for (int i = 0; i < 10000; i++) {
      counts.add(x.sample());
    }
    assertEquals(10000, x.size());
    for (int i = 0; i < 10000; i++) {
      assertEquals(1, x.count(i));
    }
  }

  @Test
  public void testGrowth() {
    ChineseRestaurant s0 = new ChineseRestaurant(10, 0.0);
    ChineseRestaurant s5 = new ChineseRestaurant(10, 0.5);
    ChineseRestaurant s9 = new ChineseRestaurant(10, 0.9);
    Set<Double> splits = ImmutableSet.of(1.0, 1.5, 2.0, 3.0, 5.0, 8.0);

    double offset0 = 0;
    int k = 0;
    int i = 0;
    Matrix m5 = new DenseMatrix(20, 3);
    Matrix m9 = new DenseMatrix(20, 3);
    while (i <= 200000) {
      double n = i / Math.pow(10, Math.floor(Math.log10(i)));
      if (splits.contains(n)) {
        //System.out.printf("%d\t%d\t%d\t%d\n", i, s0.size(), s5.size(), s9.size());
        if (i > 900) {
          double predict5 = predictSize(m5.viewPart(0, k, 0, 3), i, 0.5);
          assertEquals(predict5, Math.log(s5.size()), 1);

          double predict9 = predictSize(m9.viewPart(0, k, 0, 3), i, 0.9);
          assertEquals(predict9, Math.log(s9.size()), 1);

          //assertEquals(10.5 * Math.log(i) - offset0, s0.size(), 10);
        } else if (i > 50) {
          double x = 10.5 * Math.log(i) - s0.size();
          m5.viewRow(k).assign(new double[]{Math.log(s5.size()), Math.log(i), 1});
          m9.viewRow(k).assign(new double[]{Math.log(s9.size()), Math.log(i), 1});

          k++;
          offset0 += (x - offset0) / k;
        }
        if (i > 10000) {
          assertEquals(0.0, (double) hapaxCount(s0) / s0.size(), 0.25);
          assertEquals(0.5, (double) hapaxCount(s5) / s5.size(), 0.1);
          assertEquals(0.9, (double) hapaxCount(s9) / s9.size(), 0.05);
        }
      }
      s0.sample();
      s5.sample();
      s9.sample();
      i++;
    }
  }

  /**
   * Predict the power law growth in number of unique samples from the first few data points.
   * Also check that the fitted growth coefficient is about right.
   *
   * @param m
   * @param currentIndex        Total data points seen so far.  Unique values should be log(currentIndex)*expectedCoefficient + offset.
   * @param expectedCoefficient What slope do we expect.
   * @return The predicted value for log(currentIndex)
   */
  private static double predictSize(Matrix m, int currentIndex, double expectedCoefficient) {
    int rows = m.rowSize();
    Matrix a = m.viewPart(0, rows, 1, 2);
    Matrix b = m.viewPart(0, rows, 0, 1);

    Matrix ata = a.transpose().times(a);
    Matrix atb = a.transpose().times(b);
    QRDecomposition s = new QRDecomposition(ata);
    Matrix r = s.solve(atb).transpose();
    assertEquals(expectedCoefficient, r.get(0, 0), 0.2);
    return r.times(new DenseVector(new double[]{Math.log(currentIndex), 1})).get(0);
  }

  private static int hapaxCount(ChineseRestaurant s) {
    int r = 0;
    for (int i = 0; i < s.size(); i++) {
      if (s.count(i) == 1) {
        r++;
      }
    }
    return r;
  }
}
TOP

Related Classes of org.apache.mahout.math.random.ChineseRestaurantTest

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.