/*
* Copyright (c) 2011, Yahoo! Inc. All rights reserved.
*
* Redistribution and use of this software in source and binary forms, with or without modification,
* are permitted provided that the following conditions are met:
*
* Redistributions of source code must retain the above copyright notice, this list of conditions
* and the following disclaimer.
*
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions
* and the following disclaimer in the documentation and/or other materials provided with the
* distribution.
*
* Neither the name of Yahoo! Inc. nor the names of its contributors may be used to endorse or
* promote products derived from this software without specific prior written permission of Yahoo!
* Inc.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
* WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY
* WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package com.yahoo.labs.taxomo;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import be.ac.ulg.montefiore.run.jahmm.Hmm;
import be.ac.ulg.montefiore.run.jahmm.ObservationInteger;
import be.ac.ulg.montefiore.run.jahmm.Opdf;
import be.ac.ulg.montefiore.run.jahmm.OpdfInteger;
import be.ac.ulg.montefiore.run.jahmm.ViterbiCalculator;
import be.ac.ulg.montefiore.run.jahmm.toolbox.MarkovGenerator;
import com.yahoo.labs.taxomo.util.FrequencyTable;
import com.yahoo.labs.taxomo.util.State;
import com.yahoo.labs.taxomo.util.StateSet;
/**
* Represents a Taxomo model.
* <p>
* This model is described by: the transition probabilities, the emission probabilities, and the
* taxonomy tree from which the states are extracted.
* <p>
* Also, in a Taxomo model there is always a starting state "^" and a terminal state "$".
* <p>
* This class is based on, and extends, the Hmm class of <a
* href="http://code.google.com/p/jahmm/">JAHMM</a> by Jean-Marc François (distributed under a BSD
* license).
*/
public class Model extends Hmm<ObservationInteger> {
/**
* Printing mode for the model.
* <p>
* Iff SHORT, then only transitions/emissions with probability larger than
* {@link #PRINT_MODE_SHORT_MIN_PROBABILITY} will be written.
*
* @author chato
*/
public enum PrintMode {
FULL, SHORT, STATES_ONLY
}
private static final long serialVersionUID = 1L;
/**
* The default filename extension for a file containing a model
*/
public static final String DEFAULT_FILE_EXTENSION = ".mod";
/**
* The taxonomy tree of the states in this taxonomy
*/
private final StateSet taxo;
/**
* Number of states
*/
private final int nSt;
/**
* Number of observable symbols
*/
private final int nOb;
/**
* The minimum magnitude of a probability to print it when printing in short mode
*/
public static final double PRINT_MODE_SHORT_MIN_PROBABILITY = 1e-1;
/**
* Error tolerance when comparing floats.
*/
private static final double ERR_TOLERANCE = 1e-2;
/**
* Returns a generator associated to this model
*
* @return
*/
private MarkovGenerator<ObservationInteger> generator() {
return new MarkovGenerator<ObservationInteger>(this);
}
/**
* Gets the probability of a given state emitting a given symbol
*
* @param state The state
* @param symbol The symbol
* @return The emission probability
*/
public double getEmissionProbability(int state, int symbol) {
Opdf<ObservationInteger> oPdf = getOpdf(state);
return oPdf.probability(new ObservationInteger(symbol));
}
/**
* Initializes the parameters of this model.
*
* @param Aij The transition probabilities
* @param op The emission probabilities
*/
private void initParameters(double Aij[][], double op[][]) {
// Initial probabilities are simple: always the starting symbol starts
for (int state = 0; state < nSt; state++) {
setPi(state, 0);
}
setPi(taxo.getStartingStateNumber(), 1.0);
// Transition probabilities
for (int row = 0; row < nSt; row++) {
if (Aij[row].length != nSt) {
throw new IllegalArgumentException("Transition probabilities should be a matrix with " + (nSt) + " cols");
}
if (isSumming(Aij[row], 0.0)) {
Aij[row][row] = 1.0;
}
if (!isStochastic(Aij[row])) {
throw new IllegalArgumentException("Transition probabilities from state " + taxo.getState(row) + " do not form a stochastic vector");
}
for (int col = 0; col < nSt; col++) {
this.setAij(row, col, Aij[row][col]);
}
}
// Last state is absorbing
if (taxo.getTerminalStateNumber() != nSt - 1) {
throw new IllegalArgumentException("The terminal symbol number must be " + (nSt - 1) + " but it is " + taxo.getTerminalStateNumber());
}
for (int col = 0; col < nSt - 1; col++) {
this.setAij(nSt - 1, col, 0.0);
}
this.setAij(nSt - 1, nSt - 1, 1.0);
// Set output probabilities
if (op.length != nSt) {
throw new IllegalArgumentException("The length of the observation probabilities is incorrect");
}
for (int state = 0; state < nSt; state++) {
if (isSumming(op[state], 0.0)) {
op[state][0] = 1.0;
}
if (!isStochastic(op[state])) {
double sum = 0;
for (int i = 0; i < op[state].length; i++) {
sum += op[state][i];
}
throw new IllegalArgumentException("Observation probabilities of state " + taxo.getState(state) + " do not form a stochastic vector, they add up to "
+ sum);
}
this.setOpdf(state, new OpdfInteger(op[state]));
}
double opLast[] = new double[nOb];
Arrays.fill(opLast, 0.0);
opLast[nOb - 1] = 1.0;
this.setOpdf(nSt - 1, new OpdfInteger(opLast));
}
/**
* Checks is a vector is stochastic
*
* @param vec The vector to be tested
* @return True iff the vector components add up to 1.0 up to a certain error tolerance
*/
private boolean isStochastic(double[] vec) {
return isSumming(vec, 1.0);
}
/**
* Checks if the sum of the components of a vector matchs the given number
*
* @param vec The vector to be tested
* @param expectedSum The expected sum
* @return True iff the vector components sum expectedSum
*/
private boolean isSumming(double[] vec, double expectedSum) {
double sum = 0.0;
for (int i = 0; i < vec.length; i++) {
sum += vec[i];
}
if (Math.abs(sum - expectedSum) < ERR_TOLERANCE) {
return true;
} else {
return false;
}
}
/**
* Computes the log likelihood of a set of sequences read from an input stream
*
* @param inSequences The input stream containing the sequences
* @return Log probability
* @throws IOException
*/
public double logProb(InputStream inSequences) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(inSequences));
double sumLogProb = 0.0;
double sumProb = 0.0;
int cnt = 0;
String line;
while ((line = br.readLine()) != null) {
ArrayList<ObservationInteger> seq = taxo.obSeqWithBoundarySymbols(line.split(" "));
double p = (new ViterbiCalculator(seq, this)).lnProbability();
sumProb += Math.exp(p);
sumLogProb += p;
cnt++;
}
return sumLogProb;
}
/**
* Computes the log likelihood of a set of sequences read from a file
*
* @param inSequences The file containing the sequences
* @return Log probability
* @throws IOException
* @throws FileNotFoundException
*/
public double logProb(File inSequences) throws FileNotFoundException, IOException {
return logProb(new FileInputStream(inSequences));
}
/**
* Generates a sequence using this model.
*
* @return The generated sequence
*/
public ArrayList<String> sequence() {
MarkovGenerator<ObservationInteger> markovGenerator = generator();
ArrayList<String> seq = new ArrayList<String>();
// First observation must be initial state
if (!taxo.isStartingSymbol(taxo.getSymbol(markovGenerator.observation().value))) {
throw new IllegalStateException("First symbol should always be the starting symbol");
}
String symbol = taxo.getSymbol(markovGenerator.observation().value);
while (!taxo.isTerminalSymbol(symbol)) {
seq.add(symbol);
symbol = taxo.getSymbol(markovGenerator.observation().value);
}
if (seq.size() == 0) {
throw new IllegalStateException("Generated a zero-length sequence, which is wrong if pi of terminal state is 0.0");
}
return seq;
}
public String toString() {
return toString(PrintMode.FULL);
}
/**
* Prints the model given a specified {@link PrintMode}
*
* @param mode The printing mode
* @return A string describing the model
*/
public String toString(PrintMode mode) {
if (mode == PrintMode.STATES_ONLY) {
return taxo.getStringAllowedStates();
}
StringBuffer sb = new StringBuffer();
double minProb = 0;
if (mode == PrintMode.SHORT) {
minProb = PRINT_MODE_SHORT_MIN_PROBABILITY;
}
sb.append("# Valid states\n");
sb.append("s " + taxo.getStringAllowedStates() + "\n");
sb.append("\n");
sb.append("# Transition probabilities\n");
for (int row = 0; row < nSt - 1; row++) {
for (int col = 0; col < nSt; col++) {
double a = getAij(row, col);
if (a > minProb) {
sb.append("t " + taxo.getState(row).name() + " " + taxo.getState(col).name() + " " + a + "\n");
}
}
}
sb.append("\n");
sb.append("# Emission probabilities\n");
for (int state = 0; state < nSt; state++) {
Opdf<ObservationInteger> oPdf = getOpdf(state);
for (int symbolNum = 0; symbolNum < nOb; symbolNum++) {
double p = oPdf.probability(new ObservationInteger(symbolNum));
if (p > minProb) {
sb.append("o " + taxo.getState(state).name() + " " + taxo.getSymbol(symbolNum) + " " + p + "\n");
}
}
}
return sb.toString();
}
/**
* Computes the log probability of a sequence using the standard viterbi algorithm.
* <p>
* It makes no assumption on the model, and it is much slower than
* {@link #viterbiCalculateNonOverlap(ArrayList)} for models in which emission probabilities do
* not overlap.
*
* @param seq The sequence to be evaluated
* @return The log probability of that sequence
*/
public double viterbiCalculate(ArrayList<String> seq) {
ArrayList<ObservationInteger> oSeq = taxo.obSeqWithBoundarySymbols(seq);
ViterbiCalculator vc = new ViterbiCalculator(oSeq, this);
return vc.lnProbability();
}
/**
* Computes the log probability of a sequence assuming that emission probabilities for different
* states never overlap.
* <p>
* This is much faster than {@link #viterbiCalculate(ArrayList)}.
*
* @param seq The sequence to be evaluated
* @return The log probability of that sequence
*/
public double viterbiCalculateNonOverlap(ArrayList<String> seq) {
// Assumes there are no overlaps
double lnProb = 0.0;
// Initial probability
int firstState = taxo.getUniqueStateForSymbol(taxo.getSymbolNumber(seq.get(0)));
lnProb += Math.log(getAij(taxo.getStartingStateNumber(), firstState));
for (int i = 0; i < seq.size() - 1; i++) {
int symFrom = taxo.getSymbolNumber(seq.get(i));
int symDest = taxo.getSymbolNumber(seq.get(i + 1));
int stateFrom = taxo.getUniqueStateForSymbol(symFrom);
int stateDest = taxo.getUniqueStateForSymbol(symDest);
// Probability due to transition
lnProb += Math.log(getAij(stateFrom, stateDest));
// Probability due to emission
lnProb += Math.log(getEmissionProbability(stateFrom, symFrom));
}
// Terminal probability
int lastSymbol = taxo.getSymbolNumber(seq.get(seq.size() - 1));
int lastState = taxo.getUniqueStateForSymbol(lastSymbol);
lnProb += Math.log(getEmissionProbability(lastState, lastSymbol));
lnProb += Math.log(getAij(lastState, taxo.getTerminalStateNumber()));
return lnProb;
}
/**
* @see #viterbiCalculateNonOverlap(ArrayList)
* @param sequenceStr
* @return log probability of sequence assuming there are no overlaps
*/
public double viterbiCalculateNonOverlap(String sequenceStr) {
String[] tokens = sequenceStr.split(" ");
ArrayList<String> seq = new ArrayList<String>(tokens.length);
for (int i = 0; i < tokens.length; i++) {
seq.add(tokens[i]);
}
return viterbiCalculateNonOverlap(seq);
}
/**
* Gets the list of states in the viterbi (optimal) state path for emitting a given sequence.
* <p>
* It uses the slow {@link #viterbiCalculate(ArrayList)} that makes no assumptions on the model.
*
* @param seq The sequence of symbols being evaluated
* @return The most likely sequence of states to produce that sequence
*/
public State[] viterbiPath(ArrayList<String> seq) {
ViterbiCalculator vc = new ViterbiCalculator(taxo.obSeqWithBoundarySymbols(seq), this);
int[] stateSequence = vc.stateSequence();
State[] ms = new State[stateSequence.length];
for (int i = 0; i < stateSequence.length; i++) {
ms[i] = taxo.getState(stateSequence[i]);
}
return ms;
}
/**
* Creates a model given an input stream and a taxonomy tree
*
* @param inStream The input stream
* @param theTaxo The taxonomy tree
* @throws IOException If the input file can not be read properly
*/
public Model(InputStream inStream, StateSet theTaxo) throws IOException {
super(theTaxo.numStates());
this.taxo = theTaxo;
nSt = taxo.numStates();
nOb = taxo.numSymbols();
double op[][] = new double[nSt][nOb];
double Aij[][] = new double[nSt][nSt];
BufferedReader br = new BufferedReader(new InputStreamReader(inStream));
String line;
boolean statesOk = false;
boolean transitionsOk = false;
boolean observationsOk = false;
while ((line = br.readLine()) != null) {
if (line.length() == 0) {
// Ignore blank lines
continue;
} else if (line.charAt(0) == '#') {
// Ignore comments
continue;
} else if (line.charAt(0) == 's') {
// Ignore: is used when loading states
statesOk = true;
continue;
} else {
String[] toks = line.split(" ");
switch (toks[0].charAt(0)) {
case 't':
transitionsOk = true;
Aij[taxo.getStateNumber(toks[1])][taxo.getStateNumber(toks[2])] = Double.parseDouble(toks[3]);
break;
case 'o':
observationsOk = true;
op[taxo.getStateNumber(toks[1])][taxo.getSymbolNumber(toks[2])] = Double.parseDouble(toks[3]);
break;
default:
throw new IllegalArgumentException("Wrong format in this line: '" + line + "'");
}
}
}
if (!statesOk) {
throw new IllegalArgumentException("The model file does not contain states");
}
if (!transitionsOk) {
throw new IllegalArgumentException("The model file does not contain any transition");
}
if (!observationsOk) {
throw new IllegalArgumentException("The model file does not contain any observation probability");
}
initParameters(Aij, op);
}
/**
* Creates a model given an input file and a taxonomy tree
*
* @param inFile The input file
* @param theTaxo The taxonomy tree
* @throws IOException If the input file can not be read properly
*/
public Model(File inFile, StateSet theTaxo) throws IOException {
this(new FileInputStream(inFile), theTaxo);
}
/**
* Creates a model from a frequency table and a taxonomy
*
* @param freqs The frequency table
* @param theTaxo The taxonomy
*/
public Model(FrequencyTable freqs, StateSet theTaxo) {
super(theTaxo.numStates());
this.taxo = theTaxo;
nSt = taxo.numStates();
nOb = taxo.numSymbols();
// Initial probability
for (int state = 0; state < nSt; state++) {
this.setPi(state, 0);
}
this.setPi(theTaxo.getStartingStateNumber(), 1.0);
for (int row = 0; row < nSt; row++) {
for (int col = 0; col < nSt; col++) {
this.setAij(row, col, freqs.getStateTransitionProbability(row, col));
}
}
for (int state = 0; state < nSt; state++) {
double[] emissionProbabilities = freqs.getEmissionProbabilities(state);
setOpdf(state, new OpdfInteger(emissionProbabilities));
}
}
/**
* Creates a model given another model.
*
* @param hmm The model being copied
* @param theTaxo The taxonomy tree.
*/
public Model(Hmm<ObservationInteger> hmm, StateSet theTaxo) {
super(hmm.nbStates());
this.taxo = theTaxo;
nSt = taxo.numStates();
nOb = taxo.numSymbols();
for (int state = 0; state < nSt; state++) {
this.setPi(state, hmm.getPi(state));
}
for (int row = 0; row < nSt; row++) {
for (int col = 0; col < nSt; col++) {
this.setAij(row, col, hmm.getAij(row, col));
}
}
for (int state = 0; state < nSt; state++) {
Opdf<ObservationInteger> oPdf = hmm.getOpdf(state);
setOpdf(state, oPdf.clone());
}
}
}