/*
* Copyright (c) 2011, Yahoo! Inc. All rights reserved.
*
* Redistribution and use of this software in source and binary forms, with or without modification,
* are permitted provided that the following conditions are met:
*
* Redistributions of source code must retain the above copyright notice, this list of conditions
* and the following disclaimer.
*
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions
* and the following disclaimer in the documentation and/or other materials provided with the
* distribution.
*
* Neither the name of Yahoo! Inc. nor the names of its contributors may be used to endorse or
* promote products derived from this software without specific prior written permission of Yahoo!
* Inc.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
* WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY
* WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package com.yahoo.labs.taxomo;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.PriorityQueue;
import java.util.Vector;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import com.martiansoftware.jsap.FlaggedOption;
import com.martiansoftware.jsap.JSAP;
import com.martiansoftware.jsap.JSAPException;
import com.martiansoftware.jsap.JSAPResult;
import com.martiansoftware.jsap.Parameter;
import com.martiansoftware.jsap.SimpleJSAP;
import com.martiansoftware.jsap.Switch;
import com.yahoo.labs.taxomo.Model.PrintMode;
import com.yahoo.labs.taxomo.learn.Candidate;
import com.yahoo.labs.taxomo.learn.CloserToOrigin;
import com.yahoo.labs.taxomo.learn.ModelScorer;
import com.yahoo.labs.taxomo.learn.ProbabilityDeltaCalculator;
import com.yahoo.labs.taxomo.learn.SearchStrategy;
import com.yahoo.labs.taxomo.util.FrequencyTable;
import com.yahoo.labs.taxomo.util.StateSet;
import com.yahoo.labs.taxomo.util.SymbolTransitionFrequencies;
import com.yahoo.labs.taxomo.util.Taxonomy;
import com.yahoo.labs.taxomo.util.Util;
/**
* Learns a model given a set of input sequences.
* <p>
* Run with {@code --help} to list the command-line options.
*
* @author chato
*/
public class LearnModel {
/**
 * Logger for this class.
 * <p>
 * It is registered under {@code LearnModel} (it used to be registered under
 * {@code GenerateSequences}, a copy-paste slip that made this class log under
 * another class's category). The static initialiser below hands it to
 * {@code Util.configureLogger} with {@code Level.INFO}; {@code main} raises it to
 * {@code DEBUG} when {@code --verbose} is given.
 */
public static final Logger logger = Logger.getLogger(LearnModel.class);
static {
    Util.configureLogger(logger, Level.INFO);
}
/**
* Evaluation messages are logged at INFO level only once every this many evaluations
* (or on every evaluation when the iteration limit is smaller than this number);
* all other evaluation messages go to DEBUG.
*/
private static final int ITERATIONS_BEFORE_REPORTING = 500;
/**
* The queue cleanup policy removes the 50% less promising candidates if the queue grows too large.
*
* This number controls the criterion for triggering the cleanup when an iteration limit has been set.
* For instance, if this is 10, and the max iterations is set to 1,000, then the cleanup policy is
* triggered when the queue holds more than 10,000 candidates.
*/
private static final int CANDIDATES_PRUNING_TRIGGER_OVERHEAD_FACTOR = 10;
/**
* The queue cleanup policy is never applied if the number of candidates is less than this number.
*/
private static final int CANDIDATES_PRUNING_NEVER_TRIGGER_BELOW = 2000;
/**
* The queue cleanup policy is always applied if the number of candidates is more than this number.
*/
private static final int CANDIDATES_PRUNING_ALWAYS_TRIGGER_ABOVE = 300000;
/** Batch size used when none is set; also the default of the --batch-size option. */
public static int DEFAULT_BATCH_SIZE = 1;
/** Maximum number of candidates taken from the queue in each pass of {@link #learn()}. */
int batchSize = DEFAULT_BATCH_SIZE;
/** Default for the maximum number of states of the output model, i.e. no limit. */
public static final int DEFAULT_OUTPUT_MAX_STATES = Integer.MAX_VALUE;
/** Candidates waiting to be evaluated, ordered by the search strategy chosen in the constructor. */
final PriorityQueue<Candidate> candidates;
/** Every candidate evaluated so far, mapped to the evaluation number at which it was scored. */
final HashMap<Candidate, Integer> results;
/** The taxonomy from which candidates and state sets are built. */
final Taxonomy tree;
/** Symbol transition frequencies that candidates are scored against. */
final SymbolTransitionFrequencies symbolTransitions;
/** Limit checked against the number of stored results in {@link #learn()}; Integer.MAX_VALUE means no limit was set. */
int maxIterations = Integer.MAX_VALUE;
/** Number of evaluations recorded so far; the next recorded evaluation gets this number. */
int iterationNumber = 0;
/** Log-probability of the candidate that has the taxonomy root as its only state. */
final double minLogProbability;
/** Log-probability of the initial candidate. */
final double maxLogProbability;
/** Lower bound on the number of states, passed to the CloserToOrigin scorers. */
final int minStates = 1;
/** Number of states of the initial candidate, used as the upper bound for the CloserToOrigin scorers. */
final int maxStates;
/** Scorer used to rank evaluated candidates when reporting and when picking the best model. */
final ModelScorer outputScorer;
/**
 * Prepares a search for a model, starting from {@code initialCandidate}.
 * <p>
 * The constructor already performs two (slow) evaluations, both of which are recorded
 * in {@code results}: the candidate whose only state is the taxonomy root, whose
 * log-probability becomes {@code minLogProbability}, and {@code initialCandidate},
 * whose log-probability and number of states become {@code maxLogProbability} and
 * {@code maxStates}. The candidate queue is then seeded with the children of
 * {@code initialCandidate}.
 *
 * @param aTree the taxonomy
 * @param aSymbolTransitions the symbol transition frequencies used to score candidates
 * @param initialCandidate the candidate from which the search starts
 * @param strategy the class of the comparator ordering the candidate queue;
 *        {@code CloserToOrigin} is built from the computed bounds and {@code weight1},
 *        any other class is instantiated through its no-argument constructor
 * @param weight1 the weight passed to {@code CloserToOrigin}, both for the search
 *        strategy (when selected) and for the output scorer
 * @throws IllegalArgumentException if {@code strategy} cannot be instantiated
 */
public LearnModel(Taxonomy aTree, SymbolTransitionFrequencies aSymbolTransitions, Candidate initialCandidate, Class<SearchStrategy> strategy, double weight1) {
    tree = aTree;

    // Read symbol transition
    symbolTransitions = aSymbolTransitions;

    // Create set of results
    results = new HashMap<Candidate, Integer>();

    // Compute minProbability: the single-state (root only) candidate
    logger.debug("Evaluating minimum probability");
    ArrayList<String> rootStateOnly = new ArrayList<String>(1);
    rootStateOnly.add(tree.getRootState().name());
    Candidate singleStateCandidate = new Candidate(tree, rootStateOnly, null, null);
    doSlowEvaluation(singleStateCandidate);
    minLogProbability = singleStateCandidate.getLogProbability();
    logger.debug("Minimum probability: " + minLogProbability);

    // Compute maxProbability: the initial candidate
    logger.debug("Evaluating maximum probability");
    doSlowEvaluation(initialCandidate);
    maxLogProbability = initialCandidate.getLogProbability();
    // This message used to print minLogProbability by mistake.
    logger.debug("Maximum probability: " + maxLogProbability);
    maxStates = initialCandidate.getNumStates();
    logger.debug("Maximum states: " + maxStates);

    // Initialize candidate search policy
    SearchStrategy comparator;
    if (strategy.equals(CloserToOrigin.class)) {
        comparator = new CloserToOrigin(minLogProbability, maxLogProbability, minStates, maxStates, weight1);
    } else {
        try {
            comparator = (SearchStrategy) strategy.newInstance();
        } catch (InstantiationException e) {
            throw new IllegalArgumentException(e);
        } catch (IllegalAccessException e) {
            throw new IllegalArgumentException(e);
        }
    }
    logger.debug("Search strategy: " + comparator);

    // Initialize scorer (for the output)
    outputScorer = new CloserToOrigin(minLogProbability, maxLogProbability, minStates, maxStates, weight1);
    logger.debug("Output scorer: " + outputScorer);

    // Initialize set of candidates
    candidates = new PriorityQueue<Candidate>(100, comparator);
    addSubCandidates(initialCandidate);
}
/**
 * Runs the search loop.
 * <p>
 * Each pass first gives the pruning policy a chance to shrink the queue, then takes
 * up to {@code batchSize} candidates from the head of the queue, and finally
 * evaluates those not seen before and enqueues their children. The loop ends when
 * the queue is empty or more than {@code maxIterations} results have been stored.
 */
public void learn() {
    while (true) {
        if (candidates.isEmpty() || results.size() > maxIterations) {
            break;
        }
        logger.debug( "There are " + candidates.size() + " candidates in the queue");
        checkCandidatePruning();

        // Take the whole batch out of the queue before evaluating any of it, so
        // that children produced by this batch cannot be picked in the same pass.
        ArrayList<Candidate> batch = new ArrayList<Candidate>();
        int remaining = batchSize;
        while (remaining > 0 && !candidates.isEmpty()) {
            batch.add(candidates.poll());
            remaining--;
        }

        for (int i = 0; i < batch.size(); i++) {
            Candidate current = batch.get(i);
            if (results.containsKey(current)) {
                // Already evaluated
                continue;
            }
            // doSlowEvaluation / debugDoSlowEvaluation are the alternative evaluators.
            doQuickEvaluation(current);
            addSubCandidates(current);
        }
    }
}
/**
 * Shrinks the candidate queue to its best half.
 * <p>
 * The first {@code size / 2} candidates in queue order are kept; everything else
 * is discarded.
 */
public void doCandidatesPrunning() {
    logger.info( "Dropping less promissing candidates from the queue" );

    final int keep = candidates.size() / 2;
    final ArrayList<Candidate> survivors = new ArrayList<Candidate>(keep);
    while (survivors.size() < keep) {
        survivors.add(candidates.poll());
    }

    // Whatever is still in the queue is the less promising half.
    candidates.clear();
    candidates.addAll(survivors);

    logger.info( "Queue checking done, now we have only " + candidates.size() + " candidates" );
}
/**
 * Applies the queue cleanup policy if the candidate queue has grown too large.
 * <ul>
 * <li>Below {@code CANDIDATES_PRUNING_NEVER_TRIGGER_BELOW} candidates: never prune.</li>
 * <li>Above {@code CANDIDATES_PRUNING_ALWAYS_TRIGGER_ABOVE} candidates: always prune.</li>
 * <li>In between: prune only if an iteration limit was set and the queue holds more
 * than {@code maxIterations * CANDIDATES_PRUNING_TRIGGER_OVERHEAD_FACTOR} candidates.</li>
 * </ul>
 */
public void checkCandidatePruning() {
    final int size = candidates.size();
    if (size < CANDIDATES_PRUNING_NEVER_TRIGGER_BELOW) {
        return;
    }
    if (size > CANDIDATES_PRUNING_ALWAYS_TRIGGER_ABOVE) {
        doCandidatesPrunning();
        return;
    }
    if (maxIterations < Integer.MAX_VALUE) {
        // Multiply as long: with a large iteration limit the int product overflows
        // into a negative number, which would trigger a pruning on every call.
        final long threshold = (long) maxIterations * CANDIDATES_PRUNING_TRIGGER_OVERHEAD_FACTOR;
        if (size > threshold) {
            doCandidatesPrunning();
        }
    }
}
/**
 * Enqueues all the children that {@code candidate} generates by merging states.
 *
 * @param candidate the candidate whose children are added to the queue
 */
void addSubCandidates(Candidate candidate) {
    candidates.addAll(candidate.generateChildrenCandidatesMergingStates());
}
/**
 * Logs that {@code candidate} is about to be evaluated.
 * <p>
 * To avoid printing too many informational messages, the message is logged at INFO
 * level only when the iteration limit is below {@code ITERATIONS_BEFORE_REPORTING}
 * or the current evaluation number is a multiple of it; otherwise it goes to DEBUG.
 * The message is built only if its level is enabled, so that disabled logging costs
 * no string building in the evaluation loop.
 *
 * @param candidate the candidate about to be evaluated
 */
void loggerPrintIterationNumber(Candidate candidate) {
    final boolean asInfo = (maxIterations < ITERATIONS_BEFORE_REPORTING)
            || ((iterationNumber % ITERATIONS_BEFORE_REPORTING) == 0);
    final boolean enabled = asInfo ? logger.isInfoEnabled() : logger.isDebugEnabled();
    if (!enabled) {
        return;
    }

    String msg = "Evaluation " + iterationNumber + (maxIterations < Integer.MAX_VALUE ? "/" + maxIterations : "") + " candidate " + candidate.toBriefString();
    if (asInfo) {
        logger.info(msg);
    } else {
        logger.debug(msg);
    }
}
/**
 * Scores {@code candidate} by building its full frequency table.
 * <p>
 * The resulting log-probability is stored in the candidate, and the candidate is
 * recorded in {@code results} under the current evaluation number, which is then
 * advanced.
 *
 * @param candidate the candidate to evaluate
 */
void doSlowEvaluation(Candidate candidate) {
    loggerPrintIterationNumber(candidate);

    final FrequencyTable table = new FrequencyTable(symbolTransitions, new StateSet(tree, candidate));
    final double logProb = table.sumLogProb();
    logger.debug("LogProb (slow evaluation) = " + logProb);
    candidate.setLogProbability(logProb);

    final int evaluationNumber = iterationNumber;
    iterationNumber = evaluationNumber + 1;
    results.put(candidate, Integer.valueOf(evaluationNumber));
}
/**
 * Debugging aid: computes the log-probability of {@code candidate} from its full
 * frequency table and only logs it. Neither the candidate nor {@code results} nor
 * the evaluation counter is modified.
 *
 * @param candidate the candidate to evaluate
 */
void debugDoSlowEvaluation(Candidate candidate) {
    loggerPrintIterationNumber(candidate);

    final FrequencyTable table = new FrequencyTable(symbolTransitions, new StateSet(tree, candidate));
    final double logProb = table.sumLogProb();
    logger.debug("LogProb (slow evaluation, debug only) = " + logProb);
}
/**
 * Scores {@code candidate} through a {@code ProbabilityDeltaCalculator}.
 * <p>
 * The resulting log-likelihood is stored in the candidate, and the candidate is
 * recorded in {@code results} under the current evaluation number, which is then
 * advanced.
 *
 * @param candidate the candidate to evaluate
 */
void doQuickEvaluation(Candidate candidate) {
    loggerPrintIterationNumber(candidate);

    final double logProb = new ProbabilityDeltaCalculator(tree, symbolTransitions, candidate).getLogLikelihood();
    logger.debug("LogProb = " + logProb);
    candidate.setLogProbability(logProb);

    final int evaluationNumber = iterationNumber;
    iterationNumber = evaluationNumber + 1;
    results.put(candidate, Integer.valueOf(evaluationNumber));
}
/**
 * Sets the limit that {@link #learn()} checks against the number of stored results.
 *
 * @param aMaxExperiments the new limit; {@code Integer.MAX_VALUE} means no limit
 */
public void setMaxIterations(int aMaxExperiments) {
    this.maxIterations = aMaxExperiments;
}
/**
 * Sets how many candidates are taken from the queue in each pass of {@link #learn()}
 * before their children are added to it.
 *
 * @param aBatchSize the new batch size, at least 1
 * @throws IllegalArgumentException if {@code aBatchSize} is less than 1; such a value
 *         would make {@link #learn()} loop forever, since no candidate would ever
 *         leave a non-empty queue
 */
public void setBatchSize(int aBatchSize) {
    if (aBatchSize < 1) {
        throw new IllegalArgumentException("Batch size must be at least 1, but was " + aBatchSize);
    }
    batchSize = aBatchSize;
}
/**
 * Writes one tab-separated line per evaluated candidate, in evaluation order, after
 * a header line. The columns are the evaluation number, the log-probability with
 * five decimals, the number of states, the score given by the output scorer, and
 * the candidate itself.
 *
 * @param out where the report is written; it is not closed by this method
 * @throws IOException declared for callers; see the signature
 */
@SuppressWarnings("boxing")
void printReport(PrintWriter out) throws IOException {
    logger.info( "Preparing to write log file: re-sorting by experiment number");

    // Order the evaluated candidates by the evaluation number stored in the results map
    final Candidate[] byEvaluation = results.keySet().toArray(new Candidate[0]);
    Arrays.sort(byEvaluation, new Comparator<Candidate>() {
        public int compare(Candidate first, Candidate second) {
            return results.get(first).compareTo(results.get(second));
        }
    });

    logger.info( "Writing log file");

    // Header, then one line per candidate
    out.println("#Iter\tProbability\tNumStates\tScore\tDescription");
    for (int i = 0; i < byEvaluation.length; i++) {
        final Candidate current = byEvaluation[i];
        final int evaluationNumber = results.get(current).intValue();
        final String probability = String.format("%.5f", current.getLogProbability());
        out.println(evaluationNumber + "\t" + probability + "\t" + current.getNumStates() + "\t" + outputScorer.getScore(current) + "\t" + current);
    }
}
/**
 * Prints the best model, in {@code PrintMode.FULL} format, to {@code out}.
 *
 * @param maxOutputStates the maximum number of states wanted in the model (a hint)
 * @param out where the model is written; it is not closed by this method
 */
private void printBestModel(int maxOutputStates, PrintWriter out) {
    out.println(getBestModel(maxOutputStates).toString(PrintMode.FULL));
}
/**
 * Returns the best model found, without limiting its number of states.
 *
 * @return the best model among the evaluated candidates
 */
public Model getBestModel() {
    final Model best = getBestModel(DEFAULT_OUTPUT_MAX_STATES);
    return best;
}
/**
 * Returns the model of the best evaluated candidate, as ranked by the output scorer,
 * among those having more than one state and at most {@code maxOutputStates} states.
 * <p>
 * The limit is only a hint: if every multi-state candidate has more states than
 * {@code maxOutputStates}, a warning is logged and the limit is raised to the
 * smallest number of states that was reached. The model with a single state is
 * never returned.
 *
 * @param maxOutputStates the maximum number of states wanted in the model
 * @return the model built from the winning candidate
 * @throws IllegalStateException if no candidate with more than one state has been
 *         evaluated, so there is nothing that could be returned
 */
public Model getBestModel(int maxOutputStates) {
    logger.info( "Will score candidates by " + outputScorer );

    // Sort candidates by output scorer
    Candidate resultsSorted[] = results.keySet().toArray(new Candidate[] {});
    Arrays.sort(resultsSorted, (SearchStrategy)outputScorer);

    // Smallest number of states among the candidates that are eligible at all
    int minNumStatesReached = Integer.MAX_VALUE;
    boolean anyEligible = false;
    for( int i=0; i<resultsSorted.length; i++ ) {
        int numStates = resultsSorted[i].getNumStates();
        if( numStates > 1 ) {
            anyEligible = true;
            if( numStates < minNumStatesReached ) {
                minNumStatesReached = numStates;
            }
        }
    }
    if( !anyEligible ) {
        // Without this check a null winner would be handed to StateSet below.
        throw new IllegalStateException("Cannot pick a best model: none of the " + resultsSorted.length
                + " evaluated candidates has more than one state");
    }

    int realisticMaxOutputStates = maxOutputStates;
    if( minNumStatesReached > maxOutputStates ) {
        logger.warn( "Could not find a model with " + maxOutputStates + " states or less" );
        logger.warn( "Apart from the model with one state (which is never returned!), all the other models had " + minNumStatesReached + " states or more" );
        realisticMaxOutputStates = minNumStatesReached;
    }

    // Pick the first candidate, in scorer order, that respects the limit. One always
    // exists here, because the limit is at least minNumStatesReached.
    Candidate winner = null;
    for( int i=0; i<resultsSorted.length; i++ ) {
        int numStates = resultsSorted[i].getNumStates();
        // The model with only 1 state is never returned
        if( (numStates>1) && (numStates <= realisticMaxOutputStates ) ) {
            winner = resultsSorted[i];
            break;
        }
    }

    StateSet taxo = new StateSet(tree, winner);
    FrequencyTable frequencyTable = new FrequencyTable(symbolTransitions, taxo);
    return new Model(frequencyTable, taxo);
}
/**
 * Command-line entry point: parses the options, builds the initial candidate, reads
 * the input sequences, runs the search and writes the requested outputs.
 * <p>
 * Output writers are closed in {@code finally} blocks, so they are flushed and
 * released even if printing fails.
 *
 * @param args the command-line arguments; see {@code --help}
 * @throws IOException if opening an output file fails, or if it is thrown while
 *         reading the inputs
 * @throws JSAPException if it is thrown while setting up or running the
 *         command-line parser
 * @throws ClassNotFoundException if the class named by {@code --search-method} does
 *         not exist in the package of {@code SearchStrategy}
 * @throws IllegalAccessException declared in the signature; a strategy that cannot
 *         be instantiated is reported by the constructor as an
 *         {@code IllegalArgumentException}
 * @throws InstantiationException declared in the signature; see above
 * @throws IllegalArgumentException if none of {@code --init-explicit},
 *         {@code --init-all-level} and {@code --init-all-leaves} is given
 */
@SuppressWarnings("unchecked")
public static void main(String[] args) throws IOException, JSAPException, ClassNotFoundException, InstantiationException, IllegalAccessException {
    String searchCriteria = getSearchCriteriaAsString();
    final SimpleJSAP jsap = new SimpleJSAP(
            LearnModel.class.getName(),
            "Learns a HMM based on the frequency of the transitions observed. Requires a list of acceptable states where not there can not be an state and one of its descendants in the taxonomy.",
            new Parameter[] {
                    new Switch("verbose", 'v', "verbose", "Set verbose output"),
                    new FlaggedOption("taxonomy-file", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, JSAP.REQUIRED, 't', "taxonomy-file",
                            "File containing the description of the taxonomy."),
                    new Switch("init-all-leaves", 'l', "init-all-leaves", "Initial model states are all leaves from taxonomy."),
                    new FlaggedOption("init-all-level", JSAP.INTEGER_PARSER, JSAP.NO_DEFAULT, JSAP.NOT_REQUIRED, 'e', "init-all-level",
                            "Initial model states are leaves or internal nodes at level <= x."),
                    new FlaggedOption("init-explicit", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, JSAP.NOT_REQUIRED, 'x', "init-explicit",
                            "Initial model states are a quote-enclosed, space-separated list, given as input."),
                    new FlaggedOption("input-file", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, JSAP.REQUIRED, 'i', "input-file", "File containing the input sequences."),
                    new FlaggedOption("max-iterations", JSAP.INTEGER_PARSER, JSAP.NO_DEFAULT, JSAP.NOT_REQUIRED, 'm', "max-iterations",
                            "Maximum number of iterations to try for building the model."),
                    new FlaggedOption("batch-size", JSAP.INTEGER_PARSER, Integer.toString(DEFAULT_BATCH_SIZE), JSAP.NOT_REQUIRED, 'b', "batch-size",
                            "How many candidates to evaluate at a time before adding their children to the priority queue."),
                    new FlaggedOption("search-method", JSAP.STRING_PARSER, SearchStrategy.DEFAULT_SEARCH_STRATEGY.getSimpleName(), JSAP.NOT_REQUIRED, 's',
                            "search-method", "Search method, allowed values: " + searchCriteria),
                    new FlaggedOption("search-method-weight-1", JSAP.DOUBLE_PARSER, Double.toString(CloserToOrigin.DEFAULT_PROBABILITY_WEIGHT), JSAP.NOT_REQUIRED,
                            'w', "search-method-weight-1", "For the method " + CloserToOrigin.class.getSimpleName()
                                    + ", the relative importance of having low probability versus having less states"),
                    new FlaggedOption("output-log", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, JSAP.NOT_REQUIRED, 'g', "output-log",
                            "Output file for logging all searched models"),
                    new FlaggedOption("output-model", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, JSAP.NOT_REQUIRED, 'o', "output-model",
                            "Output file for writing the winner model"),
                    new FlaggedOption("winner-model-maxstates", JSAP.INTEGER_PARSER, Integer.toString(DEFAULT_OUTPUT_MAX_STATES), JSAP.NOT_REQUIRED, 'u',
                            "winner-model-maxstates",
                            "Output model (winner) must have at most this number of states; this is taken only as a hint, so if there is no model with that few states, "
                                    + "the one with the lower number of states will be selected"),
            });

    final JSAPResult jsapResult = jsap.parse(args);
    if (jsap.messagePrinted()) {
        return;
    }

    if (jsapResult.getBoolean("verbose")) {
        logger.setLevel(Level.DEBUG);
    }

    File inputFile = new File(jsapResult.getString("input-file"));
    File taxoFile = new File(jsapResult.getString("taxonomy-file"));
    Taxonomy tree = new Taxonomy(taxoFile);

    // Initial candidate: explicit list of states, fixed level, or all leaves
    Candidate initialCandidate;
    if (jsapResult.userSpecified("init-explicit")) {
        initialCandidate = new Candidate(tree, Util.split(jsapResult.getString("init-explicit")), null, null);
    } else if (jsapResult.userSpecified("init-all-level")) {
        initialCandidate = Candidate.createFixedLevelCandidate(tree, jsapResult.getInt("init-all-level"));
    } else if (jsapResult.getBoolean("init-all-leaves")) {
        initialCandidate = Candidate.createLeafCandidate(tree);
    } else {
        throw new IllegalArgumentException("Either --init-explicit, --init-all-leaves, or --init-all-level should be specified. See --help.");
    }
    logger.debug("Initial candidate is " + initialCandidate.toBriefString());

    // The search method is the simple name of a class in the package of SearchStrategy
    Class<SearchStrategy> strategy = (Class<SearchStrategy>) Class.forName(SearchStrategy.class.getPackage().getName() + "." + jsapResult.getString("search-method"));
    logger.debug("Search strategy is " + strategy );

    double weight1 = jsapResult.getDouble("search-method-weight-1");
    if( jsapResult.userSpecified("search-method-weight-1") ) {
        logger.info( "Will use weight1=" + weight1 );
    }

    // Read input file
    logger.info("Reading input file");
    SymbolTransitionFrequencies symbolTransitions = new SymbolTransitionFrequencies(tree);
    symbolTransitions.processFile(inputFile);

    // Learning
    LearnModel learner = new LearnModel(tree, symbolTransitions, initialCandidate, strategy, weight1);
    if (jsapResult.userSpecified("max-iterations")) {
        learner.setMaxIterations(jsapResult.getInt("max-iterations"));
    }
    if (jsapResult.userSpecified("batch-size")) {
        learner.setBatchSize(jsapResult.getInt("batch-size"));
    }
    learner.learn();

    if (jsapResult.userSpecified("output-log")) {
        String fileName = jsapResult.getString("output-log");
        logger.info("Printing logfile to '" + fileName + "'");
        PrintWriter reportWriter = outputFile(fileName);
        try {
            learner.printReport(reportWriter);
        } finally {
            reportWriter.close();
        }
    }

    if (jsapResult.userSpecified("output-model")) {
        int maxStates = jsapResult.getInt("winner-model-maxstates");
        if( jsapResult.userSpecified("winner-model-maxstates") ) {
            logger.info("Will pick the best model only among those having " + maxStates + " states or less" );
        }
        String fileName = jsapResult.getString("output-model");
        logger.info("Printing best model to '" + fileName + "'" );
        logger.info("You can use " + GenerateSequences.class.getName() + " to generate sequences");
        PrintWriter reportWriter = outputFile(fileName);
        try {
            learner.printBestModel(maxStates,reportWriter);
        } finally {
            reportWriter.close();
        }
    } else {
        logger.warn("You did not specify a file for writing the output model, it will not be printed.");
    }
}
/**
 * Lists the simple class names of all known search strategies, each followed by a
 * single space, in the order of {@code SearchStrategy.KNOWN_SEARCH_STRATEGIES}.
 *
 * @return the names of the known search strategies
 */
public static String getSearchCriteriaAsString() {
    final StringBuilder names = new StringBuilder();
    for (Class<SearchStrategy> strategyClass : SearchStrategy.KNOWN_SEARCH_STRATEGIES) {
        names.append(strategyClass.getSimpleName()).append(" ");
    }
    return names.toString();
}
/**
 * Opens a writer for an output destination.
 * <p>
 * The name {@code "-"} stands for standard output. In that case closing the returned
 * writer only flushes it, leaving {@code System.out} open; otherwise the first
 * output sent to {@code "-"} would close standard output and any later output to it
 * would be silently lost. Any other name is opened as a file with
 * {@code FileWriter}.
 *
 * @param fileName the name of the file to write, or {@code "-"} for standard output
 * @return a writer on the requested destination; the caller must close it
 * @throws IOException if the file cannot be opened for writing
 */
static PrintWriter outputFile(String fileName) throws IOException {
    if (fileName.equals("-")) {
        return new PrintWriter(System.out) {
            @Override
            public void close() {
                // System.out belongs to the process: flush, but do not close it.
                flush();
            }
        };
    }
    return new PrintWriter(new FileWriter(fileName));
}
}