// Package org.integratedmodelling.riskwiz.influence.jensen
//
// Source Code of org.integratedmodelling.riskwiz.influence.jensen.StrongJoinTree

/**
* StrongJoinTree.java
* ----------------------------------------------------------------------------------
*
* Copyright (C) 2008 www.integratedmodelling.org
* Created: Feb 19, 2008
*
* ----------------------------------------------------------------------------------
* This file is part of RiskWiz.
*
* RiskWiz is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* RiskWiz is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with the software; if not, write to the Free Software
* Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
*
* ----------------------------------------------------------------------------------
*
* @copyright 2008 www.integratedmodelling.org
* @author    Sergey Krivov
* @date      Feb 19, 2008
* @license   http://www.gnu.org/licenses/gpl.txt GNU General Public License v3
* @link      http://www.integratedmodelling.org
**/

package org.integratedmodelling.riskwiz.influence.jensen;


import java.util.HashSet;
import java.util.Hashtable;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.Vector;

import org.integratedmodelling.riskwiz.bn.BNNode;
import org.integratedmodelling.riskwiz.bn.BeliefNetwork;
import org.integratedmodelling.riskwiz.domain.DiscreteDomain;
import org.integratedmodelling.riskwiz.domain.DomainFactory;
import org.integratedmodelling.riskwiz.graph.RiskUndirectedGraph;
import org.integratedmodelling.riskwiz.influence.JTPotential;
import org.integratedmodelling.riskwiz.influence.jensen.SJTVertex.EliminationOrder;
import org.integratedmodelling.riskwiz.jtree.IJoinTreeDecision;
import org.integratedmodelling.riskwiz.pt.CPT;
import org.integratedmodelling.riskwiz.pt.PT;
import org.integratedmodelling.riskwiz.pt.map.DomainMap2;
import org.integratedmodelling.riskwiz.pt.map.FMarginalizationMap;
import org.integratedmodelling.riskwiz.pt.map.FastMap2;
import org.integratedmodelling.riskwiz.pt.map.SubtableFastMap2;
import org.jgrapht.traverse.BreadthFirstIterator;


/**
* @author Sergey Krivov
*
*/

/**
* @author Sergey Krivov
*
*/
public class StrongJoinTree extends RiskUndirectedGraph<SJTVertex, SJTEdge>
        implements IJoinTreeDecision<SJTVertex> {

    protected BeliefNetwork bn;

    protected SJTVertex root;

    protected Hashtable<BNNode, ClusterBundle> clusterHash;

    protected Hashtable<DiscreteDomain, CPT> policyHash;
 
    protected Vector<Object> rootMmaps;

    public StrongJoinTree(BeliefNetwork bn) {
        super(SJTEdge.class);
        this.bn = bn;
    }

    public StrongJoinTree() {
        super(SJTEdge.class);

    }

    public void setVaciousPotentials() {
        Set<SJTVertex> verttset = this.vertexSet();

        // initialize clusters
        for (SJTVertex vertex : verttset) {
            vertex.setVacious();
        }
        // initialize spesets
        Set<SJTEdge> edgeset = this.edgeSet();

        for (SJTEdge edge : edgeset) {
            edge.setVacious();
        }
    }

    // builds map for fast operation and initializes tree
    // unless evidences changed no need to call initialize() after this
    @Override
  public void initializeStructiure() {
        setVaciousPotentials();

        // create correspondence between Belif nodes and clusters (vertexes of
        // StrongJoinTree)
        clusterHash = new Hashtable<BNNode, ClusterBundle>();

        Set<BNNode> bNNodes = bn.vertexSet();

        // debug
    
        for (BNNode node : bNNodes) {

            SJTVertex parentCluster = assignParentCluster(node);

            if (node.isNature()) {
        
                JTPotential clusterPT = parentCluster.getPotential();
                PT nodePT = node.getDiscreteCPT();
                FastMap2 fmap = clusterPT.createSubtableFastMap(nodePT);
                FMarginalizationMap mfmap = clusterPT.createFMarginalizationMap(
                        node.getDiscretizedDomain());
                FastMap2 liklihoodfmap = clusterPT.createSubtableFastMap(
                        DomainFactory.createDomainProduct(
                                node.getDiscretizedDomain()));

                clusterHash.put(node,
                        new ClusterBundle(parentCluster, fmap, mfmap,
                        liklihoodfmap));
            } else if (node.isUtility()) {
                JTPotential clusterPT = parentCluster.getPotential();
                Vector<DiscreteDomain> parentDomains = node.getDiscreteCPT().getParentsDomains();
                FastMap2 fmap = clusterPT.createSubtableFastMap(parentDomains);
                FMarginalizationMap mfmap = clusterPT.createFMarginalizationMap(
                        node.getDiscreteCPT().getParentsDomains());

                clusterHash.put(node,
                        new ClusterBundle(parentCluster, fmap, mfmap, null));
            }

        }

        // compile maps for fast operations
        Set<SJTVertex> verttset = this.vertexSet();

        for (SJTVertex vertex : verttset) {
            Set<SJTEdge> edgesOfvertex = this.edgesOf(vertex);

            for (SJTEdge jtedge : edgesOfvertex) {
                vertex.createFastMaps(jtedge);
            }
        }
   
        createRootMarginalizationFastMap();

    }

    // StrongJoinTree map structure has to be built before calling
    // initializeDecision()
    @Override
  public void initialize() {
        setVaciousPotentials();

        Set<BNNode> bNNodes = bn.vertexSet();

        for (BNNode node : bNNodes) {
            if (node.isNature()) {
                ClusterBundle bundle = clusterHash.get(node);
                SJTVertex parentCluster = bundle.jtcluster;
                JTPotential clusterPT = parentCluster.getPotential();

                clusterPT.multiplyByProbabilitySubtable(node.getDiscreteCPT(),
                        bundle.fopmap);
            } else if (node.isUtility()) {
                ClusterBundle bundle = clusterHash.get(node);
                SJTVertex parentCluster = bundle.jtcluster;
                JTPotential clusterPT = parentCluster.getPotential();

                clusterPT.addUtilitySubtable(node.getDiscreteCPT(),
                        bundle.fopmap);

            }

        }

    }

    @Override
  public void initializeLikelihoods() {
        Set<BNNode> bNNodes = bn.vertexSet();

        for (BNNode node : bNNodes) {
            if (node.isProbabilistic()) {
                initializeLikelihood(node);
            }
        }

    }

    public void initializeLikelihood(BNNode node) {
        ClusterBundle bundle = clusterHash.get(node);
        SJTVertex parentCluster = bundle.jtcluster;
        JTPotential clusterPT = parentCluster.getPotential();

        if (node.hasEvidence()) {

            clusterPT.multiplyByProbabilitySubtable(node.getEvidence(),
                    bundle.liklihoodfmap);

        }
    }

    public void collectEvidence(SJTVertex X) {
        X.isMarked = true;
        Set<SJTVertex> neighbours = this.getNeighbors(X);

        for (SJTVertex neighbor : neighbours) {
            if (!neighbor.isMarked) {
                collectEvidence(neighbor);
                passMessage(neighbor, X);
        
            }
        }

    }

    @Override
  public void propagateEvidence() {
        policyHash = new Hashtable<DiscreteDomain, CPT>();
        unmarkAll();
    
        collectEvidence(root);
        JTPotential.marginalizeDomainsSequence(root.getPotential(), rootMmaps,
                policyHash);
   
    }

    public void passMessage(SJTVertex source, SJTVertex target) {
        // boolean pp=false;
        SJTEdge jtedge = this.getEdge(source, target);
        JTPotential sepsetPT = jtedge.getPotential();

        JTPotential sourcePT = source.getPotential();
        Vector<Object> mmaps = source.getMarginalizationFastMap(jtedge);

        // System.out.println("\nmapsize: "+ mmaps.size()+ " \n" );
        // for (Object object : mmaps) {
        // if (object instanceof MarginalizationFastMap) {
        // System.out.println( "marg," );
        // } else if (object instanceof DiscreteDomain) {
        //
        // System.out.println( "maxMarg," );
        // }
        // }
        sepsetPT = JTPotential.marginalizeDomainsSequence(sourcePT, mmaps,
                policyHash);
        jtedge.setPotential(sepsetPT);

        JTPotential targetPT = target.getPotential();

        SubtableFastMap2 fmap2 = target.getSubtableOpFastMap(jtedge);

        targetPT.multiplyBySubtableFast(sepsetPT, fmap2);

    }

    public void unmarkAll() {
        Set<SJTVertex> vertexSet = this.vertexSet();

        for (SJTVertex vertex : vertexSet) {
            vertex.isMarked = false;
        }
    }

    protected SJTVertex assignParentCluster(BNNode node) {
        if (node.isDecision()) {
            return assignParentClusterDecision(node);
        }
        Set<BNNode> family = new HashSet<BNNode>();

        if (!node.isUtility()) {
            family.add(node);
        }
        family.addAll(bn.getParents(node));
        SJTVertex v = null;
        Set<SJTVertex> verttset = this.vertexSet();

        for (SJTVertex vertex : verttset) {
            if (vertex.containsAll(family)) {
                v = vertex;
                break;
            }
        }
        return v;
    }

    protected SJTVertex assignParentClusterDecision(BNNode node) {

        for (BreadthFirstIterator<SJTVertex, SJTEdge> it = new BreadthFirstIterator<SJTVertex, SJTEdge>(
                this, this.getRoot()); it.hasNext();) {
            SJTVertex vertex = it.next();

            if (vertex.contains(node)) {
                return vertex;
            }
        }
        return null;
    }

    protected class ClusterBundle {
        protected SJTVertex jtcluster;

        protected DomainMap2 fopmap;

        protected FastMap2 liklihoodfmap;

        protected FMarginalizationMap mfmap;

        public ClusterBundle(SJTVertex jtcluster, DomainMap2 fmap,
                FMarginalizationMap mfmap, FastMap2 liklihoodfmap) {
            this.fopmap = fmap;
            this.jtcluster = jtcluster;
            this.mfmap = mfmap;
            this.liklihoodfmap = liklihoodfmap;
        }

    }

    @Override
  public SJTVertex getRoot() {
        return root;
    }

    public void setRoot(SJTVertex root) {
        this.root = root;
    }
 
    // vacious methods, TODO need an adapter

    @Override
  public BeliefNetwork getBeliefNetwork() {
    
        return bn;
    }

    @Override
  public CPT getPolicy(DiscreteDomain nodeDom) {
        // TODO Auto-generated method stub
        return policyHash.get(nodeDom);
    }
 
    public void createRootMarginalizationFastMap() {
        rootMmaps = new Vector<Object>();
        SortedSet<BNNode> marginals = new TreeSet<BNNode>(new EliminationOrder());

        marginals.addAll(root.getClique());
   
        Vector<DiscreteDomain> currentDomainProduct = new Vector<DiscreteDomain>();

        currentDomainProduct.addAll(root.getPotential().getDomainProduct());
   
        while (!marginals.isEmpty()) {
            Vector<DiscreteDomain> domsSet = new Vector<DiscreteDomain>();

            while (!marginals.isEmpty() && marginals.first().isNature()) {
                BNNode first = marginals.first();

                domsSet.add(first.getDiscretizedDomain());
                marginals.remove(first);
            }
            if (!domsSet.isEmpty()) {
        
                FMarginalizationMap mmap = new FMarginalizationMap(
                        currentDomainProduct, domsSet);

                rootMmaps.add(mmap);
                currentDomainProduct.removeAll(domsSet);
            }
            while (!marginals.isEmpty() && marginals.first().isDecision()) {
         
                BNNode first = marginals.first();
       
                rootMmaps.add(first.getDomain());
                currentDomainProduct.remove(first.getDomain());
                marginals.remove(first);
            }
        } 
     
    }
 
    // debugging functions

  
    public void report(SJTVertex v) {
        System.out.println(" clique:");
        Set<BNNode> cliq = v.getClique();

        for (BNNode node : cliq) {
            System.out.println(node.getName());
        }        
        System.out.println();
    }
  
    public void reportAll() {
        Set<SJTVertex> verttset = this.vertexSet();

        for (SJTVertex vertex : verttset) {
            report(vertex);
        }

    }

}
// TOP
//
// Related Classes of org.integratedmodelling.riskwiz.influence.jensen.StrongJoinTree
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.