/*
* Copyright (c) 2011, Yahoo! Inc. All rights reserved.
*
* Redistribution and use of this software in source and binary forms, with or without modification,
* are permitted provided that the following conditions are met:
*
* Redistributions of source code must retain the above copyright notice, this list of conditions
* and the following disclaimer.
*
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions
* and the following disclaimer in the documentation and/or other materials provided with the
* distribution.
*
* Neither the name of Yahoo! Inc. nor the names of its contributors may be used to endorse or
* promote products derived from this software without specific prior written permission of Yahoo!
* Inc.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
* WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY
* WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package com.yahoo.labs.taxomo.util;
import it.unimi.dsi.Util;
import it.unimi.dsi.fastutil.ints.Int2ObjectArrayMap;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import be.ac.ulg.montefiore.run.jahmm.ObservationInteger;
import com.yahoo.labs.taxomo.learn.Candidate;
import com.yahoo.labs.taxomo.util.State.Type;
/**
 * A set of states, represents a "cut" in a taxonomy tree.
 *
 * <p>
 * States are numbered as follows: number 0 is a synthetic starting state, the allowed states
 * receive numbers 1..n in the order in which they were given, and number n+1 is a synthetic
 * terminal state.
 *
 * <p>
 * While the allowed states are registered, the natural symbol of every leaf returned by
 * {@code getSelfOrDescendantLeaves()} is mapped to the number of the first allowed state under
 * which it was seen. If a symbol is seen a second time, the set is flagged as having overlaps and
 * {@link #getUniqueStateForSymbol(int)} refuses to answer.
 *
 * @author chato
 *
 */
public class StateSet {

    static final Logger logger = Logger.getLogger(StateSet.class);

    static {
        Util.ensureLog4JIsConfigured();
        logger.setLevel(Level.DEBUG);
    }

    /** Taxonomy providing the nodes and the symbol numbering. */
    private final Taxonomy taxo;

    /** State number to state. */
    private final Int2ObjectArrayMap<State> num2state;

    /** State name to state; includes the starting and terminal states. */
    private final Map<String, State> str2state;

    /** State to state number. */
    private final Object2IntOpenHashMap<State> state2num;

    /** Synthetic starting state, always number 0. */
    private final State startingNode;

    /** Synthetic terminal state, always the highest state number. */
    private final State terminalNode;

    /** Names of the allowed states, in the order given by the caller (not copied). */
    private final ArrayList<String> allowedStates;

    /** Whether some leaf symbol was encountered more than once while registering the states. */
    private final boolean hasOverlaps;

    /** Natural symbol of a leaf to the number of the first allowed state it was seen under. */
    private final Object2IntOpenHashMap<String> symbol2state;

    /**
     * Builds a state set from the allowed states listed in a model read from a stream.
     *
     * @param modelFile stream with the model contents; it is read but not closed here
     * @param taxo the taxonomy the states belong to
     * @throws IOException if reading the stream fails
     * @throws IllegalArgumentException if the model has no states line, or if a listed state is
     *             not contained in the taxonomy
     */
    public StateSet(InputStream modelFile, Taxonomy taxo) throws IOException {
        this(extractAllowedStates(modelFile), taxo);
    }

    /**
     * Builds a state set from the allowed states listed in a model file.
     *
     * @param modelFile the model file
     * @param taxo the taxonomy the states belong to
     * @throws FileNotFoundException if the file can not be opened
     * @throws IOException if reading the file fails
     * @throws IllegalArgumentException if the model has no states line, or if a listed state is
     *             not contained in the taxonomy
     */
    public StateSet(File modelFile, Taxonomy taxo) throws FileNotFoundException, IOException {
        this(extractAllowedStates(modelFile), taxo);
    }

    /**
     * Builds a state set from an explicit list of state names.
     *
     * @param theAllowedStates names of the allowed states; the list is kept by reference
     * @param theTaxo the taxonomy the states belong to
     * @throws IllegalArgumentException if a state name is not contained in the taxonomy
     */
    public StateSet(ArrayList<String> theAllowedStates, Taxonomy theTaxo) {
        taxo = theTaxo;
        allowedStates = theAllowedStates;

        // Deal with states
        num2state = new Int2ObjectArrayMap<State>();
        str2state = new HashMap<String, State>();
        state2num = new Object2IntOpenHashMap<State>();
        symbol2state = new Object2IntOpenHashMap<String>();

        // Add starting node
        startingNode = new State(State.STARTING, Type.LEAF);
        state2num.put(startingNode, 0);
        num2state.put(0, startingNode);
        str2state.put(State.STARTING, startingNode);

        boolean foundOverlap = false;

        // Add all allowed states, numbered from 1 in the order given
        int stateNum = 1;
        for (String state : allowedStates) {
            if (!taxo.contains(state)) {
                throw new IllegalArgumentException("The tree does not contain the state " + state);
            }
            State node = taxo.getNode(state);
            state2num.put(node, stateNum);
            num2state.put(stateNum, node);
            str2state.put(state, node);

            HashSet<State> descendantLeaves = node.getSelfOrDescendantLeaves();
            for (State leaf : descendantLeaves) {
                String symbol = leaf.getNaturalSymbol();
                if (symbol2state.containsKey(symbol)) {
                    // Symbol already claimed: keep the first mapping, remember the overlap
                    foundOverlap = true;
                } else {
                    symbol2state.put(symbol, stateNum);
                }
            }
            stateNum++;
        }
        hasOverlaps = foundOverlap;

        // Add terminal node; stateNum is now one past the last allowed state
        terminalNode = new State(State.TERMINAL, Type.LEAF);
        state2num.put(terminalNode, stateNum);
        num2state.put(stateNum, terminalNode);
        str2state.put(State.TERMINAL, terminalNode);
    }

    /**
     * Builds a state set from an explicit list of state names, loading the taxonomy from a file.
     *
     * @param taxonomyFile file passed to the {@link Taxonomy} constructor
     * @param allowedStates names of the allowed states
     * @throws IOException if the taxonomy can not be read
     */
    public StateSet(File taxonomyFile, ArrayList<String> allowedStates) throws IOException {
        this(allowedStates, new Taxonomy(taxonomyFile));
    }

    /**
     * Builds a state set from the states of a candidate.
     *
     * @param aTree the taxonomy the states belong to
     * @param candidateToTest candidate whose {@code getStates()} provides the allowed states
     */
    public StateSet(Taxonomy aTree, Candidate candidateToTest) {
        this(candidateToTest.getStates(), aTree);
    }

    /**
     * Returns the state having the given number.
     *
     * @param stateNumber the state number
     * @return the corresponding state
     * @throws IllegalArgumentException if no state has that number
     */
    public State getState(int stateNumber) {
        if (!num2state.containsKey(stateNumber)) {
            throw new IllegalArgumentException("State number " + stateNumber + " is not valid");
        }
        return num2state.get(stateNumber);
    }

    /**
     * Returns the state having the given name.
     *
     * @param stateStr the state name
     * @return the corresponding state
     * @throws IllegalArgumentException if no state has that name
     */
    public State getState(String stateStr) {
        if (!str2state.containsKey(stateStr)) {
            throw new IllegalArgumentException("State string " + stateStr + " is not valid");
        }
        return str2state.get(stateStr);
    }

    /**
     * Returns the number of the state having the given name.
     *
     * @param stateStr the state name
     * @return the state number
     * @throws IllegalArgumentException if no state has that name
     */
    public int getStateNumber(String stateStr) {
        return state2num.getInt(getState(stateStr));
    }

    /**
     * Returns the number of registered states, counting the starting and terminal states.
     *
     * @return the number of states
     */
    public int numStates() {
        return num2state.size();
    }

    /**
     * Reads the allowed states from a model: the first line starting with {@code "s "} is split
     * on single spaces and every token after the first is returned as a state name.
     *
     * <p>
     * The stream is not closed by this method.
     *
     * @param in stream with the model contents
     * @return the state names, in the order they appear in the line
     * @throws IOException if reading the stream fails
     * @throws IllegalArgumentException if no line starts with {@code "s "}
     */
    public static ArrayList<String> extractAllowedStates(InputStream in) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(in));
        String line;
        while ((line = br.readLine()) != null) {
            if (line.startsWith("s ")) {
                String toks[] = line.split(" ");
                ArrayList<String> states = new ArrayList<String>(toks.length);
                // toks[0] is the "s" marker itself
                for (int i = 1; i < toks.length; i++) {
                    states.add(toks[i]);
                }
                return states;
            }
        }
        throw new IllegalArgumentException("Could not find a line starting with 's ' in the model file");
    }

    /**
     * Reads the allowed states from a model file; see {@link #extractAllowedStates(InputStream)}.
     * The file is closed before returning, also when an exception is thrown.
     *
     * @param in the model file
     * @return the state names, in the order they appear in the line
     * @throws FileNotFoundException if the file can not be opened
     * @throws IOException if reading or closing the file fails
     * @throws IllegalArgumentException if no line starts with {@code "s "}
     */
    public static ArrayList<String> extractAllowedStates(File in) throws FileNotFoundException, IOException {
        InputStream stream = new FileInputStream(in);
        try {
            return extractAllowedStates(stream);
        } finally {
            stream.close();
        }
    }

    /**
     * Tells whether a string is the starting symbol, ignoring case.
     *
     * @param str the string to test
     * @return true if it equals {@code State.STARTING} ignoring case
     */
    public boolean isStartingSymbol(String str) {
        return State.STARTING.equalsIgnoreCase(str);
    }

    /**
     * Tells whether a string is the terminal symbol, ignoring case.
     *
     * @param str the string to test
     * @return true if it equals {@code State.TERMINAL} ignoring case
     */
    public boolean isTerminalSymbol(String str) {
        return State.TERMINAL.equalsIgnoreCase(str);
    }

    /**
     * Returns the number of the starting state.
     *
     * @return the starting state number
     */
    public int getStartingStateNumber() {
        return getStateNumber(State.STARTING);
    }

    /**
     * Returns the number of the terminal state.
     *
     * @return the terminal state number
     */
    public int getTerminalStateNumber() {
        return getStateNumber(State.TERMINAL);
    }

    /**
     * Returns the file name reported by the taxonomy.
     *
     * @return the taxonomy file name
     */
    public String getTaxonomyTreeFileName() {
        return taxo.getFileName();
    }

    /**
     * Returns the allowed state names joined by single spaces, in their original order.
     *
     * @return the space-separated allowed states; empty if there are none
     */
    public String getStringAllowedStates() {
        StringBuilder str = new StringBuilder();
        for (int i = 0; i < allowedStates.size(); i++) {
            if (i > 0) {
                str.append(" ");
            }
            str.append(allowedStates.get(i));
        }
        return str.toString();
    }

    /**
     * Array variant of {@link #obSeqWithBoundarySymbols(ArrayList)}.
     *
     * @param seq the symbols of the sequence, without boundary symbols
     * @return the observation sequence, with boundary symbols added
     * @throws IllegalArgumentException under the same conditions as the list variant
     */
    public ArrayList<ObservationInteger> obSeqWithBoundarySymbols(String[] seq) {
        ArrayList<String> seqL = new ArrayList<String>(seq.length);
        for (String str : seq) {
            seqL.add(str);
        }
        return obSeqWithBoundarySymbols(seqL);
    }

    /**
     * Converts a sequence of symbols into a sequence of observations, using the symbol numbers of
     * the taxonomy and adding the starting symbol in front and the terminal symbol at the end.
     *
     * @param seq the symbols of the sequence, without boundary symbols
     * @return the observation sequence, of length {@code seq.size() + 2}
     * @throws IllegalArgumentException if the sequence is empty, if its first element is the
     *             starting symbol, or if its last element is the terminal symbol
     */
    public ArrayList<ObservationInteger> obSeqWithBoundarySymbols(ArrayList<String> seq) {
        // Check length
        if (seq.size() == 0) {
            throw new IllegalArgumentException("Can not receive a zero-length sequence");
        }

        // Check first symbol
        String firstSymbol = seq.get(0);
        if (isStartingSymbol(firstSymbol)) {
            throw new IllegalArgumentException("Sequence can not start with the starting symbol, will add it inside the function");
        }

        // Check last symbol
        String lastSymbol = seq.get(seq.size() - 1);
        if (isTerminalSymbol(lastSymbol)) {
            throw new IllegalArgumentException("Sequence can not finish in the terminal symbol, will add it inside the function");
        }

        // Convert
        ArrayList<ObservationInteger> oSeq = new ArrayList<ObservationInteger>(seq.size() + 2);
        oSeq.add(new ObservationInteger(taxo.getStartingSymbolNumber()));
        for (String element : seq) {
            oSeq.add(new ObservationInteger(taxo.getSymbolNumber(element)));
        }
        oSeq.add(new ObservationInteger(taxo.getTerminalSymbolNumber()));
        return oSeq;
    }

    /**
     * Returns the number of the single state associated with a symbol.
     *
     * @param symbolNum the symbol number, as understood by the taxonomy
     * @return the starting or terminal state number for the boundary symbols, otherwise the number
     *         of the allowed state under which the symbol was registered
     * @throws IllegalStateException if this state set has overlaps; this is checked before
     *             anything else, so it also applies to the boundary symbols
     * @throws IllegalArgumentException if the symbol is not covered by any allowed state
     */
    public int getUniqueStateForSymbol(int symbolNum) {
        String symbol = taxo.getSymbol(symbolNum);
        if (hasOverlaps) {
            throw new IllegalStateException("There are overlaps, there is no unique state for each symbol");
        } else if (symbol.equals(State.STARTING)) {
            return getStartingStateNumber();
        } else if (symbol.equals(State.TERMINAL)) {
            return getTerminalStateNumber();
        } else if (!symbol2state.containsKey(symbol)) {
            throw new IllegalArgumentException("This symbol number '" + symbolNum + "' has no corresponding state");
        } else {
            return symbol2state.getInt(symbol);
        }
    }

    /**
     * Returns the number of symbols reported by the taxonomy.
     *
     * @return the number of symbols
     */
    public int numSymbols() {
        return taxo.numSymbols();
    }

    /**
     * Returns the symbol the taxonomy associates with a symbol number.
     *
     * @param symbolNum the symbol number
     * @return the symbol
     */
    public String getSymbol(int symbolNum) {
        return taxo.getSymbol(symbolNum);
    }

    /**
     * Returns the number the taxonomy associates with a symbol.
     *
     * @param string the symbol
     * @return the symbol number
     */
    public int getSymbolNumber(String string) {
        return taxo.getSymbolNumber(string);
    }

    /**
     * Builds a table giving, for every symbol number from 0 to {@code numSymbols() - 1}, the
     * number of its unique state.
     *
     * @return an array indexed by symbol number
     * @throws IllegalStateException if this state set has overlaps
     * @throws IllegalArgumentException if some symbol is not covered by any allowed state
     */
    public int[] getSymbolNum2StateNum() {
        final int total = numSymbols();
        int symbol2stateN[] = new int[total];
        for (int symbolNum = 0; symbolNum < total; symbolNum++) {
            symbol2stateN[symbolNum] = getUniqueStateForSymbol(symbolNum);
        }
        return symbol2stateN;
    }
}