package bgu.bio.algorithms.alignment.constrained.re;
import bgu.bio.adt.matrix.FloatMatrix3D;
import bgu.bio.adt.matrix.IntMatrix4D;
import bgu.bio.util.ScoringMatrix;
import bgu.bio.util.alphabet.constrain.SpecificRnaAlphabet;
/**
 * Dynamic-programming aligner that aligns the whole of {@code s2} against a
 * substring of {@code s1} (unaligned prefixes and suffixes of {@code s1} are
 * free) under the constraint that the alignment must end with one of a set of
 * rigid patterns.
 * <p>
 * The DP table has two layers per cell: {@link #UNCONSTRAINED} holds the best
 * score of an ordinary alignment ending at that cell, and {@link #CONSTRAINED}
 * holds the best score of an alignment whose final columns match one of the
 * rigid patterns. Constrained scores are only computed in the last column,
 * i.e. a pattern always ends at the end of {@code s2}.
 * <p>
 * Pattern syntax (patterns are matched right-to-left, ending at the cell being
 * evaluated):
 * <ul>
 * <li>{@code '1'} or {@code 'm'} - an aligned pair whose characters are equal;</li>
 * <li>{@code '0'} or {@code 's'} - an aligned pair whose characters differ;</li>
 * <li>{@code 'i'} - an insertion, consuming one character of {@code s2};</li>
 * <li>{@code 'd'} - a deletion, consuming one character of {@code s1};</li>
 * <li>{@code "(a!b)"} - an explicit column: {@code a} from {@code s1} aligned to
 * {@code b} from {@code s2}; {@code '_'} in place of {@code a} denotes an
 * insertion of {@code b}, and in place of {@code b} a deletion of {@code a}.</li>
 * </ul>
 */
public class RigidConstrainedAlignmentEngine {
	/** Third index of {@link #dpTable}: best score of a pattern-free alignment ending at the cell. */
	final static int UNCONSTRAINED = 0;
	/** Third index of {@link #dpTable}: best score of an alignment ending with a rigid pattern. */
	final static int CONSTRAINED = 1;
	/** Separates the two characters of an explicit column {@code "(a!b)"} in a pattern. */
	static final char CHAR_SEPERATOR = '!';
	/** Marks a gap inside an explicit column {@code "(a!b)"}; also used to score pattern gaps. */
	static final char EMPTY_LETTER = '_';

	/** Scores every aligned column (pair, insertion or deletion). */
	protected ScoringMatrix score;
	/** Constraint alphabet supplied at construction; not read elsewhere in this class. */
	protected SpecificRnaAlphabet constAlphabet;
	/** Characters of the first input sequence (rows of the DP table). */
	protected char[] s1;
	/** The input strings exactly as passed to {@link #align(String, String)}. */
	String S1, S2;
	/** Characters of the second input sequence (columns of the DP table). */
	protected char[] s2;
	/** Length of {@link #s1} plus one (number of DP rows in use). */
	protected int s1Length;
	/** Length of {@link #s2} plus one (number of DP columns in use). */
	protected int s2Length;
	/** DP scores indexed as [i][j][layer], layer being {@link #UNCONSTRAINED} or {@link #CONSTRAINED}. */
	protected FloatMatrix3D dpTable;
	/** The rigid patterns; see the class documentation for their syntax. */
	protected String[] pattern;
	/** Number of {@code s1} characters each pattern consumes. */
	protected int[] patternY;
	/** Number of {@code s2} characters each pattern consumes. */
	protected int[] patternX;

	// Tracing
	/** Whether back-pointers are recorded in {@link #dpTrace}. */
	protected final boolean trace;
	/**
	 * Back-pointer table indexed as [i][j][layer][k]: k = 0 holds the
	 * predecessor's row, k = 1 its column and k = 2 its layer. Allocated only
	 * when {@link #trace} is set.
	 */
	protected IntMatrix4D dpTrace;
	/** Number of DP cell updates (strict improvements) made by the last {@link #align(String, String)}. */
	protected int DPcounter;

	/**
	 * Creates an engine without pre-allocating the DP tables; they are
	 * allocated on demand by {@link #align(String, String)}.
	 *
	 * @param score         the scoring matrix used to score aligned columns
	 * @param rigidPatterns the rigid patterns, one of which must end the alignment
	 * @param constAlphabet the constraint alphabet (stored only)
	 * @param trace         whether to record back-pointers in {@link #dpTrace}
	 */
	public RigidConstrainedAlignmentEngine(ScoringMatrix score,
			String[] rigidPatterns,
			SpecificRnaAlphabet constAlphabet,
			boolean trace) {
		this.score = score;
		this.constAlphabet = constAlphabet;
		this.trace = trace;
		this.pattern = rigidPatterns;
		// Compute the pattern sizes here so that align() works no matter which
		// constructor was used.
		if (rigidPatterns != null) {
			calculatePatternDimensions();
		}
	}

	/**
	 * Creates an engine and pre-allocates DP tables of the given size.
	 *
	 * @param scoring        the scoring matrix used to score aligned columns
	 * @param rigidPatterns  the rigid patterns, one of which must end the alignment
	 * @param coAlphabet     the constraint alphabet (stored only)
	 * @param dimension1Size initial number of DP rows (length of s1 plus one)
	 * @param dimension2Size initial number of DP columns (length of s2 plus one)
	 * @param trace          whether to record back-pointers in {@link #dpTrace}
	 */
	public RigidConstrainedAlignmentEngine(ScoringMatrix scoring,
			String[] rigidPatterns,
			SpecificRnaAlphabet coAlphabet,
			int dimension1Size,
			int dimension2Size,
			boolean trace) {
		this(scoring, rigidPatterns, coAlphabet, trace);
		rebuildTheTables(dimension1Size, dimension2Size);
	}

	/**
	 * Computes, for every pattern, how many characters of {@code s2}
	 * ({@link #patternX}) and of {@code s1} ({@link #patternY}) it consumes.
	 * <p>
	 * NOTE(review): characters inside an explicit column {@code "(a!b)"} are
	 * also inspected here, so a letter equal to 0/1/m/s/i/d inside the group
	 * would be counted twice while {@link #patternScore} consumes it once.
	 * This is harmless for upper-case nucleotide letters; confirm for other
	 * alphabets.
	 */
	private void calculatePatternDimensions() {
		patternX = new int[pattern.length];
		patternY = new int[pattern.length];
		for (int pIdx = 0; pIdx < pattern.length; pIdx++) {
			final String pat = pattern[pIdx];
			for (int j = 0; j < pat.length(); j++) {
				final char c = pat.charAt(j);
				if (c == '0' || c == '1' || c == 'm' || c == 's') {
					// aligned pair: consumes one character of each sequence
					patternX[pIdx]++;
					patternY[pIdx]++;
				} else if (c == CHAR_SEPERATOR) {
					// explicit column "(a!b)": its neighbours tell what it consumes
					if (pat.charAt(j - 1) == EMPTY_LETTER) {
						patternX[pIdx]++; // "(_!b)": insertion, s2 only
					} else if (pat.charAt(j + 1) == EMPTY_LETTER) {
						patternY[pIdx]++; // "(a!_)": deletion, s1 only
					} else {
						patternX[pIdx]++;
						patternY[pIdx]++;
					}
				} else if (c == 'i') {
					patternX[pIdx]++;
				} else if (c == 'd') {
					patternY[pIdx]++;
				}
			}
		}
	}

	/**
	 * Rebuild the tables for the engine. Existing tables are kept when they
	 * are already at least as large as requested.
	 *
	 * @param newDim1 the new first dimension size
	 * @param newDim2 the new second dimension size
	 * @return true, if the tables are rebuilt and false otherwise
	 */
	private boolean rebuildTheTables(int newDim1, int newDim2) {
		if (this.dpTable == null || newDim1 > this.dpTable.getDimensionSize(0)
				|| newDim2 > this.dpTable.getDimensionSize(1)) {
			// last dimension is the UNCONSTRAINED / CONSTRAINED layer
			this.dpTable = new FloatMatrix3D(newDim1, newDim2, 2);
			if (this.trace) {
				this.dpTrace = new IntMatrix4D(newDim1, newDim2, 2, 3);
			}
			return true;
		}
		return false;
	}

	/**
	 * Runs the constrained alignment of {@code s1} against {@code s2}.
	 * Afterwards {@link #getOptimalScore()} returns the best score and
	 * {@link #getDPcounter()} the number of DP cell updates performed.
	 *
	 * @param s1 the first sequence (rows of the DP table)
	 * @param s2 the second sequence (columns of the DP table); the rigid
	 *           pattern must end at its last character
	 */
	public void align(String s1, String s2) {
		this.DPcounter = 0;
		this.s1 = s1.toCharArray();
		this.s2 = s2.toCharArray();
		S1 = s1;
		S2 = s2;
		s1Length = s1.length() + 1;
		s2Length = s2.length() + 1;
		this.rebuildTheTables(s1Length, s2Length);
		init();
		if (patternX == null || patternX.length != pattern.length) {
			calculatePatternDimensions();
		}
		if (pattern.length == 0) {
			// No pattern can end the alignment: every constrained cell stays -infinity.
			return;
		}
		fillUnconstrained();
		fillConstrained();
	}

	/**
	 * Fills the unconstrained layer for every cell that a pattern may be
	 * attached to. These are the cells in the union over patterns p of the
	 * rectangles {@code i < s1Length - patternY[p]},
	 * {@code j < s2Length - patternX[p]}. Each rectangle contains all
	 * predecessors of its cells, so a single column-major sweep over the union
	 * gives the same values as sweeping each rectangle separately, while
	 * computing every cell only once.
	 */
	private void fillUnconstrained() {
		for (int j = 0; j < s2Length; j++) {
			// number of rows of column j that lie inside at least one rectangle
			int rowLimit = 0;
			for (int pIdx = 0; pIdx < pattern.length; pIdx++) {
				if (j < s2Length - patternX[pIdx]) {
					rowLimit = Math.max(rowLimit, s1Length - patternY[pIdx]);
				}
			}
			for (int i = 0; i < rowLimit; i++) {
				calculateSingleCellUnConstrained(i, j);
			}
		}
	}

	/**
	 * Tries every pattern at every row of the last column. Rows above the
	 * smallest pattern height cannot host any pattern and are skipped.
	 */
	private void fillConstrained() {
		int minPatternY = patternY[0];
		for (int pIdx = 1; pIdx < pattern.length; pIdx++) {
			minPatternY = Math.min(minPatternY, patternY[pIdx]);
		}
		final int lastColumn = this.s2Length - 1;
		for (int i = minPatternY; i < s1Length; i++) {
			for (int pIdx = 0; pIdx < pattern.length; pIdx++) {
				calculateSingleCellConstrained(i, lastColumn, pIdx);
			}
		}
	}

	/**
	 * Relaxes unconstrained cell (i, j) from its replace, delete and insert
	 * predecessors. Cell (0, 0) is left at its initial value.
	 */
	private void calculateSingleCellUnConstrained(int i, int j) {
		if (i > 0 || j > 0) {
			// Replace
			if ((i != 0) && (j != 0)) {
				this.calculateByOtherCellUnConstrained(i, j, 1, 1, s1[i - 1], s2[j - 1]);
			}
			// Delete
			if (i != 0) {
				this.calculateByOtherCellUnConstrained(i, j, 1, 0, s1[i - 1], this.score.getAlphabet().emptyLetter());
			}
			// Insert
			if (j != 0) {
				this.calculateByOtherCellUnConstrained(i, j, 0, 1, this.score.getAlphabet().emptyLetter(), s2[j - 1]);
			}
		}
	}

	/**
	 * Resets the in-use part of the DP table. Column 0 of the unconstrained
	 * layer is 0 for every row, so the alignment may start anywhere in
	 * {@code s1}; all other cells start at negative infinity.
	 */
	private void init() {
		for (int i = 0; i < this.s1Length; i++) {
			dpTable.set(Float.NEGATIVE_INFINITY, i, 0, CONSTRAINED);
			dpTable.set(0.0f, i, 0, UNCONSTRAINED);
			for (int j = 1; j < this.s2Length; j++) {
				dpTable.set(Float.NEGATIVE_INFINITY, i, j, CONSTRAINED);
				dpTable.set(Float.NEGATIVE_INFINITY, i, j, UNCONSTRAINED);
			}
		}
	}

	/**
	 * Returns the best constrained score over all rows of the last column,
	 * i.e. over every possible end position in {@code s1}.
	 *
	 * @return the optimal score, or negative infinity if no pattern fits
	 */
	public double getOptimalScore() {
		double optimalScore = Double.NEGATIVE_INFINITY;
		for (int i = 0; i < this.s1Length; i++) {
			double temp = dpTable.get(i, this.s2Length - 1, CONSTRAINED);
			if ((!Double.isNaN(temp)) && (optimalScore < temp)) {
				optimalScore = temp;
			}
		}
		return optimalScore;
	}

	/**
	 * Tries to end an alignment at cell (it, jt) with pattern {@code pIdx}.
	 * The pattern is appended to the unconstrained alignment ending
	 * {@code patternY[pIdx]} rows and {@code patternX[pIdx]} columns earlier.
	 * The constrained cell is updated when this improves it.
	 */
	private void calculateSingleCellConstrained(int it, int jt, int pIdx) {
		int x = this.patternX[pIdx];
		int y = this.patternY[pIdx];
		float patternScore = patternScore(it, jt, pIdx);
		if (patternScore != Float.NEGATIVE_INFINITY) {
			float val = dpTable.get(it - y, jt - x, UNCONSTRAINED);
			if ((val != Float.NEGATIVE_INFINITY) &&
					(val + patternScore > dpTable.get(it, jt, CONSTRAINED))) {
				dpTable.set(val + patternScore, it, jt, CONSTRAINED);
				DPcounter++;
				// Update optimal in trace table
				if (trace) {
					dpTrace.set(it - y, it, jt, CONSTRAINED, 0);
					dpTrace.set(jt - x, it, jt, CONSTRAINED, 1);
					dpTrace.set(UNCONSTRAINED, it, jt, CONSTRAINED, 2);
				}
			}
		}
	}

	/**
	 * Scores pattern {@code pIdx} placed so that it ends right after
	 * {@code s1[i-1]} and {@code s2[j-1]}, matching it from its last character
	 * backwards.
	 * <p>
	 * NOTE(review): gaps are scored here with {@link #EMPTY_LETTER}, whereas
	 * the unconstrained recurrence uses {@code score.getAlphabet().emptyLetter()};
	 * confirm the two agree for the alphabets in use.
	 *
	 * @return the summed column scores, or negative infinity if the pattern
	 *         does not fit there (wrong characters, or a sequence runs out)
	 */
	private float patternScore(int i, int j, int pIdx) {
		String pat = this.pattern[pIdx];
		float s = 0;
		for (int k = pat.length() - 1; k >= 0; k--) {
			char l = pat.charAt(k);
			if (l == '0' || l == 's') {
				// substitution: the two characters must differ
				if (i > 0 && j > 0 && (this.s1[i - 1] != this.s2[j - 1])) {
					s += score.score(this.s1[i - 1], this.s2[j - 1]);
					i--;
					j--;
				} else {
					return Float.NEGATIVE_INFINITY;
				}
			} else if (l == '1' || l == 'm') {
				// match: the two characters must be equal
				if (i > 0 && j > 0 && (this.s1[i - 1] == this.s2[j - 1])) {
					s += score.score(this.s1[i - 1], this.s2[j - 1]);
					i--;
					j--;
				} else {
					return Float.NEGATIVE_INFINITY;
				}
			} else if (l == 'i') {
				if (j > 0) {
					s += score.score(EMPTY_LETTER, this.s2[j - 1]);
					j--;
				} else {
					return Float.NEGATIVE_INFINITY;
				}
			} else if (l == 'd') {
				if (i > 0) {
					s += score.score(this.s1[i - 1], EMPTY_LETTER);
					i--;
				} else {
					return Float.NEGATIVE_INFINITY;
				}
			} else if (l == ')') {
				// explicit column "(l1!l2)", read from its closing parenthesis
				char l2 = pat.charAt(k - 1);
				char l1 = pat.charAt(k - 3);
				k = k - 4; // skip the whole group; the loop's k-- then moves before '('
				if ((l1 != EMPTY_LETTER) && (l2 != EMPTY_LETTER)) {
					if (i > 0 && j > 0 && l1 == s1[i - 1] && l2 == s2[j - 1]) {
						s += score.score(this.s1[i - 1], this.s2[j - 1]);
						i--;
						j--;
					} else {
						return Float.NEGATIVE_INFINITY;
					}
				} else if (l1 == EMPTY_LETTER) { // insert
					if (j > 0 && l2 == s2[j - 1]) {
						s += score.score(EMPTY_LETTER, this.s2[j - 1]);
						j--;
					} else {
						return Float.NEGATIVE_INFINITY;
					}
				} else if (l2 == EMPTY_LETTER) { // delete
					if (i > 0 && l1 == s1[i - 1]) {
						s += score.score(this.s1[i - 1], EMPTY_LETTER);
						i--;
					} else {
						return Float.NEGATIVE_INFINITY;
					}
				}
			}
		}
		return s;
	}

	/**
	 * Relaxes unconstrained cell (i, j) from cell (i - iDiff, j - jDiff) using
	 * the score of aligning {@code s1Char} with {@code s2Char}. The cell is
	 * updated only on a strict improvement.
	 */
	private final void calculateByOtherCellUnConstrained(int i, int j, int iDiff, int jDiff, char s1Char, char s2Char) {
		float actionScore = score.score(s1Char, s2Char);
		final float val = dpTable.get(i - iDiff, j - jDiff, UNCONSTRAINED) + actionScore;
		if (val > dpTable.get(i, j, UNCONSTRAINED)) {
			dpTable.set(val, i, j, UNCONSTRAINED);
			DPcounter++;
			// Update optimal in trace table
			if (trace) {
				dpTrace.set(i - iDiff, i, j, UNCONSTRAINED, 0);
				dpTrace.set(j - jDiff, i, j, UNCONSTRAINED, 1);
				dpTrace.set(UNCONSTRAINED, i, j, UNCONSTRAINED, 2);
			}
		}
	}

	/**
	 * @return the number of DP cell updates made by the last
	 *         {@link #align(String, String)} call
	 */
	public int getDPcounter() {
		return DPcounter;
	}
}