/* Copyright (C) 2009 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.classify;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.logging.Logger;
import cc.mallet.topics.ParallelTopicModel;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.InfoGain;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Labeling;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Maths;
/**
* Utility functions for creating feature constraints that can be used with GE training.
* @author Gregory Druck <a href="mailto:gdruck@cs.umass.edu">gdruck@cs.umass.edu</a>
*/
public class FeatureConstraintUtil {
private static Logger logger = MalletLogger.getLogger(FeatureConstraintUtil.class.getName());
/**
* Reads range constraints stored using strings from a file. Format can be either:
*
* feature_name (label_name:lower_probability,upper_probability)+
*
* or
*
* feature_name (label_name:probability)+
*
* Constraints are only added for feature-label pairs that are present.
*
* @param filename File with feature constraints.
* @param data InstanceList used for alphabets.
* @return Constraints.
*/
public static HashMap<Integer,double[][]> readRangeConstraintsFromFile(String filename, InstanceList data) {
HashMap<Integer,double[][]> constraints = new HashMap<Integer,double[][]>();
for (int li = 0; li < data.getTargetAlphabet().size(); li++) {
System.err.println(data.getTargetAlphabet().lookupObject(li));
}
try {
BufferedReader reader = new BufferedReader(new FileReader(filename));
String line = reader.readLine();
while (line != null) {
String[] split = line.split("\\s+");
// assume the feature name has no spaces
String featureName = split[0];
int featureIndex = data.getDataAlphabet().lookupIndex(featureName,false);
if (featureIndex == -1) {
throw new RuntimeException("Feature " + featureName + " not found in the alphabet!");
}
double[][] probs = new double[data.getTargetAlphabet().size()][2];
for (int i = 0; i < probs.length; i++) Arrays.fill(probs[i ],Double.NEGATIVE_INFINITY);
for (int index = 1; index < split.length; index++) {
String[] labelSplit = split[index].split(":");
int li = data.getTargetAlphabet().lookupIndex(labelSplit[0],false);
assert (li != -1) : labelSplit[0];
if (labelSplit[1].contains(",")) {
String[] rangeSplit = labelSplit[1].split(",");
double lower = Double.parseDouble(rangeSplit[0]);
double upper = Double.parseDouble(rangeSplit[1]);
probs[li][0] = lower;
probs[li][1] = upper;
}
else {
double prob = Double.parseDouble(labelSplit[1]);
probs[li][0] = prob;
probs[li][1] = prob;
}
}
constraints.put(featureIndex, probs);
line = reader.readLine();
}
}
catch (Exception e) {
e.printStackTrace();
System.exit(1);
}
return constraints;
}
/**
* Reads feature constraints from a file, whether they are stored
* using Strings or indices.
*
* @param filename File with feature constraints.
* @param data InstanceList used for alphabets.
* @return Constraints.
*/
public static HashMap<Integer,double[]> readConstraintsFromFile(String filename, InstanceList data) {
if (testConstraintsFileIndexBased(filename)) {
return readConstraintsFromFileIndex(filename,data);
}
return readConstraintsFromFileString(filename,data);
}
/**
* Reads feature constraints stored using strings from a file.
*
* feature_name (label_name:probability)+
*
* Labels that do appear get probability 0.
*
* @param filename File with feature constraints.
* @param data InstanceList used for alphabets.
* @return Constraints.
*/
public static HashMap<Integer,double[]> readConstraintsFromFileString(String filename, InstanceList data) {
HashMap<Integer,double[]> constraints = new HashMap<Integer,double[]>();
File file = new File(filename);
try {
BufferedReader reader = new BufferedReader(new FileReader(file));
String line = reader.readLine();
while (line != null) {
String[] split = line.split("\\s+");
// assume the feature name has no spaces
String featureName = split[0];
int featureIndex = data.getDataAlphabet().lookupIndex(featureName,false);
assert(split.length - 1 == data.getTargetAlphabet().size()) : split.length + " " + data.getTargetAlphabet().size();
double[] probs = new double[split.length - 1];
for (int index = 1; index < split.length; index++) {
String[] labelSplit = split[index].split(":");
int li = data.getTargetAlphabet().lookupIndex(labelSplit[0],false);
assert(li != -1) : "Label " + labelSplit[0] + " not found";
double prob = Double.parseDouble(labelSplit[1]);
probs[li] = prob;
}
constraints.put(featureIndex, probs);
line = reader.readLine();
}
}
catch (Exception e) {
e.printStackTrace();
System.exit(1);
}
return constraints;
}
/**
* Reads feature constraints stored using strings from a file.
*
* feature_index label_0_prob label_1_prob ... label_n_prob
*
* Here each label must appear.
*
* @param filename File with feature constraints.
* @param data InstanceList used for alphabets.
* @return Constraints.
*/
public static HashMap<Integer,double[]> readConstraintsFromFileIndex(String filename, InstanceList data) {
HashMap<Integer,double[]> constraints = new HashMap<Integer,double[]>();
File file = new File(filename);
try {
BufferedReader reader = new BufferedReader(new FileReader(file));
String line = reader.readLine();
while (line != null) {
String[] split = line.split("\\s+");
int featureIndex = Integer.parseInt(split[0]);
assert(split.length - 1 == data.getTargetAlphabet().size());
double[] probs = new double[split.length - 1];
for (int index = 1; index < split.length; index++) {
double prob = Double.parseDouble(split[index]);
probs[index-1] = prob;
}
constraints.put(featureIndex, probs);
line = reader.readLine();
}
}
catch (Exception e) {
e.printStackTrace();
System.exit(1);
}
return constraints;
}
private static boolean testConstraintsFileIndexBased(String filename) {
File file = new File(filename);
String firstLine = "";
try {
BufferedReader reader = new BufferedReader(new FileReader(file));
firstLine = reader.readLine();
}
catch (Exception e) {
e.printStackTrace();
System.exit(1);
}
return !firstLine.contains(":");
}
/**
* Select features with the highest information gain.
*
* @param list InstanceList for computing information gain.
* @param numFeatures Number of features to select.
* @return List of features with the highest information gains.
*/
public static ArrayList<Integer> selectFeaturesByInfoGain(InstanceList list, int numFeatures) {
ArrayList<Integer> features = new ArrayList<Integer>();
InfoGain infogain = new InfoGain(list);
for (int rank = 0; rank < numFeatures; rank++) {
features.add(infogain.getIndexAtRank(rank));
}
return features;
}
/**
* Select top features in LDA topics.
*
* @param numSelFeatures Number of features to select.
* @param ldaEst LDAEstimatePr which provides an interface to an LDA model.
* @param seqAlphabet The alphabet for the sequence dataset, which may be different from the vector dataset alphabet.
* @param alphabet The vector dataset alphabet.
* @return ArrayList with the int indices of the selected features.
*/
public static ArrayList<Integer> selectTopLDAFeatures(int numSelFeatures, ParallelTopicModel lda, Alphabet alphabet) {
ArrayList<Integer> features = new ArrayList<Integer>();
Alphabet seqAlphabet = lda.getAlphabet();
int numTopics = lda.getNumTopics();
Object[][] sorted = lda.getTopWords(seqAlphabet.size());
for (int pos = 0; pos < seqAlphabet.size(); pos++) {
for (int ti = 0; ti < numTopics; ti++) {
Object feat = sorted[ti][pos].toString();
int fi = alphabet.lookupIndex(feat,false);
if ((fi >=0) && (!features.contains(fi))) {
logger.info("Selected feature: " + feat);
features.add(fi);
if (features.size() == numSelFeatures) {
return features;
}
}
}
}
return features;
}
public static HashMap<Integer,double[]> setTargetsUsingData(InstanceList list, ArrayList<Integer> features) {
return setTargetsUsingData(list,features,true);
}
public static HashMap<Integer,double[]> setTargetsUsingData(InstanceList list, ArrayList<Integer> features, boolean normalize) {
return setTargetsUsingData(list,features,false,normalize);
}
/**
* Set target distributions using estimates from data.
*
* @param list InstanceList used to estimate targets.
* @param features List of features for constraints.
* @param normalize Whether to normalize by feature counts
* @return Constraints (map of feature index to target), with targets
* set using estimates from supplied data.
*/
public static HashMap<Integer,double[]> setTargetsUsingData(InstanceList list, ArrayList<Integer> features, boolean useValues, boolean normalize) {
HashMap<Integer,double[]> constraints = new HashMap<Integer,double[]>();
double[][] featureLabelCounts = getFeatureLabelCounts(list,useValues);
for (int i = 0; i < features.size(); i++) {
int fi = features.get(i);
if (fi != list.getDataAlphabet().size()) {
double[] prob = featureLabelCounts[fi];
if (normalize) {
// Smooth probability distributions by adding a (very)
// small count. We just need to make sure they aren't
// zero in which case the KL-divergence is infinite.
MatrixOps.plusEquals(prob, 1e-8);
MatrixOps.timesEquals(prob, 1./MatrixOps.sum(prob));
}
constraints.put(fi, prob);
}
}
return constraints;
}
/**
* Set target distributions using "Schapire" heuristic described in
* "Learning from Labeled Features using Generalized Expectation Criteria"
* Gregory Druck, Gideon Mann, Andrew McCallum.
*
* @param labeledFeatures HashMap of feature indices to lists of label indices for that feature.
* @param numLabels Total number of labels.
* @param majorityProb Probability mass divided among majority labels.
* @return Constraints (map of feature index to target distribution), with target
* distributions set using heuristic.
*/
public static HashMap<Integer,double[]> setTargetsUsingHeuristic(HashMap<Integer,ArrayList<Integer>> labeledFeatures, int numLabels, double majorityProb) {
HashMap<Integer,double[]> constraints = new HashMap<Integer,double[]>();
Iterator<Integer> keyIter = labeledFeatures.keySet().iterator();
while (keyIter.hasNext()) {
int fi = keyIter.next();
ArrayList<Integer> labels = labeledFeatures.get(fi);
constraints.put(fi, getHeuristicPrior(labels,numLabels,majorityProb));
}
return constraints;
}
/**
* Set target distributions using feature voting heuristic described in
* "Learning from Labeled Features using Generalized Expectation Criteria"
* Gregory Druck, Gideon Mann, Andrew McCallum.
*
* @param labeledFeatures HashMap of feature indices to lists of label indices for that feature.
* @param trainingData InstanceList to use for computing expectations with feature voting.
* @return Constraints (map of feature index to target distribution), with target
* distributions set using feature voting.
*/
public static HashMap<Integer, double[]> setTargetsUsingFeatureVoting(HashMap<Integer,ArrayList<Integer>> labeledFeatures,
InstanceList trainingData) {
HashMap<Integer,double[]> constraints = new HashMap<Integer,double[]>();
int numLabels = trainingData.getTargetAlphabet().size();
Iterator<Integer> keyIter = labeledFeatures.keySet().iterator();
double[][] featureCounts = new double[labeledFeatures.size()][numLabels];
for (int ii = 0; ii < trainingData.size(); ii++) {
Instance instance = trainingData.get(ii);
FeatureVector fv = (FeatureVector)instance.getData();
Labeling labeling = trainingData.get(ii).getLabeling();
double[] labelDist = new double[numLabels];
if (labeling == null) {
labelByVoting(labeledFeatures,instance,labelDist);
} else {
int li = labeling.getBestIndex();
labelDist[li] = 1.;
}
keyIter = labeledFeatures.keySet().iterator();
int i = 0;
while (keyIter.hasNext()) {
int fi = keyIter.next();
if (fv.location(fi) >= 0) {
for (int li = 0; li < numLabels; li++) {
featureCounts[i][li] += labelDist[li] * fv.valueAtLocation(fv.location(fi));
}
}
i++;
}
}
keyIter = labeledFeatures.keySet().iterator();
int i = 0;
while (keyIter.hasNext()) {
int fi = keyIter.next();
// smoothing counts
MatrixOps.plusEquals(featureCounts[i], 1e-8);
MatrixOps.timesEquals(featureCounts[i],1./MatrixOps.sum(featureCounts[i]));
constraints.put(fi, featureCounts[i]);
i++;
}
return constraints;
}
/**
* Label features using heuristic described in
* "Learning from Labeled Features using Generalized Expectation Criteria"
* Gregory Druck, Gideon Mann, Andrew McCallum.
*
* @param list InstanceList used to compute statistics for labeling features.
* @param features List of features to label.
* @param reject Whether to reject labeling features.
* @return Labeled features, HashMap mapping feature indices to list of labels.
*/
public static HashMap<Integer, ArrayList<Integer>> labelFeatures(InstanceList list, ArrayList<Integer> features, boolean reject) {
HashMap<Integer,ArrayList<Integer>> labeledFeatures = new HashMap<Integer,ArrayList<Integer>>();
double[][] featureLabelCounts = getFeatureLabelCounts(list,true);
int numLabels = list.getTargetAlphabet().size();
int minRank = 100 * numLabels;
InfoGain infogain = new InfoGain(list);
double sum = 0;
for (int rank = 0; rank < minRank; rank++) {
sum += infogain.getValueAtRank(rank);
}
double mean = sum / minRank;
for (int i = 0; i < features.size(); i++) {
int fi = features.get(i);
// reject features with infogain
// less than cutoff
if (reject && infogain.value(fi) < mean) {
//System.err.println("Oracle labeler rejected labeling: " + list.getDataAlphabet().lookupObject(fi));
logger.info("Oracle labeler rejected labeling: " + list.getDataAlphabet().lookupObject(fi));
continue;
}
double[] prob = featureLabelCounts[fi];
MatrixOps.plusEquals(prob,1e-8);
MatrixOps.timesEquals(prob, 1./MatrixOps.sum(prob));
int[] sortedIndices = getMaxIndices(prob);
ArrayList<Integer> labels = new ArrayList<Integer>();
if (numLabels > 2) {
// take anything within a factor of 2 of the best
// but no more than numLabels/2
boolean discard = false;
double threshold = prob[sortedIndices[0]] / 2;
for (int li = 0; li < numLabels; li++) {
if (prob[li] > threshold) {
labels.add(li);
}
if (reject && labels.size() > (numLabels / 2)) {
//System.err.println("Oracle labeler rejected labeling: " + list.getDataAlphabet().lookupObject(fi));
logger.info("Oracle labeler rejected labeling: " + list.getDataAlphabet().lookupObject(fi));
discard = true;
break;
}
}
if (discard) {
continue;
}
}
else {
labels.add(sortedIndices[0]);
}
labeledFeatures.put(fi, labels);
}
return labeledFeatures;
}
public static HashMap<Integer, ArrayList<Integer>> labelFeatures(InstanceList list, ArrayList<Integer> features) {
return labelFeatures(list,features,true);
}
public static double[][] getFeatureLabelCounts(InstanceList list, boolean useValues) {
int numFeatures = list.getDataAlphabet().size();
int numLabels = list.getTargetAlphabet().size();
double[][] featureLabelCounts = new double[numFeatures][numLabels];
for (int ii = 0; ii < list.size(); ii++) {
Instance instance = list.get(ii);
FeatureVector featureVector = (FeatureVector)instance.getData();
// this handles distributions over labels
for (int li = 0; li < numLabels; li++) {
double py = instance.getLabeling().value(li);
for (int loc = 0; loc < featureVector.numLocations(); loc++) {
int fi = featureVector.indexAtLocation(loc);
double val;
if (useValues) {
val = featureVector.valueAtLocation(loc);
}
else {
val = 1.0;
}
featureLabelCounts[fi][li] += py * val;
}
}
}
return featureLabelCounts;
}
private static double[] getHeuristicPrior (ArrayList<Integer> labeledFeatures, int numLabels, double majorityProb) {
int numIndices = labeledFeatures.size();
double[] dist = new double[numLabels];
if (numIndices == numLabels) {
for (int i = 0; i < dist.length; i++) {
dist[i] = 1./numLabels;
}
return dist;
}
double keywordProb = majorityProb / numIndices;
double otherProb = (1 - majorityProb) / (numLabels - numIndices);
for (int i = 0; i < labeledFeatures.size(); i++) {
int li = labeledFeatures.get(i);
dist[li] = keywordProb;
}
for (int li = 0; li < numLabels; li++) {
if (dist[li] == 0) {
dist[li] = otherProb;
}
}
assert(Maths.almostEquals(MatrixOps.sum(dist),1));
return dist;
}
private static void labelByVoting(HashMap<Integer,ArrayList<Integer>> labeledFeatures, Instance instance, double[] scores) {
FeatureVector fv = (FeatureVector)instance.getData();
int numFeatures = instance.getDataAlphabet().size() + 1;
int[] numLabels = new int[instance.getTargetAlphabet().size()];
Iterator<Integer> keyIterator = labeledFeatures.keySet().iterator();
while (keyIterator.hasNext()) {
ArrayList<Integer> majorityClassList = labeledFeatures.get(keyIterator.next());
for (int i = 0; i < majorityClassList.size(); i++) {
int li = majorityClassList.get(i);
numLabels[li]++;
}
}
keyIterator = labeledFeatures.keySet().iterator();
while (keyIterator.hasNext()) {
int next = keyIterator.next();
assert(next < numFeatures);
int loc = fv.location(next);
if (loc < 0) {
continue;
}
ArrayList<Integer> majorityClassList = labeledFeatures.get(next);
for (int i = 0; i < majorityClassList.size(); i++) {
int li = majorityClassList.get(i);
scores[li] += 1;
}
}
double sum = MatrixOps.sum(scores);
if (sum == 0) {
MatrixOps.plusEquals(scores, 1.0);
sum = MatrixOps.sum(scores);
}
for (int li = 0; li < scores.length; li++) {
scores[li] /= sum;
}
}
/*
* These functions are no longer needed.
*
private static double[][] getPrWordTopic(LDAHyper lda){
int numTopics = lda.getNumTopics();
int numTypes = lda.getAlphabet().size();
double[][] prWordTopic = new double[numTopics][numTypes];
for (int ti = 0 ; ti < numTopics; ti++){
for (int wi = 0 ; wi < numTypes; wi++){
prWordTopic[ti][wi] = (double) lda.getCountFeatureTopic(wi, ti) / (double) lda.getCountTokensPerTopic(ti);
}
}
return prWordTopic;
}
private static int[][] getSortedTopic(double[][] prTopicWord){
int numTopics = prTopicWord.length;
int numTypes = prTopicWord[0].length;
int[][] sortedTopicIdx = new int[numTopics][numTypes];
for (int ti = 0; ti < numTopics; ti++){
int[] topicIdx = getMaxIndices(prTopicWord[ti]);
System.arraycopy(topicIdx, 0, sortedTopicIdx[ti], 0, topicIdx.length);
}
return sortedTopicIdx;
}
*/
private static int[] getMaxIndices(double[] x) {
ArrayList<Element> list = new ArrayList<Element>();
for (int i = 0; i < x.length; i++) {
Element element = new Element(i,x[i]);
list.add(element);
}
Collections.sort(list);
Collections.reverse(list);
int[] sortedIndices = new int[x.length];
for (int i = 0; i < x.length; i++) {
sortedIndices[i] = list.get(i).index;
}
return sortedIndices;
}
private static class Element implements Comparable<Element> {
private int index;
private double value;
public Element(int index, double value) {
this.index = index;
this.value = value;
}
public int compareTo(Element element) {
return Double.compare(this.value, element.value);
}
}
}