Package cc.mallet.fst

Source Code of cc.mallet.fst.FeatureTransducer$State

/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */




/**
   @author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a>
*/

package cc.mallet.fst;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.logging.Logger;

import cc.mallet.types.Alphabet;
import cc.mallet.types.Multinomial;
import cc.mallet.types.Sequence;

import cc.mallet.util.MalletLogger;

public class FeatureTransducer extends Transducer
{
  private static Logger logger = MalletLogger.getLogger(FeatureTransducer.class.getName());
 
  // These next two dictionaries may be the same
  Alphabet inputAlphabet;
  Alphabet outputAlphabet;
  ArrayList<State> states = new ArrayList<State> ();
  ArrayList<State> initialStates = new ArrayList<State> ();
  HashMap<String,State> name2state = new HashMap<String,State> ();
  Multinomial.Estimator initialStateCounts;
  Multinomial.Estimator finalStateCounts;
  boolean trainable = false;

  public FeatureTransducer (Alphabet inputAlphabet,
                            Alphabet outputAlphabet)
  {
    this.inputAlphabet = inputAlphabet;
    this.outputAlphabet = outputAlphabet;
    // xxx When should these be frozen?
  }

  public FeatureTransducer (Alphabet dictionary)
  {
    this (dictionary, dictionary);
  }

  public FeatureTransducer ()
  {
    this (new Alphabet ());
  }

  public Alphabet getInputAlphabet () { return inputAlphabet; }
  public Alphabet getOutputAlphabet () { return outputAlphabet; }
 
  public void addState (String name, double initialWeight, double finalWeight,
                        int[] inputs, int[] outputs, double[] weights,
                        String[] destinationNames)
  {
    if (name2state.get(name) != null)
      throw new IllegalArgumentException ("State with name `"+name+"' already exists.");
    State s = new State (name, states.size(), initialWeight, finalWeight,
                         inputs, outputs, weights, destinationNames, this);
    states.add (s);
    if (initialWeight < IMPOSSIBLE_WEIGHT)
      initialStates.add (s);
    name2state.put (name, s);
    setTrainable (false);
  }

  public void addState (String name, double initialWeight, double finalWeight,
                        Object[] inputs, Object[] outputs, double[] weights,
                        String[] destinationNames)
  {
    this.addState (name, initialWeight, finalWeight,
                   inputAlphabet.lookupIndices (inputs, true),
                   outputAlphabet.lookupIndices (outputs, true),
                   weights, destinationNames);
  }

  public int numStates () { return states.size(); }

  public Transducer.State getState (int index) {
    return states.get(index); }
 
  public Iterator<State> initialStateIterator () { return initialStates.iterator (); }

  public boolean isTrainable () { return trainable; }

  public void setTrainable (boolean f)
  {
    trainable = f;
    if (f) {
      // This wipes away any previous counts we had.
      // It also potentially allocates an esimator of a new size if
      // the number of states has increased.
      initialStateCounts = new Multinomial.LaplaceEstimator (states.size());
      finalStateCounts = new Multinomial.LaplaceEstimator (states.size());
    } else {
      initialStateCounts = null;
      finalStateCounts = null;
    }
    for (int i = 0; i < numStates(); i++)
      ((State)getState(i)).setTrainable(f);
  }

  public void reset ()
  {
    if (trainable) {
      initialStateCounts.reset ();
      finalStateCounts.reset ();
      for (int i = 0; i < numStates(); i++)
        ((State)getState(i)).reset ();
    }
  }

  public void estimate ()
  {
    if (initialStateCounts == null || finalStateCounts == null)
      throw new IllegalStateException ("This transducer not currently trainable.");
    Multinomial initialStateDistribution = initialStateCounts.estimate ();
    Multinomial finalStateDistribution = finalStateCounts.estimate ();
    for (int i = 0; i < states.size(); i++) {
      State s = states.get (i);
      s.initialWeight = initialStateDistribution.logProbability (i);
      s.finalWeight = finalStateDistribution.logProbability (i);
      s.estimate ();
    }
  }

  // Note that this is a non-static inner class, so we have access to all of
  // FeatureTransducer's instance variables.
  public class State extends Transducer.State
  {
    String name;
    int index;
    double initialWeight, finalWeight;
    Transition[] transitions;
    gnu.trove.TIntObjectHashMap input2transitions;
    Multinomial.Estimator transitionCounts;
    FeatureTransducer transducer;

    // Note that you cannot add transitions to a state once it is created.
    protected State (String name, int index, double initialWeight, double finalWeight,
                     int[] inputs, int[] outputs, double[] weights,
                     String[] destinationNames, FeatureTransducer transducer)
    {
      assert (inputs.length == outputs.length
              && inputs.length == weights.length
              && inputs.length == destinationNames.length);
      this.transducer = transducer;
      this.name = name;
      this.index = index;
      this.initialWeight = initialWeight;
      this.finalWeight = finalWeight;
      this.transitions = new Transition[inputs.length];
      this.input2transitions = new gnu.trove.TIntObjectHashMap ();
      transitionCounts = null;
      for (int i = 0; i < inputs.length; i++) {
        // This constructor places the transtion into this.input2transitions
        transitions[i] = new Transition (inputs[i], outputs[i],
                                         weights[i], this, destinationNames[i]);
        transitions[i].index = i;
      }
    }
   
    public Transducer getTransducer () { return transducer; }
    public double getInitialWeight () { return initialWeight; }
    public double getFinalWeight () { return finalWeight; }
    public void setInitialWeight (double v) { initialWeight = v; }
    public void setFinalWeight (double v) { finalWeight = v; }

    private void setTrainable (boolean f)
    {
      if (f)
        transitionCounts = new Multinomial.LaplaceEstimator (transitions.length);
      else
        transitionCounts = null;
    }

    // Temporarily here for debugging
    public Multinomial.Estimator getTransitionEstimator()
    {
      return transitionCounts;
    }

    private void reset ()
    {
      if (transitionCounts != null)
        transitionCounts.reset();
    }

    public int getIndex () { return index; }

    public Transducer.TransitionIterator transitionIterator (Sequence input,
                                                             int inputPosition,
                                                             Sequence output,
                                                             int outputPosition)
    {
      if (inputPosition < 0 || outputPosition < 0 || output != null)
        throw new UnsupportedOperationException ("Not yet implemented.");
      if (input == null)
        return transitionIterator ();
      return transitionIterator (input, inputPosition);
    }

    public Transducer.TransitionIterator transitionIterator (Sequence inputSequence,
                                                             int inputPosition)
    {
      int inputIndex = inputAlphabet.lookupIndex (inputSequence.get(inputPosition), false);
      if (inputIndex == -1)
        throw new IllegalArgumentException ("Input not in dictionary.");
      return transitionIterator (inputIndex);
    }

    public Transducer.TransitionIterator transitionIterator (Object o)
    {
      int inputIndex = inputAlphabet.lookupIndex (o, false);
      if (inputIndex == -1)
        throw new IllegalArgumentException ("Input not in dictionary.");
      return transitionIterator (inputIndex);
    }
   
    public Transducer.TransitionIterator transitionIterator (int input)
    {
      return new TransitionIterator (this, input);
    }
   
    public Transducer.TransitionIterator transitionIterator ()
    {
      return new TransitionIterator (this);
    }

    public String getName ()
    {
      return name;
    }

    public void incrementInitialCount (double count)
    {
      if (initialStateCounts == null)
        throw new IllegalStateException ("Transducer is not currently trainable.");
      initialStateCounts.increment (index, count);
    }

    public void incrementFinalCount (double count)
    {
      if (finalStateCounts == null)
        throw new IllegalStateException ("Transducer is not currently trainable.");
      finalStateCounts.increment (index, count);
    }

    private void estimate ()
    {
      if (transitionCounts == null)
        throw new IllegalStateException ("Transducer is not currently trainable.");
      Multinomial transitionDistribution = transitionCounts.estimate ();
      for (int i = 0; i < transitions.length; i++)
        transitions[i].weight = transitionDistribution.logProbability (i);
    }

    private static final long serialVersionUID = 1;
  }

  @SuppressWarnings("serial")
  protected class TransitionIterator extends Transducer.TransitionIterator
  {
    // If "index" is >= -1 we are going through all FeatureState.transitions[] by index.
    // If "index" is -2, we are following the chain of FeatureTransition.nextWithSameInput,
    // and "transition" is already initialized to the first transition.
    // If "index" is -3, we are following the chain of FeatureTransition.nextWithSameInput,
    // and the next transition should be found by following the chain.
    int index;
    Transition transition;
    State source;
    int input;

    // Iterate through all transitions, independent of input
    public TransitionIterator (State source)
    {
      //System.out.println ("FeatureTransitionIterator over all");
      this.source = source;
      this.input = -1;
      this.index = -1;
      this.transition = null;
    }
   
    public TransitionIterator (State source, int input)
    {
      //System.out.println ("SymbolTransitionIterator over "+input);
      this.source = source;
      this.input = input;
      this.index = -2;
      this.transition = (Transition) source.input2transitions.get (input);
    }

    public boolean hasNext ()
    {
      if (index >= -1) {
        //System.out.println ("hasNext index " + index);
        return (index < source.transitions.length-1);
      }
      return (index == -2 ? transition != null : transition.nextWithSameInput != null);
    };

    public Transducer.State nextState ()
    {
      if (index >= -1)
        transition = source.transitions[++index];
      else if (index == -2)
        index = -3;
      else
        transition = transition.nextWithSameInput;
      return transition.getDestinationState();
    }

    public int getIndex () { return index; }
    public Object getInput () { return inputAlphabet.lookupObject(transition.input); }
    public Object getOutput () { return outputAlphabet.lookupObject(transition.output); }
    public double getWeight () { return transition.weight; }
    public Transducer.State getSourceState () { return source; }
    public Transducer.State getDestinationState () {
      return transition.getDestinationState ()}

    public void incrementCount (double count) {
      logger.info ("FeatureTransducer incrementCount "+count);
      source.transitionCounts.increment (transition.index, count); }
   
  }

  // Note: this class has a natural ordering that is inconsistent with equals.
  protected class Transition
  {
    int input, output;
    double weight;
    int index;
    String destinationName;
    State destination = null;
    Transition nextWithSameInput;

    public Transition (int input, int output, double weight,
                       State sourceState, String destinationName)
    {
      this.input = input;
      this.output = output;
      this.weight = weight;
      this.nextWithSameInput = (Transition) sourceState.input2transitions.get (input);
      sourceState.input2transitions.put (input, this);
      // this.index is set by the caller of this constructor
      this.destinationName = destinationName;
    }

    public State getDestinationState ()
    {
      if (destination == null) {
        destination = name2state.get (destinationName);
        assert (destination != null);
      }
      return destination;
    }

  }

  private static final long serialVersionUID = 1;

 

}
TOP

Related Classes of cc.mallet.fst.FeatureTransducer$State

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.