Package cc.mallet.grmm.util

Source Code of cc.mallet.grmm.util.LabelsAssignment

/* Copyright (C) 2006 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.util;


import java.util.Map;
import java.util.List;
import java.util.ArrayList;

import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.Variable;
import cc.mallet.types.Label;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.Labels;
import cc.mallet.types.LabelsSequence;
import cc.mallet.types.Alphabet;
import cc.mallet.types.AlphabetCarrying;

import gnu.trove.THashMap;
import gnu.trove.TIntArrayList;

/**
* A special kind of assignment for Variables that
* can be arranged in a LabelsSequence.  This is an Adaptor
* to adapt LabelsSequences to Assignments.
* <p/>
* $Id: LabelsAssignment.java,v 1.1 2007/10/22 21:37:58 mccallum Exp $
*/
public class LabelsAssignment extends Assignment implements AlphabetCarrying {

  // these are just for printing; race conditions don't matter
  private static int NEXT_ID = 0;
  private int id = NEXT_ID++;

  private Variable[][] idx2var;
  private LabelsSequence lblseq;
  private Map var2label;

  public LabelsAssignment (LabelsSequence lbls)
  {
    super ();
    this.lblseq = lbls;
    setupLabel2Var ();
    addRow (toVariableArray (), toValueArray ());
  }

  private Variable[] toVariableArray ()
  {
    List vars = new ArrayList (maxTime () * numSlices ());
    for (int t = 0; t < idx2var.length; t++) {
      for (int j = 0; j < idx2var[t].length; j++) {
        vars.add (idx2var[t][j]);
      }
    }
    return (Variable[]) vars.toArray (new Variable [vars.size ()]);
  }

  private int[] toValueArray ()
  {
    TIntArrayList vals = new TIntArrayList (maxTime () * numSlices ());
    for (int t = 0; t < lblseq.size (); t++) {
      Labels lbls = lblseq.getLabels (t);
      for (int j = 0; j < lbls.size (); j++) {
        Label lbl = lbls.get (j);
        vals.add (lbl.getIndex ());
      }
    }
    return vals.toNativeArray ();
  }

  private void setupLabel2Var ()
  {
    idx2var = new Variable [lblseq.size ()][];
    var2label = new THashMap ();
    for (int t = 0; t < lblseq.size (); t++) {
      Labels lbls = lblseq.getLabels (t);
      idx2var[t] = new Variable [lbls.size ()];
      for (int j = 0; j < lbls.size (); j++) {
        Label lbl = lbls.get (j);
        Variable var = new Variable (lbl.getLabelAlphabet ());
        var.setLabel ("I"+id+"_VAR[f=" + j + "][tm=" + t + "]");
        idx2var[t][j] = var;
        var2label.put (var, lbl);
      }
    }
  }

  public Variable varOfIndex (int t, int j)
  {
    return idx2var[t][j];
  }

  public Label labelOfVar (Variable var) { return (Label) var2label.get (var); }

  public int maxTime () { return lblseq.size (); }

  // assumes that lblseq not ragged
  public int numSlices () { return idx2var[0].length; }

  public LabelsSequence getLabelsSequence ()
  {
    return lblseq;
  }

  public LabelsSequence toLabelsSequence (Assignment assn)
  {
    int numFactors = numSlices ();
    int maxTime = maxTime ();
    Labels[] lbls = new Labels [maxTime];
    for (int t = 0; t < maxTime; t++) {
      Label[] theseLabels = new Label [numFactors];
      for (int i = 0; i < numFactors; i++) {
        Variable var = varOfIndex (t, i);
        int maxidx;

        if (var != null) {
          maxidx = assn.get (var);
        } else {
          maxidx = 0;
        }

        LabelAlphabet dict = labelOfVar (var).getLabelAlphabet ();
        theseLabels[i] = dict.lookupLabel (maxidx);
      }

      lbls[t] = new Labels (theseLabels);
    }

    return new LabelsSequence (lbls);
  }


  public LabelAlphabet getOutputAlphabet (int lvl)
  {
    return idx2var[0][lvl].getLabelAlphabet ();
  }

  public Alphabet getAlphabet() { return getOutputAlphabet(0); }
  public Alphabet[] getAlphabets() { return new Alphabet[] { getAlphabet() }; } //hack

}
TOP

Related Classes of cc.mallet.grmm.util.LabelsAssignment

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.