Package cc.mallet.grmm.inference.gbp

Source Code of cc.mallet.grmm.inference.gbp.ParentChildGBP

/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.inference.gbp;


import java.util.logging.Logger;
import java.util.logging.Level;
import java.util.*;

import cc.mallet.grmm.inference.AbstractInferencer;
import cc.mallet.grmm.types.*;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Timing;

/**
* Created: May 27, 2005
*
* @author <A HREF="mailto:casutton@cs.umass.edu>casutton@cs.umass.edu</A>
* @version $Id: ParentChildGBP.java,v 1.1 2007/10/22 21:37:58 mccallum Exp $
*/
public class ParentChildGBP extends AbstractInferencer {

  private static final Logger logger = MalletLogger.getLogger (ParentChildGBP.class.getName());
  private static final boolean debug = false;

  private RegionGraphGenerator regioner;
  private MessageStrategy sender;

  private boolean useInertia = true;
  private double inertiaWeight = 0.5;

  // convergence criteria

  private static final double THRESHOLD = 1e-3;
  private static final int MAX_ITER = 500;

  // current inferencing state

  private MessageArray oldMessages;
  private MessageArray newMessages;
  private RegionGraph rg;
  private FactorGraph mdl;

  private ParentChildGBP ()
  {
  }

  public ParentChildGBP (RegionGraphGenerator regioner)
  {
    this (regioner, new FullMessageStrategy ());
  }

  public ParentChildGBP (RegionGraphGenerator regioner, MessageStrategy sender)
  {
    this.regioner = regioner;
    this.sender = sender;
  }

  public static ParentChildGBP makeBPInferencer ()
  {
    ParentChildGBP inferencer = new ParentChildGBP ();
    inferencer.regioner = new BPRegionGenerator ();
    inferencer.sender = new FullMessageStrategy ();
    return inferencer;
  }

  public static ParentChildGBP makeKikuchiInferencer ()
  {
    ParentChildGBP inferencer = new ParentChildGBP ();
    inferencer.regioner = new Kikuchi4SquareRegionGenerator ();
    inferencer.sender = new FullMessageStrategy ();
    return inferencer;
  }

  // accessors

  public boolean getUseInertia ()
  {
    return useInertia;
  }

  public void setUseInertia (boolean useInertia)
  {
    this.useInertia = useInertia;
  }

  public double getInertiaWeight ()
  {
    return inertiaWeight;
  }

  public void setInertiaWeight (double inertiaWeight)
  {
    this.inertiaWeight = inertiaWeight;
  }
  // inferencer interface

  public Factor lookupMarginal (Variable variable)
  {
    Region region = rg.findContainingRegion (variable);
    if (region == null)
      throw new IllegalArgumentException ("Could not find region containing variable "+variable+" in region graph "+rg);

    Factor belief = computeBelief (region);
    Factor varBelief = belief.marginalize (variable);
    return varBelief;
  }


  public Factor lookupMarginal (VarSet varSet)
  {
    Region region = rg.findContainingRegion (varSet);
    if (region == null)
      throw new IllegalArgumentException ("Could not find region containing clique "+varSet +" in region graph "+rg);

    Factor belief = computeBelief (region);
    Factor cliqueBelief = belief.marginalize (varSet);
    return cliqueBelief;
  }


  private Factor computeBelief (Region region)
  {
    return computeBelief (region, newMessages);
  }

  static Factor computeBelief (Region region, MessageArray messages)
  {
    DiscreteFactor result = new LogTableFactor(region.vars);

    for (Iterator it = region.factors.iterator(); it.hasNext();) {
      Factor factor = (Factor) it.next();
      result.multiplyBy(factor);
    }

    for (Iterator it = region.parents.iterator(); it.hasNext();) {
      Region parent = (Region) it.next();
      Factor msg = messages.getMessage(parent, region);
      result.multiplyBy(msg);
    }

    for (Iterator it = region.descendants.iterator(); it.hasNext();) {
      Region child = (Region) it.next();
      for (Iterator it2 = child.parents.iterator(); it2.hasNext();) {
        Region uncle = (Region) it2.next();
        if (uncle != region && !region.descendants.contains(uncle)) {
          result.multiplyBy(messages.getMessage(uncle, child));
        }
      }
    }

    result.normalize();

    return result;
  }

  public double lookupLogJoint (Assignment assn)
  {
    double factorProduct = mdl.logValue (assn);
//    value += computeFreeEnergy (rg);
    double F = computeFreeEnergy (rg);

    double value = factorProduct + F;

    if (debug)
      System.err.println ("GBP factor product:"+factorProduct+" + free energy: "+F+" = value:"+value);

    return value;
  }

  private double computeFreeEnergy (RegionGraph rg)
  {
    double avgEnergy = 0;
    double entropy = 0;
    for (Iterator it = rg.iterator (); it.hasNext();) {
      Region region = (Region) it.next();
      Factor belief = computeBelief(region);
      double thisEntropy = belief.entropy();

      if (debug)
        System.err.println("Region " + region + " c:" + region.countingNumber + "  entropy:" + thisEntropy);

      entropy += region.countingNumber * thisEntropy;

      DiscreteFactor product = new LogTableFactor(belief.varSet());
      for (Iterator ptlIt = region.factors.iterator(); ptlIt.hasNext();) {
        Factor ptl = (Factor) ptlIt.next();
        product.multiplyBy(ptl);
      }

      double thisAvgEnergy = 0;
      for (AssignmentIterator assnIt = belief.assignmentIterator(); assnIt.hasNext();) {
        Assignment assn = assnIt.assignment();

        // Note: Do not use assnIt here before fixing variable ordering issues.
        double thisEnergy = -product.logValue(assn);
//        double thisEnergy = product.phi (assnIt);
        double thisBel = belief.value(assn);
        thisAvgEnergy += thisBel * thisEnergy;
        assnIt.advance();
      }

      if (debug) {
        System.err.println("Region " + region + " c:" + region.countingNumber + " avgEnergy: " + thisAvgEnergy);
/*        DiscretePotential b2 = belief.duplicate ();
        b2.delogify ();
        System.err.println ("BELIEF:"+b2);
        System.err.println ("ENERGY:"+product);
        */
      }
      avgEnergy += region.countingNumber * thisAvgEnergy;

    }

    if (debug)
      System.err.println ("GBP computeFreeEnergy: avgEnergy:"+avgEnergy+"  entropy:"+entropy+"  free energy:"+(avgEnergy-entropy));

//    return avgEnergy + entropy;
    return avgEnergy - entropy;
  }

  public void computeMarginals (FactorGraph mdl)
  {
    Timing timing = new Timing ();

    this.mdl = mdl;
    rg = regioner.constructRegionGraph (mdl);
    RegionEdge[] pairs = chooseMessageSendingOrder ();

    newMessages = new MessageArray (rg);

    timing.tick ("GBP Region Graph construction");
   
    int iter = 0;
    do {

      oldMessages = newMessages;
      newMessages = oldMessages.duplicate ();
      sender.setMessageArray (oldMessages, newMessages);

      for (int i = 0; i < pairs.length; i++) {
        RegionEdge edge = pairs[i];
        sender.sendMessage (edge);
      }

      if (logger.isLoggable (Level.FINER)) {
        timing.tick ("GBP iteration "+iter);
      }

      iter++;

      if (useInertia)
        newMessages = sender.averageMessages (rg, oldMessages, newMessages, inertiaWeight);

    } while (!hasConverged () && (iter < MAX_ITER));

    logger.info ("GBP: Used "+iter+" iterations.");
    if (iter >= MAX_ITER) {
      logger.warning ("***WARNING: GBP not converged!");
    }
  }

  private RegionEdge[] chooseMessageSendingOrder ()
  {
    List l = new ArrayList ();
    for (Iterator it = rg.edgeIterator (); it.hasNext();) {
      RegionEdge edge = (RegionEdge) it.next ();
      l.add (edge);
    }

    Collections.sort (l, new Comparator () {
      public int compare (Object o1, Object o2)
      {
        RegionEdge e1 = (RegionEdge) o1;
        RegionEdge e2 = (RegionEdge) o2;
        int l1 = e1.to.vars.size();
        int l2 = e2.to.vars.size();
        return Double.compare (l1, l2);
      };
    });

    return (RegionEdge[]) l.toArray (new RegionEdge [l.size()]);
  }

  private boolean hasConverged ()
  {
    for (Iterator it = rg.edgeIterator (); it.hasNext();) {
      RegionEdge edge = (RegionEdge) it.next ();
      Factor oldMsg = oldMessages.getMessage (edge.from, edge.to);
      Factor newMsg = newMessages.getMessage (edge.from, edge.to);
      if (oldMsg == null) {
        assert newMsg == null;
      } else {
        if (!oldMsg.almostEquals (newMsg, THRESHOLD)) {
          /*
         //xxx debug
          if (sender instanceof SparseMessageSender)
            System.out.println ("NOT CONVERGED:\n"+newMsg+"\n.......");
          */
          return false;
        }
      }
    }

    return true;
  }

  public void dump ()
  {
    for (Iterator it = rg.edgeIterator (); it.hasNext();) {
      RegionEdge edge = (RegionEdge) it.next ();
      Factor newMsg = newMessages.getMessage (edge.from, edge.to);
      System.out.println ("Message: "+edge.from+" --> "+edge.to+" "+newMsg);
    }
  }

}
TOP

Related Classes of cc.mallet.grmm.inference.gbp.ParentChildGBP

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., which is owned by Oracle Inc. Contact coftware#gmail.com.