Package com.yahoo.labs.taxomo.util

Source Code of com.yahoo.labs.taxomo.util.StateSet

/*
* Copyright (c) 2011, Yahoo! Inc. All rights reserved.
*
* Redistribution and use of this software in source and binary forms, with or without modification,
* are permitted provided that the following conditions are met:
*
* Redistributions of source code must retain the above copyright notice, this list of conditions
* and the following disclaimer.
*
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions
* and the following disclaimer in the documentation and/or other materials provided with the
* distribution.
*
* Neither the name of Yahoo! Inc. nor the names of its contributors may be used to endorse or
* promote products derived from this software without specific prior written permission of Yahoo!
* Inc.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
* WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY
* WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

package com.yahoo.labs.taxomo.util;

import it.unimi.dsi.Util;
import it.unimi.dsi.fastutil.ints.Int2ObjectArrayMap;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;

import org.apache.log4j.Level;
import org.apache.log4j.Logger;

import be.ac.ulg.montefiore.run.jahmm.ObservationInteger;

import com.yahoo.labs.taxomo.learn.Candidate;
import com.yahoo.labs.taxomo.util.State.Type;

/**
* A set of states, represents a "cut" in a taxonomy tree.
*
* @author chato
*
*/
public class StateSet {
 
  static final Logger logger = Logger.getLogger(StateSet.class);
  static {
        Util.ensureLog4JIsConfigured();
        logger.setLevel(Level.DEBUG);
  }
 
  private Taxonomy taxo;

  private Int2ObjectArrayMap<State> num2state;

  private final Map<String, State> str2state;
 
  private final Object2IntOpenHashMap<State> state2num;

  private State startingNode;

  private State terminalNode;

  private final ArrayList<String> allowedStates;
 
  private final boolean hasOverlaps;
 
  private final Object2IntOpenHashMap<String> symbol2state;

  public StateSet(InputStream modelFile, Taxonomy taxo) throws IOException {
    this( extractAllowedStates(modelFile), taxo);
  }
 
  public StateSet(File modelFile, Taxonomy taxo) throws FileNotFoundException, IOException {
    this( extractAllowedStates(modelFile), taxo );
  }
 
  public StateSet(ArrayList<String> theAllowedStates, Taxonomy theTaxo) {

    taxo = theTaxo;
    allowedStates = theAllowedStates;

    // Deal with states
    num2state = new Int2ObjectArrayMap<State>();
    str2state = new HashMap<String, State>();
    state2num = new Object2IntOpenHashMap<State>();
    symbol2state = new Object2IntOpenHashMap<String>();

    // Add starting node
    startingNode = new State(State.STARTING, Type.LEAF);
    state2num.put(startingNode, 0);
    num2state.put(0, startingNode);
    str2state.put(State.STARTING, startingNode);

    boolean foundOverlap = false;
   
    // Add all allowed states
    int stateNum = 1;
    for (String state : allowedStates) {

      if (!taxo.contains(state)) {
        throw new IllegalArgumentException("The tree does not contain the state " + state);
      }

      State node = taxo.getNode(state);
      state2num.put(node,stateNum);
      num2state.put(stateNum, node);
      str2state.put(state, node);
     
      HashSet<State> descendantLeaves = node.getSelfOrDescendantLeaves();
      for( State leaf: descendantLeaves ) {
        if( symbol2state.containsKey(leaf.getNaturalSymbol())) {
          foundOverlap = true;
        } else {
          symbol2state.put(leaf.getNaturalSymbol(), stateNum);
        }
      }
      stateNum++;
    }
   
    hasOverlaps = foundOverlap;

    // Add terminal node
    terminalNode = new State(State.TERMINAL, Type.LEAF);
    state2num.put(terminalNode, stateNum);
    num2state.put(stateNum, terminalNode);
    str2state.put(State.TERMINAL, terminalNode);
  }

  public StateSet(File taxonomyFile, ArrayList<String> allowedStates) throws IOException {
    this(allowedStates, new Taxonomy(taxonomyFile));
  }

  public StateSet(Taxonomy aTree, Candidate candidateToTest) {
    this(candidateToTest.getStates(), aTree );

  }

  public State getState(int stateNumber) {
    if (!num2state.containsKey(stateNumber)) {
      throw new IllegalArgumentException("State number " + stateNumber + " is not valid");
    }
    return num2state.get(stateNumber);
  }

  public State getState(String stateStr) {
    if (!str2state.containsKey(stateStr)) {
      throw new IllegalArgumentException("State string " + stateStr + " is not valid");
    }
    return str2state.get(stateStr);
  }
 
  public int getStateNumber(String stateStr) {
    return state2num.getInt(getState(stateStr));
  }

  public int numStates() {
    return num2state.size();
  }
 
  public static ArrayList<String> extractAllowedStates(InputStream in) throws IOException {
    BufferedReader br = new BufferedReader(new InputStreamReader(in) );
    String line;

    while ((line = br.readLine()) != null) {
      if (line.startsWith("s ")) {
        String toks[] = line.split(" ");
        ArrayList<String> states = new ArrayList<String>();
        for (int i = 1; i < toks.length; i++) {
          states.add(toks[i]);
        }
        return states;
      }
    }
    throw new IllegalArgumentException("Could not find a line starting with ' s' in the model file");
  }
 
  public static ArrayList<String> extractAllowedStates(File in) throws FileNotFoundException, IOException {
    return extractAllowedStates( new FileInputStream(in) );
  }
 
  public boolean isStartingSymbol(String str) {
    return State.STARTING.equalsIgnoreCase(str);
  }

  public boolean isTerminalSymbol(String str) {
    return State.TERMINAL.equalsIgnoreCase(str);
  }

  public int getStartingStateNumber() {
    return getStateNumber(State.STARTING);
  }

  public int getTerminalStateNumber() {
    return getStateNumber(State.TERMINAL);
  }

  public String getTaxonomyTreeFileName() {
    return taxo.getFileName();
  }

  public String getStringAllowedStates() {
    StringBuffer str = new StringBuffer();
    for (int i = 0; i < allowedStates.size(); i++) {
      if (i > 0) {
        str.append(" ");
      }
      str.append(allowedStates.get(i));
    }
    return str.toString();
  }

  public ArrayList<ObservationInteger> obSeqWithBoundarySymbols(String[] seq) {
    ArrayList<String> seqL = new ArrayList<String>();
    for (String str : seq) {
      seqL.add(str);
    }
    return obSeqWithBoundarySymbols(seqL);
  }

  public ArrayList<ObservationInteger> obSeqWithBoundarySymbols(ArrayList<String> seq) {
    // Check length
    if (seq.size() == 0) {
      throw new IllegalArgumentException("Can not receive a zero-length sequence");
    }
    // Check first symbol
    String firstSymbol = seq.get(seq.size() - 1);
    if (isStartingSymbol(firstSymbol)) {
      throw new IllegalArgumentException("Sequence can not start with the starting symbol, will add it inside the function");
    }

    // Check last symbol
    String lastSymbol = seq.get(seq.size() - 1);
    if (isTerminalSymbol(lastSymbol)) {
      throw new IllegalArgumentException("Sequence can not finish in the terminal symbol, will add it inside the function");
    }

    // Convert
    ArrayList<ObservationInteger> oSeq = new ArrayList<ObservationInteger>(seq.size()+2);
    oSeq.add(new ObservationInteger(taxo.getStartingSymbolNumber()));
    for (String element : seq) {
      oSeq.add(new ObservationInteger(taxo.getSymbolNumber(element)));
    }
    oSeq.add(new ObservationInteger(taxo.getTerminalSymbolNumber()));

    return oSeq;
  }

  public int getUniqueStateForSymbol(int symbolNum) {
    String symbol = taxo.getSymbol(symbolNum);
    if( hasOverlaps ) {
      throw new IllegalStateException("There are overlaps, there is no unique state for each symbol");
    } else if( symbol.equals(State.STARTING)) {
      return getStartingStateNumber();
    } else if (symbol.equals(State.TERMINAL)) {
      return getTerminalStateNumber();
    } else if (! symbol2state.containsKey(symbol) ) {
      throw new IllegalArgumentException("This symbol number '" + symbolNum + "' has no corresponding state");
    } else {
      return symbol2state.getInt(symbol);
    }
  }
 
  public int numSymbols() {
    return taxo.numSymbols();
  }

  public String getSymbol(int symbolNum) {
    return taxo.getSymbol(symbolNum);
  }

  public int getSymbolNumber(String string) {
    return taxo.getSymbolNumber(string);
  }

  public int[] getSymbolNum2StateNum() {
    int symbol2stateN[] = new int[numSymbols()];
    for (int symbolNum = 0; symbolNum < numSymbols(); symbolNum++) {
      symbol2stateN[symbolNum] = getUniqueStateForSymbol(symbolNum);
    }
    return symbol2stateN;
  }

}
TOP

Related Classes of com.yahoo.labs.taxomo.util.StateSet

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.