Package org.integratedmodelling.riskwiz.inference.ls

Source Code of org.integratedmodelling.riskwiz.inference.ls.JoinTree

/**
* StrongJoinTree.java
* ----------------------------------------------------------------------------------
*
* Copyright (C) 2008 www.integratedmodelling.org
* Created: Feb 19, 2008
*
* ----------------------------------------------------------------------------------
* This file is part of RiskWiz.
*
* RiskWiz is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* RiskWiz is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with the software; if not, write to the Free Software
* Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
*
* ----------------------------------------------------------------------------------
*
* @copyright 2008 www.integratedmodelling.org
* @author    Sergey Krivov
* @date      Feb 19, 2008
* @license   http://www.gnu.org/licenses/gpl.txt GNU General Public License v3
* @link      http://www.integratedmodelling.org
**/

package org.integratedmodelling.riskwiz.inference.ls;


import java.util.HashSet;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.Set;
import java.util.Vector;

import org.integratedmodelling.riskwiz.bn.BNNode;
import org.integratedmodelling.riskwiz.bn.BeliefNetwork;
import org.integratedmodelling.riskwiz.domain.DomainFactory;
import org.integratedmodelling.riskwiz.graph.RiskUndirectedGraph;
import org.integratedmodelling.riskwiz.jtree.IJoinTree;
import org.integratedmodelling.riskwiz.jtree.JTEdge;
import org.integratedmodelling.riskwiz.pt.PT;
import org.integratedmodelling.riskwiz.pt.map.DomainMap2;
import org.integratedmodelling.riskwiz.pt.map.FMarginalizationMap;
import org.integratedmodelling.riskwiz.pt.map.FastMap2;
import org.integratedmodelling.riskwiz.pt.map.SubtableFastMap2;


/**
* @author Sergey Krivov
*
*/

/**
* @author Sergey Krivov
*
*/
public class JoinTree extends RiskUndirectedGraph<JTVertexHugin, JTEdgeHugin>
        implements IJoinTree<JTVertexHugin> {

    BeliefNetwork bn;

    Hashtable<BNNode, ClusterBundle> clusterHash;

    public JoinTree(BeliefNetwork bn) {
        super(JTEdgeHugin.class);
        this.bn = bn;
    }

    @Override
  public void initializeStructiure() {
        Set<JTVertexHugin> verttset = this.vertexSet();

        // initialize clusters
        for (JTVertexHugin vertex : verttset) {
            vertex.setAll(1);
        }
        // initialize spesets
        Set<JTEdgeHugin> edgeset = this.edgeSet();

        for (JTEdgeHugin edge : edgeset) {
            edge.setAll(1);
        }

        // create correspondence between Belif nodes and clusters (vertexes of
        // StrongJoinTree)
        clusterHash = new Hashtable<BNNode, ClusterBundle>();

        Set<BNNode> bNNodes = bn.vertexSet();

        for (BNNode node : bNNodes) {
            if (node.isNature()) {
                boolean debugInit = false;
                JTVertexHugin parentCluster = assignParentCluster(node);

                PT clusterPT = parentCluster.getPt();
                PT nodePT = node.getDiscreteCPT();
                SubtableFastMap2 fmap = clusterPT.createSubtableFastMap(nodePT);
                FMarginalizationMap mfmap = new FMarginalizationMap(
                        clusterPT.getDomainProduct(),
                        node.getDiscretizedDomain());
            
                FastMap2 liklihoodfmap = clusterPT.createSubtableFastMap(
                        DomainFactory.createDomainProduct(
                                node.getDiscretizedDomain()));

                clusterHash.put(node,
                        new ClusterBundle(parentCluster, fmap, mfmap,
                        liklihoodfmap));
                // if (debugInit)
                // System.out.println("clusterPT  \n" + clusterPT.toString()
                // + "\n");
                // if (debugInit)
                // System.out.println("nodePT  \n" + nodePT.toString() + "\n");

                clusterPT.multiplyBySubtable(nodePT, fmap);

                // if (debugInit)
                // System.out.println("clusterPT after mutiplication\n"
                // + clusterPT.toString() + "\n");
                // node.getPrior().setAll(1);
            }
        }

        // compile maps for fast operations
        for (JTVertexHugin vertex : verttset) {
            Set<JTEdgeHugin> edgesOfvertex = this.edgesOf(vertex);

            for (JTEdgeHugin jtedge : edgesOfvertex) {
                vertex.createFastMaps(jtedge);
            }
        }

    }

    // StrongJoinTree map structure has to be built before calling initialize()
    @Override
  public void initialize() {
        // boolean init=false;
        Set<JTVertexHugin> verttset = this.vertexSet();

        // initialize clusters
        for (JTVertexHugin vertex : verttset) {
            vertex.setAll(1);
        }
        // initialize spesets
        Set<JTEdgeHugin> edgeset = this.edgeSet();

        for (JTEdgeHugin edge : edgeset) {
            edge.setAll(1);
        }

        Set<BNNode> bNNodes = bn.vertexSet();

        for (BNNode node : bNNodes) {
            if (node.isNature()) {
                ClusterBundle bundle = clusterHash.get(node);
                JTVertexHugin parentCluster = bundle.jtcluster;
                PT clusterPT = parentCluster.getPt();
                PT nodePT = node.getDiscreteCPT();
                SubtableFastMap2 fmap = (SubtableFastMap2) bundle.fopmap;

                // FMarginalizationMap mfmap = bundle.mfmap;
                // FastMap2 liklihoodfmap = bundle.liklihoodfmap;
                // if(init)System.out.println( "clusterPT \n"+
                // clusterPT.toString() + "\n");
                // if(init)System.out.println( "nodePT \n"+
                // nodePT.toString() + "\n");

                clusterPT.multiplyBySubtableFast(nodePT, fmap);

                // if(init)System.out.println( "clusterPT after
                // mutiplication\n"+
                // clusterPT.toString() + "\n");
            }

        }

    }

    @Override
  public void initializeLikelihoods() {
        Set<BNNode> bNNodes = bn.vertexSet();

        for (BNNode node : bNNodes) {
            if (node.isNature()) {
                initializeLikelihood(node);
            }
        }

    }

    public void initializeLikelihood(BNNode node) {
        ClusterBundle bundle = clusterHash.get(node);
        JTVertexHugin parentCluster = bundle.jtcluster;
        PT clusterPT = parentCluster.getPt();

        if (node.hasEvidence()) {
      
            // cross fingers
            if (bundle.liklihoodfmap instanceof SubtableFastMap2) {
                clusterPT.multiplyBySubtableFast(node.getEvidence(),
                        (SubtableFastMap2) bundle.liklihoodfmap);
            } else {
                clusterPT.multiplyBySubtable(node.getEvidence(),
                        bundle.liklihoodfmap);
            }

        }
    }

    @Override
  public void propagateEvidence(JTVertexHugin X) {
   
        unmarkAll();
        collectEvidence(X);
        unmarkAll();
        distributeEvidence(X);
    }

    @Override
  public void propagateEvidence(BNNode node) {
        ClusterBundle bundle = clusterHash.get(node);
        JTVertexHugin parentCluster = bundle.jtcluster;

        PT clusterPT = parentCluster.getPt();

        clusterPT.multiplyBySubtable(node.getEvidence(), bundle.liklihoodfmap);
        unmarkAll();
        collectEvidence(parentCluster);
        unmarkAll();
        distributeEvidence(parentCluster);
    }

    @Override
  public void propagateEvidence() {
        Iterator it = this.vertexSet().iterator();

        JTVertexHugin vtx = (JTVertexHugin) it.next();

        this.propagateEvidence(vtx);
    }

    @Override
  public void collectEvidence(JTVertexHugin X) {
        X.isMarked = true;
        Set<JTVertexHugin> neighbours = this.getNeighbors(X);

        for (JTVertexHugin neighbor : neighbours) {
            if (!neighbor.isMarked) {
                collectEvidence(neighbor);
                passMessage(neighbor, X);
                // System.out.println("Collect from "+ neighbor.getName()+ "to "+ X.getName());
            }
     
        }

    }

    @Override
  public void distributeEvidence(JTVertexHugin X) {
        X.isMarked = true;
        Set<JTVertexHugin> neighbours = this.getNeighbors(X);

        for (JTVertexHugin neighbor : neighbours) {
            if (!neighbor.isMarked) {
                passMessage(X, neighbor);
                // System.out.println("Distribute from "+ X.getName()+ "to "+neighbor.getName());
                distributeEvidence(neighbor);

            }
        }

    }

    @Override
  public void passMessage(JTVertexHugin source, JTVertexHugin target) {

        JTEdgeHugin jtedge = this.getEdge(source, target);
        PT sepsetPT = jtedge.getPt();
    
        // System.out.println("was \n"+ sepsetPT.toString());

        PT oldSepsetPT = sepsetPT.clone();
        PT sourcePT = source.getPt();

        FMarginalizationMap mmap = source.getFMarginalizationMap(jtedge);

        PT.marginalizeDomainsFast(sepsetPT, sourcePT, mmap);
     
        // System.out.println("become \n"+ sepsetPT.toString());
        PT targetPT = target.getPt();
    
        SubtableFastMap2 fmap2 = target.getSubtableOpFastMap(jtedge);

        targetPT.multiplyAndDivideBySubtableFast(sepsetPT, oldSepsetPT, fmap2);
    
    }

    public void unmarkAll() {
        Set<JTVertexHugin> vertexSet = this.vertexSet();

        for (JTVertexHugin vertex : vertexSet) {
            vertex.isMarked = false;
        }
    }

    public void setNodeMarginals() {
        Set<BNNode> bNNodes = bn.vertexSet();

        for (BNNode node : bNNodes) {
            if (node.isNature()) {
                ClusterBundle cbundle = clusterHash.get(node);
                JTVertexHugin jtcluster = cbundle.jtcluster;
                FMarginalizationMap mfmap = cbundle.mfmap;
                PT marginal = new PT(mfmap.getProjectionDomainProduct());

                PT.marginalizeDomainsFast(marginal, jtcluster.getPt(), mfmap);
                // MarginalizationDomainMap mfmap = cbundle.mfmap;
                // node.setMarginal(PT.marginalizeDomains(jtcluster.getPt(),
                // mfmap));
                node.setMarginal(marginal);
            }

        }
    }

    @Override
  public void setNodeConditionalMarginals() {
        Set<BNNode> bNNodes = bn.vertexSet();

        for (BNNode node : bNNodes) {
            if (node.isNature()) {
                ClusterBundle cbundle = clusterHash.get(node);
                JTVertexHugin jtcluster = cbundle.jtcluster;
                FMarginalizationMap mfmap = cbundle.mfmap;

                PT marginal = new PT(mfmap.getProjectionDomainProduct());

                PT.marginalizeDomainsFast(marginal, jtcluster.getPt(), mfmap);
                // MarginalizationDomainMap mfmap = cbundle.mfmap;
                // PT marginal = PT.marginalizeDomains(jtcluster.getPt(),
                // mfmap);
                marginal.normalize();
                node.setMarginal(marginal);
            }

        }
    }

    private JTVertexHugin assignParentCluster(BNNode node) {
        Set<BNNode> family = new HashSet<BNNode>();

        family.add(node);
        family.addAll(bn.getParents(node));
        JTVertexHugin v = null;
        Set<JTVertexHugin> verttset = this.vertexSet();

        for (JTVertexHugin vertex : verttset) {
            if (vertex.containsAll(family)) {
                return vertex;
            }
        }
        return v;
    }

    private class ClusterBundle {
        private JTVertexHugin jtcluster;

        private DomainMap2 fopmap;

        private DomainMap2 liklihoodfmap;

        private FMarginalizationMap mfmap;

        public ClusterBundle(JTVertexHugin jtcluster, DomainMap2 fmap,
                FMarginalizationMap mfmap, DomainMap2 liklihoodfmap) {
            this.fopmap = fmap;
            this.jtcluster = jtcluster;
            this.mfmap = mfmap;
            this.liklihoodfmap = liklihoodfmap;
        }

    }

    @Override
  public BeliefNetwork getBeliefNetwork() {
        return bn;
    }

    public void dump() {

        System.out.println("Join Tree");
        System.out.println("Edges:");
        Set<JTEdgeHugin> edges = this.edgeSet();

        for (JTEdgeHugin edge : edges) {
            System.out.print(
                    this.getEdgeSource(edge).getName());
            System.out.print("<--->");
            System.out.print(
                    this.getEdgeTarget(edge).getName());
            System.out.println("");

            System.out.println("");
        }

        System.out.println("Nodes:");
        Set<JTVertexHugin> nodes = this.vertexSet();

        for (JTVertexHugin v : nodes) {
            System.out.print(v.getName() + ": ");

            printNodes(v.getClique());
        }

    }

    public void check() {

        System.out.println("Join Tree check");

        Set<JTEdgeHugin> edges = this.edgeSet();

        for (JTEdgeHugin edge : edges) {

            JTVertexHugin source = (this.getEdgeSource(edge));
            JTVertexHugin target = (this.getEdgeTarget(edge));
            Set<BNNode> edgenodes = edge.getSepset();

            if (!source.getClique().containsAll(edgenodes)
                    || !target.getClique().containsAll(edgenodes)
                    || edgenodes.isEmpty()) {
                System.out.println(
                        "Error in" + source.getName() + "<-->"
                        + target.getName());
                System.out.print(source.getName() + " Vertex nodes:");
                printNodes(source.getClique());
                System.out.println("SepsetNodes:");
                printNodes(edgenodes);
                System.out.print(target.getName() + " Vertex nodes:");
                printNodes(source.getClique());
                System.out.println("");
            }

        }

    }

    public void printNodes(Set<BNNode> bnnodes) {
        for (BNNode bNNode : bnnodes) {
            System.out.print(bNNode.getName() + ", ");
        }
        System.out.println("");
    }
 
    public void checkClusterAssignment() {
        Set<BNNode> bNNodes = bn.vertexSet();

        for (BNNode node : bNNodes) {
            if (node.isNature()) {
                ClusterBundle cbundle = clusterHash.get(node);
                JTVertexHugin jtcluster = cbundle.jtcluster;

                System.out.println(node.getName() + ": ");
       
                printNodes(jtcluster.getClique());
       
            }

        }
   
    }
 
    public void checkClusterAssignment2() {
   
        Set<JTVertexHugin> verts = new HashSet<JTVertexHugin>();

        verts.addAll(this.vertexSet());
        Set<BNNode> bNNodes = bn.vertexSet();

        for (BNNode node : bNNodes) {
            if (node.isNature()) {
                ClusterBundle cbundle = clusterHash.get(node);
                JTVertexHugin jtcluster = cbundle.jtcluster;

                verts.remove(jtcluster);
                System.out.println(node.getName() + ":  ");
                System.out.print(jtcluster.getName() + ": ");
                printNodes(jtcluster.getClique());
       
            }

        }
   
        System.out.println("Not Assigned:");
   
        for (JTVertexHugin v : verts) {
            System.out.println(v.getName());
            printNodes(v.getClique());
        }
   
    }
 
    public void printCluster(String name) {
        for (JTVertexHugin v : this.vertexSet()) {
            if (v.getName().equalsIgnoreCase(name)) {
                System.out.println(v.getName());
                System.out.println(v.getPt().toString());
            }
        }
    }
 
    public void printEdges() {
        for (JTEdgeHugin e : this.edgeSet()) {
            System.out.println(
                    "source: " + e.getSourceVertex().getName() + "  :");
            printNodes(e.getSourceVertex().getClique());
            System.out.println("edge has nodes: ");
            printNodes(e.getSepset());
            System.out.println("target: " + e.getTargetVertex().getName() + " :");
            printNodes(e.getTargetVertex().getClique());
      
        }
    }

    public void checkJTProperty() {
        JTVertexHugin[] cliqueVertexesArray = new JTVertexHugin[this.vertexSet().size()];

        this.vertexSet().toArray(cliqueVertexesArray);

        if (cliqueVertexesArray.length > 1) {
            for (int i = 0; i < cliqueVertexesArray.length; i++) {
                for (int j = i + 1; j < cliqueVertexesArray.length; j++) {
                    Set<BNNode> sepset = JTEdge.intersection(
                            cliqueVertexesArray[i].getClique(),
                            cliqueVertexesArray[j].getClique());

                    if (!sepset.isEmpty()) {
                        checkPath(cliqueVertexesArray[i], cliqueVertexesArray[j],
                                sepset);
                    }
                }
            }
        }
    }
 
    public void checkNeighbours() {
        Set<JTVertexHugin> vertex = this.vertexSet();

        for (JTVertexHugin v : vertex) {
            Set<JTVertexHugin> neighbours = this.getNeighbors(v);

            System.out.println(v.getName() + "  has neighbors:");
            for (JTVertexHugin jtVertexHugin : neighbours) {
                System.out.print(jtVertexHugin.getName() + ",");
       
            }
            System.out.println("");
        }
    
    }
 
    public void checkNormalization() {
        Set<JTVertexHugin> vertex = this.vertexSet();

        for (JTVertexHugin v : vertex) {
            System.out.print(v.getName() + ", sum is  ");
            System.out.println(v.getPt().sum());
            System.out.println("");
        }
    
    }

    private void checkPath(JTVertexHugin source, JTVertexHugin target,
            Set<BNNode> sepset) {
        Set<JTVertexHugin> neighbours = this.getNeighbors(source);

        for (JTVertexHugin v : neighbours) {
            Vector<JTVertexHugin> visited = new Vector<JTVertexHugin>();

            visited.add(source);
            visited.add(v);
      
            if (v.getClique().containsAll(sepset)) {
                checkPath(source, v, target, sepset, true, visited);
            } else {
                checkPath(source, v, target, sepset, false, visited);
            }
        }

    }

    private void checkPath(JTVertexHugin source, JTVertexHugin temp,
            JTVertexHugin target, Set<BNNode> sepset, boolean report,
            Vector<JTVertexHugin> visited) {
   
        if (temp == target) {
            if (report) {
                System.out.println("OK:");
                reportPath(visited);
                return;
            } else {
                System.out.println("Error:");
                printNodes(sepset);
                reportPath(visited);
                return;
            }
     
        }
   
        Set<JTVertexHugin> neighbours = this.getNeighbors(temp);

        for (JTVertexHugin v : neighbours) {
            if (!visited.contains(v)) {
                Vector<JTVertexHugin> visited2 = new Vector<JTVertexHugin>();

                visited2.addAll(visited);
                visited2.add(v);
                if (v.getClique().containsAll(sepset)) {
                    checkPath(source, v, target, sepset, report, visited2);
                } else {
                    checkPath(source, v, target, sepset, false, visited2);
                }
            }
      
        }

    }
 
    public void reportPath(Vector<JTVertexHugin> visited) {
   
        for (JTVertexHugin jtVertexHugin : visited) {
            System.out.print(jtVertexHugin.getName() + ",");
            printNodes(jtVertexHugin.getClique());
     
        }
        System.out.println("");
   
    }

}
TOP

Related Classes of org.integratedmodelling.riskwiz.inference.ls.JoinTree

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.