Package com.yahoo.labs.taxomo

Source Code of com.yahoo.labs.taxomo.LearnModel

/*
* Copyright (c) 2011, Yahoo! Inc. All rights reserved.
*
* Redistribution and use of this software in source and binary forms, with or without modification,
* are permitted provided that the following conditions are met:
*
* Redistributions of source code must retain the above copyright notice, this list of conditions
* and the following disclaimer.
*
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions
* and the following disclaimer in the documentation and/or other materials provided with the
* distribution.
*
* Neither the name of Yahoo! Inc. nor the names of its contributors may be used to endorse or
* promote products derived from this software without specific prior written permission of Yahoo!
* Inc.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
* WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY
* WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

package com.yahoo.labs.taxomo;

import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.PriorityQueue;
import java.util.Vector;

import org.apache.log4j.Level;
import org.apache.log4j.Logger;

import com.martiansoftware.jsap.FlaggedOption;
import com.martiansoftware.jsap.JSAP;
import com.martiansoftware.jsap.JSAPException;
import com.martiansoftware.jsap.JSAPResult;
import com.martiansoftware.jsap.Parameter;
import com.martiansoftware.jsap.SimpleJSAP;
import com.martiansoftware.jsap.Switch;
import com.yahoo.labs.taxomo.Model.PrintMode;
import com.yahoo.labs.taxomo.learn.Candidate;
import com.yahoo.labs.taxomo.learn.CloserToOrigin;
import com.yahoo.labs.taxomo.learn.ModelScorer;
import com.yahoo.labs.taxomo.learn.ProbabilityDeltaCalculator;
import com.yahoo.labs.taxomo.learn.SearchStrategy;
import com.yahoo.labs.taxomo.util.FrequencyTable;
import com.yahoo.labs.taxomo.util.StateSet;
import com.yahoo.labs.taxomo.util.SymbolTransitionFrequencies;
import com.yahoo.labs.taxomo.util.Taxonomy;
import com.yahoo.labs.taxomo.util.Util;

/**
* Learns a model given a set of input sequences.
* <p>
* See <tt>--help</tt> for command-line options.
* @author chato
*
*/
public class LearnModel {

  public static final Logger logger = Logger.getLogger(GenerateSequences.class);
  static {
    Util.configureLogger(logger, Level.INFO);
  }
 
  private static final int ITERATIONS_BEFORE_REPORTING = 500;
 
  /**
   * The queue cleanup policy removes the 50% less promising candidates if the queue grows too large.
   *
   * This number controls the criterion for triggering the cleanup. For instance, if this is 5, and the max iterations
   * is set to 1,000, then the cleanup policy is triggered when reaching 5,000 candidates.
   */
  private static final int CANDIDATES_PRUNING_TRIGGER_OVERHEAD_FACTOR = 10;
 
  /**
   * The queue cleanup policy is never applied if the number of candidates is less than this number.
   */
  private static final int CANDIDATES_PRUNING_NEVER_TRIGGER_BELOW = 2000;
 
  /**
   * The queue cleanup policy is always applied if the number of candidates is more than this number.
   */
  private static final int CANDIDATES_PRUNING_ALWAYS_TRIGGER_ABOVE = 300000;
 
  public static int DEFAULT_BATCH_SIZE = 1;

  int batchSize = DEFAULT_BATCH_SIZE;

  public static final int DEFAULT_OUTPUT_MAX_STATES = Integer.MAX_VALUE;

  final PriorityQueue<Candidate> candidates;

  final HashMap<Candidate, Integer> results;

  final Taxonomy tree;

  final SymbolTransitionFrequencies symbolTransitions;

  int maxIterations = Integer.MAX_VALUE;

  int iterationNumber = 0;

  final double minLogProbability;

  final double maxLogProbability;

  final int minStates = 1;

  final int maxStates;
 
  final ModelScorer outputScorer;
 
  public LearnModel(Taxonomy aTree, SymbolTransitionFrequencies aSymbolTransitions, Candidate initialCandidate, Class<SearchStrategy> strategy, double weight1) {
    tree = aTree;

    // Read symbol transition
    symbolTransitions = aSymbolTransitions;
   
    // Create set of results
    results = new HashMap<Candidate, Integer>();

    // Compute minProbability
    logger.debug("Evaluating minimum probability");
    ArrayList<String> rootStateOnly = new ArrayList<String>(1);
    rootStateOnly.add(tree.getRootState().name());
    Candidate singleStateCandidate = new Candidate(tree, rootStateOnly, null, null);
    doSlowEvaluation(singleStateCandidate);
    minLogProbability = singleStateCandidate.getLogProbability();
    logger.debug("Minimum probability: " + minLogProbability);

    // Compute maxProbability
    logger.debug("Evaluating maximum probability");
    doSlowEvaluation(initialCandidate);
    maxLogProbability = initialCandidate.getLogProbability();
    logger.debug("Maximum probability: " + minLogProbability);
    maxStates = initialCandidate.getNumStates();
    logger.debug("Maximum states: " + maxStates);

    // Initialize candidate search policy
    SearchStrategy comparator;
    if (strategy.equals(CloserToOrigin.class)) {
      comparator = new CloserToOrigin(minLogProbability, maxLogProbability, minStates, maxStates, weight1);
    } else {
      try {
        comparator = (SearchStrategy) strategy.newInstance();
      } catch (InstantiationException e) {
        throw new IllegalArgumentException(e);
      } catch (IllegalAccessException e) {
        throw new IllegalArgumentException(e);
      }
    }
    logger.debug("Search strategy: " + comparator);
   
    // Initialize scorer (for the output)
    outputScorer = new CloserToOrigin(minLogProbability, maxLogProbability, minStates, maxStates, weight1);
    logger.debug("Output scorer: " + outputScorer);

    // Initialize set of candidates
    candidates = new PriorityQueue<Candidate>(100, comparator);
    addSubCandidates(initialCandidate);
  }

  public void learn() {

    while (!candidates.isEmpty() && results.size() <= maxIterations) {
     
      logger.debug( "There are " + candidates.size() + " candidates in the queue");
      checkCandidatePruning();
     
      Vector<Candidate> candidatesToTest = new Vector<Candidate>();
      for (int i = 0; (i < batchSize && candidates.size() > 0); i++) {
        candidatesToTest.add(candidates.poll());
      }

      for (Candidate candidate : candidatesToTest) {
        if (!results.containsKey(candidate)) {

          doQuickEvaluation(candidate);
          // doSlowEvaluation(candidate);
          // debugDoSlowEvaluation(candidate);

          addSubCandidates(candidate);

        }
      }
     
    }
  }
 
  public void doCandidatesPrunning() {
    logger.info( "Dropping less promissing candidates from the queue" );
   
    int newMaxElements = (candidates.size())/2;
    ArrayList<Candidate> topCandidates = new ArrayList<Candidate>(newMaxElements);
    for( int i=0; i<newMaxElements; i++ ) {
      topCandidates.add( candidates.poll() );
    }
    candidates.clear();
    candidates.addAll(topCandidates);
    logger.info( "Queue checking done, now we have only " + candidates.size() + " candidates" );

  }
 
  public void checkCandidatePruning() {
    int size = candidates.size();
   
    if( size < CANDIDATES_PRUNING_NEVER_TRIGGER_BELOW ) {
      return;
     
    } else if( size > CANDIDATES_PRUNING_ALWAYS_TRIGGER_ABOVE ) {
      doCandidatesPrunning();
     
    } else if(  ( maxIterations < Integer.MAX_VALUE )
         && ( size > (maxIterations * CANDIDATES_PRUNING_TRIGGER_OVERHEAD_FACTOR) )
         ) {
      doCandidatesPrunning();
    }
  }

  void addSubCandidates(Candidate candidate) {
    ArrayList<Candidate> subCandidates = candidate.generateChildrenCandidatesMergingStates();
    for (Candidate subCandidate : subCandidates) {
      candidates.add(subCandidate);
    }
  }
 
 
  void loggerPrintIterationNumber(Candidate candidate) {
    // Avoid printing too many informational messages
    String msg = "Evaluation " + iterationNumber + (maxIterations < Integer.MAX_VALUE ? "/" + maxIterations : "") + " candidate " + candidate.toBriefString();
    if( (maxIterations < ITERATIONS_BEFORE_REPORTING) || ((iterationNumber % ITERATIONS_BEFORE_REPORTING) == 0) ) {
      logger.info(msg);
    } else {
      logger.debug(msg);
    }
  }

  void doSlowEvaluation(Candidate candidate) {
    loggerPrintIterationNumber(candidate);
    StateSet taxo = new StateSet(tree, candidate);
    FrequencyTable frequencyTable = new FrequencyTable(symbolTransitions, taxo);
    double perf = frequencyTable.sumLogProb();
    logger.debug("LogProb (slow evaluation) = " + perf);
    candidate.setLogProbability(perf);
    results.put(candidate, new Integer(iterationNumber++));
  }

  void debugDoSlowEvaluation(Candidate candidate) {
    loggerPrintIterationNumber(candidate);
    StateSet taxo = new StateSet(tree, candidate);
    FrequencyTable frequencyTable = new FrequencyTable(symbolTransitions, taxo);
    double perf = frequencyTable.sumLogProb();
    logger.debug("LogProb (slow evaluation, debug only) = " + perf);
  }

  void doQuickEvaluation(Candidate candidate) {
    loggerPrintIterationNumber(candidate);
    ProbabilityDeltaCalculator calc = new ProbabilityDeltaCalculator(tree, symbolTransitions, candidate);
    double perf = calc.getLogLikelihood();
    logger.debug("LogProb = " + perf);
    candidate.setLogProbability(perf);
    results.put(candidate, new Integer(iterationNumber++));
  }

  public void setMaxIterations(int aMaxExperiments) {
    maxIterations = aMaxExperiments;
  }

  public void setBatchSize(int aBatchSize) {
    batchSize = aBatchSize;
  }

  @SuppressWarnings("boxing")
  void printReport(PrintWriter out) throws IOException {

    logger.info( "Preparing to write log file: re-sorting by experiment number");
    // Sort candidates by experiment number
    Candidate resultsSorted[] = results.keySet().toArray(new Candidate[] {});
    Arrays.sort(resultsSorted, new Comparator<Candidate>() {
      public int compare(Candidate arg0, Candidate arg1) {
        int it0 = results.get(arg0).intValue();
        int it1 = results.get(arg1).intValue();
        if (it0 < it1) {
          return -1;
        } else if (it0 > it1) {
          return +1;
        } else {
          return 0;
        }
      }
    });

    logger.info( "Writing log file");
    // Print descriptions
    out.println("#Iter\tProbability\tNumStates\tScore\tDescription");
    for (Candidate candidate : resultsSorted) {
      int itNumber = results.get(candidate).intValue();
      out.println(itNumber + "\t" + String.format("%.5f", candidate.getLogProbability()) + "\t" + candidate.getNumStates() + "\t" + outputScorer.getScore(candidate) + "\t" + candidate);
    }
  }

  private void printBestModel(int maxOutputStates, PrintWriter out) {
    Model hmm = getBestModel(maxOutputStates);
    out.println(hmm.toString(PrintMode.FULL));
  }
 
  public Model getBestModel() {
    return getBestModel(DEFAULT_OUTPUT_MAX_STATES);
  }
 
  public Model getBestModel(int maxOutputStates) {
    logger.info( "Will score candidates by " + outputScorer );
   
    // Sort candidates by output scorer
    Candidate resultsSorted[] = results.keySet().toArray(new Candidate[] {});
    Arrays.sort(resultsSorted, (SearchStrategy)outputScorer);
   
    int minNumStatesReached = Integer.MAX_VALUE;
    for( int i=0; i<resultsSorted.length; i++ ) {
      int numStates = resultsSorted[i].getNumStates();
      if( (numStates>1) && (numStates < minNumStatesReached) ) {
        minNumStatesReached = numStates;
      }
    }
   
    int realisticMaxOutputStates = maxOutputStates;
    if( minNumStatesReached > maxOutputStates ) {
      logger.warn( "Could not find a model with " + maxOutputStates + " states or less" );
      logger.warn( "Apart from the model with one state (which is never returned!), all the other models had " + minNumStatesReached + " states or more" );
      realisticMaxOutputStates = minNumStatesReached;
    }
   
    // Try to find the best one
    Candidate winner = null;
    for( int i=0; i<resultsSorted.length; i++ ) {
      int numStates = resultsSorted[i].getNumStates();
      // The model with only 1 state is never returned
      if( (numStates>1) && (numStates <= realisticMaxOutputStates ) ) {
        winner = resultsSorted[i];
        break;
      }
    }
       
    StateSet taxo = new StateSet(tree, winner);
    FrequencyTable frequencyTable = new FrequencyTable(symbolTransitions, taxo);
    Model hmm = new Model(frequencyTable, taxo);
    return hmm;
  }

  /**
   * @param args
   * @throws IOException
   * @throws JSAPException
   * @throws ClassNotFoundException
   * @throws IllegalAccessException
   * @throws InstantiationException
   */
  @SuppressWarnings("unchecked")
  public static void main(String[] args) throws IOException, JSAPException, ClassNotFoundException, InstantiationException, IllegalAccessException {

    String searchCriteria = getSearchCriteriaAsString();

    final SimpleJSAP jsap = new SimpleJSAP(
        LearnModel.class.getName(),
        "Learns a HMM based on the frequency of the transitions observed. Requires a list of acceptable states where not there can not be an state and one of its descendants in the taxonomy.",
        new Parameter[] {
            new Switch("verbose", 'v', "verbose", "Set verbose output"),
            new FlaggedOption("taxonomy-file", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, JSAP.REQUIRED, 't', "taxonomy-file",
                "File containing the description of the taxonomy."),
            new Switch("init-all-leaves", 'l', "init-all-leaves", "Initial model states are all leaves from taxonomy."),
            new FlaggedOption("init-all-level", JSAP.INTEGER_PARSER, JSAP.NO_DEFAULT, JSAP.NOT_REQUIRED, 'e', "init-all-level",
                "Initial model states are leaves or internal nodes at level <= x."),
            new FlaggedOption("init-explicit", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, JSAP.NOT_REQUIRED, 'x', "init-explicit",
                "Initial model states are a quote-enclosed, space-separated list, given as input."),
            new FlaggedOption("input-file", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, JSAP.REQUIRED, 'i', "input-file", "File containing the input sequences."),
            new FlaggedOption("max-iterations", JSAP.INTEGER_PARSER, JSAP.NO_DEFAULT, JSAP.NOT_REQUIRED, 'm', "max-iterations",
              "Maximum number of iterations to try for building the model."),
            new FlaggedOption("batch-size", JSAP.INTEGER_PARSER, Integer.toString(DEFAULT_BATCH_SIZE), JSAP.NOT_REQUIRED, 'b', "batch-size",
              "How many candidates to evaluate at a time before adding their children to the priority queue."),
            new FlaggedOption("search-method", JSAP.STRING_PARSER, SearchStrategy.DEFAULT_SEARCH_STRATEGY.getSimpleName(), JSAP.NOT_REQUIRED, 's',
                "search-method", "Search method, allowed values: " + searchCriteria.toString()),
            new FlaggedOption("search-method-weight-1", JSAP.DOUBLE_PARSER, Double.toString(CloserToOrigin.DEFAULT_PROBABILITY_WEIGHT), JSAP.NOT_REQUIRED,
                'w', "search-method-weight-1", "For the method " + CloserToOrigin.class.getSimpleName()
                    + ", the relative importance of having low probability versus having less states"),
            new FlaggedOption("output-log", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, JSAP.NOT_REQUIRED, 'g', "output-log",
                "Output file for logging all searched models"),
            new FlaggedOption("output-model", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, JSAP.NOT_REQUIRED, 'o', "output-model",
                "Output file for writing the winner model"),
            new FlaggedOption("winner-model-maxstates", JSAP.INTEGER_PARSER, Integer.toString(DEFAULT_OUTPUT_MAX_STATES), JSAP.NOT_REQUIRED, 'u',
                "winner-model-maxstates",
                "Output model (winner) must have at most this number of states; this is taken only as a hint, so if there is no model with that few states, "
                    + "the one with the lower number of states will be selected"),
           
        });

    final JSAPResult jsapResult = jsap.parse(args);
    if (jsap.messagePrinted())
      return;

    if (jsapResult.getBoolean("verbose")) {
      logger.setLevel(Level.DEBUG);
    }

    File inputFile = new File(jsapResult.getString("input-file"));
    File taxoFile = new File(jsapResult.getString("taxonomy-file"));
    Taxonomy tree = new Taxonomy(taxoFile);
    Candidate initialCandidate;

    if (jsapResult.userSpecified("init-explicit")) {
      initialCandidate = new Candidate(tree, Util.split(jsapResult.getString("init-explicit")), null, null);

    } else if (jsapResult.userSpecified("init-all-level")) {
      initialCandidate = Candidate.createFixedLevelCandidate(tree, jsapResult.getInt("init-all-level"));
     
    } else if (jsapResult.getBoolean("init-all-leaves")) {
      initialCandidate = Candidate.createLeafCandidate(tree);

    } else {
      throw new IllegalArgumentException("Either --init-explicit, --init-all-leaves, or --init-all-level should be specified. See --help.");
    }
    logger.debug("Initial candidate is " + initialCandidate.toBriefString());
   
    Class<SearchStrategy> strategy = (Class<SearchStrategy>) Class.forName(SearchStrategy.class.getPackage().getName() + "." + jsapResult.getString("search-method"));
    logger.debug("Search strategy is " + strategy );
   
    double weight1 = jsapResult.getDouble("search-method-weight-1");
    if( jsapResult.userSpecified("search-method-weight-1") ) {
      logger.info( "Will use weight1=" + weight1 );
    }
   
    // Read input file
    logger.info("Reading input file");
    SymbolTransitionFrequencies symbolTransitions = new SymbolTransitionFrequencies(tree);
    symbolTransitions.processFile(inputFile);

    // Learning
    LearnModel learner = new LearnModel(tree, symbolTransitions, initialCandidate, strategy, weight1);
    if (jsapResult.userSpecified("max-iterations")) {
      learner.setMaxIterations(jsapResult.getInt("max-iterations"));
    }
    if (jsapResult.userSpecified("batch-size")) {
      learner.setBatchSize(jsapResult.getInt("batch-size"));
    }
    learner.learn();

    if (jsapResult.userSpecified("output-log")) {
      String fileName = jsapResult.getString("output-log");
      logger.info("Printing logfile to '" + fileName + "'");

      PrintWriter reportWriter = outputFile(fileName);
      learner.printReport(reportWriter);
      reportWriter.close();
    }

    if (jsapResult.userSpecified("output-model")) {
      int maxStates = jsapResult.getInt("winner-model-maxstates");
      if( jsapResult.userSpecified("winner-model-maxstates") ) {
        logger.info("Will pick the best model only among those having " + maxStates + " states or less" );
      }
     
      String fileName = jsapResult.getString("output-model");
      logger.info("Printing best model to '" + fileName + "'" );
      logger.info("You can use " + GenerateSequences.class.getName() + " to generate sequences");
      PrintWriter reportWriter = outputFile(fileName);
      learner.printBestModel(maxStates,reportWriter);
      reportWriter.close();

    } else {
      logger.warn("You did not specify a file for writing the output model, it will not be printed.");
    }

  }

  public static String getSearchCriteriaAsString() {
    StringBuffer searchCriteria = new StringBuffer();
    for (Class<SearchStrategy> comparator : SearchStrategy.KNOWN_SEARCH_STRATEGIES) {
      searchCriteria.append(comparator.getSimpleName() + " ");
    }
    return searchCriteria.toString();
  }

  static PrintWriter outputFile(String fileName) throws IOException {
    if (fileName.equals("-")) {
      return new PrintWriter(System.out);
    } else {
      return new PrintWriter(new FileWriter(fileName));
    }
  }
}
TOP

Related Classes of com.yahoo.labs.taxomo.LearnModel

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.