/*
* Copyright (c) 2011, Yahoo! Inc. All rights reserved.
*
* Redistribution and use of this software in source and binary forms, with or without modification,
* are permitted provided that the following conditions are met:
*
* Redistributions of source code must retain the above copyright notice, this list of conditions
* and the following disclaimer.
*
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions
* and the following disclaimer in the documentation and/or other materials provided with the
* distribution.
*
* Neither the name of Yahoo! Inc. nor the names of its contributors may be used to endorse or
* promote products derived from this software without specific prior written permission of Yahoo!
* Inc.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
* WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY
* WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package com.yahoo.labs.taxomo.util;
import it.unimi.dsi.fastutil.ints.Int2ObjectArrayMap;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Vector;
import com.yahoo.labs.taxomo.util.State.Type;
/**
 * Represents a taxonomy of states.
 *
 * <p>The taxonomy is read from a text description with one root-to-leaf path per line, the
 * path elements being separated by {@code "/"}. Every path must start with the same root
 * element. After loading, each leaf is assigned a symbol number: {@code 0} is reserved for
 * {@link State#STARTING}, leaves are numbered {@code 1..numLeaves()} in the order in which
 * they were first seen, and {@code numLeaves() + 1} is {@link State#TERMINAL}.
 *
 * @author chato
 */
public class Taxonomy {

    /** Separator between the elements of a path in the taxonomy description. */
    private static final String DEFAULT_TAXONOMY_SEPARATOR = "/";

    /** Separator between the symbols of an input sequence. */
    private static final String SEQUENCE_SEPARATOR = " ";

    /** Root of the tree; created from the first element of the first path. */
    private State root;

    /** Leaves in the order in which they were created while loading. */
    private ArrayList<State> leaves;

    /** Index from node name to node; the first node created with a given name wins. */
    private HashMap<String,State> nodes;

    private final String fileName;

    private final Int2ObjectArrayMap<String> num2symbol;

    private final Object2IntOpenHashMap<String> symbol2num;

    /**
     * Builds a taxonomy from a stream containing one path per line.
     *
     * <p>The stream is read until its end but is NOT closed; it remains owned by the caller.
     * Characters are decoded with the platform default charset.
     *
     * @param fileName name reported by {@link #getFileName()}
     * @param inStream stream to read the taxonomy description from
     * @throws IOException if reading from the stream fails
     * @throws IllegalArgumentException if the description is empty, contains an empty branch,
     *         has paths with different roots, or contains a node of degree one
     */
    public Taxonomy( String fileName, InputStream inStream ) throws IOException {
        this( fileName, readTaxonomyDescription(inStream) );
    }

    /**
     * Builds a taxonomy from a file containing one path per line. The file is closed before
     * this constructor returns, whether or not loading succeeds.
     *
     * @param inputFile file to read; its simple name is reported by {@link #getFileName()}
     * @throws IOException if the file cannot be opened or read
     * @throws IllegalArgumentException if the description is invalid (see
     *         {@link #Taxonomy(String, InputStream)})
     */
    public Taxonomy( File inputFile ) throws IOException {
        this( inputFile.getName(), readTaxonomyDescription(inputFile) );
    }

    /**
     * Shared initialization for the public constructors: builds the tree, verifies it and
     * assigns symbol numbers.
     */
    private Taxonomy( String fileName, ArrayList<String[]> taxonomyDesc ) {
        this.fileName = fileName;

        loadTaxonomy(taxonomyDesc);

        // Verify degree in taxonomy
        verifyTaxonomy();

        // Deal with symbols: 0 is the starting symbol, then the leaves, then the terminal symbol
        symbol2num = new Object2IntOpenHashMap<String>();
        num2symbol = new Int2ObjectArrayMap<String>();

        symbol2num.put(State.STARTING, 0);
        num2symbol.put(0, State.STARTING);

        int nextNumber = 1;
        for( State leaf: leaves ) {
            String symbol = leaf.getNaturalSymbol();
            symbol2num.put(symbol, nextNumber);
            num2symbol.put(nextNumber, symbol);
            nextNumber++;
        }

        symbol2num.put(State.TERMINAL, nextNumber);
        num2symbol.put(nextNumber, State.TERMINAL);
    }

    /** Opens the file, reads the taxonomy description from it and always closes the file. */
    private static ArrayList<String[]> readTaxonomyDescription( File inputFile ) throws IOException {
        InputStream inStream = new FileInputStream(inputFile);
        try {
            return readTaxonomyDescription(inStream);
        } finally {
            inStream.close();
        }
    }

    /** Reads every line of the stream and splits it into path elements; does not close the stream. */
    private static ArrayList<String[]> readTaxonomyDescription( InputStream inStream ) throws IOException {
        // NOTE(review): no charset is given, so the platform default is used.
        BufferedReader br = new BufferedReader(new InputStreamReader(inStream));
        ArrayList<String[]> taxonomyDesc = new ArrayList<String[]>();
        String line;
        while ((line = br.readLine()) != null) {
            taxonomyDesc.add(line.split(DEFAULT_TAXONOMY_SEPARATOR));
        }
        return taxonomyDesc;
    }

    /**
     * Checks that no node in the name index has degree one.
     *
     * @throws IllegalArgumentException if such a node exists
     */
    private void verifyTaxonomy() {
        for( State node: nodes.values() ) {
            if( node.getDegree() == 1 ) {
                throw new IllegalArgumentException( "Node " + node + " in the taxonomy has degree one." );
            }
        }
    }

    /**
     * Builds the tree, the name index and the list of leaves from the given paths.
     *
     * @param taxonomyDesc one array of path elements per branch, root first
     * @throws IllegalArgumentException if there are no paths, a path has no elements, or a
     *         path starts with an element other than the root (compared ignoring case)
     */
    private void loadTaxonomy( ArrayList<String[]> taxonomyDesc ) {
        if( taxonomyDesc.isEmpty() ) {
            throw new IllegalArgumentException( "Empty taxonomy" );
        }

        leaves = new ArrayList<State>( taxonomyDesc.size() );
        nodes = new HashMap<String,State>();

        for( String[] path: taxonomyDesc ) {
            if( path.length == 0 ) {
                throw new IllegalArgumentException( "Empty branch" );
            }

            if( root == null ) {
                // The first path defines the root
                root = new State(path[0], Type.INTERNAL);
                nodes.put(root.name(), root);
            } else if( ! path[0].equalsIgnoreCase(root.name()) ) {
                throw new IllegalArgumentException( "A path starts with a root that is not " + root.name() );
            }

            State node = root;
            for( int i=1; i<path.length; i++ ) { // Start from 1, element 0 is always root
                if( node.hasChild(path[i]) ) {
                    node = node.getChild(path[i]);
                    continue;
                }

                // A newly created node is a leaf only if it is the last element of this path
                Type type = ( i == path.length-1 ) ? Type.LEAF : Type.INTERNAL;
                State newNode = new State( path[i], type );
                node.addChild(newNode);

                // Keep the first node registered under a given name
                if( ! nodes.containsKey(newNode.name()) ) {
                    nodes.put( newNode.name(), newNode );
                }
                if( newNode.isLeaf() ) {
                    leaves.add(newNode);
                }
                node = newNode;
            }
        }
    }

    /**
     * @return the number of distinct node names in the taxonomy
     */
    public int numNodes() {
        return nodes.size();
    }

    /**
     * @return the number of leaves in the taxonomy
     */
    public int numLeaves() {
        return leaves.size();
    }

    /**
     * @param state a node name
     * @return whether a node with that name exists
     */
    public boolean contains(String state) {
        return nodes.containsKey(state);
    }

    /**
     * @param state a node name
     * @return the node registered under that name, or {@code null} if there is none
     */
    public State getNode(String state) {
        return nodes.get(state);
    }

    /**
     * @return a new array with the leaves, in the order in which they were loaded
     */
    public State[] getLeaves() {
        return leaves.toArray( new State[leaves.size()] );
    }

    /**
     * Finds the leaf whose natural symbol equals the given symbol.
     *
     * @param symbol the symbol to look for
     * @return the first matching leaf
     * @throws IllegalArgumentException if no leaf matches
     */
    public State getLeafForSymbol(String symbol) {
        for( State leaf: leaves ) {
            if( leaf.getNaturalSymbol().equals(symbol) ) {
                return leaf;
            }
        }
        throw new IllegalArgumentException("Could not find any leaf matching symbol '" + symbol + "'");
    }

    /**
     * @return the file name given at construction time
     */
    public String getFileName() {
        return fileName;
    }

    /**
     * Translates a symbol number into its symbol.
     *
     * @param i the symbol number
     * @return the symbol
     * @throws IllegalArgumentException if the number is not assigned to any symbol
     */
    public String getSymbol(int i) {
        if (!num2symbol.containsKey(i)) {
            throw new IllegalArgumentException("Wrong symbol number '" + i + "'");
        }
        return num2symbol.get(i);
    }

    /**
     * Translates a symbol into its symbol number.
     *
     * @param str the symbol
     * @return the symbol number
     * @throws IllegalArgumentException if the symbol is unknown
     */
    public int getSymbolNumber(String str) {
        if (!symbol2num.containsKey(str)) {
            throw new IllegalArgumentException("Wrong symbol name '" + str + "'");
        }
        return symbol2num.getInt(str);
    }

    /**
     * @return the number of the starting symbol
     */
    public int getStartingSymbolNumber() {
        return getSymbolNumber(State.STARTING);
    }

    /**
     * @return the number of the terminal symbol
     */
    public int getTerminalSymbolNumber() {
        return getSymbolNumber(State.TERMINAL);
    }

    /**
     * @return the number of symbols, including the starting and terminal symbols
     */
    public int numSymbols() {
        return num2symbol.size();
    }

    /**
     * Splits an input sequence on single spaces and checks that it neither starts with the
     * starting symbol nor ends with the terminal symbol, since both are added by the parsers.
     *
     * @throws IllegalArgumentException if the line yields no tokens or violates either check
     */
    private String[] tokenizeSequence(String line) {
        String[] tokens = line.split(SEQUENCE_SEPARATOR);
        // A line made only of separators splits into an empty array
        if (tokens.length == 0) {
            throw new IllegalArgumentException("Input sequence contains no symbols");
        }
        if (tokens[0].equals(State.STARTING)) {
            throw new IllegalArgumentException("Can not start input sequence with starting symbol");
        }
        if (tokens[tokens.length - 1].equals(State.TERMINAL)) {
            throw new IllegalArgumentException("Can not end input sequence with terminal symbol");
        }
        return tokens;
    }

    /**
     * Parses a space-separated sequence of symbols into symbol numbers, surrounded by the
     * starting and terminal symbol numbers.
     *
     * @param line the sequence, without starting and terminal symbols
     * @return an array of length (number of tokens + 2)
     * @throws IllegalArgumentException if the sequence is empty, starts with the starting
     *         symbol, ends with the terminal symbol, or contains an unknown symbol
     */
    public int[] parseSymbolSequenceToSymbolNums(String line) {
        String[] tokens = tokenizeSequence(line);

        int[] seq = new int[tokens.length + 2];
        seq[0] = getStartingSymbolNumber();
        for (int i = 0; i < tokens.length; i++) {
            seq[i + 1] = getSymbolNumber(tokens[i]);
        }
        seq[tokens.length + 1] = getTerminalSymbolNumber();
        return seq;
    }

    /**
     * Parses a space-separated sequence of symbols into a list of symbols, surrounded by the
     * starting and terminal symbols.
     *
     * @param line the sequence, without starting and terminal symbols
     * @return a list of size (number of tokens + 2)
     * @throws IllegalArgumentException if the sequence is empty, starts with the starting
     *         symbol, ends with the terminal symbol, or contains an unknown symbol
     */
    public ArrayList<String> parseSymbolSequenceToSymbolStrings(String line) {
        String[] tokens = tokenizeSequence(line);

        ArrayList<String> seq = new ArrayList<String>(tokens.length + 2);
        seq.add(getSymbol(getStartingSymbolNumber()));
        for (String token: tokens) {
            if( ! symbol2num.containsKey(token) ) {
                throw new IllegalArgumentException( "Input sequence contains this symbol '" + token + "' which is unknown");
            }
            seq.add(token);
        }
        seq.add(getSymbol(getTerminalSymbolNumber()));
        return seq;
    }

    /**
     * Gets the root state of the taxonomy
     * @return The root state
     */
    public State getRootState() {
        return root;
    }

    /**
     * Gets all states at a given level. Leaves found above the target level are included as
     * well, because they have no descendants to stand in for them at that level.
     *
     * @param targetLevel The target level (0=root)
     * @return The states at that level, plus any leaves at shallower levels
     * @throws IllegalStateException if the visit reaches a node deeper than the target level,
     *         which happens for a negative target level
     */
    public State[] getLevel(int targetLevel) {
        ArrayList<State> allLevel = new ArrayList<State>();
        LinkedList<State> toVisit = new LinkedList<State>();
        Object2IntOpenHashMap<State> node2level = new Object2IntOpenHashMap<State>();

        toVisit.add(root);
        node2level.put(root, 0);

        // Breadth-first visit that stops descending once the target level is reached
        while( ! toVisit.isEmpty() ) {
            State node = toVisit.poll();
            int nodeLevel = node2level.getInt(node);

            if( nodeLevel > targetLevel ) {
                throw new IllegalStateException("This visit should not have nodes at a deeper level than the target level");
            }

            if( nodeLevel == targetLevel || node.isLeaf() ) {
                allLevel.add(node);
            } else {
                for( State child: node.getChildren() ) {
                    toVisit.add(child);
                    node2level.put(child, nodeLevel+1);
                }
            }
        }
        return allLevel.toArray( new State[allLevel.size()] );
    }
}