Package org.apache.mahout.math.random

Source Code of org.apache.mahout.math.random.MultiNormal

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mahout.math.random;

import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.DiagonalMatrix;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.DoubleFunction;

import java.util.Random;

/**
* Samples from a multi-variate normal distribution.
* <p/>
* This is done by sampling from several independent unit normal distributions to get a vector u.
* The sample value that is returned is then A u + m where A is derived from the covariance matrix
* and m is the mean of the result.
* <p/>
* If \Sigma is the desired covariance matrix, then you can use any value of A such that A A' =
* \Sigma.  The Cholesky decomposition can be used to compute A if \Sigma is positive definite.
* Slightly more expensive is to use the SVD U S V' = \Sigma and then set A = U \sqrt{S}.
*
* Useful special cases occur when \Sigma is diagonal so that A = \sqrt(\Sigma) or where \Sigma = r I.
*
* Another special case is where m = 0.
*/
public class MultiNormal implements Sampler<Vector> {
  private final Random gen;
  private final int dimension;
  private final Matrix scale;
  private final Vector mean;

  /**
   * Constructs a sampler with diagonal scale matrix.
   * @param diagonal The diagonal elements of the scale matrix.
   */
  public MultiNormal(Vector diagonal) {
    this(new DiagonalMatrix(diagonal), null);
  }

  /**
   * Constructs a sampler with diagonal scale matrix and (potentially)
   * non-zero mean.
   * @param diagonal The scale matrix's principal diagonal.
   * @param mean The desired mean.  Set to null if zero mean is desired.
   */
  public MultiNormal(Vector diagonal, Vector mean) {
    this(new DiagonalMatrix(diagonal), mean);
  }

  /**
   * Constructs a sampler with non-trivial scale matrix and mean.
   */
  public MultiNormal(Matrix a, Vector mean) {
    this(a, mean, a.columnSize());
  }

  public MultiNormal(int dimension) {
    this(null, null, dimension);
  }

  public MultiNormal(double radius, Vector mean) {
    this(new DiagonalMatrix(radius, mean.size()), mean);
  }

  private MultiNormal(Matrix scale, Vector mean, int dimension) {
    gen = RandomUtils.getRandom();
    this.dimension = dimension;
    this.scale = scale;
    this.mean = mean;
  }

  @Override
  public Vector sample() {
    Vector v = new DenseVector(dimension).assign(
      new DoubleFunction() {
        @Override
        public double apply(double ignored) {
          return gen.nextGaussian();
        }
      }
    );
    if (mean != null) {
      if (scale != null) {
        return scale.times(v).plus(mean);
      } else {
        return v.plus(mean);
      }
    } else {
      if (scale != null) {
        return scale.times(v);
      } else {
        return v;
      }
    }
  }

  public Vector getScale() {
    return mean;
  }
}
TOP

Related Classes of org.apache.mahout.math.random.MultiNormal

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.