Package cc.mallet.grmm.types

Source Code of cc.mallet.grmm.types.PottsTableFactor

/* Copyright (C) 2006 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.types;

import cc.mallet.grmm.util.Matrices;
import cc.mallet.types.Matrix;
import cc.mallet.types.SparseMatrixn;
import cc.mallet.util.Randoms;

/**
* A factor over a continuous variable alpha and discrete variables <tt>x</tt>
*  such that <tt>phi(x|alpha)<tt> is Potts.  That is, for fixed alpha, <tt>phi(x)</tt> = 1
*  if all x are equal, and <tt>exp^{-alpha}</tt> otherwise.
* $Id: PottsTableFactor.java,v 1.1 2007/10/22 21:37:44 mccallum Exp $
*/
public class PottsTableFactor extends AbstractFactor implements ParameterizedFactor {

  private Variable alpha;
  private VarSet xs;

  public PottsTableFactor (VarSet xs, Variable alpha)
  {
    super (combineVariables (alpha, xs));
    this.alpha = alpha;
    this.xs = xs;
    if (!alpha.isContinuous ()) throw new IllegalArgumentException ("alpha must be continuous");
  }

  public PottsTableFactor (Variable x1, Variable x2, Variable alpha)
  {
    super (new HashVarSet (new Variable[] { x1, x2, alpha }));
    this.alpha = alpha;
    this.xs = new HashVarSet (new Variable[] { x1, x2 });
    if (!alpha.isContinuous ()) throw new IllegalArgumentException ("alpha must be continuous");
  }

  private static VarSet combineVariables (Variable alpha, VarSet xs)
  {
    VarSet ret = new HashVarSet (xs);
    ret.add (alpha);
    return ret;
  }

  protected Factor extractMaxInternal (VarSet varSet)
  {
    throw new UnsupportedOperationException ();
  }

  protected double lookupValueInternal (int i)
  {
    throw new UnsupportedOperationException ();
  }

  protected Factor marginalizeInternal (VarSet varsToKeep)
  {
    throw new UnsupportedOperationException ();
  }

  /* Inefficient, but this will seldom be called. */
  public double value (AssignmentIterator it)
  {
    Assignment assn = it.assignment();
    Factor tbl = sliceForAlpha (assn);
    return tbl.value (assn);
  }

  private Factor sliceForAlpha (Assignment assn)
  {
    double alph = assn.getDouble (alpha);
    int[] sizes = sizesFromVarSet (xs);
    Matrix diag = Matrices.diag (sizes, alph);
    Matrix matrix = Matrices.constant (sizes, -alph);
    matrix.plusEquals (diag);
    return LogTableFactor.makeFromLogMatrix (xs.toVariableArray (), (SparseMatrixn) matrix);
  }

  private int[] sizesFromVarSet (VarSet xs)
  {
    int[] szs = new int [xs.size ()];
    for (int i = 0; i < xs.size (); i++) {
      szs[i] = xs.get (i).getNumOutcomes ();
    }
    return szs;
  }

  public Factor normalize ()
  {
    throw new UnsupportedOperationException ();
  }

  public Assignment sample (Randoms r)
  {
    throw new UnsupportedOperationException ();
  }

  public double logValue (AssignmentIterator it)
  {
    return Math.log (value (it));
  }

  public Factor slice (Assignment assn)
  {
    Factor alphSlice = sliceForAlpha (assn);
    // recursively slice, in case assn includes some of the xs
    return alphSlice.slice (assn);
  }

  public String dumpToString ()
  {
    StringBuffer buf = new StringBuffer ();
    buf.append ("[Potts: alpha:");
    buf.append (alpha);
    buf.append (" xs:");
    buf.append (xs);
    buf.append ("]");
    return buf.toString ();
  }

  public double sumGradLog (Factor q, Variable param, Assignment theta)
  {
    if (param != alpha) throw new IllegalArgumentException ();
    Factor q_xs = q.marginalize (xs);
    double qDiff = 0.0;

    for (AssignmentIterator it = xs.assignmentIterator (); it.hasNext(); it.advance()) {
      Assignment assn = it.assignment ();
      if (!isAllEqual (assn)) {
        qDiff += -q_xs.value (it);
      }
    }
   
    return qDiff;
  }

  public double secondDerivative (Factor q, Variable param, Assignment theta)
  {
    double e_x = sumGradLog (q, param, theta);

    Factor q_xs = q.marginalize (xs);
    double e_x2 = 0.0;

    for (AssignmentIterator it = xs.assignmentIterator (); it.hasNext(); it.advance()) {
      Assignment assn = it.assignment ();
      if (!isAllEqual (assn)) {
        e_x2 += q_xs.value (it);
      }
    }

    return e_x2 - (e_x * e_x);
  }

  private boolean isAllEqual (Assignment assn)
  {
    Object val1 = assn.getObject (xs.get (0));
    for (int i = 1; i < xs.size (); i++) {
      Object val2 = assn.getObject (xs.get (i));
      if (!val1.equals (val2)) return false;
    }
    return true;
  }

  public Factor duplicate ()
  {
    return new PottsTableFactor (xs, alpha);
  }

  public boolean isNaN ()
  {
    return false;
  }

  public boolean almostEquals (Factor p, double epsilon)
  {
    return equals (p);
  }

  public boolean equals (Object o)
  {
    if (this == o) return true;
    if (o == null || getClass () != o.getClass ()) return false;

    final PottsTableFactor that = (PottsTableFactor) o;

    if (alpha != null ? !alpha.equals (that.alpha) : that.alpha != null) return false;
    if (xs != null ? !xs.equals (that.xs) : that.xs != null) return false;

    return true;
  }

  public int hashCode ()
  {
    int result;
    result = (alpha != null ? alpha.hashCode () : 0);
    result = 29 * result + (xs != null ? xs.hashCode () : 0);
    return result;
  }
}
TOP

Related Classes of cc.mallet.grmm.types.PottsTableFactor

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.