Package upenn.junto.util

Source Code of upenn.junto.util.ProbUtil

package upenn.junto.util;

import gnu.trove.iterator.TObjectDoubleIterator;
import gnu.trove.map.hash.TObjectDoubleHashMap;

import java.util.ArrayList;

public class ProbUtil {
 
  public static TObjectDoubleHashMap GetUniformPrior(ArrayList<String> labels) {
    int totalLabels = labels.size();
    assert (totalLabels > 0);
    double prior = 1.0 / totalLabels;
    assert (prior > 0);

    TObjectDoubleHashMap retMap = new TObjectDoubleHashMap();
    for (int li = 0; li < totalLabels; ++li) {
      retMap.put(labels.get(li), prior);
    }
    return (retMap);
  }

  // this method returns result += mult * addDist
  public static void AddScores(TObjectDoubleHashMap result, double mult,
                               TObjectDoubleHashMap addDist) {
    assert (result != null);
    assert (addDist != null);

    TObjectDoubleIterator iter = addDist.iterator();
    while (iter.hasNext()) {
      iter.advance();
      double adjVal = mult * iter.value();
     
      //      System.out.println(">> adjVal: " + mult + " " + iter.key() + " " + iter.value() + " " + adjVal);
      result.adjustOrPutValue(iter.key(), adjVal, adjVal);
    }
  }

  public static void DivScores(TObjectDoubleHashMap result, double divisor) {
    assert (result != null);
    assert (divisor > 0);

    TObjectDoubleIterator li = result.iterator();
    while (li.hasNext()) {
      li.advance();
      // System.out.println("Before: " + " " + li.key() + " " + li.value() + " " + divisor);
      double newVal = (1.0 * li.value()) / divisor;
      result.put(li.key(), newVal);
      // System.out.println("After: " + " " + li.key() + " " + result.get(li.key()) + " " + divisor);
    }
  }
 
  public static void KeepTopScoringKeys(TObjectDoubleHashMap m, int keepTopK) {
    ArrayList<ObjectDoublePair> lsps = CollectionUtil.ReverseSortMap(m);

    // the array is sorted from large to small, so start
    // from beginning and retain only top scoring k keys.
    m.clear();
    int totalAdded = 0;
    int totalSorted = lsps.size();
    // for (int li = lsps.size() - 1; li >= 0 && totalAdded <= keepTopK; --li) {
    for (int li = 0; li < totalSorted && totalAdded < keepTopK; ++li) {
      ++totalAdded;

      if (lsps.get(li).GetScore() > 0) {
        m.put(lsps.get(li).GetLabel(), lsps.get(li).GetScore());
      }
    }
   
    // size of the new map is upper bounded by the max
    // number of entries requested
    assert (m.size() <= keepTopK);
  }

  public static void Normalize(TObjectDoubleHashMap m) {
    Normalize(m, Integer.MAX_VALUE);
  }

  public static void Normalize(TObjectDoubleHashMap m, int keepTopK) {
    // if the number of labels to retain are not the trivial
    // default value, then keep the top scoring k labels as requested
    if (keepTopK != Integer.MAX_VALUE) {
      KeepTopScoringKeys(m, keepTopK);
    }

    TObjectDoubleIterator mi = m.iterator();
    double denom = 0;
    while (mi.hasNext()) {
      mi.advance();
      denom += mi.value();
    }
    // assert (denom > 0);

    if (denom > 0) {
      mi = m.iterator();
      while (mi.hasNext()) {
        mi.advance();
        double newVal = mi.value() / denom;
        mi.setValue(newVal);
      }
    }
  }

  public static double GetSum(TObjectDoubleHashMap m) {
    TObjectDoubleIterator mi = m.iterator();
    double sum = 0;
    while (mi.hasNext()) {
      mi.advance();
      sum += mi.value();
    }
    return (sum);
  }

  public static double GetDifferenceNorm2Squarred(TObjectDoubleHashMap m1,
                                                  double m1Mult, TObjectDoubleHashMap m2, double m2Mult) {
    TObjectDoubleHashMap diffMap = new TObjectDoubleHashMap();

    // copy m1 into the difference map
    TObjectDoubleIterator iter = m1.iterator();
    while (iter.hasNext()) {
      iter.advance();
      diffMap.put(iter.key(), m1Mult * iter.value());
    }

    iter = m2.iterator();
    while (iter.hasNext()) {
      iter.advance();
      diffMap.adjustOrPutValue(iter.key(), -1 * m2Mult * iter.value(), -1
                               * m2Mult * iter.value());
    }

    double val = 0;
    iter = diffMap.iterator();
    while (iter.hasNext()) {
      iter.advance();
      val += iter.value() * iter.value();
    }

    return (Math.sqrt(val));
  }

  // KL (m1 || m2)
  public static double GetKLDifference(TObjectDoubleHashMap m1,
                                       TObjectDoubleHashMap m2) {
    double divergence = 0;

    TObjectDoubleIterator iter = m1.iterator();
    while (iter.hasNext()) {
      iter.advance();
      if (iter.value() > 0) {
        //        if (!m2.containsKey(iter.key()) && m2.get(iter.key()) <= 0) {
        //          divergence += Double.NEGATIVE_INFINITY;
        //        } else {
        // add a small quantity to the numerator and denominator to avoid
        // infinite divergence
        divergence += iter.value()
          * Math.log((iter.value() + Constants.GetSmallConstant())
                     / (m2.get(iter.key()) + Constants.GetSmallConstant()));
        //        }
      }
    }

    return (divergence);
  }

  // Entropy(m1)
  public static double GetEntropy(TObjectDoubleHashMap m1) {
    double entropy = 0;
    TObjectDoubleIterator iter = m1.iterator();
    while (iter.hasNext()) {
      iter.advance();
      if (iter.value() > 0) {
        entropy += -1 * iter.value() * Math.log(iter.value());
      }
    }

    return (entropy);
  }

}
TOP

Related Classes of upenn.junto.util.ProbUtil

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.