package bgu.bio.algorithms.graphs.hsa;
import gnu.trove.list.array.TIntArrayList;
import java.util.Arrays;
import bgu.bio.adt.graphs.Tree;
import bgu.bio.algorithms.graphs.CostFunction;
import bgu.bio.algorithms.graphs.hsa.matchers.BipartiteCavityMatcher;
import bgu.bio.algorithms.graphs.hsa.matchers.MatcherFactory;
import bgu.bio.util.MathOperations;
/**
 * Computes an optimal alignment cost between two trees (or between a tree and
 * itself) by dynamic programming over directed edges, keeping one
 * {@link BipartiteCavityMatcher} per pair of nodes.
 * <p>
 * NOTE(review): the class name presumably stands for "homeomorphic subtree
 * alignment" — TODO confirm.
 * <p>
 * Instances keep per-call scratch state (the matcher table, edge-encounter
 * counters and timers) in fields, so a single instance must not be used by
 * several threads at the same time.
 */
public class HSA {

    /** Counter value of a node not yet seen as the target of a processed edge. */
    private static final int FIRST_TIME = 0;
    /** Counter value of a node seen exactly once as the target of a processed edge. */
    private static final int SECOND_TIME = 1;

    /**
     * Label-pair cost function. Whether a pair's cost is infinite is passed to
     * the factory when that pair's matcher is created.
     */
    protected CostFunction costFunction;
    /** Creates the per-node-pair matchers. */
    protected MatcherFactory factory;
    /**
     * Matcher table: {@code matchers[i][j]} pairs node {@code i} of the first
     * tree with node {@code j} of the second. It grows on demand and never shrinks.
     */
    protected BipartiteCavityMatcher[][] matchers;
    /** Whether the matchers' batch cavity-matching processing is triggered during the DP. */
    protected final boolean useCavity;
    /**
     * Holds a start timestamp (ms) after {@link #computeSubtreeHSA(Tree, Tree)}.
     * {@link #computeHSAFromMatchers} turns it into an elapsed time in ms.
     */
    protected long computationTime;
    /** Same convention as {@link #computationTime}, but started earlier (before matcher setup). */
    protected long totalTime;
    /** Per node of the first tree: number of processed first-tree edges ending at it. */
    private int[] tEdgeEncounters;
    /** Per node of the second tree: number of processed second-tree edges ending at it (reset per first-tree edge). */
    private int[] sEdgeEncounters;

    /**
     * Creates an aligner that triggers cavity preprocessing.
     *
     * @param w       label cost function
     * @param factory factory for the per-node-pair matchers
     */
    public HSA(CostFunction w, MatcherFactory factory) {
        this(w, factory, true);
    }

    /**
     * Creates an aligner.
     *
     * @param w         label cost function
     * @param factory   factory for the per-node-pair matchers
     * @param useCavity whether to trigger the matchers' batch cavity-matching processing
     */
    public HSA(CostFunction w, MatcherFactory factory, boolean useCavity) {
        this.costFunction = w;
        this.factory = factory;
        this.useCavity = useCavity;
    }

    /**
     * Creates an aligner with cavity preprocessing and a pre-allocated matcher
     * table of {@code initialSize1 x initialSize2}. The table grows later if needed.
     *
     * @param initialSize1 initial number of rows (nodes of the first tree)
     * @param initialSize2 initial number of columns (nodes of the second tree)
     * @param e            label cost function
     * @param factory      factory for the per-node-pair matchers
     */
    public HSA(int initialSize1, int initialSize2, CostFunction e,
            MatcherFactory factory) {
        this(initialSize1, initialSize2, e, factory, true);
    }

    /**
     * Creates an aligner with a pre-allocated matcher table of
     * {@code initialSize1 x initialSize2}. The table grows later if needed.
     * A zero-sized table is allowed.
     *
     * @param initialSize1 initial number of rows (nodes of the first tree)
     * @param initialSize2 initial number of columns (nodes of the second tree)
     * @param e            label cost function
     * @param factory      factory for the per-node-pair matchers
     * @param useCavity    whether to trigger the matchers' batch cavity-matching processing
     */
    public HSA(int initialSize1, int initialSize2, CostFunction e,
            MatcherFactory factory, boolean useCavity) {
        this(e, factory, useCavity);
        this.matchers = new BipartiteCavityMatcher[initialSize1][initialSize2];
    }

    /**
     * Aligns {@code t} against {@code s}.
     *
     * @return the minimum matching cost over all root pairs, or
     *         {@code MathOperations.INFINITY} if no pair improves on it
     */
    public double computeHSA(Tree t, Tree s) {
        return computeHSA(t, s, null);
    }

    /**
     * Aligns {@code t} against {@code s} and optionally traces back an optimal alignment.
     *
     * @param alignment {@code null}, or an array of length at least 3. Its first three
     *                  lists are created or cleared, then filled with parallel entries:
     *                  {@code [0]} first-tree node ids, {@code [1]} second-tree node ids,
     *                  {@code [2]} group ids.
     * @return the minimum matching cost over all root pairs
     */
    public double computeHSA(Tree t, Tree s, TIntArrayList[] alignment) {
        computeSubtreeHSA(t, s);
        return computeHSAFromMatchers(t, s, alignment, false);
    }

    /**
     * Aligns {@code t} against itself. Candidate roots are restricted to pairs
     * of adjacent nodes.
     *
     * @return the minimum cost over adjacent root pairs
     */
    public double computeHSA(Tree t) {
        computeSubtreeHSA(t, t);
        return computeHSAFromMatchers(t, t, null, true);
    }

    /**
     * Aligns {@code t} against itself (adjacent root pairs only) and optionally
     * traces back an optimal alignment.
     *
     * @param alignment see {@link #computeHSA(Tree, Tree, TIntArrayList[])}
     * @return the minimum cost over adjacent root pairs
     */
    public double computeHSA(Tree t, TIntArrayList[] alignment) {
        computeSubtreeHSA(t, t);
        return computeHSAFromMatchers(t, t, alignment, true);
    }

    /**
     * Picks the best root pair from the filled matcher table, optionally traces
     * back the alignment, then finalises the timers and releases all matchers.
     * <p>
     * The release and timer bookkeeping run in a {@code finally} block, so
     * they happen even when the traceback fails.
     *
     * @param t            first tree
     * @param s            second tree (the same object as {@code t} in self mode)
     * @param alignment    optional traceback output, see
     *                     {@link #computeHSA(Tree, Tree, TIntArrayList[])}
     * @param selfSymmetry if {@code true}, only adjacent pairs {@code (i, j)} of
     *                     {@code t} are roots, scored by matching edge i->j
     *                     against edge j->i
     * @return the minimum cost found, or {@code MathOperations.INFINITY}
     */
    protected double computeHSAFromMatchers(Tree t, Tree s,
            TIntArrayList[] alignment, boolean selfSymmetry) {
        try {
            double cost = MathOperations.INFINITY;
            int tRoot = -1;
            int sRoot = -1;
            if (!selfSymmetry) {
                for (int i = 0; i < t.getNodeNum(); ++i) {
                    for (int j = 0; j < s.getNodeNum(); ++j) {
                        final double currCost = matchers[i][j].minCostMatching();
                        if (MathOperations.greater(cost, currCost)) {
                            cost = currCost;
                            tRoot = i;
                            sRoot = j;
                        }
                    }
                }
            } else {
                for (int i = 0; i < t.getNodeNum(); ++i) {
                    final int amount = t.outDeg(i);
                    for (int n = 0; n < amount; n++) {
                        final int j = t.getNeighbor(i, n);
                        final double currCost = matchers[i][j].getMatchCostXY(n,
                                t.getNeighborIx(j, i));
                        if (MathOperations.greater(cost, currCost)) {
                            cost = currCost;
                            tRoot = i;
                            sRoot = j;
                        }
                    }
                }
            }
            if (alignment != null) {
                traceAlignment(t, s, alignment, tRoot, sRoot);
            }
            return cost;
        } finally {
            computationTime = System.currentTimeMillis() - computationTime;
            releaseMatchers(t.getNodeNum(), s.getNodeNum());
            totalTime = System.currentTimeMillis() - totalTime;
        }
    }

    /**
     * Clears {@code alignment} and traces back the optimal solution rooted at
     * {@code (tRoot, sRoot)}. The lists stay empty when no root was selected
     * ({@code tRoot == -1}).
     */
    private void traceAlignment(Tree t, Tree s, TIntArrayList[] alignment,
            int tRoot, int sRoot) {
        resetTracebackLists(alignment);
        if (tRoot == -1) {
            return;
        }
        alignment[0].add(tRoot);
        alignment[1].add(sRoot);
        alignment[2].add(0); // the root pair is always group 0
        final BipartiteCavityMatcher matcher = matchers[tRoot][sRoot];
        // NOTE(review): presumably invalidates a cached optimum so that
        // minCostMatching() recomputes it and exposes the matching through
        // getCurrMatching() — TODO confirm against BipartiteCavityMatcher.
        matcher.setMinCostFullMatching(Double.NaN);
        matcher.minCostMatching();
        final TIntArrayList[] matching = matcher.getCurrMatching();
        int nextGroupId = 1;
        for (int pair = 0; pair < matching[0].size(); ++pair) {
            nextGroupId = traceback(t, s, alignment, tRoot,
                    matching[0].get(pair), sRoot,
                    matching[1].get(pair), nextGroupId);
        }
    }

    /**
     * Ensures {@code alignment[0..2]} are non-null, empty lists.
     *
     * @param alignment array of length at least 3
     */
    private void resetTracebackLists(TIntArrayList[] alignment) {
        for (int k = 0; k < 3; ++k) {
            if (alignment[k] == null) {
                alignment[k] = new TIntArrayList();
            } else {
                alignment[k].resetQuick();
            }
        }
    }

    /**
     * Recursively appends to {@code alignment} the node pairs used by the optimal
     * solution for the directed edge pair (tFrom -> neighbour #tToNeighborIx,
     * sFrom -> neighbour #sToNeighborIx). It tries, in order:
     * <ol>
     * <li>matching tTo with sTo;</li>
     * <li>skipping tTo and continuing through one of its other neighbours;</li>
     * <li>skipping sTo and continuing through one of its other neighbours.</li>
     * </ol>
     * The first option whose recomputed cost equals the stored cost is used.
     *
     * @return the next unused group id
     * @throws IllegalStateException if no option reproduces the stored cost
     */
    private int traceback(Tree t, Tree s, TIntArrayList[] alignment, int tFrom,
            int tToNeighborIx, int sFrom, int sToNeighborIx, int nextGroupIx) {
        final double cost = matchers[tFrom][sFrom].getMatchCostXY(tToNeighborIx,
                sToNeighborIx);
        final int tTo = t.outAdjLists[tFrom][tToNeighborIx];
        final int sTo = s.outAdjLists[sFrom][sToNeighborIx];
        final int tFromNeighborIx = t.getNeighborIx(tTo, tFrom);
        final int sFromNeighborIx = s.getNeighborIx(sTo, sFrom);
        final BipartiteCavityMatcher matcher = matchers[tTo][sTo];

        // Option 1: tTo is matched with sTo.
        if (MathOperations.equals(cost,
                matcher.minCostCavityMatching(tFromNeighborIx, sFromNeighborIx)
                        + costFunction.cost(t.labels[tTo], s.labels[sTo]))) {
            alignment[0].add(tTo);
            alignment[1].add(sTo);
            alignment[2].add(nextGroupIx);
            final TIntArrayList[] matching = matcher.getCavityMatching(
                    tFromNeighborIx, sFromNeighborIx);
            // Stay in the current group only while both nodes have degree 2 and
            // exactly one child pair is matched (an unbranched path).
            if (matching[0].size() != 1 || t.outDeg(tTo) != 2
                    || s.outDeg(sTo) != 2) {
                ++nextGroupIx;
            }
            for (int pair = 0; pair < matching[0].size(); ++pair) {
                nextGroupIx = traceback(t, s, alignment, tTo,
                        matching[0].get(pair), sTo, matching[1].get(pair),
                        nextGroupIx);
            }
            return nextGroupIx;
        }

        // Option 2: tTo is skipped; all its other neighbours except the continuation are pruned.
        double pruneAll = -matcher.safeDelCostX(tFromNeighborIx);
        for (int u = 0; u < t.outAdjLists[tTo].length; ++u) {
            pruneAll += matcher.safeDelCostX(u);
        }
        double initWeight = pruneAll + t.smoothCost[tTo];
        for (int u = 0; u < t.outAdjLists[tTo].length; ++u) {
            if (u != tFromNeighborIx) {
                if (MathOperations.equals(cost, initWeight
                        + matchers[tTo][sFrom].getMatchCostXY(u, sToNeighborIx)
                        - matcher.safeDelCostX(u))) {
                    if (t.outDeg(tTo) != 2) {
                        ++nextGroupIx;
                    }
                    nextGroupIx = traceback(t, s, alignment, tTo, u, sFrom,
                            sToNeighborIx, nextGroupIx);
                    return nextGroupIx;
                }
            }
        }

        // Option 3: sTo is skipped, symmetrically.
        pruneAll = -matcher.safeDelCostY(sFromNeighborIx);
        for (int u = 0; u < s.outAdjLists[sTo].length; ++u) {
            pruneAll += matcher.safeDelCostY(u);
        }
        initWeight = pruneAll + s.smoothCost[sTo];
        for (int u = 0; u < s.outAdjLists[sTo].length; ++u) {
            if (u != sFromNeighborIx) {
                if (MathOperations.equals(cost, initWeight
                        + matchers[tFrom][sTo].getMatchCostXY(tToNeighborIx, u)
                        - matcher.safeDelCostY(u))) {
                    if (s.outDeg(sTo) != 2) {
                        ++nextGroupIx;
                    }
                    nextGroupIx = traceback(t, s, alignment, tFrom,
                            tToNeighborIx, sTo, u, nextGroupIx);
                    return nextGroupIx;
                }
            }
        }
        throw new IllegalStateException("Traceback error!");
    }

    /**
     * Fills the matcher table: for every pair of directed edges (et in {@code t},
     * es in {@code s}) it stores, in {@code matchers[fromT][fromS]}, the best cost
     * of aligning the subtree behind edge et with the subtree behind edge es. The
     * best of three options is used: skip toT, skip toS, or match toT with toS.
     * <p>
     * Also starts {@link #totalTime} and {@link #computationTime}; they are turned
     * into elapsed times by {@link #computeHSAFromMatchers}.
     * <p>
     * NOTE(review): correctness relies on {@code edgeToNeighbors} being ordered
     * so that an edge is processed after the edges it depends on — not visible
     * here; TODO confirm in {@code Tree}.
     *
     * @param t first tree
     * @param s second tree
     * @throws IllegalStateException if a node has more than one infinite-weight
     *         (forced) neighbour besides the one the edge came from
     */
    protected void computeSubtreeHSA(Tree t, Tree s) {
        totalTime = System.currentTimeMillis();
        final int tNodes = t.getNodeNum();
        final int sNodes = s.getNodeNum();
        ensureMatcherCapacity(tNodes, sNodes);
        initMatchers(t, s);

        tEdgeEncounters = ensureCounterCapacity(tEdgeEncounters, tNodes);
        Arrays.fill(tEdgeEncounters, 0, tNodes, FIRST_TIME);
        // sEdgeEncounters is cleared once per first-tree edge, below.
        sEdgeEncounters = ensureCounterCapacity(sEdgeEncounters, sNodes);

        computationTime = System.currentTimeMillis();
        for (int et = 0; et < t.getEdgeNum(); et++) {
            final int fromT = t.edgeToNeighbors[et][0];
            final int toTNeighborIx = t.edgeToNeighbors[et][1];
            final int toT = t.outAdjLists[fromT][toTNeighborIx];
            final int fromTNeighborIx = t.getNeighborIx(toT, fromT);
            Arrays.fill(sEdgeEncounters, 0, sNodes, FIRST_TIME);
            for (int es = 0; es < s.getEdgeNum(); es++) {
                final int fromS = s.edgeToNeighbors[es][0];
                final int toSNeighborIx = s.edgeToNeighbors[es][1];
                final int toS = s.outAdjLists[fromS][toSNeighborIx];
                final int fromSNeighborIx = s.getNeighborIx(toS, fromS);
                final BipartiteCavityMatcher matcher = matchers[toT][toS];

                // Trigger batch cavity processing the second time toS is reached
                // as a target, depending on how often toT has been reached.
                if (useCavity && sEdgeEncounters[toS] == SECOND_TIME) {
                    if (tEdgeEncounters[toT] == FIRST_TIME) {
                        matcher.processAllCavityMatchingY(fromTNeighborIx);
                    } else if (tEdgeEncounters[toT] == SECOND_TIME) {
                        matcher.processAllPairsCavityMatching();
                    }
                }

                double currScore = skipTNodeScore(t, toT, fromTNeighborIx,
                        matchers[toT][fromS], toSNeighborIx);
                final double skipS = skipSNodeScore(s, toS, fromSNeighborIx,
                        matchers[fromT][toS], toTNeighborIx);
                if (MathOperations.greater(currScore, skipS)) {
                    currScore = skipS;
                }
                final double matchCost = matcher.minCostCavityMatching(
                        fromTNeighborIx, fromSNeighborIx)
                        + costFunction.cost(t.labels[toT], s.labels[toS]);
                if (MathOperations.greater(currScore, matchCost)) {
                    currScore = matchCost;
                }
                matchers[fromT][fromS].setMatchCost(toTNeighborIx,
                        toSNeighborIx, currScore);
                ++sEdgeEncounters[toS];
            }
            ++tEdgeEncounters[toT];
        }
    }

    /**
     * Grows the matcher table to at least {@code tNodes x sNodes}, keeping any
     * larger existing dimension. A table with zero rows is handled: it has no
     * row 0 to read a width from.
     */
    private void ensureMatcherCapacity(int tNodes, int sNodes) {
        final int rows = matchers == null ? 0 : matchers.length;
        final int cols = rows == 0 ? 0 : matchers[0].length;
        if (matchers == null || rows < tNodes || cols < sNodes) {
            matchers = new BipartiteCavityMatcher[Math.max(rows, tNodes)][Math.max(cols, sNodes)];
        }
    }

    /**
     * Creates one matcher per node pair. Each matcher receives the deletion
     * cost of every neighbour edge of both nodes.
     */
    private void initMatchers(Tree t, Tree s) {
        final int tNodes = t.getNodeNum();
        final int sNodes = s.getNodeNum();
        for (int i = 0; i < tNodes; i++) {
            final int degI = t.outDeg(i);
            for (int j = 0; j < sNodes; j++) {
                final int degJ = s.outDeg(j);
                final boolean finiteLabelCost = !MathOperations.isInfinity(
                        costFunction.cost(t.getLabel(i), s.getLabel(j)));
                final BipartiteCavityMatcher matcher = factory.make(degI, degJ,
                        finiteLabelCost);
                matchers[i][j] = matcher;
                for (int u = 0; u < degI; u++) {
                    matcher.setDelCostX(u, t.weights[i][u]);
                }
                for (int u = 0; u < degJ; u++) {
                    matcher.setDelCostY(u, s.weights[j][u]);
                }
            }
        }
    }

    /**
     * Returns {@code counters} if it can hold {@code size} entries, otherwise a
     * new zeroed array.
     */
    private static int[] ensureCounterCapacity(int[] counters, int size) {
        return (counters == null || counters.length < size) ? new int[size] : counters;
    }

    /**
     * Score for the edge pair when node {@code toT} of {@code t} is skipped.
     * <ul>
     * <li>The alignment continues into exactly one other neighbour {@code u}; its
     * cost is read from {@code matcherToFrom} at {@code (u, toSNeighborIx)}.</li>
     * <li>Every remaining neighbour, other than {@code u} and the one at
     * {@code fromTNeighborIx}, is charged its edge weight.</li>
     * <li>{@code t.smoothCost[toT]} is added.</li>
     * <li>A neighbour with infinite weight cannot be charged, so it is forced
     * to be {@code u}.</li>
     * </ul>
     *
     * @throws IllegalStateException if more than one such forced neighbour exists
     */
    private static double skipTNodeScore(Tree t, int toT, int fromTNeighborIx,
            BipartiteCavityMatcher matcherToFrom, int toSNeighborIx) {
        double best = MathOperations.INFINITY;
        double prunable = 0;
        boolean foundForced = false;
        // IMPORTANT: at most one forced neighbour is supported.
        final int degree = t.outAdjLists[toT].length;
        for (int u = 0; u < degree; ++u) {
            if (u == fromTNeighborIx) {
                continue;
            }
            if (MathOperations.isInfinity(t.weights[toT][u])) {
                if (foundForced) {
                    throw new IllegalStateException("more than one forced match!");
                }
                best = matcherToFrom.getMatchCostXY(u, toSNeighborIx);
                foundForced = true;
            } else {
                prunable += t.weights[toT][u];
                final double value = matcherToFrom.getMatchCostXY(u,
                        toSNeighborIx) - t.weights[toT][u];
                if (MathOperations.greater(best, value) && !foundForced) {
                    best = value;
                }
            }
        }
        return best + prunable + t.smoothCost[toT];
    }

    /**
     * Mirror of {@link #skipTNodeScore} for node {@code toS} of {@code s}.
     * <p>
     * Here the continuation cost is read from {@code matcherFromTo} at
     * {@code (toTNeighborIx, u)}.
     *
     * @throws IllegalStateException if more than one forced neighbour exists
     */
    private static double skipSNodeScore(Tree s, int toS, int fromSNeighborIx,
            BipartiteCavityMatcher matcherFromTo, int toTNeighborIx) {
        double best = MathOperations.INFINITY;
        double prunable = 0;
        boolean foundForced = false;
        // IMPORTANT: at most one forced neighbour is supported.
        final int degree = s.outAdjLists[toS].length;
        for (int u = 0; u < degree; ++u) {
            if (u == fromSNeighborIx) {
                continue;
            }
            if (MathOperations.isInfinity(s.weights[toS][u])) {
                if (foundForced) {
                    throw new IllegalStateException("more than one forced match!");
                }
                best = matcherFromTo.getMatchCostXY(toTNeighborIx, u);
                foundForced = true;
            } else {
                prunable += s.weights[toS][u];
                final double value = matcherFromTo.getMatchCostXY(toTNeighborIx,
                        u) - s.weights[toS][u];
                if (MathOperations.greater(best, value) && !foundForced) {
                    best = value;
                }
            }
        }
        // Grouped as best + (prunable + smooth); keep this association so the
        // floating-point result does not change.
        return best + (prunable + s.smoothCost[toS]);
    }

    /**
     * Calls {@code retain()} on every matcher in the top-left
     * {@code tNodes x sNodes} area and clears the slot. NOTE(review):
     * {@code retain()} presumably hands the matcher back to the factory for
     * reuse — TODO confirm. Missing or null slots are skipped, so this is safe
     * to call from a {@code finally} block.
     */
    private void releaseMatchers(int tNodes, int sNodes) {
        if (matchers == null) {
            return;
        }
        final int rows = Math.min(tNodes, matchers.length);
        for (int i = 0; i < rows; ++i) {
            final BipartiteCavityMatcher[] row = matchers[i];
            final int cols = Math.min(sNodes, row.length);
            for (int j = 0; j < cols; ++j) {
                if (row[j] != null) {
                    row[j].retain();
                    row[j] = null;
                }
            }
        }
    }

    /** @return the label cost function */
    public CostFunction getCostFunction() {
        return costFunction;
    }

    /** @param costFunction the new label cost function, used from the next computation on */
    public void setCostFunction(CostFunction costFunction) {
        this.costFunction = costFunction;
    }

    /** @return the matcher factory */
    public MatcherFactory getFactory() {
        return factory;
    }

    /** @param factory the new matcher factory, used from the next computation on */
    public void setFactory(MatcherFactory factory) {
        this.factory = factory;
    }

    /**
     * @return elapsed ms from the start of the DP over edge pairs to the end of the
     *         last {@code computeHSA} call's root selection and traceback
     */
    public long getComputationTime() {
        return computationTime;
    }

    /** @return elapsed ms of the last {@code computeHSA} call, including matcher setup */
    public long getTotalTime() {
        return totalTime;
    }
}