/*
* Copyright (c) 2011, Yahoo! Inc. All rights reserved.
*
* Redistribution and use of this software in source and binary forms, with or without modification,
* are permitted provided that the following conditions are met:
*
* Redistributions of source code must retain the above copyright notice, this list of conditions
* and the following disclaimer.
*
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions
* and the following disclaimer in the documentation and/or other materials provided with the
* distribution.
*
* Neither the name of Yahoo! Inc. nor the names of its contributors may be used to endorse or
* promote products derived from this software without specific prior written permission of Yahoo!
* Inc.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
* WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY
* WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package com.yahoo.labs.taxomo;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import com.martiansoftware.jsap.FlaggedOption;
import com.martiansoftware.jsap.JSAP;
import com.martiansoftware.jsap.JSAPException;
import com.martiansoftware.jsap.JSAPResult;
import com.martiansoftware.jsap.Parameter;
import com.martiansoftware.jsap.SimpleJSAP;
import com.martiansoftware.jsap.Switch;
import com.yahoo.labs.taxomo.Model.PrintMode;
import com.yahoo.labs.taxomo.learn.Candidate;
import com.yahoo.labs.taxomo.learn.CloserToOrigin;
import com.yahoo.labs.taxomo.learn.SearchStrategy;
import com.yahoo.labs.taxomo.util.SymbolTransitionFrequencies;
import com.yahoo.labs.taxomo.util.Taxonomy;
import com.yahoo.labs.taxomo.util.Util;
/**
* Clusters sequences using EM, representing each cluster by a Taxomo model.
* <p>
* See <tt>--help</tt> for command-line options.
*
* @author chato
*
*/
public class ClusterSequences {

    static final Logger logger = Logger.getLogger(ClusterSequences.class);

    static {
        Util.configureLogger(logger, Level.INFO);
    }

    /** Default value of the <tt>--clusters</tt> option. */
    static final int DEFAULT_CLUSTERS = 10;

    /** Default value of the <tt>--passes</tt> option. */
    static final int DEFAULT_PASSES = 5;

    /** Current cluster index of every sequence, indexed by line number in {@link #inFile}. */
    final int[] sequence2cluster;

    /** Scratch vector holding the cluster assignment computed during the current pass. */
    final int[] aux;

    final Taxonomy tree;

    /** Input file, one sequence per line; it is re-read on every pass. */
    final File inFile;

    final int nClusters;

    /** Number of sequences, as reported by {@code Util.countLines(inFile)}. */
    final int nSequences;

    /** Candidate that is cloned to start the learner of every cluster. */
    final Candidate prototypeInitialCandidate;

    final Class<SearchStrategy> searchStrategy;

    /** Weight passed to the learner; <tt>-1</tt> means "not specified", in which case 1.0 is used. */
    final double learnerMaxWeight1;

    /** Maximum learner iterations; <tt>-1</tt> means "not specified" (the learner's setter is not called). */
    final int learnerMaxIterations;

    /** Learner batch size; <tt>-1</tt> means "not specified" (the learner's setter is not called). */
    final int learnerBatchSize;

    /** Model learned for each cluster in the most recent pass; entries are null before the first pass. */
    Model[] bestModel;

    /**
     * Prepares a clustering run: counts the sequences in the input file and assigns every sequence
     * to a uniformly random cluster.
     *
     * @param aTree taxonomy used to build frequency tables and to learn the models
     * @param aInFile input file with one sequence per line
     * @param aNClusters number of clusters, must be at least 1
     * @param initialCandidate prototype candidate, cloned for every learner
     * @param aStrategy search strategy class handed to the learner
     * @param aWeight1 weight handed to the learner, or -1 to use 1.0
     * @param aMaxIterations maximum learner iterations, or -1 to leave the learner's own setting
     * @param aBatchSize learner batch size, or -1 to leave the learner's own setting
     * @throws IllegalArgumentException if <tt>aNClusters</tt> is smaller than 1
     */
    public ClusterSequences(Taxonomy aTree, File aInFile, int aNClusters, Candidate initialCandidate, Class<SearchStrategy> aStrategy, double aWeight1,
            int aMaxIterations, int aBatchSize) {
        logger.info("Initializing clustering");

        // Fail early with a clear message instead of a NegativeArraySizeException or
        // ArrayIndexOutOfBoundsException later on.
        if (aNClusters < 1) {
            throw new IllegalArgumentException("The number of clusters must be at least 1, but was " + aNClusters);
        }

        tree = aTree;
        inFile = aInFile;
        nClusters = aNClusters;
        nSequences = Util.countLines(inFile);
        logger.info("File contains " + nSequences + " sequences");

        logger.info("Doing random initial clustering assignment");
        sequence2cluster = new int[nSequences];
        for (int sequenceNum = 0; sequenceNum < nSequences; sequenceNum++) {
            sequence2cluster[sequenceNum] = (int) Math.floor(Math.random() * nClusters);
        }
        aux = new int[nSequences];

        prototypeInitialCandidate = initialCandidate;
        searchStrategy = aStrategy;
        learnerMaxWeight1 = aWeight1;
        learnerMaxIterations = aMaxIterations;
        learnerBatchSize = aBatchSize;

        bestModel = new Model[nClusters];
    }

    /**
     * Runs the given number of clustering passes.
     *
     * @param passes number of passes to run; nothing happens if it is smaller than 1
     * @throws IllegalStateException if reading the input file fails; the I/O error is kept as the cause
     */
    void doClustering(int passes) {
        for (int pass = 1; pass <= passes; pass++) {
            logger.info("Doing pass " + pass + "/" + passes);
            try {
                doPass();
            } catch (IOException e) {
                // Keep the original exception as the cause so the failure can be diagnosed.
                throw new IllegalStateException("There was an I/O error", e);
            }
        }
    }

    /**
     * Reads the input file and builds one frozen frequency table per cluster, feeding every
     * sequence into the table of the cluster it is currently assigned to.
     *
     * @return one frozen frequency table per cluster, indexed by cluster number
     * @throws IOException if the input file cannot be read
     */
    SymbolTransitionFrequencies[] getFrequencyTables() throws IOException {
        logger.info("Reading the elements to update per-cluster frequency tables");

        // Create empty frequency tables
        SymbolTransitionFrequencies[] symbolFrequencyTable = new SymbolTransitionFrequencies[nClusters];
        for (int i = 0; i < nClusters; i++) {
            symbolFrequencyTable[i] = new SymbolTransitionFrequencies(tree);
        }

        // Read input file line by line
        BufferedReader reader = new BufferedReader(new FileReader(inFile));
        try {
            String line;
            int seqNum = 0;
            while ((line = reader.readLine()) != null) {
                // Process one line and update the corresponding frequency table
                symbolFrequencyTable[sequence2cluster[seqNum]].processString(line);
                seqNum++;
            }
        } finally {
            // The file is re-opened on every pass, so the handle must not leak.
            reader.close();
        }

        // Freeze the frequency tables
        for (int i = 0; i < nClusters; i++) {
            symbolFrequencyTable[i].freeze();
        }
        return symbolFrequencyTable;
    }

    /**
     * Learns one model per cluster from its frequency table and stores it in {@link #bestModel}.
     *
     * @param symbolFrequencyTable one frequency table per cluster, indexed by cluster number
     */
    void learnModelSettingBestModel(SymbolTransitionFrequencies[] symbolFrequencyTable) {
        for (int clusterNum = 0; clusterNum < nClusters; clusterNum++) {
            logger.info("Learning model for cluster " + clusterNum);

            // Clone initial candidate (so that it does not have a set
            // logProbability)
            Candidate initialCandidate = prototypeInitialCandidate.clone();

            // Initialize weight1 for model
            double weight1 = learnerMaxWeight1 != -1 ? learnerMaxWeight1 : 1.0;

            // Create and set-up learner
            LearnModel learner = new LearnModel(tree, symbolFrequencyTable[clusterNum], initialCandidate, searchStrategy, weight1);
            if (learnerBatchSize != -1) {
                learner.setBatchSize(learnerBatchSize);
            }
            if (learnerMaxIterations != -1) {
                learner.setMaxIterations(learnerMaxIterations);
            }

            // Learn
            learner.learn();

            // Get best model
            bestModel[clusterNum] = learner.getBestModel();
        }
    }

    /**
     * Re-reads the input file and, for every sequence, stores in {@link #aux} the cluster whose
     * model yields the highest value of {@code viterbiCalculateNonOverlap}; ties keep the
     * lowest-numbered cluster. Also logs how many sequences would change cluster.
     *
     * @throws IOException if the input file cannot be read
     */
    void evaluateNearestNeighborSettingAux() throws IOException {
        logger.info("Evaluating nearest cluster for each element");

        BufferedReader reader = new BufferedReader(new FileReader(inFile));
        try {
            int seqNum = 0;
            String line;
            while ((line = reader.readLine()) != null) {
                ArrayList<String> seq = Util.split(line);
                double maxProbability = Double.NEGATIVE_INFINITY;
                int bestCluster = -1;
                for (int clusterNum = 0; clusterNum < nClusters; clusterNum++) {
                    double probability = bestModel[clusterNum].viterbiCalculateNonOverlap(seq);
                    if (probability > maxProbability) {
                        bestCluster = clusterNum;
                        maxProbability = probability;
                    }
                }
                if (bestCluster == -1) {
                    // No model scored above negative infinity (all scores were -Infinity or NaN).
                    // Keep the sequence where it is; storing -1 would make the caller index
                    // an array with -1.
                    bestCluster = sequence2cluster[seqNum];
                }
                aux[seqNum] = bestCluster;
                seqNum++;
            }
        } finally {
            reader.close();
        }

        int movedElements = 0;
        for (int sequenceNum = 0; sequenceNum < nSequences; sequenceNum++) {
            if (aux[sequenceNum] != sequence2cluster[sequenceNum]) {
                movedElements++;
            }
        }
        logger.info("Elements that moved: " + movedElements + " (" + Math.floor((double) movedElements * 10000.0 / (double) sequence2cluster.length) / 100 + "%)");
    }

    /**
     * Runs one clustering pass: rebuilds the per-cluster frequency tables, learns a model per
     * cluster, reassigns every sequence to its best cluster, and logs the resulting cluster sizes.
     *
     * @throws IOException if the input file cannot be read
     */
    void doPass() throws IOException {
        // Clear auxiliary vector
        Arrays.fill(aux, 0);

        // Update frequency tables
        SymbolTransitionFrequencies[] symbolFrequencyTables = getFrequencyTables();

        // Get best models
        learnModelSettingBestModel(symbolFrequencyTables);

        // Evaluate nearest-neighbors
        evaluateNearestNeighborSettingAux();

        // Replace old cluster by new clusters
        logger.info("Replacing old cluster by new clustering");
        int[] clusterSizes = new int[nClusters];
        for (int sequenceNum = 0; sequenceNum < nSequences; sequenceNum++) {
            sequence2cluster[sequenceNum] = aux[sequenceNum];
            clusterSizes[sequence2cluster[sequenceNum]]++;
        }

        // Report
        for (int clusterNum = 0; clusterNum < nClusters; clusterNum++) {
            logger.info(" Cluster " + clusterNum + " " + (int) Math.floor((double) clusterSizes[clusterNum] * 100.0 / (double) nSequences) + "% "
                    + bestModel[clusterNum].toString(PrintMode.STATES_ONLY));
        }
    }

    /**
     * Writes the model of every cluster to
     * <tt>outModelsBasename_&lt;cluster&gt;</tt> + {@code Model.DEFAULT_FILE_EXTENSION}.
     * Clusters that have no model yet (no pass was run) are skipped with a warning.
     *
     * @param outModelsBasename prefix of the output file names
     * @throws IOException if a file cannot be opened or written
     */
    private void printModels(String outModelsBasename) throws IOException {
        for (int clusterNum = 0; clusterNum < nClusters; clusterNum++) {
            String outFileName = outModelsBasename + "_" + clusterNum + Model.DEFAULT_FILE_EXTENSION;
            if (bestModel[clusterNum] == null) {
                // Happens when no clustering pass was run; avoid a NullPointerException.
                logger.warn("There is no model for cluster " + clusterNum + ", will not write " + outFileName);
                continue;
            }
            logger.info("Writing model of cluster " + clusterNum + " to " + outFileName);
            PrintWriter writer = new PrintWriter(new BufferedWriter(new FileWriter(outFileName), 1 * 1024 * 1024)); // 1Mb
            try {
                writer.println(bestModel[clusterNum].toString(PrintMode.FULL));
                // PrintWriter never throws on write errors; checkError() flushes and reports them.
                if (writer.checkError()) {
                    throw new IOException("Error writing model of cluster " + clusterNum + " to " + outFileName);
                }
            } finally {
                writer.close();
            }
        }
    }

    /**
     * Writes a header line followed by one line per input sequence of the form
     * <tt>cluster TAB original-line</tt>.
     *
     * @param outFileName file to write the assignment to
     * @throws IOException if the input file cannot be read or the output file cannot be written
     */
    private void printClusterAssignment(String outFileName) throws IOException {
        logger.info("Writing clustering assignment to " + outFileName);
        PrintWriter writer = new PrintWriter(new BufferedWriter(new FileWriter(outFileName), 1 * 1024 * 1024)); // 1Mb
        try {
            BufferedReader reader = new BufferedReader(new FileReader(inFile));
            try {
                String line;
                int seqNum = 0;
                writer.println("#Cluster\tSequence");
                while ((line = reader.readLine()) != null) {
                    writer.println(sequence2cluster[seqNum] + "\t" + line);
                    seqNum++;
                }
            } finally {
                reader.close();
            }
            // PrintWriter never throws on write errors; checkError() flushes and reports them.
            if (writer.checkError()) {
                throw new IOException("Error writing clustering assignment to " + outFileName);
            }
        } finally {
            writer.close();
        }
    }

    /**
     * Command-line entry point: parses the options, builds the initial candidate, runs the
     * clustering and writes the cluster assignment and, optionally, the per-cluster models.
     *
     * @param args command-line arguments, see <tt>--help</tt>
     * @throws JSAPException if the command-line parser cannot be set up
     * @throws IOException if an input or output file cannot be read or written
     * @throws ClassNotFoundException if the requested search method is not a class in the
     *         package of {@link SearchStrategy}
     * @throws IllegalArgumentException if none of the <tt>--init-*</tt> options is given
     */
    public static void main(String[] args) throws JSAPException, IOException, ClassNotFoundException {
        final SimpleJSAP jsap = new SimpleJSAP(ClusterSequences.class.getName(),
                "Parses input sequences to do clustering on them. Input sequences should have one sequence per line. "
                        + "Each sequence is a tab-separated list of symbols.", new Parameter[] {
                        new Switch("verbose", 'v', "verbose", "Set verbose output"),
                        new FlaggedOption("taxonomy-file", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, JSAP.REQUIRED, 't', "taxonomy-file",
                                "File containing the description of the taxonomy."),
                        new FlaggedOption("clusters", JSAP.INTEGER_PARSER, Integer.toString(DEFAULT_CLUSTERS), JSAP.NOT_REQUIRED, 'c', "clusters",
                                "Number of clusters to use."),
                        new FlaggedOption("passes", JSAP.INTEGER_PARSER, Integer.toString(DEFAULT_PASSES), JSAP.NOT_REQUIRED, 'p', "passes",
                                "Number of passes of the algorithm to do."),
                        new Switch("init-all-leaves", 'l', "init-all-leaves", "Initial model states are all leaves from taxonomy."),
                        new FlaggedOption("init-all-level", JSAP.INTEGER_PARSER, JSAP.NO_DEFAULT, JSAP.NOT_REQUIRED, 'e', "init-all-level",
                                "Initial model states are leaves or internal nodes at level <= x."),
                        new FlaggedOption("init-explicit", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, JSAP.NOT_REQUIRED, 'x', "init-explicit",
                                "Initial model states are a quote-enclosed, space-separated list, given as input."),
                        new FlaggedOption("max-iterations", JSAP.INTEGER_PARSER, JSAP.NO_DEFAULT, JSAP.NOT_REQUIRED, 'm', "max-iterations",
                                "Maximum number of iterations to try for building the model."),
                        new FlaggedOption("batch-size", JSAP.INTEGER_PARSER, Integer.toString(LearnModel.DEFAULT_BATCH_SIZE), JSAP.NOT_REQUIRED, 'b', "batch-size",
                                "How many candidates to evaluate at a time before adding their children to the priority queue."),
                        new FlaggedOption("search-method", JSAP.STRING_PARSER, SearchStrategy.DEFAULT_SEARCH_STRATEGY.getSimpleName(), JSAP.NOT_REQUIRED, 's',
                                "search-method", "Search method, allowed values: " + LearnModel.getSearchCriteriaAsString()),
                        new FlaggedOption("search-method-weight-1", JSAP.DOUBLE_PARSER, Double.toString(CloserToOrigin.DEFAULT_PROBABILITY_WEIGHT), JSAP.NOT_REQUIRED,
                                'w', "search-method-weight-1", "For the method " + CloserToOrigin.class.getSimpleName()
                                        + ", the relative importance of having low probability versus having less states"),
                        // NOTE(review): this option is declared but its value is never read in this class.
                        new FlaggedOption("winner-model-maxstates", JSAP.INTEGER_PARSER, Integer.toString(LearnModel.DEFAULT_OUTPUT_MAX_STATES), JSAP.NOT_REQUIRED, 'u',
                                "winner-model-maxstates",
                                "Output model (winner) must have at most this number of states; this is taken only as a hint, so if there is no model with that few states, "
                                        + "the one with the lower number of states will be selected"),
                        new FlaggedOption("input-file", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, JSAP.REQUIRED, 'i', "input-file", "File containing the input sequences."),
                        new FlaggedOption("output-assignment", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, JSAP.REQUIRED, 'a', "output-assignment",
                                "File for writing the output cluster assignments."),
                        new FlaggedOption("output-models-basename", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, JSAP.NOT_REQUIRED, 'o', "output-models-basename",
                                "Base filename for writing the models."),
                });

        final JSAPResult jsapResult = jsap.parse(args);
        if (jsap.messagePrinted()) {
            return;
        }

        if (jsapResult.getBoolean("verbose")) {
            logger.setLevel(Level.DEBUG);
        }

        final String inFileName = jsapResult.getString("input-file");
        File inFile = new File(inFileName);
        final String outFileName = jsapResult.getString("output-assignment");
        final int nClusters = jsapResult.getInt("clusters");
        final int clusteringIterations = jsapResult.getInt("passes");

        // Load tree
        File taxoFile = new File(jsapResult.getString("taxonomy-file"));
        Taxonomy tree = new Taxonomy(taxoFile);

        // Build initial candidate
        Candidate initialCandidate;
        if (jsapResult.userSpecified("init-explicit")) {
            initialCandidate = new Candidate(tree, Util.split(jsapResult.getString("init-explicit")), null, null);
        } else if (jsapResult.userSpecified("init-all-level")) {
            initialCandidate = Candidate.createFixedLevelCandidate(tree, jsapResult.getInt("init-all-level"));
        } else if (jsapResult.getBoolean("init-all-leaves")) {
            initialCandidate = Candidate.createLeafCandidate(tree);
        } else {
            throw new IllegalArgumentException("Either --init-explicit, --init-all-leaves, or --init-all-level should be specified. See --help.");
        }
        logger.debug("Initial candidate is " + initialCandidate.toBriefString());

        // Set learning parameters
        // The class is looked up by name, so the cast to Class<SearchStrategy> cannot be checked.
        @SuppressWarnings("unchecked")
        Class<SearchStrategy> strategy = (Class<SearchStrategy>) Class.forName(SearchStrategy.class.getPackage().getName() + "." + jsapResult.getString("search-method"));
        logger.debug("Search strategy is " + strategy);

        // -1 means "not specified by the user" for the three learner parameters below
        double weight1 = -1;
        if (jsapResult.userSpecified("search-method-weight-1")) {
            weight1 = jsapResult.getDouble("search-method-weight-1");
            logger.info("Will use weight1=" + weight1);
        }
        int maxIterations = -1;
        if (jsapResult.userSpecified("max-iterations")) {
            maxIterations = jsapResult.getInt("max-iterations");
            logger.info("Will use maxIterations=" + maxIterations);
        }
        int batchSize = -1;
        if (jsapResult.userSpecified("batch-size")) {
            batchSize = jsapResult.getInt("batch-size");
            logger.info("Will use batchSize=" + batchSize);
        }

        // Do the clustering
        ClusterSequences clustering = new ClusterSequences(tree, inFile, nClusters, initialCandidate, strategy, weight1, maxIterations, batchSize);
        clustering.doClustering(clusteringIterations);

        // Write clusters
        clustering.printClusterAssignment(outFileName);

        // Write models
        if (jsapResult.userSpecified("output-models-basename")) {
            final String outModelsBasename = jsapResult.getString("output-models-basename");
            clustering.printModels(outModelsBasename);
        } else {
            logger.warn("You did not specify a file for writing the output models, will not print them");
        }
    }
}