Package org.integratedmodelling.riskwiz.learning.parameter.bayes

Source Code of org.integratedmodelling.riskwiz.learning.parameter.bayes.BayesianLearner

/**
* BayesianLearner.java
* ----------------------------------------------------------------------------------
*
* Copyright (C) 2008 www.integratedmodelling.org
* Created: May 16, 2008
*
* ----------------------------------------------------------------------------------
* This file is part of RiskWiz.
*
* RiskWiz is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* RiskWiz is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with the software; if not, write to the Free Software
* Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
*
* ----------------------------------------------------------------------------------
*
* @copyright 2008 www.integratedmodelling.org
* @author    Sergey Krivov
* @date      May 16, 2008
* @license   http://www.gnu.org/licenses/gpl.txt GNU General Public License v3
* @link      http://www.integratedmodelling.org
**/

package org.integratedmodelling.riskwiz.learning.parameter.bayes;


import java.io.IOException;
import java.util.Set;
import java.util.Vector;

import org.integratedmodelling.riskwiz.bn.BNNode;
import org.integratedmodelling.riskwiz.bn.BeliefNetwork;
import org.integratedmodelling.riskwiz.domain.DiscreteDomain;
import org.integratedmodelling.riskwiz.learning.IParameterLearner;
import org.integratedmodelling.riskwiz.learning.bndata.IGraphData;
import org.integratedmodelling.riskwiz.learning.bndata.IGraphDataSource;
import org.integratedmodelling.riskwiz.learning.bndata.IGraphTable;
import org.integratedmodelling.riskwiz.learning.dtable.Dirichlet;
import org.integratedmodelling.riskwiz.learning.dtable.DistTable;
import org.integratedmodelling.riskwiz.pfunction.TabularCPD;
import org.nfunk.jep.ParseException;


/**
* @author Sergey Krivov
*
*/
public class BayesianLearner implements IParameterLearner {
    BeliefNetwork bnet;

    /**
     *
     */
    public BayesianLearner() {// TODO Auto-generated constructor stub
    }
 
    /*
     * (non-Javadoc)
     *
     * @see org.integratedmodelling.riskwiz.learning.IParameterLearner#initialize(org.integratedmodelling.riskwiz.bn.BeliefNetwork)
     */
    @Override
  public void initialize(BeliefNetwork bn) {
        bnet = bn;
        initializeUniformDistributions();

    }

    /*
     * (non-Javadoc)
     *
     * @see org.integratedmodelling.riskwiz.learning.IParameterLearner#getBeliefNetwork()
     */
    @Override
  public BeliefNetwork getFinalResult() {
        Set<BNNode> nodes = bnet.vertexSet();

        for (BNNode node : nodes) {
            if (node.isProbabilistic()) {
                node.clearProperty("distribution");
            }
        }
        return bnet;
    }

    /*
     * (non-Javadoc)
     *
     * @see org.integratedmodelling.riskwiz.learning.IParameterLearner#learnFromTable(org.integratedmodelling.riskwiz.learning.bndata.IGraphTable)
     */
    @Override
  public void learnFromTable(IGraphTable gtable) {
        gtable.connect(bnet);
        Set<BNNode> nodes = bnet.vertexSet();

        for (BNNode node : nodes) {
            if (node.isProbabilistic() && (!node.isExpression())
                    && gtable.hasCompleteProjection(node)) {
                learnNodeParameters(node, gtable, gtable.getValues());
            }
        }

    }
 
    @Override
  public void learnFromDataSource(IGraphDataSource datasource) throws IOException {
        datasource.connect(bnet);
        while (datasource.hasNextValues()) {
            Vector<Vector<String>> table = datasource.getNextValues();
            Set<BNNode> nodes = bnet.vertexSet();

            for (BNNode node : nodes) {
                if (node.isProbabilistic() && (!node.isExpression())
                        && datasource.hasCompleteProjection(node)) {
                    learnNodeParameters(node, datasource, table);
                }
            }
     
            datasource.readNextValues();
        }
   
        datasource.close();
   
    }

    protected void learnNodeParameters(BNNode node, IGraphData gtable, Vector<Vector<String>> tuples) {
        DistTable dtable = (DistTable) node.getProperty("distribution");
    
        for (Vector<String> tuple : tuples) {
            int[] cptquery = gtable.getQuery(node, tuple);

            if (isCompleteQuery(cptquery)) {
                Dirichlet dist = dtable.getValue1(cptquery);

                dist.increment(cptquery[0]);
            }
        }

        node.setFunction(dtable.createCPF());
    }

    public boolean isCompleteQuery(int[] cptquery) {

        for (int i = 0; i < cptquery.length; i++) {
            if (cptquery[i] == -1) {
                return false;
            }
        }
        return true;

    }

    public void initializeWithPriors(BeliefNetwork bn, int virtualSamples) throws ParseException {
        bnet = bn;
        initializeDistributions(virtualSamples);

    }

    private void initializeUniformDistributions() {
        Set<BNNode> nodes = bnet.vertexSet();

        for (BNNode node : nodes) {
            if (node.getFunction() instanceof TabularCPD) {
                TabularCPD function = (TabularCPD) node.getFunction();
         
                DistTable dtable = new DistTable(
                        (DiscreteDomain) function.getDomain(),
                        function.getParentsDomains());

                dtable.setUniformDistributions();
                node.setProperty("distribution", dtable);
                node.setFunction(dtable.createCPF());
            }
        }
    }

    private void initializeDistributions(int virtualSamples) throws ParseException {
        Set<BNNode> nodes = bnet.vertexSet();

        for (BNNode node : nodes) {
            if (node.getFunction() instanceof TabularCPD) {
                TabularCPD function = (TabularCPD) node.getFunction();
                DistTable dtable = new DistTable(function, virtualSamples);

                node.setProperty("distribution", dtable);
                node.setFunction(dtable.createCPF());
            }
        }
    }

}
TOP

Related Classes of org.integratedmodelling.riskwiz.learning.parameter.bayes.BayesianLearner

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.