Package cc.mallet.grmm.inference

Source Code of cc.mallet.grmm.inference.JunctionTreePropagation

/* Copyright (C) 2006 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://mallet.cs.umass.edu/
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.inference;


import java.util.Collection;
import java.util.Iterator;
import java.util.logging.Logger;
import java.util.logging.Level;
import java.io.Serializable;
import java.io.ObjectOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;

import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.HashVarSet;
import cc.mallet.grmm.types.VarSet;
import cc.mallet.grmm.types.Variable;
import cc.mallet.util.MalletLogger;

/**
* An implementation of Hugin-style propagation for junction trees.
* This destructively modifies the junction tree so that its clique potentials
* are the true marginals of the underlying graph.
* <p/>
* End users will not usually need to use this class directly.  Use
* <tt>JunctionTreeInferencer</tt> instead.
* <p/>
* This class is not an instance of Inferencer because it destructively
* modifies the junction tree, which the Inferencer methods do not do to
* factor graphs.
* <p/>
* Created: Feb 1, 2006
*
* @author <A HREF="mailto:casutton@cs.umass.edu>casutton@cs.umass.edu</A>
* @version $Id: JunctionTreePropagation.java,v 1.1 2007/10/22 21:37:49 mccallum Exp $
*/
class JunctionTreePropagation implements Serializable {

  private static Logger logger = MalletLogger.getLogger (JunctionTreePropagation.class.getName ());

  transient private int totalMessagesSent = 0;

  private MessageStrategy strategy;

  public JunctionTreePropagation (MessageStrategy strategy)
  {
    this.strategy = strategy;
  }

  public static JunctionTreePropagation createSumProductInferencer ()
  {
    return new JunctionTreePropagation (new SumProductMessageStrategy ());
  }

  public static JunctionTreePropagation createMaxProductInferencer ()
  {
    return new JunctionTreePropagation (new MaxProductMessageStrategy ());
  }


  public int getTotalMessagesSent ()
  {
    return totalMessagesSent;
  }

  public void computeMarginals (JunctionTree jt)
  {
    propagate (jt);
    jt.normalizeAll ();      // Necessary if jt originally unnormalized
  }

/* Hugin-style propagation for junction trees */

  // bottom-up pass
  private void collectEvidence (JunctionTree jt, VarSet parent, VarSet child)
  {
    logger.finer ("collectEvidence " + parent + " --> " + child);
    for (Iterator it = jt.getChildren (child).iterator (); it.hasNext ();) {
      VarSet gchild = (VarSet) it.next ();
      collectEvidence (jt, child, gchild);
    }
    if (parent != null) {
      totalMessagesSent++;
      strategy.sendMessage (jt, child, parent);
    }
  }

  // top-down pass
  private void distributeEvidence (JunctionTree jt, VarSet parent)
  {
    for (Iterator it = jt.getChildren (parent).iterator (); it.hasNext ();) {
      VarSet child = (VarSet) it.next ();
      totalMessagesSent++;
      strategy.sendMessage (jt, parent, child);
      distributeEvidence (jt, child);
    }
  }

  private void propagate (JunctionTree jt)
  {
    VarSet root = (VarSet) jt.getRoot ();
    collectEvidence (jt, null, root);
    distributeEvidence (jt, root);
  }


  public Factor lookupMarginal (JunctionTree jt, VarSet varSet)
  {
    if (jt == null) { throw new IllegalStateException ("Call computeMarginals() first."); }

    VarSet parent = jt.findParentCluster (varSet);
    if (parent == null) {
      throw new UnsupportedOperationException
              ("No parent cluster in " + jt + " for clique " + varSet);
    }

    Factor cpf = jt.getCPF (parent);
    if (logger.isLoggable (Level.FINER)) {
      logger.finer ("Lookup jt marginal: clique " + varSet + " cluster " + parent);
      logger.finest ("  cpf " + cpf);
    }

    Factor marginal = strategy.extractBelief (cpf, varSet);
    marginal.normalize ();

    return marginal;
  }

  public Factor lookupMarginal (JunctionTree jt, Variable var)
  {
    if (jt == null) { throw new IllegalStateException ("Call computeMarginals() first."); }

    VarSet parent = jt.findParentCluster (var);
    Factor cpf = jt.getCPF (parent);
    if (logger.isLoggable (Level.FINER)) {
      logger.finer ("Lookup jt marginal: var " + var + " cluster " + parent);
      logger.finest (" cpf " + cpf);
    }

    Factor marginal = strategy.extractBelief (cpf, new HashVarSet (new Variable[] { var }));
    marginal.normalize ();

    return marginal;
  }

  ///////////////////////////////////////////////////////////////////////////
  //   MEESAGE STRATEGIES
  ///////////////////////////////////////////////////////////////////////////


  /**
   * Implements a strategy pattern for message sending.  This allows sum-product
   * and max-product messages, e.g., to be different implementations of this strategy.
   */
  public interface MessageStrategy {

    /**
     * Sends a message from the clique FROM to TO in a junction tree.
     */
    public void sendMessage (JunctionTree jt, VarSet from, VarSet to);

    public Factor extractBelief (Factor cpf, VarSet varSet);

  }

  public static class SumProductMessageStrategy implements MessageStrategy, Serializable {

    /**
     * This sends a sum-product message, normalized to avoid
     * underflow.
     */
    public void sendMessage (JunctionTree jt, VarSet from, VarSet to)
    {
      Collection sepset = jt.getSepset (from, to);
      Factor fromCpf = jt.getCPF (from);
      Factor toCpf = jt.getCPF (to);
      Factor oldSepsetPot = jt.getSepsetPot (from, to);
      Factor lambda = fromCpf.marginalize (sepset);

      lambda.normalize ();

      jt.setSepsetPot (lambda, from, to);
      toCpf = toCpf.multiply (lambda);
      toCpf.divideBy (oldSepsetPot);
      toCpf.normalize ();
      jt.setCPF (to, toCpf);
    }

    public Factor extractBelief (Factor cpf, VarSet varSet)
    {
      return cpf.marginalize (varSet);
    }

    // Serialization
    private static final long serialVersionUID = 1;
    private static final int CUURENT_SERIAL_VERSION = 1;

    private void writeObject (ObjectOutputStream out) throws IOException
    {
      out.defaultWriteObject ();
      out.writeInt (CUURENT_SERIAL_VERSION);
    }

    private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException
    {
      in.defaultReadObject ();
      in.readInt (); // version
    }
  }


  public static class MaxProductMessageStrategy implements MessageStrategy, Serializable {

    /**
     * This sends a max-product message.
     */
    public void sendMessage (JunctionTree jt, VarSet from, VarSet to)
    {
//      System.err.println ("Send message "+from+" --> "+to);
      Collection sepset = jt.getSepset (from, to);
      Factor fromCpf = jt.getCPF (from);
      Factor toCpf = jt.getCPF (to);
      Factor oldSepsetPot = jt.getSepsetPot (from, to);
      Factor lambda = fromCpf.extractMax (sepset);

      lambda.normalize ();

      jt.setSepsetPot (lambda, from, to);
      toCpf = toCpf.multiply (lambda);
      toCpf.divideBy (oldSepsetPot);
      toCpf.normalize ();
      jt.setCPF (to, toCpf);
    }

    public Factor extractBelief (Factor cpf, VarSet varSet)
    {
      return cpf.extractMax (varSet);
    }

    // Serialization
    private static final long serialVersionUID = 1;
    private static final int CUURENT_SERIAL_VERSION = 1;

    private void writeObject (ObjectOutputStream out) throws IOException
    {
      out.defaultWriteObject ();
      out.writeInt (CUURENT_SERIAL_VERSION);
    }

    private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException
    {
      in.defaultReadObject ();
      in.readInt (); // version
    }
  }


  // Serialization
  private static final long serialVersionUID = 1;
  private static final int CUURENT_SERIAL_VERSION = 1;

  private void writeObject (ObjectOutputStream out) throws IOException
  {
    out.defaultWriteObject ();
    out.writeInt (CUURENT_SERIAL_VERSION);
  }

  private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException
  {
    in.defaultReadObject ();
    in.readInt (); // version
  }

}
TOP

Related Classes of cc.mallet.grmm.inference.JunctionTreePropagation

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.