/* Copyright (C) 2006 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.types;
import cc.mallet.grmm.util.Matrices;
import cc.mallet.types.Matrix;
import cc.mallet.types.SparseMatrixn;
import cc.mallet.util.Randoms;
/**
* A factor over a continuous variable alpha and discrete variables <tt>x</tt>
* such that <tt>phi(x|alpha)</tt> is a Potts potential. That is, for fixed alpha, <tt>phi(x)</tt> = 1
* if all x are equal, and <tt>exp(-alpha)</tt> otherwise.
* $Id: PottsTableFactor.java,v 1.1 2007/10/22 21:37:44 mccallum Exp $
*/
public class PottsTableFactor extends AbstractFactor implements ParameterizedFactor {

  /** The continuous parameter; an assignment where the xs disagree gets log-value -alpha. */
  private final Variable alpha;

  /** The discrete variables whose agreement this factor scores. */
  private final VarSet xs;

  /**
   * Creates a Potts factor over the discrete variables {@code xs}, parameterized by {@code alpha}.
   *
   * @param xs    the discrete variables (kept by reference, not copied)
   * @param alpha the continuous parameter variable
   * @throws IllegalArgumentException if {@code alpha} is not continuous
   */
  public PottsTableFactor (VarSet xs, Variable alpha)
  {
    super (combineVariables (alpha, xs));
    this.alpha = alpha;
    this.xs = xs;
    if (!alpha.isContinuous ()) throw new IllegalArgumentException ("alpha must be continuous");
  }

  /**
   * Convenience constructor for the pairwise case over {@code x1} and {@code x2}.
   *
   * @throws IllegalArgumentException if {@code alpha} is not continuous
   */
  public PottsTableFactor (Variable x1, Variable x2, Variable alpha)
  {
    super (new HashVarSet (new Variable[] { x1, x2, alpha }));
    this.alpha = alpha;
    this.xs = new HashVarSet (new Variable[] { x1, x2 });
    if (!alpha.isContinuous ()) throw new IllegalArgumentException ("alpha must be continuous");
  }

  /** Returns a fresh VarSet holding a copy of {@code xs} plus {@code alpha}; {@code xs} is not modified. */
  private static VarSet combineVariables (Variable alpha, VarSet xs)
  {
    VarSet ret = new HashVarSet (xs);
    ret.add (alpha);
    return ret;
  }

  protected Factor extractMaxInternal (VarSet varSet)
  {
    throw new UnsupportedOperationException ();
  }

  protected double lookupValueInternal (int i)
  {
    throw new UnsupportedOperationException ();
  }

  protected Factor marginalizeInternal (VarSet varsToKeep)
  {
    throw new UnsupportedOperationException ();
  }

  /**
   * Returns the factor value at the iterator's current assignment, which must assign alpha.
   * Inefficient, because it builds the whole table for the current alpha, but this will seldom be called.
   */
  public double value (AssignmentIterator it)
  {
    Assignment assn = it.assignment ();
    Factor tbl = sliceForAlpha (assn);
    return tbl.value (assn);
  }

  /**
   * Builds a log-space table factor over {@code xs} for the alpha value in {@code assn}.
   * The log-matrix is {@code -alpha} everywhere plus {@code alpha} wherever
   * {@code Matrices.diag} places it, so those entries end up at log-value 0.
   * NOTE(review): this assumes Matrices.diag marks exactly the all-indices-equal cells
   * and that Matrices.constant returns a SparseMatrixn (the cast below depends on it).
   */
  private Factor sliceForAlpha (Assignment assn)
  {
    double alph = assn.getDouble (alpha);
    int[] sizes = sizesFromVarSet (xs);
    Matrix diag = Matrices.diag (sizes, alph);
    Matrix matrix = Matrices.constant (sizes, -alph);
    matrix.plusEquals (diag);
    return LogTableFactor.makeFromLogMatrix (xs.toVariableArray (), (SparseMatrixn) matrix);
  }

  /** Returns the number of outcomes of each variable in {@code vars}, in VarSet order. */
  private static int[] sizesFromVarSet (VarSet vars)
  {
    int[] szs = new int [vars.size ()];
    for (int i = 0; i < vars.size (); i++) {
      szs[i] = vars.get (i).getNumOutcomes ();
    }
    return szs;
  }

  public Factor normalize ()
  {
    throw new UnsupportedOperationException ();
  }

  public Assignment sample (Randoms r)
  {
    throw new UnsupportedOperationException ();
  }

  /**
   * Returns the log of the factor value at the iterator's current assignment.
   * The value is read directly from the log-space table rather than computed as
   * log(value(it)). Otherwise exp(-alpha) underflows to 0 for large alpha and
   * this method would return -Infinity instead of -alpha.
   */
  public double logValue (AssignmentIterator it)
  {
    Assignment assn = it.assignment ();
    return sliceForAlpha (assn).logValue (assn);
  }

  /**
   * Fixes alpha to its value in {@code assn}, then slices the resulting table by
   * {@code assn} as well, in case it also assigns some of the xs.
   */
  public Factor slice (Assignment assn)
  {
    Factor alphSlice = sliceForAlpha (assn);
    return alphSlice.slice (assn);
  }

  public String dumpToString ()
  {
    StringBuffer buf = new StringBuffer ();
    buf.append ("[Potts: alpha:");
    buf.append (alpha);
    buf.append (" xs:");
    buf.append (xs);
    buf.append ("]");
    return buf.toString ();
  }

  /**
   * Returns the expectation under {@code q} of d/d(alpha) log phi.
   * Here log phi is -alpha when the xs are not all equal and 0 otherwise,
   * so the expectation is minus the q-mass of unequal assignments.
   *
   * @throws IllegalArgumentException if {@code param} is not this factor's alpha
   */
  public double sumGradLog (Factor q, Variable param, Assignment theta)
  {
    checkParam (param);
    return -massOnUnequalAssignments (q);
  }

  /**
   * Returns E_q[g^2] - (E_q[g])^2, where g = d/d(alpha) log phi = -1{xs not all equal}.
   * Because g^2 equals the indicator itself, this reduces to p - p^2, where p is
   * the q-mass of unequal assignments.
   *
   * @throws IllegalArgumentException if {@code param} is not this factor's alpha
   */
  public double secondDerivative (Factor q, Variable param, Assignment theta)
  {
    checkParam (param);
    double mass = massOnUnequalAssignments (q);
    return mass - mass * mass;
  }

  /** Rejects any parameter other than this factor's alpha (compared by reference). */
  private void checkParam (Variable param)
  {
    if (param != alpha) {
      throw new IllegalArgumentException ("Unknown parameter " + param
              + "; this factor is parameterized only by " + alpha);
    }
  }

  /** Sums q's marginal over {@code xs} across every assignment where the xs are not all equal. */
  private double massOnUnequalAssignments (Factor q)
  {
    Factor qXs = q.marginalize (xs);
    double mass = 0.0;
    for (AssignmentIterator it = xs.assignmentIterator (); it.hasNext (); it.advance ()) {
      if (!isAllEqual (it.assignment ())) {
        mass += qXs.value (it);
      }
    }
    return mass;
  }

  /**
   * Returns true iff every variable in {@code xs} takes an equal outcome object in {@code assn}.
   * An empty {@code xs} counts as vacuously all-equal.
   */
  private boolean isAllEqual (Assignment assn)
  {
    int n = xs.size ();
    if (n == 0) return true;
    Object first = assn.getObject (xs.get (0));
    for (int i = 1; i < n; i++) {
      if (!first.equals (assn.getObject (xs.get (i)))) return false;
    }
    return true;
  }

  /** Returns a new factor over the same alpha and the same (shared, not copied) xs VarSet. */
  public Factor duplicate ()
  {
    return new PottsTableFactor (xs, alpha);
  }

  /** Always false: this factor stores no numeric table that could contain NaN. */
  public boolean isNaN ()
  {
    return false;
  }

  /** Delegates to {@link #equals}; {@code epsilon} is unused because there is no numeric state to compare. */
  public boolean almostEquals (Factor p, double epsilon)
  {
    return equals (p);
  }

  public boolean equals (Object o)
  {
    if (this == o) return true;
    if (o == null || getClass () != o.getClass ()) return false;
    final PottsTableFactor that = (PottsTableFactor) o;
    if (alpha != null ? !alpha.equals (that.alpha) : that.alpha != null) return false;
    if (xs != null ? !xs.equals (that.xs) : that.xs != null) return false;
    return true;
  }

  public int hashCode ()
  {
    int result;
    result = (alpha != null ? alpha.hashCode () : 0);
    result = 29 * result + (xs != null ? xs.hashCode () : 0);
    return result;
  }
}