// Package weka.classifiers.bayes.net.search.local
//
// Source code of weka.classifiers.bayes.net.search.local.GeneticSearch$BayesNetRepresentation

/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/

/*
* GeneticSearch.java
* Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.bayes.net.search.local;

import weka.classifiers.bayes.BayesNet;
import weka.classifiers.bayes.net.ParentSet;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionHandler;
import weka.core.RevisionUtils;
import weka.core.Utils;

import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;

/**
<!-- globalinfo-start -->
* This Bayes Network learning algorithm uses genetic search for finding a well scoring Bayes network structure. Genetic search works by having a population of Bayes network structures and allow them to mutate and apply cross over to get offspring. The best network structure found during the process is returned.
* <p/>
<!-- globalinfo-end -->
*
<!-- options-start -->
* Valid options are: <p/>
*
* <pre> -L &lt;integer&gt;
*  Population size</pre>
*
* <pre> -A &lt;integer&gt;
*  Descendant population size</pre>
*
* <pre> -U &lt;integer&gt;
*  Number of runs</pre>
*
* <pre> -M
*  Use mutation.
*  (default true)</pre>
*
* <pre> -C
*  Use cross-over.
*  (default true)</pre>
*
* <pre> -O
*  Use tournament selection (true) or maximum subpopulation (false).
*  (default false)</pre>
*
* <pre> -R &lt;seed&gt;
*  Random number seed</pre>
*
* <pre> -mbc
*  Applies a Markov Blanket correction to the network structure,
*  after a network structure is learned. This ensures that all
*  nodes in the network are part of the Markov blanket of the
*  classifier node.</pre>
*
* <pre> -S [BAYES|MDL|ENTROPY|AIC|CROSS_CLASSIC|CROSS_BAYES]
*  Score type (BAYES, BDeu, MDL, ENTROPY and AIC)</pre>
*
<!-- options-end -->
*
* @author Remco Bouckaert (rrb@xm.co.nz)
* @version $Revision: 1.5 $
*/
public class GeneticSearch
    extends LocalScoreSearchAlgorithm {

    /** for serialization */
    static final long serialVersionUID = -7037070678911459757L;
 
    /** number of runs **/
    int m_nRuns = 10;

  /** size of population **/
  int m_nPopulationSize = 10;

  /** size of descendant population **/
  int m_nDescendantPopulationSize = 100;

  /** use cross-over? **/
  boolean m_bUseCrossOver = true;

  /** use mutation? **/
  boolean m_bUseMutation = true;
 
  /** use tournament selection or take best sub-population **/
  boolean m_bUseTournamentSelection = false
 
  /** random number seed **/
  int m_nSeed = 1;
 
  /** random number generator **/
  Random m_random = null;


  /** used in BayesNetRepresentation for efficiently determining
   * whether a number is square 
   */
  static boolean [] g_bIsSquare;
 
  class BayesNetRepresentation implements RevisionHandler {
   
    /** number of nodes in network **/   
    int m_nNodes = 0;

    /** bit representation of parent sets
     * m_bits[iTail + iHead * m_nNodes] represents arc iTail->iHead
     */
    boolean [] m_bits;
   
    /** score of represented network structure **/
    double m_fScore = 0.0f;
   
    /**
     * return score of represented network structure
     *
     * @return the score
     */
    public double getScore() {
      return m_fScore;
    } // getScore

    /**
     * c'tor
     *
     * @param nNodes the number of nodes
     */
    BayesNetRepresentation (int nNodes) {
      m_nNodes = nNodes;
    } // c'tor
   
    /** initialize with a random structure by randomly placing
     * m_nNodes arcs.
     */
    public void randomInit() {
      do {
        m_bits = new boolean [m_nNodes * m_nNodes];
        for (int i = 0; i < m_nNodes; i++) {
          int iPos;
          do {
            iPos = m_random.nextInt(m_nNodes * m_nNodes);
          } while (isSquare(iPos));
          m_bits[iPos] = true;
        }
      } while (hasCycles());
      calcScore();
    }

    /** calculate score of current network representation
     * As a side effect, the parent sets are set
     */
    void calcScore() {
      // clear current network
      for (int iNode = 0; iNode < m_nNodes; iNode++) {
        ParentSet parentSet = m_BayesNet.getParentSet(iNode);
        while (parentSet.getNrOfParents() > 0) {
          parentSet.deleteLastParent(m_BayesNet.m_Instances);
        }
      }
      // insert arrows
      for (int iNode = 0; iNode < m_nNodes; iNode++) {
        ParentSet parentSet = m_BayesNet.getParentSet(iNode);
        for (int iNode2 = 0; iNode2 < m_nNodes; iNode2++) {
          if (m_bits[iNode2 + iNode * m_nNodes]) {
            parentSet.addParent(iNode2, m_BayesNet.m_Instances);
          }
        }
      }
      // calc score
      m_fScore = 0.0;
      for (int iNode = 0; iNode < m_nNodes; iNode++) {
        m_fScore += calcNodeScore(iNode);
      }
    } // calcScore

    /** check whether there are cycles in the network
     *
     * @return true if a cycle is found, false otherwise
     */
    public boolean hasCycles() {
      // check for cycles
      boolean[] bDone = new boolean[m_nNodes];
      for (int iNode = 0; iNode < m_nNodes; iNode++) {

        // find a node for which all parents are 'done'
        boolean bFound = false;

        for (int iNode2 = 0; !bFound && iNode2 < m_nNodes; iNode2++) {
          if (!bDone[iNode2]) {
            boolean bHasNoParents = true;
            for (int iParent = 0; iParent < m_nNodes; iParent++) {
              if (m_bits[iParent + iNode2 * m_nNodes] && !bDone[iParent]) {
                bHasNoParents = false;
              }
            }
            if (bHasNoParents) {
              bDone[iNode2] = true;
              bFound = true;
            }
          }
        }
        if (!bFound) {
          return true;
        }
      }
      return false;
    } // hasCycles

    /** create clone of current object
     * @return cloned object
     */
    BayesNetRepresentation copy() {
      BayesNetRepresentation b = new BayesNetRepresentation(m_nNodes);
      b.m_bits = new boolean [m_bits.length];
      for (int i = 0; i < m_nNodes * m_nNodes; i++) {
        b.m_bits[i] = m_bits[i];
      }
      b.m_fScore = m_fScore;
      return b;   
    } // copy

    /** Apply mutation operation to BayesNet
     * Calculate score and as a side effect sets BayesNet parent sets.
     */
    void mutate() {
      // flip a bit
      do {       
        int iBit;
        do {
          iBit = m_random.nextInt(m_nNodes * m_nNodes);
        } while (isSquare(iBit));
       
        m_bits[iBit] = !m_bits[iBit];
      } while (hasCycles());

      calcScore();
    } // mutate

    /** Apply cross-over operation to BayesNet
     * Calculate score and as a side effect sets BayesNet parent sets.
     * @param other BayesNetRepresentation to cross over with
     */
    void crossOver(BayesNetRepresentation other) {
      boolean [] bits = new boolean [m_bits.length];
      for (int i = 0; i < m_bits.length; i++) {
        bits[i] = m_bits[i];
      }
      int iCrossOverPoint = m_bits.length;
      do {
        // restore to original state
        for (int i = iCrossOverPoint; i < m_bits.length; i++) {
          m_bits[i] = bits[i];
        }
        // take all bits from cross-over points onwards
        iCrossOverPoint = m_random.nextInt(m_bits.length);
        for (int i = iCrossOverPoint; i < m_bits.length; i++) {
          m_bits[i] = other.m_bits[i];
        }
      } while (hasCycles());
      calcScore();
    } // crossOver
       
    /** check if number is square and initialize g_bIsSquare structure
     * if necessary
     * @param nNum number to check (should be below m_nNodes * m_nNodes)
     * @return true if number is square
     */
    boolean isSquare(int nNum) {
      if (g_bIsSquare == null || g_bIsSquare.length < nNum) {
        g_bIsSquare = new boolean [m_nNodes * m_nNodes];
        for (int i = 0; i < m_nNodes; i++) {
          g_bIsSquare[i * m_nNodes + i] = true;
        }
      }
      return g_bIsSquare[nNum];
    } // isSquare

    /**
     * Returns the revision string.
     *
     * @return    the revision
     */
    public String getRevision() {
      return RevisionUtils.extract("$Revision: 1.5 $");
    }
  } // class BayesNetRepresentation
       
  /**
   * search determines the network structure/graph of the network
   * with a genetic search algorithm.
   *
   * @param bayesNet the network to use
   * @param instances the data to use
   * @throws Exception if population size doesn fit or neither cross-over or mutation was chosen
   */
  protected void search(BayesNet bayesNet, Instances instances) throws Exception {
    // sanity check
    if (getDescendantPopulationSize() < getPopulationSize()) {
      throw new Exception ("Descendant PopulationSize should be at least Population Size");
    }
    if (!getUseCrossOver() && !getUseMutation()) {
      throw new Exception ("At least one of mutation or cross-over should be used");
    }
   
    m_random = new Random(m_nSeed);

    // keeps track of best structure found so far
    BayesNet bestBayesNet;
    // keeps track of score pf best structure found so far
    double fBestScore = 0.0
    for (int iAttribute = 0; iAttribute < instances.numAttributes(); iAttribute++) {
      fBestScore += calcNodeScore(iAttribute);
    }

    // initialize bestBayesNet
    bestBayesNet = new BayesNet();
    bestBayesNet.m_Instances = instances;
    bestBayesNet.initStructure();
    copyParentSets(bestBayesNet, bayesNet);
   
               
        // initialize population       
    BayesNetRepresentation  [] population = new BayesNetRepresentation [getPopulationSize()];
        for (int i = 0; i < getPopulationSize(); i++) {
          population[i] = new BayesNetRepresentation (instances.numAttributes());
      population[i].randomInit();
      if (population[i].getScore() > fBestScore) {
        copyParentSets(bestBayesNet, bayesNet);
        fBestScore = population[i].getScore();
       
      }
        }
       
        // go do the search       
        for (int iRun = 0; iRun < m_nRuns; iRun++) {
          // create descendants
      BayesNetRepresentation  [] descendantPopulation = new BayesNetRepresentation  [getDescendantPopulationSize()];
      for (int i = 0; i < getDescendantPopulationSize(); i++) {
        descendantPopulation[i] = population[m_random.nextInt(getPopulationSize())].copy();
        if (getUseMutation()) {
          if (getUseCrossOver() && m_random.nextBoolean()) {
            descendantPopulation[i].crossOver(population[m_random.nextInt(getPopulationSize())]);           
          } else {
            descendantPopulation[i].mutate();               
          }
        } else {
          // use crossover
          descendantPopulation[i].crossOver(population[m_random.nextInt(getPopulationSize())]);
        }

        if (descendantPopulation[i].getScore() > fBestScore) {
          copyParentSets(bestBayesNet, bayesNet);
          fBestScore = descendantPopulation[i].getScore();
        }
      }
      // select new population
      boolean [] bSelected = new boolean [getDescendantPopulationSize()];
      for (int i = 0; i < getPopulationSize(); i++) {
        int iSelected = 0;
        if (m_bUseTournamentSelection) {
          // use tournament selection
          iSelected = m_random.nextInt(getDescendantPopulationSize());
          while (bSelected[iSelected]) {
            iSelected = (iSelected + 1) % getDescendantPopulationSize();
          }
          int iSelected2 =  m_random.nextInt(getDescendantPopulationSize());
          while (bSelected[iSelected2]) {
            iSelected2 = (iSelected2 + 1) % getDescendantPopulationSize();
          }
          if (descendantPopulation[iSelected2].getScore() > descendantPopulation[iSelected].getScore()) {
            iSelected = iSelected2;
          }
        } else {
          // find best scoring network in population
          while (bSelected[iSelected]) {
            iSelected++;
          }
          double fScore = descendantPopulation[iSelected].getScore();
          for (int j = 0; j < getDescendantPopulationSize(); j++) {
            if (!bSelected[j] && descendantPopulation[j].getScore() > fScore) {
              fScore = descendantPopulation[j].getScore();
              iSelected = j;
            }
          }
        }
        population[i] = descendantPopulation[iSelected];
        bSelected[iSelected] = true;
      }
        }
       
        // restore current network to best network
    copyParentSets(bayesNet, bestBayesNet);
   
    // free up memory
    bestBayesNet = null;
    } // search


  /** copyParentSets copies parent sets of source to dest BayesNet
   * @param dest destination network
   * @param source source network
   */
  void copyParentSets(BayesNet dest, BayesNet source) {
    int nNodes = source.getNrOfNodes();
    // clear parent set first
    for (int iNode = 0; iNode < nNodes; iNode++) {
      dest.getParentSet(iNode).copy(source.getParentSet(iNode));
    }   
  } // CopyParentSets

    /**
    * @return number of runs
    */
    public int getRuns() {
        return m_nRuns;
    } // getRuns

    /**
     * Sets the number of runs
     * @param nRuns The number of runs to set
     */
    public void setRuns(int nRuns) {
        m_nRuns = nRuns;
    } // setRuns

  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration listOptions() {
    Vector newVector = new Vector(7);

    newVector.addElement(new Option("\tPopulation size", "L", 1, "-L <integer>"));
    newVector.addElement(new Option("\tDescendant population size", "A", 1, "-A <integer>"));
    newVector.addElement(new Option("\tNumber of runs", "U", 1, "-U <integer>"));
    newVector.addElement(new Option("\tUse mutation.\n\t(default true)", "M", 0, "-M"));
    newVector.addElement(new Option("\tUse cross-over.\n\t(default true)", "C", 0, "-C"));
    newVector.addElement(new Option("\tUse tournament selection (true) or maximum subpopulatin (false).\n\t(default false)", "O", 0, "-O"));
    newVector.addElement(new Option("\tRandom number seed", "R", 1, "-R <seed>"));

    Enumeration enu = super.listOptions();
    while (enu.hasMoreElements()) {
      newVector.addElement(enu.nextElement());
    }
    return newVector.elements();
  } // listOptions

  /**
   * Parses a given list of options. <p/>
   *
   <!-- options-start -->
   * Valid options are: <p/>
   *
   * <pre> -L &lt;integer&gt;
   *  Population size</pre>
   *
   * <pre> -A &lt;integer&gt;
   *  Descendant population size</pre>
   *
   * <pre> -U &lt;integer&gt;
   *  Number of runs</pre>
   *
   * <pre> -M
   *  Use mutation.
   *  (default true)</pre>
   *
   * <pre> -C
   *  Use cross-over.
   *  (default true)</pre>
   *
   * <pre> -O
   *  Use tournament selection (true) or maximum subpopulatin (false).
   *  (default false)</pre>
   *
   * <pre> -R &lt;seed&gt;
   *  Random number seed</pre>
   *
   * <pre> -mbc
   *  Applies a Markov Blanket correction to the network structure,
   *  after a network structure is learned. This ensures that all
   *  nodes in the network are part of the Markov blanket of the
   *  classifier node.</pre>
   *
   * <pre> -S [BAYES|MDL|ENTROPY|AIC|CROSS_CLASSIC|CROSS_BAYES]
   *  Score type (BAYES, BDeu, MDL, ENTROPY and AIC)</pre>
   *
   <!-- options-end -->
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {
    String sPopulationSize = Utils.getOption('L', options);
    if (sPopulationSize.length() != 0) {
      setPopulationSize(Integer.parseInt(sPopulationSize));
    }
    String sDescendantPopulationSize = Utils.getOption('A', options);
    if (sDescendantPopulationSize.length() != 0) {
      setDescendantPopulationSize(Integer.parseInt(sDescendantPopulationSize));
    }
    String sRuns = Utils.getOption('U', options);
    if (sRuns.length() != 0) {
      setRuns(Integer.parseInt(sRuns));
    }
    String sSeed = Utils.getOption('R', options);
    if (sSeed.length() != 0) {
      setSeed(Integer.parseInt(sSeed));
    }
    setUseMutation(Utils.getFlag('M', options));
    setUseCrossOver(Utils.getFlag('C', options));
    setUseTournamentSelection(Utils.getFlag('O', options));
   
    super.setOptions(options);
  } // setOptions

  /**
   * Gets the current settings of the search algorithm.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String[] getOptions() {
    String[] superOptions = super.getOptions();
    String[] options = new String[11 + superOptions.length];
    int current = 0;
   
    options[current++] = "-L";
    options[current++] = "" + getPopulationSize();

    options[current++] = "-A";
    options[current++] = "" + getDescendantPopulationSize();

    options[current++] = "-U";
    options[current++] = "" + getRuns();

    options[current++] = "-R";
    options[current++] = "" + getSeed();

    if (getUseMutation()) {
      options[current++] = "-M";
    }
    if (getUseCrossOver()) {
      options[current++] = "-C";
    }
    if (getUseTournamentSelection()) {
      options[current++] = "-O";
    }

    // insert options from parent class
    for (int iOption = 0; iOption < superOptions.length; iOption++) {
      options[current++] = superOptions[iOption];
    }

    // Fill up rest with empty strings, not nulls!
    while (current < options.length) {
      options[current++] = "";
    }
    return options;
  } // getOptions

  /**
   * @return whether cross-over is used
   */
  public boolean getUseCrossOver() {
    return m_bUseCrossOver;
  }

  /**
   * @return whether mutation is used
   */
  public boolean getUseMutation() {
    return m_bUseMutation;
  }

  /**
   * @return descendant population size
   */
  public int getDescendantPopulationSize() {
    return m_nDescendantPopulationSize;
  }

  /**
   * @return population size
   */
  public int getPopulationSize() {
    return m_nPopulationSize;
  }

  /**
   * @param bUseCrossOver sets whether cross-over is used
   */
  public void setUseCrossOver(boolean bUseCrossOver) {
    m_bUseCrossOver = bUseCrossOver;
  }

  /**
   * @param bUseMutation sets whether mutation is used
   */
  public void setUseMutation(boolean bUseMutation) {
    m_bUseMutation = bUseMutation;
  }

  /**
   * @return whether Tournament Selection (true) or Maximum Sub-Population (false) should be used
   */
  public boolean getUseTournamentSelection() {
    return m_bUseTournamentSelection;
  }

  /**
   * @param bUseTournamentSelection sets whether Tournament Selection or Maximum Sub-Population should be used
   */
  public void setUseTournamentSelection(boolean bUseTournamentSelection) {
    m_bUseTournamentSelection = bUseTournamentSelection;
  }

  /**
   * @param iDescendantPopulationSize sets descendant population size
   */
  public void setDescendantPopulationSize(int iDescendantPopulationSize) {
    m_nDescendantPopulationSize = iDescendantPopulationSize;
  }

  /**
   * @param iPopulationSize sets population size
   */
  public void setPopulationSize(int iPopulationSize) {
    m_nPopulationSize = iPopulationSize;
  }

  /**
  * @return random number seed
  */
  public int getSeed() {
    return m_nSeed;
  } // getSeed

  /**
   * Sets the random number seed
   * @param nSeed The number of the seed to set
   */
  public void setSeed(int nSeed) {
    m_nSeed = nSeed;
  } // setSeed

  /**
   * This will return a string describing the classifier.
   * @return The string.
   */
  public String globalInfo() {
    return "This Bayes Network learning algorithm uses genetic search for finding a well scoring " +
    "Bayes network structure. Genetic search works by having a population of Bayes network structures " +
    "and allow them to mutate and apply cross over to get offspring. The best network structure " +
    "found during the process is returned.";
  } // globalInfo
 
  /**
   * @return a string to describe the Runs option.
   */
  public String runsTipText() {
    return "Sets the number of generations of Bayes network structure populations.";
  } // runsTipText
 
  /**
   * @return a string to describe the Seed option.
   */
  public String seedTipText() {
    return "Initialization value for random number generator." +
    " Setting the seed allows replicability of experiments.";
  } // seedTipText

  /**
   * @return a string to describe the Population Size option.
   */
  public String populationSizeTipText() {
    return "Sets the size of the population of network structures that is selected each generation.";
  } // populationSizeTipText

  /**
   * @return a string to describe the Descendant Population Size option.
   */
  public String descendantPopulationSizeTipText() {
    return "Sets the size of the population of descendants that is created each generation.";
  } // descendantPopulationSizeTipText

  /**
   * @return a string to describe the Use Mutation option.
   */
  public String useMutationTipText() {
    return "Determines whether mutation is allowed. Mutation flips a bit in the bit " +
      "representation of the network structure. At least one of mutation or cross-over " +
      "should be used.";
  } // useMutationTipText

  /**
   * @return a string to describe the Use Cross-Over option.
   */
  public String useCrossOverTipText() {
    return "Determines whether cross-over is allowed. Cross over combined the bit " +
      "representations of network structure by taking a random first k bits of one" +
      "and adding the remainder of the other. At least one of mutation or cross-over " +
      "should be used.";
  } // useCrossOverTipText

  /**
   * @return a string to describe the Use Tournament Selection option.
   */
  public String useTournamentSelectionTipText() {
    return "Determines the method of selecting a population. When set to true, tournament " +
      "selection is used (pick two at random and the highest is allowed to continue). " +
      "When set to false, the top scoring network structures are selected.";
  } // useTournamentSelectionTipText

  /**
   * Returns the revision string.
   *
   * @return    the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 1.5 $");
  }
} // GeneticSearch
// TOP
//
// Related classes of weka.classifiers.bayes.net.search.local.GeneticSearch$BayesNetRepresentation
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.