/* Copyright (C) 2006 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.util;
import java.util.Map;
import java.util.List;
import java.util.ArrayList;
import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.Variable;
import cc.mallet.types.Label;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.Labels;
import cc.mallet.types.LabelsSequence;
import cc.mallet.types.Alphabet;
import cc.mallet.types.AlphabetCarrying;
import gnu.trove.THashMap;
import gnu.trove.TIntArrayList;
/**
 * A special kind of assignment for Variables that
 * can be arranged in a LabelsSequence.  This is an Adaptor
 * to adapt LabelsSequences to Assignments.
 * <p/>
 * One Variable is created for every (time step, slice) position of the
 * wrapped LabelsSequence, and a single row is added to this assignment that
 * gives each of those variables the index of the Label at its position.
 * <p/>
 * $Id: LabelsAssignment.java,v 1.1 2007/10/22 21:37:58 mccallum Exp $
 */
public class LabelsAssignment extends Assignment implements AlphabetCarrying {

  // these are just for printing; race conditions don't matter
  private static int NEXT_ID = 0;
  private int id = NEXT_ID++;

  // idx2var[t][j] is the variable created for slice j of time step t.
  private Variable[][] idx2var;

  // The sequence this assignment was built from.
  private LabelsSequence lblseq;

  // Maps each created Variable to the Label it was created from.
  private Map var2label;

  /**
   * Creates one variable per label in <tt>lbls</tt> and adds a single row
   * assigning every variable the index of its label.
   *
   * @param lbls the label sequence to adapt
   */
  public LabelsAssignment (LabelsSequence lbls)
  {
    super ();
    this.lblseq = lbls;
    setupLabel2Var ();
    addRow (toVariableArray (), toValueArray ());
  }

  /**
   * Flattens idx2var in time-major order (all slices of time step 0, then
   * all slices of time step 1, ...).  The order matches toValueArray().
   */
  private Variable[] toVariableArray ()
  {
    int total = 0;
    for (int t = 0; t < idx2var.length; t++) {
      total += idx2var[t].length;
    }

    Variable[] vars = new Variable [total];
    int next = 0;
    for (int t = 0; t < idx2var.length; t++) {
      System.arraycopy (idx2var[t], 0, vars, next, idx2var[t].length);
      next += idx2var[t].length;
    }
    return vars;
  }

  /**
   * Collects the index of every label in lblseq, in the same time-major
   * order that toVariableArray() uses for the variables.
   */
  private int[] toValueArray ()
  {
    int total = 0;
    for (int t = 0; t < lblseq.size (); t++) {
      total += lblseq.getLabels (t).size ();
    }

    int[] vals = new int [total];
    int next = 0;
    for (int t = 0; t < lblseq.size (); t++) {
      Labels lbls = lblseq.getLabels (t);
      for (int j = 0; j < lbls.size (); j++) {
        vals[next++] = lbls.get (j).getIndex ();
      }
    }
    return vals;
  }

  /**
   * Builds idx2var and var2label: a fresh Variable over the label's own
   * alphabet is created for each label in lblseq.
   */
  private void setupLabel2Var ()
  {
    idx2var = new Variable [lblseq.size ()][];
    var2label = new THashMap ();
    for (int t = 0; t < lblseq.size (); t++) {
      Labels lbls = lblseq.getLabels (t);
      idx2var[t] = new Variable [lbls.size ()];
      for (int j = 0; j < lbls.size (); j++) {
        Label lbl = lbls.get (j);
        Variable var = new Variable (lbl.getLabelAlphabet ());
        // The label text is only a human-readable name built from the instance id and position.
        var.setLabel ("I"+id+"_VAR[f=" + j + "][tm=" + t + "]");
        idx2var[t][j] = var;
        var2label.put (var, lbl);
      }
    }
  }

  /**
   * Returns the variable created for slice <tt>j</tt> of time step <tt>t</tt>.
   *
   * @throws ArrayIndexOutOfBoundsException if either index is out of range
   */
  public Variable varOfIndex (int t, int j)
  {
    return idx2var[t][j];
  }

  /**
   * Returns the label that <tt>var</tt> was created from, or null if
   * <tt>var</tt> was not created by this object.
   */
  public Label labelOfVar (Variable var) { return (Label) var2label.get (var); }

  /** Returns the number of time steps in the wrapped sequence. */
  public int maxTime () { return lblseq.size (); }

  /**
   * Returns the number of slices at the first time step, or 0 if the
   * sequence has no time steps.  This is the width of every time step only
   * if the sequence is not ragged.
   */
  public int numSlices () { return (idx2var.length == 0) ? 0 : idx2var[0].length; }

  /** Returns the label sequence this assignment was built from. */
  public LabelsSequence getLabelsSequence ()
  {
    return lblseq;
  }

  /**
   * Converts an assignment over this object's variables back into a
   * LabelsSequence laid out like the original one.
   *
   * @param assn assignment that is queried (via <tt>get</tt>) for the value
   *             of each variable created by this object
   * @return a new sequence holding, for each position, the label whose index
   *         in that position's alphabet is the value <tt>assn</tt> gives
   */
  public LabelsSequence toLabelsSequence (Assignment assn)
  {
    int maxTime = maxTime ();
    Labels[] lbls = new Labels [maxTime];
    for (int t = 0; t < maxTime; t++) {
      // Use this row's own width rather than numSlices(), so that a ragged
      // sequence (which the constructor accepts) converts back correctly.
      int numFactors = idx2var[t].length;
      Label[] theseLabels = new Label [numFactors];
      for (int i = 0; i < numFactors; i++) {
        Variable var = varOfIndex (t, i);
        // Every slot of idx2var is filled in setupLabel2Var(); the null
        // guard is kept only for subclasses that override varOfIndex().
        int maxidx = (var != null) ? assn.get (var) : 0;
        LabelAlphabet dict = labelOfVar (var).getLabelAlphabet ();
        theseLabels[i] = dict.lookupLabel (maxidx);
      }
      lbls[t] = new Labels (theseLabels);
    }
    return new LabelsSequence (lbls);
  }

  /**
   * Returns the label alphabet of slice <tt>lvl</tt>, taken from the
   * variable at the first time step.
   *
   * @throws ArrayIndexOutOfBoundsException if the sequence has no time
   *         steps or <tt>lvl</tt> is out of range
   */
  public LabelAlphabet getOutputAlphabet (int lvl)
  {
    return idx2var[0][lvl].getLabelAlphabet ();
  }

  /** Returns the alphabet of slice 0. */
  public Alphabet getAlphabet() { return getOutputAlphabet(0); }

  /** Returns a one-element array holding only the alphabet of slice 0. */
  public Alphabet[] getAlphabets() { return new Alphabet[] { getAlphabet() }; } //hack
}