Package org.apache.mahout.clustering.lda

Source Code of org.apache.mahout.clustering.lda.LDAInference$InferredDocument

/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mahout.clustering.lda;

import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;

import org.apache.commons.math.special.Gamma;
import org.apache.mahout.matrix.BinaryFunction;
import org.apache.mahout.matrix.DenseMatrix;
import org.apache.mahout.matrix.DenseVector;
import org.apache.mahout.matrix.Matrix;
import org.apache.mahout.matrix.Vector;

/**
* Class for performing infererence on a document, which involves
* computing (an approximation to) p(word|topic) for each word and
* topic, and a prior distribution p(topic) for each topic.
*/
public class LDAInference {

  private static final double E_STEP_CONVERGENCE = 1.0E-6;

  public LDAInference(LDAState state) {
    this.state = state;
  }

  /**
  * An estimate of the probabilitys for each document.
  * Gamma(k) is the probability of seeing topic k in
  * the document, phi(k,w) is the probability of
  * topic k generating w in this document.
  */
  public static class InferredDocument {

    private final Vector wordCounts;
    private final Vector gamma; // p(topic)
    private final Matrix mphi; // log p(columnMap(w)|t)
    private final Map<Integer, Integer> columnMap; // maps words into the matrix's column map
    public final double logLikelihood;

    public double phi(int k, int w) {
      return mphi.getQuick(k, columnMap.get(w));
    }

    InferredDocument(Vector wordCounts, Vector gamma,
                     Map<Integer, Integer> columnMap, Matrix phi,
                     double ll) {
      this.wordCounts = wordCounts;
      this.gamma = gamma;
      this.mphi = phi;
      this.columnMap = columnMap;
      this.logLikelihood = ll;
    }

    public Vector getWordCounts() {
      return wordCounts;
    }

    public Vector getGamma() {
      return gamma;
    }
  }

  /**
  * Performs inference on the given document, returning
  * an InferredDocument.
  */
  public InferredDocument infer(Vector wordCounts) {
    double docTotal = wordCounts.zSum();
    int docLength = wordCounts.size();

    // initialize variational approximation to p(z|doc)
    Vector gamma = new DenseVector(state.numTopics);
    gamma.assign(state.topicSmoothing + docTotal / state.numTopics);
    Vector nextGamma = new DenseVector(state.numTopics);

    DenseMatrix phi = new DenseMatrix(state.numTopics, docLength);

    // digamma is expensive, precompute
    Vector digammaGamma = digamma(gamma);
    // and log normalize:
    double digammaSumGamma = digamma(gamma.zSum());
    digammaGamma = digammaGamma.plus(-digammaSumGamma);

    Map<Integer, Integer> columnMap = new HashMap<Integer, Integer>();

    int iteration = 0;
    final int MAX_ITER = 20;

    boolean converged = false;
    double oldLL = 1;
    while (!converged && iteration < MAX_ITER) {
      nextGamma.assign(state.topicSmoothing); // nG := alpha, for all topics

      int mapping = 0;
      for (Iterator<Vector.Element> iter = wordCounts.iterateNonZero();
          iter.hasNext();) {
      Vector.Element e = iter.next();
        int word = e.index();
        Vector phiW = eStepForWord(word, digammaGamma);
        phi.assignColumn(mapping, phiW);
        if (iteration == 0) { // first iteration
          columnMap.put(word, mapping);
        }

        for (int k = 0; k < nextGamma.size(); ++k) {
          double g = nextGamma.getQuick(k);
          nextGamma.setQuick(k, g + e.get() * Math.exp(phiW.get(k)));
        }

        mapping++;
      }

      Vector tempG = gamma;
      gamma = nextGamma;
      nextGamma = tempG;

      // digamma is expensive, precompute
      digammaGamma = digamma(gamma);
      // and log normalize:
      digammaSumGamma = digamma(gamma.zSum());
      digammaGamma = digammaGamma.plus(-digammaSumGamma);

      double ll = computeLikelihood(wordCounts, columnMap, phi, gamma, digammaGamma);
      assert !Double.isNaN(ll);
      converged = oldLL < 0 && ((oldLL - ll) / oldLL < E_STEP_CONVERGENCE);

      oldLL = ll;
      iteration++;
    }

    return new InferredDocument(wordCounts, gamma, columnMap, phi, oldLL);
  }

  private final LDAState state;

  private double computeLikelihood(Vector wordCounts, Map<Integer, Integer> columnMap,
      Matrix phi, Vector gamma, Vector digammaGamma) {
    double ll = 0.0;

    // log normalizer for q(gamma);
    ll += Gamma.logGamma(state.topicSmoothing * state.numTopics);
    ll -= state.numTopics * Gamma.logGamma(state.topicSmoothing);
    assert !Double.isNaN(ll) : state.topicSmoothing + " " + state.numTopics;

    // now for the the rest of q(gamma);
    for (int k = 0; k < state.numTopics; ++k) {
      ll += (state.topicSmoothing - gamma.get(k)) * digammaGamma.get(k);
      ll += Gamma.logGamma(gamma.get(k));

    }
    ll -= Gamma.logGamma(gamma.zSum());
    assert !Double.isNaN(ll);


    // for each word
    for (Iterator<Vector.Element> iter = wordCounts.iterateNonZero();
        iter.hasNext();) {
      Vector.Element e = iter.next();
      int w = e.index();
      double n = e.get();
      int mapping = columnMap.get(w);
      // now for each topic:
      for (int k = 0; k < state.numTopics; k++) {
        double llPart = 0.0;
        llPart += Math.exp(phi.get(k, mapping))
          * (digammaGamma.get(k) - phi.get(k, mapping)
             + state.logProbWordGivenTopic(w, k));

        ll += llPart * n;

        assert state.logProbWordGivenTopic(w, k< 0;
        assert !Double.isNaN(llPart);
      }
    }
    assert ll <= 0;
    return ll;
  }

  /**
   * Compute log q(k|w,doc) for each topic k, for a given word.
   */
  private Vector eStepForWord(int word, Vector digammaGamma) {
    Vector phi = new DenseVector(state.numTopics); // log q(k|w), for each w
    double phiTotal = Double.NEGATIVE_INFINITY; // log Normalizer
    for (int k = 0; k < state.numTopics; ++k) { // update q(k|w)'s param phi
      phi.set(k, state.logProbWordGivenTopic(word, k) + digammaGamma.get(k));
      phiTotal = LDAUtil.logSum(phiTotal, phi.get(k));

      assert !Double.isNaN(phiTotal);
      assert !Double.isNaN(state.logProbWordGivenTopic(word, k));
      assert !Double.isInfinite(state.logProbWordGivenTopic(word, k));
      assert !Double.isNaN(digammaGamma.get(k));
    }
    return phi.plus(-phiTotal); // log normalize
  }


  private static Vector digamma(Vector v) {
    Vector digammaGamma = new DenseVector(v.size());
    digammaGamma.assign(v, new BinaryFunction() {
      @Override
      public double apply(double unused, double g) {
        return digamma(g);
      }
    });
    return digammaGamma;
  }

  /**
   * Approximation to the digamma function, from Radford Neal.
   *
   * Original License:
   * Copyright (c) 1995-2003 by Radford M. Neal
   *
   * Permission is granted for anyone to copy, use, modify, or distribute this
   * program and accompanying programs and documents for any purpose, provided
   * this copyright notice is retained and prominently displayed, along with
   * a note saying that the original programs are available from Radford Neal's
   * web page, and note is made of any changes made to the programs.  The
   * programs and documents are distributed without any warranty, express or
   * implied.  As the programs were written for research purposes only, they have
   * not been tested to the degree that would be advisable in any important
   * application.  All use of these programs is entirely at the user's own risk.
   *
   *
   * Ported to Java for Mahout.
   *
   */
  private static double digamma(double x) {
    double r = 0.0;

    while (x <= 5) {
      r -= 1 / x;
      x += 1;
    }

    double f = 1.0 / (x * x);
    double t = f * (-1 / 12.0
        + f * (1 / 120.0
        + f * (-1 / 252.0
        + f * (1 / 240.0
        + f * (-1 / 132.0
        + f * (691 / 32760.0
        + f * (-1 / 12.0
        + f * 3617.0 / 8160.0)))))));
    return r + Math.log(x) - 0.5 / x + t;
  }

}
TOP

Related Classes of org.apache.mahout.clustering.lda.LDAInference$InferredDocument

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.