Package org.apache.mahout.clustering.dirichlet.models

Source Code of org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalModel

/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mahout.clustering.dirichlet.models;

import org.apache.mahout.matrix.AbstractVector;
import org.apache.mahout.matrix.SquareRootFunction;
import org.apache.mahout.matrix.Vector;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;

public class AsymmetricSampledNormalModel implements Model<Vector> {

  private static final double sqrt2pi = Math.sqrt(2.0 * Math.PI);

  // the parameters
  private Vector mean;

  private Vector stdDev;

  // the observation statistics, initialized by the first observation
  private int s0 = 0;

  private Vector s1;

  private Vector s2;

  public AsymmetricSampledNormalModel() {
    super();
  }

  public AsymmetricSampledNormalModel(Vector mean, Vector stdDev) {
    super();
    this.mean = mean;
    this.stdDev = stdDev;
    this.s0 = 0;
    this.s1 = mean.like();
    this.s2 = mean.like();
  }

  public Vector getMean() {
    return mean;
  }

  public Vector getStdDev() {
    return stdDev;
  }

  /**
   * Return an instance with the same parameters
   *
   * @return an AsymmetricSampledNormalModel
   */
  AsymmetricSampledNormalModel sample() {
    return new AsymmetricSampledNormalModel(mean, stdDev);
  }

  @Override
  public void observe(Vector x) {
    s0++;
    if (s1 == null) {
      s1 = x.clone();
    } else {
      s1 = s1.plus(x);
    }
    if (s2 == null) {
      s2 = x.times(x);
    } else {
      s2 = s2.plus(x.times(x));
    }
  }

  @Override
  public void computeParameters() {
    if (s0 == 0) {
      return;
    }
    mean = s1.divide(s0);
    // compute the two component stds
    if (s0 > 1) {
      stdDev = s2.times(s0).minus(s1.times(s1)).assign(new SquareRootFunction())
          .divide(s0);
    } else {
      stdDev.assign(Double.MIN_NORMAL);
    }
  }

  /**
   * Calculate a pdf using the supplied sample and stdDev
   *
   * @param x  a Vector sample
   * @param sd a double std deviation
   */
  private double pdf(Vector x, double sd) {
    if (x.getNumNondefaultElements() != 2) {
      throw new IllegalArgumentException();
    }
    double sd2 = sd * sd;
    double exp = -(x.dot(x) - 2 * x.dot(mean) + mean.dot(mean)) / (2 * sd2);
    double ex = Math.exp(exp);
    return ex / (sd * sqrt2pi);
  }

  @Override
  public double pdf(Vector x) {
    // return the product of the two component pdfs
    if (x.getNumNondefaultElements() != 2) {
      throw new IllegalArgumentException();
    }
    double pdf0 = pdf(x, stdDev.get(0));
    double pdf1 = pdf(x, stdDev.get(1));
    // if (pdf0 < 0 || pdf0 > 1 || pdf1 < 0 || pdf1 > 1)
    // System.out.print("");
    return pdf0 * pdf1;
  }

  @Override
  public int count() {
    return s0;
  }

  @Override
  public String toString() {
    StringBuilder buf = new StringBuilder(50);
    buf.append("asnm{n=").append(s0).append(" m=[");
    if (mean != null) {
      for (int i = 0; i < mean.size(); i++) {
        buf.append(String.format("%.2f", mean.get(i))).append(", ");
      }
    }
    buf.append("] sd=[");
    if (stdDev != null) {
      for (int i = 0; i < stdDev.size(); i++) {
        buf.append(String.format("%.2f", stdDev.get(i))).append(", ");
      }
    }
    buf.append("]}");
    return buf.toString();
  }

  @Override
  public void readFields(DataInput in) throws IOException {
    this.mean = AbstractVector.readVector(in);
    this.stdDev = AbstractVector.readVector(in);
    this.s0 = in.readInt();
    this.s1 = AbstractVector.readVector(in);
    this.s2 = AbstractVector.readVector(in);
  }

  @Override
  public void write(DataOutput out) throws IOException {
    AbstractVector.writeVector(out, mean);
    AbstractVector.writeVector(out, stdDev);
    out.writeInt(s0);
    AbstractVector.writeVector(out, s1);
    AbstractVector.writeVector(out, s2);
  }
}
TOP

Related Classes of org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalModel

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.