/*
* Copyright (c) 2011, Yahoo! Inc. All rights reserved.
*
* Redistribution and use of this software in source and binary forms, with or without modification, are
* permitted provided that the following conditions are met:
*
* Redistributions of source code must retain the above copyright notice, this list of conditions
* and the following disclaimer.
*
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions
* and the following disclaimer in the documentation and/or other materials provided with the
* distribution.
*
* Neither the name of Yahoo! Inc. nor the names of its contributors may be used to endorse or
* promote products derived from this software without specific prior written permission of Yahoo!
* Inc.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
* WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY
* WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package com.yahoo.labs.taxomo.learn;
import java.util.ArrayList;
import java.util.HashSet;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import com.yahoo.labs.taxomo.util.State;
import com.yahoo.labs.taxomo.util.StateSet;
import com.yahoo.labs.taxomo.util.Taxonomy;
import com.yahoo.labs.taxomo.util.Util;
/**
 * Represents a candidate set of parameters for a model.
 * <p>
 * In its current implementation it is mostly a lightweight {@link StateSet}, with
 * methods for generating other candidates by merging states in the current
 * candidate. A candidate is a list of state names (looked up as nodes of a
 * {@link Taxonomy}); it optionally remembers the candidate it was derived from
 * and the taxonomy node that was merged to derive it.
 * <p>
 * The log probability starts out as {@code NaN}, meaning "not computed yet", and
 * can be set exactly once through {@link #setLogProbability(double)}.
 *
 * @author chato
 *
 */
public class Candidate {
    static final Logger logger = Logger.getLogger(Candidate.class);

    static {
        Util.configureLogger(logger, Level.DEBUG);
    }

    /** Maximum number of state names printed by {@link #toBriefString()}. */
    private static final int BRIEF_MAX_STATES = 10;

    private final Taxonomy taxo;

    /** Candidate this one was derived from, or {@code null} for a starting candidate. */
    private final Candidate parent;

    /** Names of the taxonomy nodes that make up this candidate (stored by reference). */
    private final ArrayList<String> states;

    /** Log probability of this candidate; {@code NaN} until it has been set. */
    double logProbability = Double.NaN;

    /** Taxonomy node merged to obtain this candidate from its parent, or {@code null}. */
    private final State mergedNode;

    /**
     * Creates a candidate.
     *
     * @param aTree taxonomy used to resolve state names into nodes
     * @param aStates names of the states of this candidate; the list is stored as-is, not copied
     * @param aParent candidate this one was derived from, or {@code null}
     * @param aMergedNode node merged to derive this candidate from {@code aParent}, or {@code null}
     */
    public Candidate(Taxonomy aTree, ArrayList<String> aStates, Candidate aParent, State aMergedNode) {
        taxo = aTree;
        states = aStates;
        parent = aParent;
        mergedNode = aMergedNode;
    }

    /**
     * Returns a new candidate with the same taxonomy, state list (shared, not copied),
     * parent and merged node. The log probability is NOT copied: the copy starts
     * with no probability set.
     */
    @Override
    public Candidate clone() {
        return new Candidate(taxo, states, parent, mergedNode);
    }

    /** @return number of states in this candidate */
    public int getNumStates() {
        return states.size();
    }

    /**
     * @return the log probability previously set through {@link #setLogProbability(double)}
     * @throws IllegalStateException if the log probability has not been set yet
     */
    public double getLogProbability() {
        // NaN never compares equal to anything, itself included, so the check must use
        // Double.isNaN; "logProbability == Double.NaN" would always be false.
        if (Double.isNaN(logProbability)) {
            throw new IllegalStateException("The probability of this candidate has not been computed yet");
        }
        return logProbability;
    }

    /**
     * @return the parent's log probability, or this candidate's own log probability when
     *         there is no parent
     * @throws IllegalStateException if the selected log probability has not been set yet
     */
    public double getParentLogProbability() {
        if (parent != null) {
            return parent.getLogProbability();
        } else {
            return getLogProbability();
        }
    }

    /**
     * @return the parent's number of states, or this candidate's own number of states when
     *         there is no parent
     */
    public int getParentNumStates() {
        if (parent != null) {
            return parent.getNumStates();
        } else {
            return getNumStates();
        }
    }

    /**
     * This compares the hash codes of the two candidates; this would introduce
     * a random, but fixed, permutation on the candidate lists, to avoid picking
     * always candidates in lexicographical ordering by states.
     * <p>
     * This comparator does not compare candidates by log probability. Note that this
     * class does not declare {@code Comparable}, and that two different candidates
     * whose hash codes collide compare as 0 (inconsistent with {@link #equals(Object)}).
     *
     * @param other another {@code Candidate}
     * @return comparison with the other candidate's hashCode
     * @throws ClassCastException if {@code other} is not a {@code Candidate}
     */
    public int compareTo(Object other) {
        int mine = this.hashCode();
        int theirs = ((Candidate) other).hashCode();
        if (mine < theirs) {
            return -1;
        } else if (mine > theirs) {
            return +1;
        } else {
            return 0;
        }
    }

    /**
     * Two candidates are equal when they have the same number of states and every
     * state of this candidate appears among the other's states (order is ignored).
     * Returns {@code false} for {@code null} or non-{@code Candidate} arguments.
     * <p>
     * NOTE(review): this, together with the order-insensitive {@link #hashCode()},
     * assumes state lists never contain duplicate names.
     */
    @Override
    public boolean equals(Object other) {
        if (this == other) {
            return true;
        }
        if (!(other instanceof Candidate)) {
            return false;
        }
        Candidate otherCand = (Candidate) other;
        if (otherCand.getNumStates() != getNumStates()) {
            return false;
        }
        HashSet<String> otherStates = new HashSet<String>(otherCand.getStates());
        for (String state : getStates()) {
            if (!otherStates.contains(state)) {
                return false;
            }
        }
        return true;
    }

    /** Sum of the states' hash codes, hence independent of state order (matching equals). */
    @Override
    public int hashCode() {
        int hashCode = 0;
        for (String state : states) {
            hashCode += state.hashCode();
        }
        return hashCode;
    }

    /** @return the internal list of state names (not a copy) */
    public ArrayList<String> getStates() {
        return states;
    }

    /** @return all state names separated by single spaces, in parentheses, e.g. {@code "(a b c)"} */
    @Override
    public String toString() {
        return "(" + printList(states) + ")";
    }

    /**
     * Like {@link #toString()}, but prints at most {@value #BRIEF_MAX_STATES} state names;
     * when states were omitted, the total count is appended, e.g.
     * {@code "(a b ... j... [25 states])"}.
     */
    public String toBriefString() {
        StringBuilder sb = new StringBuilder("(");
        int shown = 0;
        for (String str : states) {
            if (shown == BRIEF_MAX_STATES) {
                break;
            }
            if (shown > 0) {
                sb.append(' ');
            }
            sb.append(str);
            shown++;
        }
        // Only mention the total when something was actually left out.
        if (states.size() > BRIEF_MAX_STATES) {
            sb.append("... [" + states.size() + " states]");
        }
        sb.append(")");
        return sb.toString();
    }

    /**
     * Sets the log probability of this candidate; may be called only once.
     *
     * @throws IllegalStateException if the log probability was already set
     * @throws IllegalArgumentException if {@code aLogProbability} is NaN
     */
    public void setLogProbability(double aLogProbability) {
        if (!Double.isNaN(logProbability)) {
            throw new IllegalStateException("Trying to set probability of this candidate again");
        } else if (Double.isNaN(aLogProbability)) {
            throw new IllegalArgumentException("Argument was NaN");
        } else {
            logProbability = aLogProbability;
        }
    }

    /**
     * Generates every child candidate obtainable by merging states of this candidate
     * into their common taxonomy parent.
     * <p>
     * A parent node is considered mergeable when
     * {@code parentNode.hasOnlyChildrenInSet(allowedStates)} holds for the nodes of the
     * current states (presumably: all of its children are current states). One child
     * candidate is produced per mergeable node, with this candidate as its parent. The
     * order of the returned list follows {@code HashSet} iteration order.
     *
     * @return the child candidates; empty when nothing can be merged
     */
    public ArrayList<Candidate> generateChildrenCandidatesMergingStates() {
        HashSet<State> allowedStates = new HashSet<State>();
        for (String state : states) {
            allowedStates.add(taxo.getNode(state));
        }

        // Get nodes that can be merged; each parent is tested only once.
        HashSet<State> mergeableNodes = new HashSet<State>();
        HashSet<State> examined = new HashSet<State>();
        for (State node : allowedStates) {
            State parentNode = node.getParent();
            if (parentNode != null && examined.add(parentNode)) {
                if (parentNode.hasOnlyChildrenInSet(allowedStates)) {
                    mergeableNodes.add(parentNode);
                }
            }
        }

        // Create list of candidates
        ArrayList<Candidate> candidates = new ArrayList<Candidate>(mergeableNodes.size());
        for (State nodeToMerge : mergeableNodes) {
            candidates.add(new Candidate(taxo, generateCandidateStates(nodeToMerge), this, nodeToMerge));
        }
        return candidates;
    }

    /**
     * Builds the state list obtained by replacing every current state whose taxonomy
     * parent is {@code mergeableNode} with {@code mergeableNode} itself (inserted once,
     * at the position of its first child). Other states keep their relative order.
     *
     * @throws IllegalStateException if the merge does not reduce the number of states
     */
    private ArrayList<String> generateCandidateStates(State mergeableNode) {
        ArrayList<String> newAllowedStates = new ArrayList<String>(states.size() + 1);
        boolean addedMergeableNode = false;
        for (String state : states) {
            State node = taxo.getNode(state);
            // Identity comparison: taxonomy nodes are compared as objects, not by name.
            if (node.getParent() == mergeableNode) {
                if (!addedMergeableNode) {
                    newAllowedStates.add(mergeableNode.name());
                    addedMergeableNode = true;
                }
            } else {
                newAllowedStates.add(state);
            }
        }
        if (newAllowedStates.size() >= states.size()) {
            logger.error("My states : " + printList(states));
            logger.error("State to merge : " + mergeableNode.name());
            logger.error("New states : " + printList(newAllowedStates));
            throw new IllegalStateException(
                    "I generated a sub-candidate by merging states, but it has the same number of states as me; perhaps there is a degree-one node in the taxonomy");
        }
        return newAllowedStates;
    }

    /** @return the candidate this one was derived from, or {@code null} */
    public Candidate getParent() {
        return parent;
    }

    /** @return the node merged to derive this candidate from its parent, or {@code null} */
    public State getMergedNode() {
        return mergedNode;
    }

    /**
     * Creates a parentless candidate whose states are the leaves of the taxonomy,
     * in the order returned by {@code tree.getLeaves()}.
     */
    public static Candidate createLeafCandidate(Taxonomy tree) {
        return new Candidate(tree, namesOf(tree.getLeaves()), null, null);
    }

    /**
     * Joins the elements of a list with single spaces.
     *
     * @param lst strings to join
     * @return the joined string; the empty string for an empty list
     */
    public static String printList(ArrayList<String> lst) {
        StringBuilder sb = new StringBuilder();
        boolean first = true;
        for (String element : lst) {
            if (!first) {
                sb.append(' ');
            }
            sb.append(element);
            first = false;
        }
        return sb.toString();
    }

    /**
     * Creates a parentless candidate whose states are the nodes returned by
     * {@code tree.getLevel(level)}, in that order.
     */
    public static Candidate createFixedLevelCandidate(Taxonomy tree, int level) {
        return new Candidate(tree, namesOf(tree.getLevel(level)), null, null);
    }

    /** Collects the names of the given nodes, preserving their order. */
    private static ArrayList<String> namesOf(State[] nodes) {
        ArrayList<String> names = new ArrayList<String>(nodes.length);
        for (State node : nodes) {
            names.add(node.name());
        }
        return names;
    }
}