Package com.yahoo.labs.taxomo

Source Code of com.yahoo.labs.taxomo.Model

/*
* Copyright (c) 2011, Yahoo! Inc. All rights reserved.
*
* Redistribution and use of this software in source and binary forms, with or without modification,
* are permitted provided that the following conditions are met:
*
* Redistributions of source code must retain the above copyright notice, this list of conditions
* and the following disclaimer.
*
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions
* and the following disclaimer in the documentation and/or other materials provided with the
* distribution.
*
* Neither the name of Yahoo! Inc. nor the names of its contributors may be used to endorse or
* promote products derived from this software without specific prior written permission of Yahoo!
* Inc.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
* WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY
* WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

package com.yahoo.labs.taxomo;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;

import be.ac.ulg.montefiore.run.jahmm.Hmm;
import be.ac.ulg.montefiore.run.jahmm.ObservationInteger;
import be.ac.ulg.montefiore.run.jahmm.Opdf;
import be.ac.ulg.montefiore.run.jahmm.OpdfInteger;
import be.ac.ulg.montefiore.run.jahmm.ViterbiCalculator;
import be.ac.ulg.montefiore.run.jahmm.toolbox.MarkovGenerator;

import com.yahoo.labs.taxomo.util.FrequencyTable;
import com.yahoo.labs.taxomo.util.State;
import com.yahoo.labs.taxomo.util.StateSet;

/**
* Represents a Taxomo model.
* <p>
* This model is described by: the transition probabilities, the emission probabilities, and the
* taxonomy tree from which the states are extracted.
* <p>
* Also, in a Taxomo model there is always a starting state "^" and a terminal state "$".
* <p>
* This class is based on, and extends, the Hmm class of <a
* href="http://code.google.com/p/jahmm/">JAHMM</a> by Jean-Marc François (distributed under a BSD
* license).
*/
public class Model extends Hmm<ObservationInteger> {

  /**
   * Printing mode for the model.
   * <p>
   * Iff SHORT, then only transitions/emissions with probability larger than
   * {@link #PRINT_MODE_SHORT_MIN_PROBABILITY} will be written.
   *
   * @author chato
   */
  public enum PrintMode {
    FULL, SHORT, STATES_ONLY
  }

  private static final long serialVersionUID = 1L;

  /**
   * The default filename extension for a file containing a model
   */
  public static final String DEFAULT_FILE_EXTENSION = ".mod";

  /**
   * The taxonomy tree of the states in this taxonomy
   */
  private final StateSet taxo;

  /**
   * Number of states
   */
  private final int nSt;

  /**
   * Number of observable symbols
   */
  private final int nOb;

  /**
   * The minimum magnitude of a probability to print it when printing in short mode
   */
  public static final double PRINT_MODE_SHORT_MIN_PROBABILITY = 1e-1;

  /**
   * Error tolerance when comparing floats.
   */
  private static final double ERR_TOLERANCE = 1e-2;

  /**
   * Returns a generator associated to this model
   *
   * @return
   */
  private MarkovGenerator<ObservationInteger> generator() {
    return new MarkovGenerator<ObservationInteger>(this);
  }

  /**
   * Gets the probability of a given state emitting a given symbol
   *
   * @param state The state
   * @param symbol The symbol
   * @return The emission probability
   */
  public double getEmissionProbability(int state, int symbol) {
    Opdf<ObservationInteger> oPdf = getOpdf(state);
    return oPdf.probability(new ObservationInteger(symbol));
  }

  /**
   * Initializes the parameters of this model.
   *
   * @param Aij The transition probabilities
   * @param op The emission probabilities
   */
  private void initParameters(double Aij[][], double op[][]) {

    // Initial probabilities are simple: always the starting symbol starts
    for (int state = 0; state < nSt; state++) {
      setPi(state, 0);
    }
    setPi(taxo.getStartingStateNumber(), 1.0);

    // Transition probabilities
    for (int row = 0; row < nSt; row++) {
      if (Aij[row].length != nSt) {
        throw new IllegalArgumentException("Transition probabilities should be a matrix with " + (nSt) + " cols");
      }
      if (isSumming(Aij[row], 0.0)) {
        Aij[row][row] = 1.0;
      }
      if (!isStochastic(Aij[row])) {
        throw new IllegalArgumentException("Transition probabilities from state " + taxo.getState(row) + " do not form a stochastic vector");
      }
      for (int col = 0; col < nSt; col++) {
        this.setAij(row, col, Aij[row][col]);
      }
    }

    // Last state is absorbing
    if (taxo.getTerminalStateNumber() != nSt - 1) {
      throw new IllegalArgumentException("The terminal symbol number must be " + (nSt - 1) + " but it is " + taxo.getTerminalStateNumber());
    }
    for (int col = 0; col < nSt - 1; col++) {
      this.setAij(nSt - 1, col, 0.0);
    }
    this.setAij(nSt - 1, nSt - 1, 1.0);

    // Set output probabilities
    if (op.length != nSt) {
      throw new IllegalArgumentException("The length of the observation probabilities is incorrect");
    }
    for (int state = 0; state < nSt; state++) {
      if (isSumming(op[state], 0.0)) {
        op[state][0] = 1.0;
      }
      if (!isStochastic(op[state])) {
        double sum = 0;
        for (int i = 0; i < op[state].length; i++) {
          sum += op[state][i];
        }
        throw new IllegalArgumentException("Observation probabilities of state " + taxo.getState(state) + " do not form a stochastic vector, they add up to "
            + sum);
      }
      this.setOpdf(state, new OpdfInteger(op[state]));
    }
    double opLast[] = new double[nOb];
    Arrays.fill(opLast, 0.0);
    opLast[nOb - 1] = 1.0;
    this.setOpdf(nSt - 1, new OpdfInteger(opLast));
  }

  /**
   * Checks is a vector is stochastic
   *
   * @param vec The vector to be tested
   * @return True iff the vector components add up to 1.0 up to a certain error tolerance
   */
  private boolean isStochastic(double[] vec) {
    return isSumming(vec, 1.0);
  }

  /**
   * Checks if the sum of the components of a vector matchs the given number
   *
   * @param vec The vector to be tested
   * @param expectedSum The expected sum
   * @return True iff the vector components sum expectedSum
   */
  private boolean isSumming(double[] vec, double expectedSum) {
    double sum = 0.0;
    for (int i = 0; i < vec.length; i++) {
      sum += vec[i];
    }
    if (Math.abs(sum - expectedSum) < ERR_TOLERANCE) {
      return true;
    } else {
      return false;
    }
  }

  /**
   * Computes the log likelihood of a set of sequences read from an input stream
   *
   * @param inSequences The input stream containing the sequences
   * @return Log probability
   * @throws IOException
   */
  public double logProb(InputStream inSequences) throws IOException {
    BufferedReader br = new BufferedReader(new InputStreamReader(inSequences));
    double sumLogProb = 0.0;
    double sumProb = 0.0;
    int cnt = 0;
    String line;

    while ((line = br.readLine()) != null) {
      ArrayList<ObservationInteger> seq = taxo.obSeqWithBoundarySymbols(line.split(" "));
      double p = (new ViterbiCalculator(seq, this)).lnProbability();
      sumProb += Math.exp(p);
      sumLogProb += p;
      cnt++;
    }
    return sumLogProb;
  }

  /**
   * Computes the log likelihood of a set of sequences read from a file
   *
   * @param inSequences The file containing the sequences
   * @return Log probability
   * @throws IOException
   * @throws FileNotFoundException
   */
  public double logProb(File inSequences) throws FileNotFoundException, IOException {
    return logProb(new FileInputStream(inSequences));
  }

  /**
   * Generates a sequence using this model.
   *
   * @return The generated sequence
   */
  public ArrayList<String> sequence() {
    MarkovGenerator<ObservationInteger> markovGenerator = generator();
    ArrayList<String> seq = new ArrayList<String>();

    // First observation must be initial state
    if (!taxo.isStartingSymbol(taxo.getSymbol(markovGenerator.observation().value))) {
      throw new IllegalStateException("First symbol should always be the starting symbol");
    }

    String symbol = taxo.getSymbol(markovGenerator.observation().value);
    while (!taxo.isTerminalSymbol(symbol)) {
      seq.add(symbol);
      symbol = taxo.getSymbol(markovGenerator.observation().value);
    }

    if (seq.size() == 0) {
      throw new IllegalStateException("Generated a zero-length sequence, which is wrong if pi of terminal state is 0.0");
    }
    return seq;
  }

  public String toString() {
    return toString(PrintMode.FULL);
  }

  /**
   * Prints the model given a specified {@link PrintMode}
   *
   * @param mode The printing mode
   * @return A string describing the model
   */
  public String toString(PrintMode mode) {
    if (mode == PrintMode.STATES_ONLY) {
      return taxo.getStringAllowedStates();
    }

    StringBuffer sb = new StringBuffer();

    double minProb = 0;
    if (mode == PrintMode.SHORT) {
      minProb = PRINT_MODE_SHORT_MIN_PROBABILITY;
    }

    sb.append("# Valid states\n");
    sb.append("s " + taxo.getStringAllowedStates() + "\n");
    sb.append("\n");

    sb.append("# Transition probabilities\n");
    for (int row = 0; row < nSt - 1; row++) {
      for (int col = 0; col < nSt; col++) {
        double a = getAij(row, col);
        if (a > minProb) {
          sb.append("t " + taxo.getState(row).name() + " " + taxo.getState(col).name() + " " + a + "\n");
        }
      }
    }
    sb.append("\n");

    sb.append("# Emission probabilities\n");
    for (int state = 0; state < nSt; state++) {
      Opdf<ObservationInteger> oPdf = getOpdf(state);
      for (int symbolNum = 0; symbolNum < nOb; symbolNum++) {
        double p = oPdf.probability(new ObservationInteger(symbolNum));
        if (p > minProb) {
          sb.append("o " + taxo.getState(state).name() + " " + taxo.getSymbol(symbolNum) + " " + p + "\n");
        }
      }
    }

    return sb.toString();
  }

  /**
   * Computes the log probability of a sequence using the standard viterbi algorithm.
   * <p>
   * It makes no assumption on the model, and it is much slower than
   * {@link #viterbiCalculateNonOverlap(ArrayList)} for models in which emission probabilities do
   * not overlap.
   *
   * @param seq The sequence to be evaluated
   * @return The log probability of that sequence
   */
  public double viterbiCalculate(ArrayList<String> seq) {
    ArrayList<ObservationInteger> oSeq = taxo.obSeqWithBoundarySymbols(seq);
    ViterbiCalculator vc = new ViterbiCalculator(oSeq, this);
    return vc.lnProbability();
  }

  /**
   * Computes the log probability of a sequence assuming that emission probabilities for different
   * states never overlap.
   * <p>
   * This is much faster than {@link #viterbiCalculate(ArrayList)}.
   *
   * @param seq The sequence to be evaluated
   * @return The log probability of that sequence
   */
  public double viterbiCalculateNonOverlap(ArrayList<String> seq) {
    // Assumes there are no overlaps
    double lnProb = 0.0;

    // Initial probability
    int firstState = taxo.getUniqueStateForSymbol(taxo.getSymbolNumber(seq.get(0)));
    lnProb += Math.log(getAij(taxo.getStartingStateNumber(), firstState));

    for (int i = 0; i < seq.size() - 1; i++) {
      int symFrom = taxo.getSymbolNumber(seq.get(i));
      int symDest = taxo.getSymbolNumber(seq.get(i + 1));

      int stateFrom = taxo.getUniqueStateForSymbol(symFrom);
      int stateDest = taxo.getUniqueStateForSymbol(symDest);

      // Probability due to transition
      lnProb += Math.log(getAij(stateFrom, stateDest));

      // Probability due to emission
      lnProb += Math.log(getEmissionProbability(stateFrom, symFrom));
    }

    // Terminal probability
    int lastSymbol = taxo.getSymbolNumber(seq.get(seq.size() - 1));
    int lastState = taxo.getUniqueStateForSymbol(lastSymbol);
    lnProb += Math.log(getEmissionProbability(lastState, lastSymbol));
    lnProb += Math.log(getAij(lastState, taxo.getTerminalStateNumber()));

    return lnProb;
  }

  /**
   * @see #viterbiCalculateNonOverlap(ArrayList)
   * @param sequenceStr
   * @return log probability of sequence assuming there are no overlaps
   */
  public double viterbiCalculateNonOverlap(String sequenceStr) {
    String[] tokens = sequenceStr.split(" ");
    ArrayList<String> seq = new ArrayList<String>(tokens.length);
    for (int i = 0; i < tokens.length; i++) {
      seq.add(tokens[i]);
    }
    return viterbiCalculateNonOverlap(seq);
  }

  /**
   * Gets the list of states in the viterbi (optimal) state path for emitting a given sequence.
   * <p>
   * It uses the slow {@link #viterbiCalculate(ArrayList)} that makes no assumptions on the model.
   *
   * @param seq The sequence of symbols being evaluated
   * @return The most likely sequence of states to produce that sequence
   */
  public State[] viterbiPath(ArrayList<String> seq) {
    ViterbiCalculator vc = new ViterbiCalculator(taxo.obSeqWithBoundarySymbols(seq), this);
    int[] stateSequence = vc.stateSequence();
    State[] ms = new State[stateSequence.length];
    for (int i = 0; i < stateSequence.length; i++) {
      ms[i] = taxo.getState(stateSequence[i]);
    }
    return ms;
  }

  /**
   * Creates a model given an input stream and a taxonomy tree
   *
   * @param inStream The input stream
   * @param theTaxo The taxonomy tree
   * @throws IOException If the input file can not be read properly
   */
  public Model(InputStream inStream, StateSet theTaxo) throws IOException {
    super(theTaxo.numStates());
    this.taxo = theTaxo;

    nSt = taxo.numStates();
    nOb = taxo.numSymbols();

    double op[][] = new double[nSt][nOb];
    double Aij[][] = new double[nSt][nSt];

    BufferedReader br = new BufferedReader(new InputStreamReader(inStream));
    String line;

    boolean statesOk = false;
    boolean transitionsOk = false;
    boolean observationsOk = false;

    while ((line = br.readLine()) != null) {

      if (line.length() == 0) {
        // Ignore blank lines
        continue;
      } else if (line.charAt(0) == '#') {
        // Ignore comments
        continue;
      } else if (line.charAt(0) == 's') {
        // Ignore: is used when loading states
        statesOk = true;
        continue;

      } else {
        String[] toks = line.split(" ");
        switch (toks[0].charAt(0)) {
        case 't':
          transitionsOk = true;
          Aij[taxo.getStateNumber(toks[1])][taxo.getStateNumber(toks[2])] = Double.parseDouble(toks[3]);
          break;
        case 'o':
          observationsOk = true;
          op[taxo.getStateNumber(toks[1])][taxo.getSymbolNumber(toks[2])] = Double.parseDouble(toks[3]);
          break;
        default:
          throw new IllegalArgumentException("Wrong format in this line: '" + line + "'");
        }
      }
    }

    if (!statesOk) {
      throw new IllegalArgumentException("The model file does not contain states");
    }
    if (!transitionsOk) {
      throw new IllegalArgumentException("The model file does not contain any transition");
    }
    if (!observationsOk) {
      throw new IllegalArgumentException("The model file does not contain any observation probability");
    }

    initParameters(Aij, op);

  }

  /**
   * Creates a model given an input file and a taxonomy tree
   *
   * @param inFile The input file
   * @param theTaxo The taxonomy tree
   * @throws IOException If the input file can not be read properly
   */
  public Model(File inFile, StateSet theTaxo) throws IOException {
    this(new FileInputStream(inFile), theTaxo);
  }

  /**
   * Creates a model from a frequency table and a taxonomy
   *
   * @param freqs The frequency table
   * @param theTaxo The taxonomy
   */
  public Model(FrequencyTable freqs, StateSet theTaxo) {
    super(theTaxo.numStates());
    this.taxo = theTaxo;

    nSt = taxo.numStates();
    nOb = taxo.numSymbols();

    // Initial probability
    for (int state = 0; state < nSt; state++) {
      this.setPi(state, 0);
    }
    this.setPi(theTaxo.getStartingStateNumber(), 1.0);

    for (int row = 0; row < nSt; row++) {
      for (int col = 0; col < nSt; col++) {
        this.setAij(row, col, freqs.getStateTransitionProbability(row, col));
      }
    }

    for (int state = 0; state < nSt; state++) {
      double[] emissionProbabilities = freqs.getEmissionProbabilities(state);
      setOpdf(state, new OpdfInteger(emissionProbabilities));
    }

  }

  /**
   * Creates a model given another model.
   *
   * @param hmm The model being copied
   * @param theTaxo The taxonomy tree.
   */
  public Model(Hmm<ObservationInteger> hmm, StateSet theTaxo) {
    super(hmm.nbStates());
    this.taxo = theTaxo;

    nSt = taxo.numStates();
    nOb = taxo.numSymbols();

    for (int state = 0; state < nSt; state++) {
      this.setPi(state, hmm.getPi(state));
    }
    for (int row = 0; row < nSt; row++) {
      for (int col = 0; col < nSt; col++) {
        this.setAij(row, col, hmm.getAij(row, col));
      }
    }

    for (int state = 0; state < nSt; state++) {
      Opdf<ObservationInteger> oPdf = hmm.getOpdf(state);
      setOpdf(state, oPdf.clone());
    }
  }

}
TOP

Related Classes of com.yahoo.labs.taxomo.Model

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.