Package cc.mallet.grmm.types

Source Code of cc.mallet.grmm.types.DirectedModel

/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.grmm.types; // Generated package name

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org._3pq.jgrapht.DirectedGraph;
import org._3pq.jgrapht.graph.DefaultDirectedGraph;
import org._3pq.jgrapht.alg.ConnectivityInspector;
import gnu.trove.THashMap;


/**
*  Class for directed graphical models: a FactorGraph whose factors must
*   all be CPTs, and which rejects any CPT that would create a directed cycle.
*
* Created: Mon Sep 15 14:50:19 2003
*
* @author <a href="mailto:casutton@cs.umass.edu">Charles Sutton</a>
* @version $Id: DirectedModel.java,v 1.1 2007/10/22 21:37:44 mccallum Exp $
*/

public class DirectedModel extends FactorGraph {

  private Map allCpts = new THashMap ();

  // Graph object used to prevent directed cycles
  private DirectedGraph graph = new DefaultDirectedGraph ();

  public DirectedModel ()
  {
  }

  public DirectedModel (Variable[] vars)
  {
    super (vars);
  }

  public DirectedModel (int capacity)
  {
    super (capacity);
  }

  protected void beforeFactorAdd (Factor factor)
  {
    super.beforeFactorAdd (factor);
    if (!(factor instanceof CPT)) {
      throw new IllegalArgumentException ("Factors of a directed model must be an instance of CPT, was "+factor);
    }

    CPT cpt = (CPT) factor;
    Variable child = cpt.getChild ();
    VarSet parents = cpt.getParents ();
    if (graph.containsVertex (child)) {
      checkForNoCycle (parents, child, cpt);
    }
  }

  private void checkForNoCycle (VarSet parents, Variable child, CPT cpt) {
    ConnectivityInspector inspector = new ConnectivityInspector (graph);
    for (Iterator it = parents.iterator (); it.hasNext ();) {
      Variable rent = (Variable) it.next ();
      if (inspector.pathExists (child, rent)) {
        throw new IllegalArgumentException ("Error adding CPT: Would create directed cycle"+
                        "From: "+rent+" To:"+child+"\nCPT: "+cpt);
      }
    }
  }

  protected void afterFactorAdd (Factor factor)
  {
    super.afterFactorAdd (factor);
    CPT cpt = (CPT) factor;
    Variable child = cpt.getChild ();
    VarSet parents = cpt.getParents ();
    allCpts.put (child, cpt);

    graph.addVertex (child);
    graph.addAllVertices (parents);
    for (Iterator it = parents.iterator (); it.hasNext ();) {
      Variable rent = (Variable) it.next ();
      graph.addEdge (rent, child);
    }
  }

  /**
   *  Returns the conditional distribution <tt>P ( node | Parents (node) )</tt>
   */
  public CPT getCptofVar (Variable node)
  {
    return (CPT) allCpts.get (node);
  }

  // Serialization garbage

  private static final long serialVersionUID = 1;
  private static final int CURRENT_SERIAL_VERSION = 1;

  private void writeObject (ObjectOutputStream out) throws IOException
  {
    out.defaultWriteObject ();
    out.writeInt (CURRENT_SERIAL_VERSION);
  }


  private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException
  {
    in.defaultReadObject ();
    int version = in.readInt ();
  }

}// DirectedModel
TOP

Related Classes of cc.mallet.grmm.types.DirectedModel

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.