/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.clustering;
import java.util.ArrayList;
import java.util.Collection;
import org.apache.mahout.clustering.dirichlet.UncommonDistributions;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.function.SquareRootFunction;
import org.junit.Before;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Checks {@link RunningSumsGaussianAccumulator} and {@link OnlineGaussianAccumulator} against each
 * other and against mean / standard-deviation values computed directly from a randomly generated
 * 2-d sample set.
 */
public final class TestGaussianAccumulators extends MahoutTestCase {

  private static final Logger log = LoggerFactory.getLogger(TestGaussianAccumulators.class);

  /** Samples regenerated by {@link #setUp()} before every test. */
  private Collection<VectorWritable> sampleData;
  /** Number of vectors in {@link #sampleData}. */
  private int sampleN;
  /** Per-dimension mean of {@link #sampleData}. */
  private Vector sampleMean;
  /** Per-dimension standard deviation of {@link #sampleData} (variance uses an n - 1 denominator). */
  private Vector sampleStd;

  /**
   * Generates fresh sample data and computes its reference mean and standard deviation directly,
   * so the accumulators can be compared against independently derived values.
   */
  @Override
  @Before
  public void setUp() throws Exception {
    super.setUp();
    sampleData = new ArrayList<VectorWritable>();
    generateSamples();

    sampleN = 0;
    Vector sum = new DenseVector(2);
    for (VectorWritable v : sampleData) {
      v.get().addTo(sum);
      sampleN++;
    }
    sampleMean = sum.divide(sampleN);

    Vector sampleVar = new DenseVector(2);
    for (VectorWritable v : sampleData) {
      Vector delta = v.get().minus(sampleMean);
      delta.times(delta).addTo(sampleVar);
    }
    // Unbiased sample variance: divide by n - 1 rather than n.
    sampleVar = sampleVar.divide(sampleN - 1);
    sampleStd = sampleVar.clone();
    sampleStd.assign(new SquareRootFunction());
    log.info("Observing {} samples m=[{}, {}] sd=[{}, {}]",
        new Object[] {sampleN, sampleMean.get(0), sampleMean.get(1), sampleStd.get(0), sampleStd.get(1)});
  }

  /**
   * Generates normally distributed 2-d samples and appends them to {@link #sampleData}.
   *
   * @param num number of samples to generate
   * @param mx x-value of the distribution mean
   * @param my y-value of the distribution mean
   * @param sdx x-value standard deviation of the distribution
   * @param sdy y-value standard deviation of the distribution
   */
  private void generate2dSamples(int num, double mx, double my, double sdx, double sdy) {
    log.info("Generating {} samples m=[{}, {}] sd=[{}, {}]", new Object[] {num, mx, my, sdx, sdy});
    for (int i = 0; i < num; i++) {
      sampleData.add(new VectorWritable(new DenseVector(new double[] {
          UncommonDistributions.rNorm(mx, sdx),
          UncommonDistributions.rNorm(my, sdy)})));
    }
  }

  /** Populates {@link #sampleData} with the fixed data set used by every test. */
  private void generateSamples() {
    generate2dSamples(50000, 1, 2, 3, 4);
  }

  @Test
  public void testAccumulatorNoSamples() {
    GaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator();
    GaussianAccumulator accumulator1 = new OnlineGaussianAccumulator();
    accumulator0.compute();
    accumulator1.compute();
    assertEquals("N", accumulator0.getN(), accumulator1.getN(), EPSILON);
    assertEquals("Means", accumulator0.getMean(), accumulator1.getMean());
    assertEquals("Avg Stds", accumulator0.getAverageStd(), accumulator1.getAverageStd(), EPSILON);
  }

  @Test
  public void testAccumulatorOneSample() {
    GaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator();
    GaussianAccumulator accumulator1 = new OnlineGaussianAccumulator();
    Vector sample = new DenseVector(2);
    accumulator0.observe(sample, 1.0);
    accumulator1.observe(sample, 1.0);
    accumulator0.compute();
    accumulator1.compute();
    assertEquals("N", accumulator0.getN(), accumulator1.getN(), EPSILON);
    assertEquals("Means", accumulator0.getMean(), accumulator1.getMean());
    assertEquals("Avg Stds", accumulator0.getAverageStd(), accumulator1.getAverageStd(), EPSILON);
  }

  @Test
  public void testOLAccumulatorResults() {
    assertMatchesSampleStatistics("OL", new OnlineGaussianAccumulator(), EPSILON);
  }

  @Test
  public void testRSAccumulatorResults() {
    assertMatchesSampleStatistics("RS", new RunningSumsGaussianAccumulator(), 0.0001);
  }

  @Test
  public void testAccumulatorWeightedResults() {
    assertAccumulatorsAgreeWithWeight(0.5);
  }

  @Test
  public void testAccumulatorWeightedResults2() {
    assertAccumulatorsAgreeWithWeight(1.5);
  }

  /**
   * Feeds every sample to {@code accumulator} with weight 1.0 and asserts that its N, mean and
   * standard deviation match the reference values computed in {@link #setUp()}.
   *
   * @param label short accumulator name used as the prefix of log and assertion messages
   * @param accumulator a freshly constructed accumulator
   * @param stdEpsilon tolerance for the summed standard deviations
   */
  private void assertMatchesSampleStatistics(String label, GaussianAccumulator accumulator, double stdEpsilon) {
    for (VectorWritable vw : sampleData) {
      accumulator.observe(vw.get(), 1.0);
    }
    accumulator.compute();
    log.info("{} Observed {} samples m=[{}, {}] sd=[{}, {}]",
        new Object[] {
            label,
            accumulator.getN(),
            accumulator.getMean().get(0),
            accumulator.getMean().get(1),
            accumulator.getStd().get(0),
            accumulator.getStd().get(1)
        });
    assertEquals(label + " N", sampleN, accumulator.getN(), EPSILON);
    assertEquals(label + " Mean", sampleMean.zSum(), accumulator.getMean().zSum(), EPSILON);
    assertEquals(label + " Std", sampleStd.zSum(), accumulator.getStd().zSum(), stdEpsilon);
  }

  /**
   * Feeds every sample, with the same {@code weight}, to both accumulator implementations and
   * asserts that they agree on N, mean, standard deviation and variance.
   *
   * @param weight weight passed to each {@code observe} call
   */
  private void assertAccumulatorsAgreeWithWeight(double weight) {
    GaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator();
    GaussianAccumulator accumulator1 = new OnlineGaussianAccumulator();
    for (VectorWritable vw : sampleData) {
      accumulator0.observe(vw.get(), weight);
      accumulator1.observe(vw.get(), weight);
    }
    accumulator0.compute();
    accumulator1.compute();
    assertEquals("N", accumulator0.getN(), accumulator1.getN(), EPSILON);
    assertEquals("Means", accumulator0.getMean().zSum(), accumulator1.getMean().zSum(), EPSILON);
    // NOTE(review): the std/variance tolerances are looser than EPSILON, presumably because the two
    // algorithms accumulate rounding error differently — confirm.
    assertEquals("Stds", accumulator0.getStd().zSum(), accumulator1.getStd().zSum(), 0.001);
    assertEquals("Variance", accumulator0.getVariance().zSum(), accumulator1.getVariance().zSum(), 0.01);
  }
}