// Package cc.mallet.pipe.iterator

// Source Code of cc.mallet.pipe.iterator.RandomTokenSequenceIterator

/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */




/**
   @author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a>
*/

package cc.mallet.pipe.iterator;

import java.net.URI;
import java.net.URISyntaxException;
import java.util.Iterator;
import java.util.logging.*;

import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.iterator.PipeInputIterator;
import cc.mallet.types.Alphabet;
import cc.mallet.types.Dirichlet;
import cc.mallet.types.Instance;
import cc.mallet.types.Label;
import cc.mallet.types.Multinomial;
import cc.mallet.types.TokenSequence;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Randoms;

public class RandomTokenSequenceIterator implements Iterator<Instance>
{
  private static Logger logger = MalletLogger.getLogger(RandomTokenSequenceIterator.class.getName());

  Randoms r;
  Dirichlet classCentroidDistribution;
  double classCentroidAvergeAlphaMean;
  double classCentroidAvergeAlphaVariance;
  double featureVectorSizePoissonLambda;
  double classInstanceCountPoissonLamba;
  String[] classNames;

  int[] numInstancesPerClass;            // indexed over classes
  Dirichlet[] classCentroid;            // indexed over classes
  int currentClassIndex;
  int currentInstanceIndex;
 
  public RandomTokenSequenceIterator (Randoms r,
                                      // the generator of all random-ness used here
                                      Dirichlet classCentroidDistribution,
                                      // includes a Alphabet
                                      double classCentroidAvergeAlphaMean,
                                      // Gaussian mean on the sum of alphas
                                      double classCentroidAvergeAlphaVariance,
                                      // Gaussian variance on the sum of alphas
                                      double featureVectorSizePoissonLambda,
                                      double classInstanceCountPoissonLamba,
                                      String[] classNames)
  {
    this.r = r;
    this.classCentroidDistribution = classCentroidDistribution;
    assert (classCentroidDistribution.getAlphabet() instanceof Alphabet);
    this.classCentroidAvergeAlphaMean = classCentroidAvergeAlphaMean;
    this.classCentroidAvergeAlphaVariance = classCentroidAvergeAlphaVariance;
    this.featureVectorSizePoissonLambda = featureVectorSizePoissonLambda;
    this.classInstanceCountPoissonLamba = classInstanceCountPoissonLamba;
    this.classNames = classNames;
    this.numInstancesPerClass = new int[classNames.length];
    this.classCentroid = new Dirichlet[classNames.length];
    for (int i = 0; i < classNames.length; i++) {
      logger.fine ("classCentroidAvergeAlphaMean = "+classCentroidAvergeAlphaMean);
      double aveAlpha = r.nextGaussian (classCentroidAvergeAlphaMean,
                                        classCentroidAvergeAlphaVariance);
      logger.fine ("aveAlpha = "+aveAlpha);
      classCentroid[i] = classCentroidDistribution.randomDirichlet (r, aveAlpha);
      //logger.fine ("Dirichlet for class "+classNames[i]);  classCentroid[i].print();
    }
    reset ();
  }

  public RandomTokenSequenceIterator (Randoms r, Alphabet vocab, String[] classnames)
  {
    this (r, new Dirichlet(vocab, 2.0),
          30, 0,
          10, 20, classnames);
  }

  public Alphabet getAlphabet () { return classCentroidDistribution.getAlphabet(); }

  private static Alphabet dictOfSize (int size)
  {
    Alphabet ret = new Alphabet ();
    for (int i = 0; i < size; i++)
      ret.lookupIndex ("feature"+i);
    return ret;
  }

  private static String[] classNamesOfSize (int size)
  {
    String[] ret = new String[size];
    for (int i = 0; i < size; i++)
      ret[i] = "class"+i;
    return ret;
  }

  public RandomTokenSequenceIterator (Randoms r, int vocabSize, int numClasses)
  {
    this (r, new Dirichlet(dictOfSize(vocabSize), 2.0),
          30, 0,
          10, 20, classNamesOfSize(numClasses));
  }

  public void reset ()
  {
    for (int i = 0; i < classNames.length; i++) {
      this.numInstancesPerClass[i] = r.nextPoisson (classInstanceCountPoissonLamba);
      logger.fine ("Class "+classNames[i]+" will have "
                   +numInstancesPerClass[i]+" instances.");
    }
    this.currentClassIndex = classNames.length - 1;
    this.currentInstanceIndex = numInstancesPerClass[currentClassIndex] - 1;
  }

  public Instance next ()
  {
    if (currentInstanceIndex < 0) {
      if (currentClassIndex <= 0)
        throw new IllegalStateException ("No next TokenSequence.");
      currentClassIndex--;
      currentInstanceIndex = numInstancesPerClass[currentClassIndex] - 1;
    }
    URI uri = null;
    try { uri = new URI ("random:" + classNames[currentClassIndex] + "/" + currentInstanceIndex); }
    catch (Exception e) {e.printStackTrace(); throw new IllegalStateException (); }
    //xxx Producing small numbers? int randomSize = r.nextPoisson (featureVectorSizePoissonLambda);
    int randomSize = (int)featureVectorSizePoissonLambda;
    TokenSequence ts = classCentroid[currentClassIndex].randomTokenSequence (r, randomSize);
    //logger.fine ("FeatureVector "+currentClassIndex+" "+currentInstanceIndex); fv.print();
    currentInstanceIndex--;
    return new Instance (ts, classNames[currentClassIndex], uri, null);
  }

  public boolean hasNext ()  {  return ! (currentClassIndex <= 0 && currentInstanceIndex <= 0)}
 
  public void remove () {
    throw new IllegalStateException ("This Iterator<Instance> does not support remove().");
  }

}
// TOP
//
// Related Classes of cc.mallet.pipe.iterator.RandomTokenSequenceIterator
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.