// Package bgu.bio.algorithms.graphs.hsa

// Source Code of bgu.bio.algorithms.graphs.hsa.HSA

package bgu.bio.algorithms.graphs.hsa;

import gnu.trove.list.array.TIntArrayList;

import java.util.Arrays;

import bgu.bio.adt.graphs.Tree;
import bgu.bio.algorithms.graphs.CostFunction;
import bgu.bio.algorithms.graphs.hsa.matchers.BipartiteCavityMatcher;
import bgu.bio.algorithms.graphs.hsa.matchers.MatcherFactory;
import bgu.bio.util.MathOperations;

public class HSA {

  private static final int FIRST_TIME = 0;
  private static final int SECOND_TIME = 1;

  protected CostFunction costFunction;
  protected MatcherFactory factory;
  protected BipartiteCavityMatcher[][] matchers;
  protected final boolean useCavity;
  protected long computationTime;
  protected long totalTime;
  private int[] tEdgeEncounters;
  private int[] sEdgeEncounters;

  public HSA(CostFunction w, MatcherFactory factory) {
    this(w, factory, true);
  }

  public HSA(CostFunction w, MatcherFactory factory, boolean useCavity) {
    this.costFunction = w;
    this.factory = factory;
    this.useCavity = useCavity;
  }

  public HSA(int initialSize1, int initialSize2, CostFunction e,
      MatcherFactory factory) {
    this(initialSize1, initialSize2, e, factory, true);
  }

  public HSA(int initialSize1, int initialSize2, CostFunction e,
      MatcherFactory factory, boolean useCavity) {
    this(e, factory, useCavity);
    this.matchers = new BipartiteCavityMatcher[initialSize1][initialSize2];
  }

  public double computeHSA(Tree t, Tree s) {
    return computeHSA(t, s, null);
  }

  public double computeHSA(Tree t, Tree s, TIntArrayList[] alignment) {
    computeSubtreeHSA(t, s);
    return computeHSAFromMatchers(t, s, alignment, false);
  }

  public double computeHSA(Tree t) {
    computeSubtreeHSA(t, t);
    return computeHSAFromMatchers(t, t, null, true);
  }

  public double computeHSA(Tree t, TIntArrayList[] alignment) {
    computeSubtreeHSA(t, t);
    return computeHSAFromMatchers(t, t, alignment, true);
  }

  protected double computeHSAFromMatchers(Tree t, Tree s,
      TIntArrayList[] alignment, boolean selfSimitry) {
    double cost = MathOperations.INFINITY, currCost;
    int tRoot = -1, sRoot = -1;

    if (!selfSimitry) {
      for (int i = 0; i < t.getNodeNum(); ++i) {
        for (int j = 0; j < s.getNodeNum(); ++j) {
          currCost = matchers[i][j].minCostMatching();
          if (MathOperations.greater(cost, currCost)) {
            cost = currCost;
            tRoot = i;
            sRoot = j;
          }
        }
      }
    } else {
      for (int i = 0; i < t.getNodeNum(); ++i) {
        int amount = t.outDeg(i);
        for (int n = 0; n < amount; n++) {
          int j = t.getNeighbor(i, n);
          currCost = matchers[i][j].getMatchCostXY(n,t.getNeighborIx(j, i));

          if (MathOperations.greater(cost, currCost)) {
            cost = currCost;
            tRoot = i;
            sRoot = j;
          }
        }
      }
    }

    if (alignment != null) {
      resetTracebackLists(alignment);
     
      // tracing back an optimal alignment:
      if (tRoot != -1) {
        alignment[0].add(tRoot);
        alignment[1].add(sRoot);
        alignment[2].add(0);

        BipartiteCavityMatcher matcher = matchers[tRoot][sRoot];
        matcher.setMinCostFullMatching(Double.NaN);
        matcher.minCostMatching();
        TIntArrayList[] matching = matcher.getCurrMatching();

        int nextGroupId = 1;
        for (int pair = 0; pair < matching[0].size(); ++pair) {
          nextGroupId = traceback(t, s, alignment, tRoot,
              matching[0].get(pair), sRoot,
              matching[1].get(pair), nextGroupId);
        }
      }

    }

    computationTime = System.currentTimeMillis() - computationTime;

    for (int i = 0; i < t.getNodeNum(); ++i) {
      for (int j = 0; j < s.getNodeNum(); ++j) {
        matchers[i][j].retain();
        matchers[i][j] = null;
      }
    }

    totalTime = System.currentTimeMillis() - totalTime;

    return cost;
  }

  /**
   * @param alignment
   */
  private void resetTracebackLists(TIntArrayList[] alignment) {
    if (alignment[0] == null)
      alignment[0] = new TIntArrayList();
    else
      alignment[0].resetQuick();
    if (alignment[1] == null)
      alignment[1] = new TIntArrayList();
    else
      alignment[1].resetQuick();
    if (alignment[2] == null)
      alignment[2] = new TIntArrayList();
    else
      alignment[2].resetQuick();
  }

  private int traceback(Tree t, Tree s, TIntArrayList[] alignment, int tFrom,
      int tToNeighborIx, int sFrom, int sToNeighborIx, int nextGroupIx) {

    double cost = matchers[tFrom][sFrom].getMatchCostXY(tToNeighborIx,
        sToNeighborIx);

    int tTo = t.outAdjLists[tFrom][tToNeighborIx];
    int sTo = s.outAdjLists[sFrom][sToNeighborIx];

    int tFromNeighborIx = t.getNeighborIx(tTo, tFrom);
    int sFromNeighborIx = s.getNeighborIx(sTo, sFrom);

    BipartiteCavityMatcher matcher = matchers[tTo][sTo];

    if (MathOperations.equals(cost,
        matcher.minCostCavityMatching(tFromNeighborIx, sFromNeighborIx)
            + costFunction.cost(t.labels[tTo], s.labels[sTo]))) {
      alignment[0].add(tTo);
      alignment[1].add(sTo);
      alignment[2].add(nextGroupIx);

      TIntArrayList[] matching = matcher.getCavityMatching(
          tFromNeighborIx, sFromNeighborIx);
      if (matching[0].size() != 1 || t.outDeg(tTo) != 2
          || s.outDeg(sTo) != 2) {
        ++nextGroupIx;
      }
      for (int pair = 0; pair < matching[0].size(); ++pair) {
        nextGroupIx = traceback(t, s, alignment, tTo,
            matching[0].get(pair), sTo, matching[1].get(pair),
            nextGroupIx);
      }
      return nextGroupIx;
    }

    double pruneAll = -matcher.safeDelCostX(tFromNeighborIx);
    // -t.weights[tTo][tFromNeighborIx];
    for (int u = 0; u < t.outAdjLists[tTo].length; ++u) {
      pruneAll += matcher.safeDelCostX(u);
      // t.weights[tTo][u];
    }

    double initWeight = pruneAll + t.smoothCost[tTo];

    for (int u = 0; u < t.outAdjLists[tTo].length; ++u) {
      if (u != tFromNeighborIx) {
        if (MathOperations.equals(cost, initWeight
            + matchers[tTo][sFrom].getMatchCostXY(u, sToNeighborIx)
            - matcher.safeDelCostX(u))) { // t.weights[tTo][u])) {
          if (t.outDeg(tTo) != 2) {
            ++nextGroupIx;
          }
          nextGroupIx = traceback(t, s, alignment, tTo, u, sFrom,
              sToNeighborIx, nextGroupIx);
          return nextGroupIx;
        }
      }
    }

    pruneAll = -matcher.safeDelCostY(sFromNeighborIx);// -s.weights[sTo][sFromNeighborIx];
    for (int u = 0; u < s.outAdjLists[sTo].length; ++u) {
      pruneAll += matcher.safeDelCostY(u);// s.weights[sTo][u];
    }

    initWeight = pruneAll + s.smoothCost[sTo];

    for (int u = 0; u < s.outAdjLists[sTo].length; ++u) {
      if (u != sFromNeighborIx) {
        if (MathOperations.equals(cost, initWeight
            + matchers[tFrom][sTo].getMatchCostXY(tToNeighborIx, u)
            - matcher.safeDelCostY(u))) { // s.weights[sTo][u])) {
          if (s.outDeg(sTo) != 2) {
            ++nextGroupIx;
          }
          nextGroupIx = traceback(t, s, alignment, tFrom,
              tToNeighborIx, sTo, u, nextGroupIx);
          return nextGroupIx;
        }
      }
    }

    throw new RuntimeException("Traceback error!");

  }

  /**
   * @param t
   * @param s
   */
  protected void computeSubtreeHSA(Tree t, Tree s) {

    totalTime = System.currentTimeMillis();

    if (this.matchers == null || this.matchers.length < t.getNodeNum()
        || this.matchers[0].length < s.getNodeNum()) {
      int tSize = t.getNodeNum();
      int sSize = s.getNodeNum();
      if (matchers != null) {
        tSize = Math.max(matchers.length, t.getNodeNum());
        sSize = Math.max(matchers[0].length, s.getNodeNum());
      }
      this.matchers = new BipartiteCavityMatcher[tSize][sSize];
    }

    BipartiteCavityMatcher matcher;

    // initiation of the matchers
    for (int i = 0; i < t.getNodeNum(); i++) {
      for (int j = 0; j < s.getNodeNum(); j++) {
        final int degI = t.outDeg(i);
        final int degJ = s.outDeg(j);
        matcher = factory.make(
            degI,
            degJ,
            !MathOperations.isInfinity(costFunction.cost(
                t.getLabel(i), s.getLabel(j))));
        matchers[i][j] = matcher;
        for (int u = 0; u < degI; u++) {
          matcher.setDelCostX(u, t.weights[i][u]);
        }

        for (int u = 0; u < degJ; u++) {
          matcher.setDelCostY(u, s.weights[j][u]);
        }
      }
    }

    if (tEdgeEncounters == null || tEdgeEncounters.length < t.getNodeNum()) {
      tEdgeEncounters = new int[t.getNodeNum()];
    }
    Arrays.fill(tEdgeEncounters, 0, t.getNodeNum(), FIRST_TIME);

    if (sEdgeEncounters == null || sEdgeEncounters.length < s.getNodeNum()) {
      sEdgeEncounters = new int[s.getNodeNum()];
    }

    computationTime = System.currentTimeMillis();

    double currScore, tmpScore; // , pruneAll
    for (int et = 0; et < t.getEdgeNum(); et++) {
      int fromT = t.edgeToNeighbors[et][0];
      int toTNeighborIx = t.edgeToNeighbors[et][1];
      int toT = t.outAdjLists[fromT][toTNeighborIx];
      int fromTNeighborIx = t.getNeighborIx(toT, fromT);

      Arrays.fill(sEdgeEncounters, 0, s.getNodeNum(), FIRST_TIME);
      for (int es = 0; es < s.getEdgeNum(); es++) {
        int fromS = s.edgeToNeighbors[es][0];
        int toSNeighborIx = s.edgeToNeighbors[es][1];
        int toS = s.outAdjLists[fromS][toSNeighborIx];
        int fromSNeighborIx = s.getNeighborIx(toS, fromS);

        matcher = matchers[toT][toS];

        if (useCavity) {
          if (sEdgeEncounters[toS] == SECOND_TIME) {
            if (tEdgeEncounters[toT] == FIRST_TIME) {
              matcher.processAllCavityMatchingY(fromTNeighborIx);
            } else if (tEdgeEncounters[toT] == SECOND_TIME) {
              matcher.processAllPairsCavityMatching();
            }
          }
        }
       
        final BipartiteCavityMatcher matcherToFrom = matchers[toT][fromS];
        double cavityPrune = 0;
        boolean foundForced = false;

        // IMPROTANT: we assume that at most one forced match exists!!!
        tmpScore = MathOperations.INFINITY;
        final int tNeighborsSize = t.outAdjLists[toT].length;
        for (int u = 0; u < tNeighborsSize; ++u) {
          if (u != fromTNeighborIx) {
            if (!MathOperations.isInfinity(t.weights[toT][u])) {
              cavityPrune += t.weights[toT][u];
              final double value = matcherToFrom.getMatchCostXY(u,
                  toSNeighborIx) - t.weights[toT][u];
              if (MathOperations.greater(
                  tmpScore,
                  value)
                  && !foundForced) {
                tmpScore = value;
              }
            }       
            else {
              if (foundForced) {
                throw new RuntimeException(
                    "more than two forced matches!");
              }
              tmpScore = matcherToFrom.getMatchCostXY(u,
                  toSNeighborIx);
              foundForced = true;
            }

          }
        }

        currScore = tmpScore + cavityPrune + t.smoothCost[toT];


        final BipartiteCavityMatcher matcherFromTo = matchers[fromT][toS];
        tmpScore = MathOperations.INFINITY;
        foundForced = false;
        cavityPrune = 0;

        // pruneAll = 0;
        final int sNeighborsSize = s.outAdjLists[toS].length;
        for (int u = 0; u < sNeighborsSize; ++u) {
          if (u != fromSNeighborIx) {
            if (!MathOperations.isInfinity(s.weights[toS][u])) {
              cavityPrune += s.weights[toS][u];
              final double value = matcherFromTo.getMatchCostXY(toTNeighborIx,
                  u) - s.weights[toS][u];
              if (MathOperations.greater(
                  tmpScore,
                  value)
                  && !foundForced) {
                tmpScore = value;
              }
              // pruneAll += s.weights[toS][u];
            } else {
              if (foundForced) {
                throw new RuntimeException(
                    "more than two forced matches!");
              }
              tmpScore = matcherFromTo.getMatchCostXY(
                  toTNeighborIx, u);
              foundForced = true;
            }
          }
        }
        // }

        tmpScore += cavityPrune + s.smoothCost[toS];
        if (MathOperations.greater(currScore, tmpScore)) {
          currScore = tmpScore;
        }

        // Best matching of toT and toS:
        double matchCost = matcher.minCostCavityMatching(
            fromTNeighborIx, fromSNeighborIx)
            + costFunction.cost(t.labels[toT], s.labels[toS]);
        if (MathOperations.greater(currScore, matchCost)) {
          currScore = matchCost;
        }

        matchers[fromT][fromS].setMatchCost(toTNeighborIx,
            toSNeighborIx, currScore);

        ++sEdgeEncounters[toS];
      }
      ++tEdgeEncounters[toT];
    }
  }

  public CostFunction getCostFunction() {
    return costFunction;
  }

  public void setCostFunction(CostFunction costFunction) {
    this.costFunction = costFunction;
  }

  public MatcherFactory getFactory() {
    return factory;
  }

  public void setFactory(MatcherFactory factory) {
    this.factory = factory;
  }

  public long getComputationTime() {
    return computationTime;
  }

  public long getTotalTime() {
    return totalTime;
  }

}
// TOP
//
// Related Classes of bgu.bio.algorithms.graphs.hsa.HSA
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.