Package

Source Code of RegionFrame

import htm.Region;
import htm.Region.RegionStats;

import java.awt.Color;
import java.awt.Dimension;
import java.awt.GridBagConstraints;
import java.awt.GridBagLayout;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.ItemEvent;
import java.awt.event.ItemListener;
import java.awt.event.KeyEvent;
import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Random;

import javax.swing.BorderFactory;
import javax.swing.JCheckBoxMenuItem;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JMenu;
import javax.swing.JMenuBar;
import javax.swing.JMenuItem;
import javax.swing.JOptionPane;
import javax.swing.JPanel;
import javax.swing.KeyStroke;
import javax.swing.event.ChangeEvent;
import javax.swing.event.ChangeListener;

/**
* A top-level UI window that displays a menu with options for running
* time steps of an HTM Region.  Within the window is a RegionPanel which
* renders a visualization of the current state of the Region.
* Currently the Region parameters and inputs are hardcoded.
* The input is a text file defined by the INPUT_FILE static variable.
* The Region is defined as having 25x25 columns.  When reading from the
* text file, each word is used to seed a random number the used to select
* a random 40 bits (or columns directly if spatial pooling is hardcoded)
* in a byte array to be 1's (the rest 0).
* Finally, the UI displays a set of statistics about the Region including
* how many segments and synapses are present over time.
* @author Barry Maturkanich
*/
public class RegionFrame extends JPanel {

  private static final long serialVersionUID = 1L;
 
  private static final String INPUT_FILE =
    //"C:/Apps/Numenta/input.txt";
    "C:/Apps/Numenta/classics/A Tale of Two Cities (Charles Dickens).txt";
 
  private static final String punctuation = "\"!#$%&'()*+,-./:;<=>?@[]^_`{|}~\"";
  private static final String[] PUNC = {"!","\"","(",")",".",";","?","[","]","{","}"};
 
  private ArrayDeque<String> _lastLine = null;
  private boolean _isWordBreak = false;
  private Region _region;
  private Dimension _regionShape;
  private RegionPanel _regionPanel;
  private Random _rand = new Random();
  private byte[] _inputData;
  private Thread _thread;
  private final ChangeListener _statListener;
  private BufferedReader _buffer;
  private final NumberFormat _nf;
 
  public static void main(String[] args) {
    // Schedule a job for the event-dispatching thread:
    // creating and showing this application's GUI.
    javax.swing.SwingUtilities.invokeLater(new Runnable() {
      public void run() {
        new RegionFrame();
      }
    });
  }
 
  public RegionFrame() {
    super(new GridBagLayout());
    setBorder(BorderFactory.createLineBorder(Color.black));
    _nf = DecimalFormat.getInstance();
    _nf.setMaximumFractionDigits(3);
   
    //Status labels:
    //accuracies (pred acc, active acc, last+average)
    //total segments, average per cell, most any cell
    //total segUpdates, average per cell, most any cell
    //total synapses, average per segment, most any segment
    JPanel labelArea = new JPanel(new GridBagLayout());
   
    JPanel accArea = new JPanel(new GridBagLayout());
    accArea.setBorder(BorderFactory.createTitledBorder("Region Accuracy"));
    final JLabel predAccLabel0 = new JLabel(" Prediction Acc: ");
    final JLabel predAccLabel = new JLabel("100%");
    final JLabel activeAccLabel0 = new JLabel(" Activation Acc: ");
    final JLabel activeAccLabel = new JLabel("100%");
    final JLabel predMAccLabel0 = new JLabel("  Mean Pred Acc: ");
    final JLabel predMAccLabel = new JLabel("0%");
    final JLabel activeMAccLabel0 = new JLabel("  Mean Active Acc: ");
    final JLabel activeMAccLabel = new JLabel("0%");
    accArea.add(predAccLabel0, setGrid(0,0));
    accArea.add(predAccLabel, setGrid(1,0));
    accArea.add(activeAccLabel0, setGrid(0,1));
    accArea.add(activeAccLabel, setGrid(1,1));
    accArea.add(predMAccLabel0, setGrid(0,2));
    accArea.add(predMAccLabel, setGrid(1,2));
    accArea.add(activeMAccLabel0, setGrid(0,3));
    accArea.add(activeMAccLabel, setGrid(1,3));
   
    final JLabel[] totalSegLabel = new JLabel[3];
    final JLabel[] meanSegLabel = new JLabel[3];
    final JLabel[] medSegLabel = new JLabel[3];
    final JLabel[] mostSegLabel = new JLabel[3];
   
    JPanel segArea = new JPanel(new GridBagLayout());
    segArea.setBorder(BorderFactory.createTitledBorder("Cell Segments"));
    final JLabel totalSegLabel0 = new JLabel(" Total Segments: ");
    totalSegLabel[0] = new JLabel("0");
    final JLabel meanSegLabel0 = new JLabel(" Mean Segments: ");
    meanSegLabel[0] = new JLabel("0");
    final JLabel medSegLabel0 = new JLabel(" Median Segments: ");
    medSegLabel[0] = new JLabel("0");
    final JLabel mostSegLabel0 = new JLabel(" Most Segments: ");
    mostSegLabel[0] = new JLabel("0");
    segArea.add(totalSegLabel0, setGrid(0,0));
    segArea.add(totalSegLabel[0], setGrid(1,0));
    segArea.add(meanSegLabel0, setGrid(0,1));
    segArea.add(meanSegLabel[0], setGrid(1,1));
    segArea.add(medSegLabel0, setGrid(0,2));
    segArea.add(medSegLabel[0], setGrid(1,2));
    segArea.add(mostSegLabel0, setGrid(0,3));
    segArea.add(mostSegLabel[0], setGrid(1,3));
   
    JPanel seg1Area = new JPanel(new GridBagLayout());
    seg1Area.setBorder(BorderFactory.createTitledBorder("Sequence Segments"));
    final JLabel totalSeg1Label0 = new JLabel(" Total Sequence: ");
    totalSegLabel[1] = new JLabel("0");
    final JLabel meanSeg1Label0 = new JLabel(" Mean Sequence: ");
    meanSegLabel[1] = new JLabel("0");
    final JLabel medSeg1Label0 = new JLabel(" Median Sequence: ");
    medSegLabel[1] = new JLabel("0");
    final JLabel mostSeg1Label0 = new JLabel(" Most Sequence: ");
    mostSegLabel[1] = new JLabel("0");
    seg1Area.add(totalSeg1Label0, setGrid(0,0));
    seg1Area.add(totalSegLabel[1], setGrid(1,0));
    seg1Area.add(meanSeg1Label0, setGrid(0,1));
    seg1Area.add(meanSegLabel[1], setGrid(1,1));
    seg1Area.add(medSeg1Label0, setGrid(0,2));
    seg1Area.add(medSegLabel[1], setGrid(1,2));
    seg1Area.add(mostSeg1Label0, setGrid(0,3));
    seg1Area.add(mostSegLabel[1], setGrid(1,3));
   
    JPanel seg2Area = new JPanel(new GridBagLayout());
    seg2Area.setBorder(BorderFactory.createTitledBorder("Regular Segments"));
    final JLabel totalSeg2Label0 = new JLabel(" Total Regular: ");
    totalSegLabel[2] = new JLabel("0");
    final JLabel meanSeg2Label0 = new JLabel(" Mean Regular: ");
    meanSegLabel[2] = new JLabel("0");
    final JLabel medSeg2Label0 = new JLabel(" Median Regular: ");
    medSegLabel[2] = new JLabel("0");
    final JLabel mostSeg2Label0 = new JLabel(" Most Regular: ");
    mostSegLabel[2] = new JLabel("0");
    seg2Area.add(totalSeg2Label0, setGrid(0,0));
    seg2Area.add(totalSegLabel[2], setGrid(1,0));
    seg2Area.add(meanSeg2Label0, setGrid(0,1));
    seg2Area.add(meanSegLabel[2], setGrid(1,1));
    seg2Area.add(medSeg2Label0, setGrid(0,2));
    seg2Area.add(medSegLabel[2], setGrid(1,2));
    seg2Area.add(mostSeg2Label0, setGrid(0,3));
    seg2Area.add(mostSegLabel[2], setGrid(1,3));
   
    JPanel pendArea = new JPanel(new GridBagLayout());
    pendArea.setBorder(BorderFactory.createTitledBorder("Pending Segments"));
    final JLabel totalPendLabel0 = new JLabel(" Pending Segments: ");
    final JLabel totalPendLabel = new JLabel("0");
    final JLabel meanPendLabel0 = new JLabel(" Average Pending: ");
    final JLabel meanPendLabel = new JLabel("0");
    final JLabel medPendLabel0 = new JLabel(" Median Pending: ");
    final JLabel medPendLabel = new JLabel("0");
    final JLabel mostPendLabel0 = new JLabel(" Most Pending: ");
    final JLabel mostPendLabel = new JLabel("0");
    pendArea.add(totalPendLabel0, setGrid(0,0));
    pendArea.add(totalPendLabel, setGrid(1,0));
    pendArea.add(meanPendLabel0, setGrid(0,1));
    pendArea.add(meanPendLabel, setGrid(1,1));
    pendArea.add(medPendLabel0, setGrid(0,2));
    pendArea.add(medPendLabel, setGrid(1,2));
    pendArea.add(mostPendLabel0, setGrid(0,3));
    pendArea.add(mostPendLabel, setGrid(1,3));
   
    final JLabel[] totalSynLabel = new JLabel[3];
    final JLabel[] meanSynLabel = new JLabel[3];
    final JLabel[] medSynLabel = new JLabel[3];
    final JLabel[] mostSynLabel = new JLabel[3];
   
    JPanel synArea = new JPanel(new GridBagLayout());
    synArea.setBorder(BorderFactory.createTitledBorder("Synapses"));
    final JLabel totalSynLabel0 = new JLabel(" Total Synapses: ");
    totalSynLabel[0] = new JLabel("0");
    final JLabel meanSynLabel0 = new JLabel(" Mean Synapses: ");
    meanSynLabel[0] = new JLabel("0");
    final JLabel medSynLabel0 = new JLabel(" Median Synapses: ");
    medSynLabel[0] = new JLabel("0");
    final JLabel mostSynLabel0 = new JLabel(" Most Synapses: ");
    mostSynLabel[0] = new JLabel("0");
    synArea.add(totalSynLabel0, setGrid(0,0));
    synArea.add(totalSynLabel[0], setGrid(1,0));
    synArea.add(meanSynLabel0, setGrid(0,1));
    synArea.add(meanSynLabel[0], setGrid(1,1));
    synArea.add(medSynLabel0, setGrid(0,2));
    synArea.add(medSynLabel[0], setGrid(1,2));
    synArea.add(mostSynLabel0, setGrid(0,3));
    synArea.add(mostSynLabel[0], setGrid(1,3));
   
    JPanel syn1Area = new JPanel(new GridBagLayout());
    syn1Area.setBorder(BorderFactory.createTitledBorder("Sequence Synapses"));
    final JLabel totalSyn1Label0 = new JLabel(" Total SeqSyn: ");
    totalSynLabel[1] = new JLabel("0");
    final JLabel meanSyn1Label0 = new JLabel(" Mean SeqSyn: ");
    meanSynLabel[1] = new JLabel("0");
    final JLabel medSyn1Label0 = new JLabel(" Median SeqSyn: ");
    medSynLabel[1] = new JLabel("0");
    final JLabel mostSyn1Label0 = new JLabel(" Most SeqSyn: ");
    mostSynLabel[1] = new JLabel("0");
    syn1Area.add(totalSyn1Label0, setGrid(0,0));
    syn1Area.add(totalSynLabel[1], setGrid(1,0));
    syn1Area.add(meanSyn1Label0, setGrid(0,1));
    syn1Area.add(meanSynLabel[1], setGrid(1,1));
    syn1Area.add(medSyn1Label0, setGrid(0,2));
    syn1Area.add(medSynLabel[1], setGrid(1,2));
    syn1Area.add(mostSyn1Label0, setGrid(0,3));
    syn1Area.add(mostSynLabel[1], setGrid(1,3));
   
    JPanel syn2Area = new JPanel(new GridBagLayout());
    syn2Area.setBorder(BorderFactory.createTitledBorder("Regular Synapses"));
    final JLabel totalSyn2Label0 = new JLabel(" Total RegSyn: ");
    totalSynLabel[2] = new JLabel("0");
    final JLabel meanSyn2Label0 = new JLabel(" Mean RegSyn: ");
    meanSynLabel[2] = new JLabel("0");
    final JLabel medSyn2Label0 = new JLabel(" Median RegSyn: ");
    medSynLabel[2] = new JLabel("0");
    final JLabel mostSyn2Label0 = new JLabel(" Most RegSyn: ");
    mostSynLabel[2] = new JLabel("0");
    syn2Area.add(totalSyn2Label0, setGrid(0,0));
    syn2Area.add(totalSynLabel[2], setGrid(1,0));
    syn2Area.add(meanSyn2Label0, setGrid(0,1));
    syn2Area.add(meanSynLabel[2], setGrid(1,1));
    syn2Area.add(medSyn2Label0, setGrid(0,2));
    syn2Area.add(medSynLabel[2], setGrid(1,2));
    syn2Area.add(mostSyn2Label0, setGrid(0,3));
    syn2Area.add(mostSynLabel[2], setGrid(1,3));
   
    _statListener = new ChangeListener() {
      public void stateChanged(ChangeEvent e) {
        RegionStats stats = _region.getStats();
       
        activeAccLabel.setText(_nf.format(stats.activationAccuracy*100.0)+"%");
        predAccLabel.setText(_nf.format(stats.predicationAccuracy*100.0)+"%");
       
        for(int i=0; i<3; ++i) {
          totalSegLabel[i].setText(String.valueOf(stats.totalSegments[i]));
          meanSegLabel[i].setText(_nf.format(stats.meanSegments[i]));
          medSegLabel[i].setText(String.valueOf(stats.medianSegments[i]));
          mostSegLabel[i].setText(String.valueOf(stats.mostSegments[i]));
         
          totalSynLabel[i].setText(String.valueOf(stats.totalSynapses[i]));
          meanSynLabel[i].setText(_nf.format(stats.meanSynapses[i]));
          medSynLabel[i].setText(String.valueOf(stats.medianSynapses[i]));
          mostSynLabel[i].setText(String.valueOf(stats.mostSynapses[i]));
        }
       
        totalPendLabel.setText(String.valueOf(stats.pendingSegments));
        meanPendLabel.setText(_nf.format(stats.meanPending));
        medPendLabel.setText(String.valueOf(stats.medianPending));
        mostPendLabel.setText(String.valueOf(stats.mostPending));
      }
    };
   
    labelArea.add(accArea, setGrid(0,0, GridBagConstraints.BOTH));
    labelArea.add(segArea, setGrid(1,0, GridBagConstraints.BOTH));
    labelArea.add(seg1Area, setGrid(2,0, GridBagConstraints.BOTH));
    labelArea.add(seg2Area, setGrid(3,0, GridBagConstraints.BOTH));
    labelArea.add(pendArea, setGrid(4,0, GridBagConstraints.BOTH));
    labelArea.add(synArea, setGrid(5,0, GridBagConstraints.BOTH));
    labelArea.add(syn1Area, setGrid(6,0, GridBagConstraints.BOTH));
    labelArea.add(syn2Area, setGrid(7,0, GridBagConstraints.BOTH));
    super.add(labelArea, setGrid(0,0, GridBagConstraints.HORIZONTAL));
   
    _region = createRegion();
    try {
      _buffer = new BufferedReader(new FileReader(INPUT_FILE));
    }catch(FileNotFoundException e) {
      e.printStackTrace();
    }
   
    _regionPanel = new RegionPanel(_region);
    super.add(_regionPanel, setGrid(0,1, GridBagConstraints.BOTH));
   
    JFrame f = new JFrame("Region Visualizer");
    f.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
    f.setJMenuBar(createJMenuBar());
    f.add(this);
    f.pack();
    f.setVisible(true);
  }
 
  private Region createRegion() {
    _regionShape = new Dimension(25,25);
    int inputSizeX = _regionShape.width;
    int inputSizeY = _regionShape.height;
    int colGridSizeX = inputSizeX;
    int colGridSizeY = inputSizeY;
    float pctInputPerCol = 0.2f;
    float pctMinOverlap = 0.08f;
    int localityRadius = 0;
    float pctLocalActivity = 0.25f;
    int cellsPerCol = 4;
    int segActiveThreshold = 3;
    int newSynapseCount = 5;
   
    _inputData = new byte[inputSizeX*inputSizeY];
   
    Region.FULL_DEFAULT_SPATIAL_PERMANENCE = true;
    final Region region = new Region(inputSizeX, inputSizeY,
        colGridSizeX, colGridSizeY,
        pctInputPerCol, pctMinOverlap, localityRadius,
        pctLocalActivity, cellsPerCol, segActiveThreshold,
        newSynapseCount, _inputData);
    region.setSpatialHardcoded(true);
    region.setSpatialLearning(false);
    region.setTemporalLearning(true);
   
    return region;
  }
 
  private Thread createThread() {
    return new Thread(new Runnable() {
      public void run() {
        //hide the region graphics while running as we don't want to wait
        //for repainting each time-step.  plus since we're on a non-UI thread
        //we would need to perform repaints on the UI-thread and thus also
        //wait for thread sync each time step
        _regionPanel.setVisible(false);
        try {
          int i=0;
          String word = readWordFromFile(_buffer);
          while(word!=null && !Thread.interrupted()) {
            //System.out.println("word: "+word);
            getWordRepresentation(word);
           
            //javax.swing.SwingUtilities.invokeAndWait(runRegion);
            _region.runOnce();
            if(i % 10 == 0)//update UI stats once every 10 runs for speed
              _statListener.stateChanged(new ChangeEvent(_region));
            ++i;
           
            word = readWordFromFile(_buffer);
          }
        } catch(IOException io) {
          io.printStackTrace();//fall through and terminate thread
        } finally {
          _regionPanel.setVisible(true);
          System.out.println("thread ended");
        }
      }
    });
  }
 
  private JMenuBar createJMenuBar() {
    JMenuBar mainMenuBar = new JMenuBar();
    JMenu menu1 = new JMenu("Run");
    mainMenuBar.add(menu1);
   
    // Creating the MenuItems
    JMenuItem stepItem = new JMenuItem("Forward 1 Step", KeyEvent.VK_F);
    stepItem.setAccelerator(KeyStroke.getKeyStroke(KeyEvent.VK_F,0));
    stepItem.addActionListener(new ActionListener() {
      public void actionPerformed(ActionEvent e) {
        try {
          String word = readWordFromFile(_buffer);
          if(word==null)
            return;
          System.out.println("word: "+word);
          getWordRepresentation(word);
          _region.runOnce();
          _statListener.stateChanged(new ChangeEvent(_region));
          _regionPanel.repaint();
        }catch(IOException e1) {
          JOptionPane.showMessageDialog(
              RegionFrame.this, "Failure to read the input file.");
        }
      }
    });
    menu1.add(stepItem);
   
    final JCheckBoxMenuItem playItem = new JCheckBoxMenuItem("Play");
    playItem.setMnemonic(KeyEvent.VK_P);
    playItem.setAccelerator(KeyStroke.getKeyStroke(KeyEvent.VK_P,0));
    playItem.addItemListener(new ItemListener() {
      public void itemStateChanged(ItemEvent e) {
        if(_thread==null || !_thread.isAlive()) {
          _thread = createThread();
          _thread.start();
          System.out.println("Thread running");
        }
        else {
          _thread.interrupt();
          _regionPanel.repaint();
        }
      }
    });
    menu1.add(playItem);
   
    return mainMenuBar;
  }
 
  private GridBagConstraints setGrid(int gridx, int gridy) {
    return setGrid(gridx,gridy, GridBagConstraints.NONE);
  }
 
  private GridBagConstraints setGrid(int gridx, int gridy, int fill) {
    return setGrid(gridx, gridy, fill, GridBagConstraints.WEST);
  }
 
  private GridBagConstraints setGrid(int gridx, int gridy, int fill, int anchor) {
    GridBagConstraints gb = new GridBagConstraints();
    gb.gridx = gridx;
    gb.gridy = gridy;
    gb.fill = fill;
    gb.anchor = anchor;
    return gb;
  }
 
  //  def getWordRepresentation(self, word):
  //    """
  //    Calculate a (NumColumns choose ActiveColumns) bit representation of
  //    the specified word string.
  //    """
  //    nx,ny = self._regionShape
  //    ncol = nx*ny #num columns in region 0
  //    nact = 40 #num active columns to represent any input
  //    random.seed(word*2)
  //    ibits = random.sample(xrange(0,ncol), nact)
  //    #Create ncol-bit number where indicies in list are 1-bits
  //    array = numpy.array(numpy.zeros((ncol)), dtype=numpy.uint8)
  //    if word!="": #blank word means all zero input to break sequence
  //      array[ibits] = 1
  //    array = array.reshape(self._regionShape)
  //    #print "Rep: ",array #[:,0]
  //    return array
  /**
   * Calculate a (NumColumns choose ActiveColumns) bit representation of
   * the specified word string.
   */
  private void getWordRepresentation(String word) {
    int nx = _regionShape.width;
    int ny = _regionShape.height;
    int ncol = nx*ny;
    int nact = 40; //num active columns to represent any input
    _rand.setSeed(word.hashCode()*2);
   
    Arrays.fill(_inputData, (byte)0);
    if(word.equals(""))
      return; //blank word means all zero input to break sequence
   
    //choose nact random column indicies to represent word
    int r=0;
    while(r < nact) {
      int i = _rand.nextInt(ncol);
      if(_inputData[i]==0) {
        _inputData[i] = 1;
        r++;
      }
    }
  }
 
    /*def readWordFromFile(self):
      """ Read the next word from the last read line in the file. """
      if self._isWordBreak:
        self._isWordBreak = False
        return ""
     
      while self._fileLine==None or len(self._fileLine)==0:
        line = self._file.readline()
        if line=="":
          self._fileLine = None
          return None
        self._fileLine = line.split() #split on whitespace
       
      #parse line by popping next word and stripping punctuation
      wordf = self._fileLine.pop(0).lower()
      self._isWordBreak = wordf.endswith(PUNC)
      return wordf.translate(None, string.punctuation)
   */
  private String readWordFromFile(BufferedReader file) throws IOException {
    if(_isWordBreak) {
      _isWordBreak = false;
      return "";
    }
   
    while(_lastLine==null || _lastLine.size()==0) {
      String line = file.readLine();
      if(line==null) {
        _lastLine = null;
        return null;
      }
      String[] lineSplit = line.split("\\s+");
      _lastLine = new ArrayDeque<String>(lineSplit.length);
      for(int i=0; i<lineSplit.length; ++i)
        _lastLine.push(lineSplit[i]);
    }
   
    //parse line by popping next word and stripping punctuation
    String wordf = _lastLine.removeLast();
   
    _isWordBreak = false;
    for(int i=0; i<PUNC.length; ++i) {
      if(wordf.endsWith(PUNC[i])) {
        _isWordBreak = true;
        break;
      }
    }
   
    //remove all punctuation characters from the word
    StringBuilder sb = new StringBuilder();
    for(int c=0; c<wordf.length(); ++c) {
      char ch = wordf.charAt(c);
      boolean ok = true;
      for(int i=0; i<punctuation.length(); ++i) {
        if(ch==punctuation.charAt(i)) {
          ok = false;
          break;
        }
      }
      if(ok)
        sb.append(ch);
    }
   
    return sb.toString();
  }

}
TOP

Related Classes of RegionFrame

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.