package edu.brown.markov;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import org.apache.commons.collections15.EnumerationUtils;
import org.apache.commons.collections15.set.ListOrderedSet;
import weka.core.Attribute;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;
import edu.brown.utils.CollectionUtil;
import edu.brown.utils.StringUtil;
/**
*
* @author pavlo
*/
public class MarkovAttributeSet extends ListOrderedSet<Attribute> implements Comparable<MarkovAttributeSet> {
private static final long serialVersionUID = 1L;
private Double cost;
public MarkovAttributeSet(Set<Attribute> items) {
super(items);
}
public MarkovAttributeSet(Attribute...items) {
super((Set<Attribute>)CollectionUtil.addAll(new HashSet<Attribute>(), items));
}
/**
* Copy constructor
* @param clone
*/
public MarkovAttributeSet(MarkovAttributeSet clone) {
super(clone);
}
protected MarkovAttributeSet(Instances data, Collection<Integer> idxs) {
for (Integer i : idxs) {
this.add(data.attribute(i));
} // FOR
}
@SuppressWarnings("unchecked")
protected MarkovAttributeSet(Instances data, String prefix) {
for (Attribute a : (List<Attribute>)EnumerationUtils.toList(data.enumerateAttributes())) {
if (a.name().startsWith(prefix)) this.add(a);
} // FOR
}
public Filter createFilter(Instances data) throws Exception {
Set<Integer> indexes = new HashSet<Integer>();
for (int i = 0, cnt = this.size(); i < cnt; i++) {
indexes.add(this.get(i).index());
} // FOR
SortedSet<Integer> to_remove = new TreeSet<Integer>();
for (int i = 0, cnt = data.numAttributes(); i < cnt; i++) {
if (indexes.contains(i) == false) {
to_remove.add(i+1);
}
} // FOR
Remove filter = new Remove();
filter.setInputFormat(data);
String options[] = { "-R", StringUtil.join(",", to_remove) };
filter.setOptions(options);
return (filter);
}
//
// public Instances copyData(Instances data) throws Exception {
// Set<Integer> indexes = new HashSet<Integer>();
// for (int i = 0, cnt = this.size(); i < cnt; i++) {
// indexes.add(this.get(i).index());
// } // FOR
//
// SortedSet<Integer> to_remove = new TreeSet<Integer>();
// for (int i = 0, cnt = data.numAttributes(); i < cnt; i++) {
// if (indexes.contains(i) == false) {
// to_remove.add(i+1);
// }
// } // FOR
//
// Remove filter = new Remove();
// filter.setInputFormat(data);
// filter.setAttributeIndices(StringUtil.join(",", to_remove));
// for (int i = 0, cnt = data.numInstances(); i < cnt; i++) {
// filter.input(data.instance(i));
// } // FOR
// filter.batchFinished();
//
// Instances newData = filter.getOutputFormat();
// Instance processed;
// while ((processed = filter.output()) != null) {
// newData.add(processed);
// } // WHILE
// return (newData);
// }
public Double getCost() {
return (this.cost);
}
public void setCost(Double cost) {
this.cost = cost;
}
@Override
public int compareTo(MarkovAttributeSet o) {
if (this.cost != o.cost) {
return (this.cost != null ? this.cost.compareTo(o.cost) : o.cost.compareTo(this.cost));
} else if (this.size() != o.size()) {
return (this.size() - o.size());
} else if (this.containsAll(o)) {
return (0);
}
for (int i = 0, cnt = this.size(); i < cnt; i++) {
int idx0 = this.get(i).index();
int idx1 = o.get(i).index();
if (idx0 != idx1) return (idx0 - idx1);
} // FOR
return (0);
}
@Override
public String toString() {
return (MarkovAttributeSet.toString(this));
}
public static String toString(Set<Attribute> attrs) {
StringBuilder sb = new StringBuilder();
String add = "[";
for (Attribute a : attrs) {
sb.append(add).append(a.name());
add = ", ";
}
sb.append("]");
return sb.toString();
}
}