Package com.yahoo.labs.taxomo

Source Code of com.yahoo.labs.taxomo.ClusterSequences

/*
* Copyright (c) 2011, Yahoo! Inc. All rights reserved.
*
* Redistribution and use of this software in source and binary forms, with or without modification,
* are permitted provided that the following conditions are met:
*
* Redistributions of source code must retain the above copyright notice, this list of conditions
* and the following disclaimer.
*
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions
* and the following disclaimer in the documentation and/or other materials provided with the
* distribution.
*
* Neither the name of Yahoo! Inc. nor the names of its contributors may be used to endorse or
* promote products derived from this software without specific prior written permission of Yahoo!
* Inc.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
* WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY
* WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

package com.yahoo.labs.taxomo;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;

import org.apache.log4j.Level;
import org.apache.log4j.Logger;

import com.martiansoftware.jsap.FlaggedOption;
import com.martiansoftware.jsap.JSAP;
import com.martiansoftware.jsap.JSAPException;
import com.martiansoftware.jsap.JSAPResult;
import com.martiansoftware.jsap.Parameter;
import com.martiansoftware.jsap.SimpleJSAP;
import com.martiansoftware.jsap.Switch;
import com.yahoo.labs.taxomo.Model.PrintMode;
import com.yahoo.labs.taxomo.learn.Candidate;
import com.yahoo.labs.taxomo.learn.CloserToOrigin;
import com.yahoo.labs.taxomo.learn.SearchStrategy;
import com.yahoo.labs.taxomo.util.SymbolTransitionFrequencies;
import com.yahoo.labs.taxomo.util.Taxonomy;
import com.yahoo.labs.taxomo.util.Util;

/**
* Clusters sequences using EM and representing each cluster by a Taxomo model.
* <p>
* See <tt>--help</tt> for command-line options.
*
* @author chato
*
*/
public class ClusterSequences {
  static final Logger logger = Logger.getLogger(ClusterSequences.class);

  static {
    Util.configureLogger(logger, Level.INFO);
  }

  static final int DEFAULT_CLUSTERS = 10;

  static final int DEFAULT_PASSES = 5;

  final int sequence2cluster[];

  final int aux[];

  final Taxonomy tree;

  final File inFile;

  final int nClusters;

  final int nSequences;

  final Candidate prototypeInitialCandidate;

  final Class<SearchStrategy> searchStrategy;

  final double learnerMaxWeight1;

  final int learnerMaxIterations;

  final int learnerBatchSize;

  Model bestModel[];

  public ClusterSequences(Taxonomy aTree, File aInFile, int aNClusters, Candidate initialCandidate, Class<SearchStrategy> aStrategy, double aWeight1,
      int aMaxIterations, int aBatchSize) {
    logger.info("Initializing clustering");

    tree = aTree;
    inFile = aInFile;
    nClusters = aNClusters;
    nSequences = Util.countLines(inFile);
    logger.info("File contains " + nSequences + " sequences");

    logger.info("Doing random initial clustering assignment");
    sequence2cluster = new int[nSequences];
    for (int sequenceNum = 0; sequenceNum < nSequences; sequenceNum++) {
      sequence2cluster[sequenceNum] = (int) Math.floor(Math.random() * nClusters);
    }

    aux = new int[nSequences];
    prototypeInitialCandidate = initialCandidate;
    searchStrategy = aStrategy;
    learnerMaxWeight1 = aWeight1;
    learnerMaxIterations = aMaxIterations;
    learnerBatchSize = aBatchSize;

    bestModel = new Model[nClusters];
  }

  void doClustering(int passes) {
    for (int pass = 1; pass <= passes; pass++) {
      logger.info("Doing pass " + pass + "/" + passes);
      try {
        doPass();
      } catch (IOException e) {
        throw new IllegalStateException("There was an I/O error");
      }
    }
  }

  SymbolTransitionFrequencies[] getFrequencyTables() throws IOException {
    logger.info("Reading the elements to update per-cluster frequency tables");

    // Create empty frequency tables
    SymbolTransitionFrequencies symbolFrequencyTable[] = new SymbolTransitionFrequencies[nClusters];
    for (int i = 0; i < nClusters; i++) {
      symbolFrequencyTable[i] = new SymbolTransitionFrequencies(tree);
    }

    // Read input file line by line
    BufferedReader reader = new BufferedReader(new FileReader(inFile));
    String line;
    int seqNum = 0;
    while ((line = reader.readLine()) != null) {
      // Process one line and update the corresponding frequency table
      symbolFrequencyTable[sequence2cluster[seqNum]].processString(line);
      seqNum++;
    }

    // Freeze the frequency tables
    for (int i = 0; i < nClusters; i++) {
      symbolFrequencyTable[i].freeze();
    }
    return symbolFrequencyTable;
  }

  void learnModelSettingBestModel(SymbolTransitionFrequencies[] symbolFrequencyTable) {

    for (int clusterNum = 0; clusterNum < nClusters; clusterNum++) {

      logger.info("Learning model for cluster " + clusterNum);

      // Clone initial candidate (so that it does not have a set
      // logProbability)
      Candidate initialCandidate = prototypeInitialCandidate.clone();

      // Initialize weight1 for model
      double weight1 = learnerMaxWeight1 != -1 ? learnerMaxWeight1 : 1.0;

      // Create and set-up learner
      LearnModel learner = new LearnModel(tree, symbolFrequencyTable[clusterNum], initialCandidate, searchStrategy, weight1);
      if (learnerBatchSize != -1) {
        learner.setBatchSize(learnerBatchSize);
      }
      if (learnerMaxIterations != -1) {
        learner.setMaxIterations(learnerMaxIterations);
      }

      // Learn
      learner.learn();

      // Get best model
      bestModel[clusterNum] = learner.getBestModel();
    }

  }

  void evaluateNearestNeighborSettingAux() throws IOException {
    logger.info("Evaluating nearest cluster for each element");

    BufferedReader reader = new BufferedReader(new FileReader(inFile));
    int seqNum = 0;
    String line;
    while ((line = reader.readLine()) != null) {

      ArrayList<String> seq = Util.split(line);

      double maxProbability = Double.NEGATIVE_INFINITY;
      int bestCluster = -1;
      for (int clusterNum = 0; clusterNum < nClusters; clusterNum++) {
        double probability = bestModel[clusterNum].viterbiCalculateNonOverlap(seq);
        if (probability > maxProbability) {
          bestCluster = clusterNum;
          maxProbability = probability;
        }
      }

      aux[seqNum] = bestCluster;
      seqNum++;
    }

    int movedElements = 0;
    for (int sequenceNum = 0; sequenceNum < nSequences; sequenceNum++) {
      if (aux[sequenceNum] != sequence2cluster[sequenceNum]) {
        movedElements++;
      }
    }
    logger.info("Elements that moved: " + movedElements + " (" + Math.floor((double) movedElements * 10000.0 / (double) sequence2cluster.length) / 100 + "%)");
  }

  void doPass() throws IOException {

    // Clear auxiliary vector
    Arrays.fill(aux, 0);

    // Update frequency tables
    SymbolTransitionFrequencies symbolFrequencyTables[] = getFrequencyTables();

    // Get best models
    learnModelSettingBestModel(symbolFrequencyTables);

    // Evaluate nearest-neighbors
    evaluateNearestNeighborSettingAux();

    // Replace old cluster by new clusters
    logger.info("Replacing old cluster by new clustering");
    int[] clusterSizes = new int[nClusters];
    for (int sequenceNum = 0; sequenceNum < nSequences; sequenceNum++) {
      sequence2cluster[sequenceNum] = aux[sequenceNum];
      clusterSizes[sequence2cluster[sequenceNum]]++;
    }

    // Report
    for (int clusterNum = 0; clusterNum < nClusters; clusterNum++) {
      logger.info(" Cluster " + clusterNum + " " + (int) Math.floor((double) clusterSizes[clusterNum] * 100.0 / (double) nSequences) + "% "
          + bestModel[clusterNum].toString(PrintMode.STATES_ONLY));
    }

  }

  private void printModels(String outModelsBasename) throws IOException {
    for (int clusterNum = 0; clusterNum < nClusters; clusterNum++) {
      String outFileName = outModelsBasename + "_" + clusterNum + Model.DEFAULT_FILE_EXTENSION;
      logger.info("Writing model of cluster " + clusterNum + " to " + outFileName);
      PrintWriter writer = new PrintWriter(new BufferedWriter(new FileWriter(outFileName), 1 * 1024 * 1024)); // 1Mb
      writer.println(bestModel[clusterNum].toString(PrintMode.FULL));
      writer.close();
    }
  }

  private void printClusterAssignment(String outFileName) throws IOException {
    logger.info("Writing clustering assignment to " + outFileName);
    PrintWriter writer = new PrintWriter(new BufferedWriter(new FileWriter(outFileName), 1 * 1024 * 1024)); // 1Mb
    BufferedReader reader = new BufferedReader(new FileReader(inFile));
    String line;
    int seqNum = 0;

    writer.println("#Cluster\tSequence");
    while ((line = reader.readLine()) != null) {
      writer.println(sequence2cluster[seqNum] + "\t" + line);
      seqNum++;
    }
    writer.close();

  }

  @SuppressWarnings("unchecked")
  public static void main(String[] args) throws JSAPException, IOException, ClassNotFoundException {

    final SimpleJSAP jsap = new SimpleJSAP(ClusterSequences.class.getName(),
        "Parses input sequences to do clustering on them. Input sequences should have one sequence per line. "
            + "Each sequence is a tab-separated list of symbols.", new Parameter[] {
            new Switch("verbose", 'v', "verbose", "Set verbose output"),
            new FlaggedOption("taxonomy-file", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, JSAP.REQUIRED, 't', "taxonomy-file",
                "File containing the description of the taxonomy."),
            new FlaggedOption("clusters", JSAP.INTEGER_PARSER, Integer.toString(DEFAULT_CLUSTERS), JSAP.NOT_REQUIRED, 'c', "clusters",
                "Number of clusters to use."),
            new FlaggedOption("passes", JSAP.INTEGER_PARSER, Integer.toString(DEFAULT_PASSES), JSAP.NOT_REQUIRED, 'p', "passes",
                "Number of passes of the algorithm to do."),

            new Switch("init-all-leaves", 'l', "init-all-leaves", "Initial model states are all leaves from taxonomy."),
            new FlaggedOption("init-all-level", JSAP.INTEGER_PARSER, JSAP.NO_DEFAULT, JSAP.NOT_REQUIRED, 'e', "init-all-level",
                "Initial model states are leaves or internal nodes at level <= x."),
            new FlaggedOption("init-explicit", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, JSAP.NOT_REQUIRED, 'x', "init-explicit",
                "Initial model states are a quote-enclosed, space-separated list, given as input."),
            new FlaggedOption("max-iterations", JSAP.INTEGER_PARSER, JSAP.NO_DEFAULT, JSAP.NOT_REQUIRED, 'm', "max-iterations",
                "Maximum number of iterations to try for building the model."),
            new FlaggedOption("batch-size", JSAP.INTEGER_PARSER, Integer.toString(LearnModel.DEFAULT_BATCH_SIZE), JSAP.NOT_REQUIRED, 'b', "batch-size",
                "How many candidates to evaluate at a time before adding their children to the priority queue."),
            new FlaggedOption("search-method", JSAP.STRING_PARSER, SearchStrategy.DEFAULT_SEARCH_STRATEGY.getSimpleName(), JSAP.NOT_REQUIRED, 's',
                "search-method", "Search method, allowed values: " + LearnModel.getSearchCriteriaAsString()),
            new FlaggedOption("search-method-weight-1", JSAP.DOUBLE_PARSER, Double.toString(CloserToOrigin.DEFAULT_PROBABILITY_WEIGHT), JSAP.NOT_REQUIRED,
                'w', "search-method-weight-1", "For the method " + CloserToOrigin.class.getSimpleName()
                    + ", the relative importance of having low probability versus having less states"),
            new FlaggedOption("winner-model-maxstates", JSAP.INTEGER_PARSER, Integer.toString(LearnModel.DEFAULT_OUTPUT_MAX_STATES), JSAP.NOT_REQUIRED, 'u',
                "winner-model-maxstates",
                "Output model (winner) must have at most this number of states; this is taken only as a hint, so if there is no model with that few states, "
                    + "the one with the lower number of states will be selected"),

            new FlaggedOption("input-file", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, JSAP.REQUIRED, 'i', "input-file", "File containing the input sequences."),
            new FlaggedOption("output-assignment", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, JSAP.REQUIRED, 'a', "output-assignment",
                "File for writing the output cluster assignments."),
            new FlaggedOption("output-models-basename", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, JSAP.NOT_REQUIRED, 'o', "output-models-basename",
                "Base filename for writing the models."),

        });

    final JSAPResult jsapResult = jsap.parse(args);
    if (jsap.messagePrinted())
      return;

    if (jsapResult.getBoolean("verbose")) {
      logger.setLevel(Level.DEBUG);
    }

    final String inFileName = jsapResult.getString("input-file");
    File inFile = new File(inFileName);
    final String outFileName = jsapResult.getString("output-assignment");

    final int nClusters = jsapResult.getInt("clusters");
    final int clusteringIterations = jsapResult.getInt("passes");

    // Load tree
    File taxoFile = new File(jsapResult.getString("taxonomy-file"));
    Taxonomy tree = new Taxonomy(taxoFile);

    // Build initial candidate
    Candidate initialCandidate;
    if (jsapResult.userSpecified("init-explicit")) {
      initialCandidate = new Candidate(tree, Util.split(jsapResult.getString("init-explicit")), null, null);

    } else if (jsapResult.userSpecified("init-all-level")) {
      initialCandidate = Candidate.createFixedLevelCandidate(tree, jsapResult.getInt("init-all-level"));

    } else if (jsapResult.getBoolean("init-all-leaves")) {
      initialCandidate = Candidate.createLeafCandidate(tree);

    } else {
      throw new IllegalArgumentException("Either --init-explicit, --init-all-leaves, or --init-all-level should be specified. See --help.");
    }
    logger.debug("Initial candidate is " + initialCandidate.toBriefString());

    // Set learning parameters
    Class<SearchStrategy> strategy = (Class<SearchStrategy>) Class.forName(SearchStrategy.class.getPackage().getName() + "." + jsapResult.getString("search-method"));
    logger.debug("Search strategy is " + strategy);

    double weight1 = -1;
    if (jsapResult.userSpecified("search-method-weight-1")) {
      weight1 = jsapResult.getDouble("search-method-weight-1");
      logger.info("Will use weight1=" + weight1);
    }
    int maxIterations = -1;
    if (jsapResult.userSpecified("max-iterations")) {
      maxIterations = jsapResult.getInt("max-iterations");
      logger.info("Will use maxIterations=" + maxIterations);
    }
    int batchSize = -1;
    if (jsapResult.userSpecified("batch-size")) {
      batchSize = jsapResult.getInt("batch-size");
      logger.info("Will use batchSize=" + batchSize);
    }

    // Do the clustering
    ClusterSequences clustering = new ClusterSequences(tree, inFile, nClusters, initialCandidate, strategy, weight1, maxIterations, batchSize);
    clustering.doClustering(clusteringIterations);

    // Write clusters
    clustering.printClusterAssignment(outFileName);

    // Write models
    if (jsapResult.userSpecified("output-models-basename")) {

      final String outModelsBasename = jsapResult.getString("output-models-basename");
      clustering.printModels(outModelsBasename);

    } else {
      logger.warn("You did not specify a file for writing the output models, will not print them");
    }

  }

}
TOP

Related Classes of com.yahoo.labs.taxomo.ClusterSequences

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.