/*
* Copyright (c) 2011, Yahoo! Inc. All rights reserved.
*
* Redistribution and use of this software in source and binary forms, with or without modification,
* are permitted provided that the following conditions are met:
*
* Redistributions of source code must retain the above copyright notice, this list of conditions
* and the following disclaimer.
*
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions
* and the following disclaimer in the documentation and/or other materials provided with the
* distribution.
*
* Neither the name of Yahoo! Inc. nor the names of its contributors may be used to endorse or
* promote products derived from this software without specific prior written permission of Yahoo!
* Inc.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
* WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY
* WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package com.yahoo.labs.taxomo.util;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import it.unimi.dsi.io.FileLinesCollection;
import it.unimi.dsi.io.FileLinesCollection.FileLinesIterator;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.Charset;
import java.util.HashSet;
/**
 * Sparse representation of a symbol transition frequency matrix, as observed in a set of input sequences.
 * <p>
 * Per-symbol counts are kept in a dense array indexed by symbol number. Counts of consecutive
 * symbol pairs (transitions) are kept in a hash map keyed by {@link Coordinate}, so only
 * transitions that were actually observed take up space.
 * <p>
 * The table is filled through {@link #processStream(InputStream)}, {@link #processFile(File)} or
 * {@link #processString(String)}. The first two freeze the table when done. {@link #processString(String)}
 * does not freeze it, so several strings can be accumulated before calling {@link #freeze()}.
 * Once frozen, every {@code process*} method throws {@link IllegalArgumentException}.
 * <p>
 * This class performs no synchronization.
 *
 * @author chato
 */
public class SymbolTransitionFrequencies {

    /** Charset used to decode input lines, both from files and from streams. */
    private static final Charset INPUT_CHARSET = Charset.forName("UTF-8");

    /** Taxonomy used to parse input lines into symbol numbers and to name symbols. */
    private final Taxonomy tree;

    /** Set once the table has been built; further processing is then rejected. */
    boolean frozen = false;

    /**
     * A (row, column) cell of a sparse matrix, used as a hash key.
     * <p>
     * NOTE(review): the fields are public and mutable (kept for backward compatibility).
     * Do not modify a Coordinate after inserting it into a hash-based collection.
     */
    public static class Coordinate {
        /** Largest number of symbols the enclosing class accepts (checked in its constructor). */
        public static final int MAX_STATES = 1000000;

        public int row;
        public int col;

        public Coordinate(int aRow, int aCol) {
            row = aRow;
            col = aCol;
        }

        @Override
        public int hashCode() {
            // The multiplication may overflow for large columns; wrap-around is harmless for hashing.
            return row + (MAX_STATES * col);
        }

        /**
         * Two coordinates are equal when both row and column match.
         * Returns {@code false} (rather than throwing) for {@code null} or non-Coordinate arguments,
         * as required by the {@link Object#equals(Object)} contract.
         */
        @Override
        public boolean equals(Object other) {
            if (this == other) {
                return true;
            }
            if (!(other instanceof Coordinate)) {
                return false;
            }
            Coordinate that = (Coordinate) other;
            return row == that.row && col == that.col;
        }
    }

    /** Number of occurrences of each symbol, indexed by symbol number. */
    final private int[] symbolFrequency;

    /** Number of occurrences of each observed (previous symbol, next symbol) pair. */
    final private Object2IntOpenHashMap<Coordinate> symbolTransitionFrequency;

    /**
     * Creates an empty, unfrozen table for the symbols of the given taxonomy.
     *
     * @param theTree taxonomy providing the symbol set and the line parser
     * @throws IllegalArgumentException if the taxonomy has more than {@link Coordinate#MAX_STATES} symbols
     */
    public SymbolTransitionFrequencies(Taxonomy theTree) {
        tree = theTree;
        if (tree.numSymbols() > Coordinate.MAX_STATES) {
            throw new IllegalArgumentException("Can handle only up to " + Coordinate.MAX_STATES + " states");
        }
        symbolFrequency = new int[tree.numSymbols()];
        symbolTransitionFrequency = new Object2IntOpenHashMap<Coordinate>();
    }

    /**
     * Counts every line of the stream, decoded as UTF-8, and then freezes the table.
     * The stream is not closed; the caller remains responsible for closing {@code in}.
     *
     * @param in source of lines, one symbol sequence per line
     * @throws IOException if reading from {@code in} fails
     * @throws IllegalArgumentException if the table is already frozen
     */
    public void processStream(InputStream in) throws IOException {
        checkNotFrozen();
        // Decode explicitly as UTF-8 so that streams and files (see processFile) are parsed the same way.
        BufferedReader br = new BufferedReader(new InputStreamReader(in, INPUT_CHARSET));
        String line;
        while ((line = br.readLine()) != null) {
            processString(line);
        }
        freeze();
    }

    /**
     * Counts every line of the given UTF-8 file and then freezes the table.
     *
     * @param inputFile file with one symbol sequence per line
     * @throws IllegalArgumentException if the table is already frozen
     */
    public void processFile(File inputFile) {
        checkNotFrozen();
        FileLinesCollection flc = new FileLinesCollection(inputFile.getPath(), INPUT_CHARSET.name());
        FileLinesIterator it = flc.iterator();
        while (it.hasNext()) {
            processString(it.next().toString());
        }
        freeze();
    }

    /**
     * Parses one line into a symbol sequence and adds its symbol and consecutive-pair counts.
     * Does not freeze the table.
     *
     * @param line a symbol sequence in the format accepted by the taxonomy's parser
     * @throws IllegalArgumentException if the table is already frozen
     */
    public void processString(String line) {
        checkNotFrozen();
        int[] sequence = tree.parseSymbolSequenceToSymbolNums(line);
        for (int i = 0; i < sequence.length; i++) {
            symbolFrequency[sequence[i]]++;
            if (i > 0) {
                Coordinate coord = new Coordinate(sequence[i - 1], sequence[i]);
                // The map's default return value is never changed from 0, so absent keys read as 0.
                symbolTransitionFrequency.put(coord, symbolTransitionFrequency.getInt(coord) + 1);
            }
        }
    }

    /** Marks the table as built; subsequent {@code process*} calls are rejected. */
    public void freeze() {
        frozen = true;
    }

    /** @return whether {@link #freeze()} has been called */
    public boolean isFrozen() {
        return frozen;
    }

    /**
     * @param symbolNum symbol number
     * @return how many times the symbol was observed
     */
    public int frequency(int symbolNum) {
        return symbolFrequency[symbolNum];
    }

    /**
     * @param symbolSrc symbol number of the earlier symbol of the pair
     * @param symbolDest symbol number of the symbol immediately following it
     * @return how many times {@code symbolDest} immediately followed {@code symbolSrc}, or 0 if never
     */
    public int transitionFrequency(int symbolSrc, int symbolDest) {
        // Absent keys yield the map's default return value, which is left at 0.
        return symbolTransitionFrequency.getInt(new Coordinate(symbolSrc, symbolDest));
    }

    /**
     * Adds every observed symbol transition count into a state-level matrix, mapping each symbol
     * to a state through {@code symbol2state}.
     *
     * @param stateTransitionFrequencies accumulator, indexed [source state][destination state]; updated in place
     * @param symbol2state state index for each symbol number
     */
    public void fillStateTransitionFrequencies(int[][] stateTransitionFrequencies, int[] symbol2state) {
        for (Coordinate coord : symbolTransitionFrequency.keySet()) {
            stateTransitionFrequencies[symbol2state[coord.row]][symbol2state[coord.col]] += symbolTransitionFrequency.getInt(coord);
        }
    }

    /**
     * Like {@link #fillStateTransitionFrequencies(int[][], int[])}, but maps each symbol transition
     * through two symbol-to-state mappings. Only cells listed in {@code usedStateCoordinates} are
     * updated. When both mappings send a transition to the same cell, it is counted once.
     *
     * @param stateTransitionFrequencies accumulator, indexed [source state][destination state]; updated in place
     * @param symbol2state1 first state index for each symbol number
     * @param symbol2state2 second state index for each symbol number
     * @param usedStateCoordinates state cells that may be updated
     */
    public void fillStateTransitionFrequencies(int[][] stateTransitionFrequencies, int[] symbol2state1, int[] symbol2state2, HashSet<Coordinate> usedStateCoordinates) {
        for (Coordinate coord : symbolTransitionFrequency.keySet()) {
            final int count = symbolTransitionFrequency.getInt(coord);

            final int row1 = symbol2state1[coord.row];
            final int col1 = symbol2state1[coord.col];
            if (usedStateCoordinates.contains(new Coordinate(row1, col1))) {
                stateTransitionFrequencies[row1][col1] += count;
            }

            final int row2 = symbol2state2[coord.row];
            final int col2 = symbol2state2[coord.col];
            // Skip the second mapping when it lands on the same cell, to avoid double counting.
            if (row1 != row2 || col1 != col2) {
                if (usedStateCoordinates.contains(new Coordinate(row2, col2))) {
                    stateTransitionFrequencies[row2][col2] += count;
                }
            }
        }
    }

    /**
     * Dumps every symbol frequency followed by the full (dense) transition-frequency matrix.
     * The output has numSymbols^2 transition lines, so use it only for small taxonomies.
     */
    @Override
    public String toString() {
        final int n = tree.numSymbols();
        StringBuilder sb = new StringBuilder();
        for (int symbol = 0; symbol < n; symbol++) {
            sb.append("SymbolFreq(").append(symbol).append(" ").append(tree.getSymbol(symbol))
              .append(") =").append(frequency(symbol)).append("\n");
        }
        for (int symbol1 = 0; symbol1 < n; symbol1++) {
            for (int symbol2 = 0; symbol2 < n; symbol2++) {
                sb.append("SymbolTranFreq(").append(symbol1).append(" ").append(tree.getSymbol(symbol1))
                  .append(",").append(symbol2).append(" ").append(tree.getSymbol(symbol2))
                  .append(") = ").append(transitionFrequency(symbol1, symbol2)).append("\n");
            }
        }
        return sb.toString();
    }

    /**
     * Rejects further processing once the table is frozen.
     * IllegalArgumentException (rather than IllegalStateException) is kept for backward compatibility.
     */
    private void checkNotFrozen() {
        if (frozen) {
            throw new IllegalArgumentException("This table was already built");
        }
    }
}