// Package org.encog.workbench.tabs.training
//
// Source Code of org.encog.workbench.tabs.training.BasicTrainingProgress

/*
* Encog(tm) Workbench v3.0
* http://www.heatonresearch.com/encog/
* http://code.google.com/p/encog-java/
* Copyright 2008-2011 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*  
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.workbench.tabs.training;

import java.awt.BorderLayout;
import java.awt.Color;
import java.awt.Font;
import java.awt.FontMetrics;
import java.awt.Graphics;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.io.File;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

import javax.swing.JButton;
import javax.swing.JComboBox;
import javax.swing.JPanel;

import org.encog.StatusReportable;
import org.encog.mathutil.randomize.Distort;
import org.encog.ml.MLMethod;
import org.encog.ml.MLResettable;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.svm.training.search.SVMSearchJob;
import org.encog.ml.train.MLTrain;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.persist.EncogDirectoryPersistence;
import org.encog.util.Format;
import org.encog.util.file.FileUtil;
import org.encog.util.validate.ValidateNetwork;
import org.encog.workbench.EncogWorkBench;
import org.encog.workbench.frames.document.tree.ProjectEGFile;
import org.encog.workbench.tabs.EncogCommonTab;
import org.encog.workbench.util.EncogFonts;
import org.encog.workbench.util.TimeSpanFormatter;

/**
* Common dialog box that displays the training progress for most of the
* training methods. A chart is shown that displays the decline of the error.
* Additionally, the user can start, stop or abort the training process.
*
* @author jheaton
*
*/
public class BasicTrainingProgress extends EncogCommonTab implements Runnable,
    ActionListener, StatusReportable {

  private final JComboBox comboReset;

  /**
   * The start button.
   */
  private final JButton buttonStart;

  /**
   * The stop button.
   */
  private final JButton buttonStop;

  /**
   * The close button.
   */
  private final JButton buttonClose;

  /**
   * The body of the dialog box is stored in this panel.
   */
  private final JPanel panelBody;

  /**
   * The buttons are hold in this panel.
   */
  private final JPanel panelButtons;

  /**
   * The background thread that processes training.
   */
  private Thread thread;

  /**
   * Has training been canceled.
   */
  private boolean cancel;

  /**
   * Panel used to display the current status of the training.
   */
  protected TrainingStatusPanel statusPanel;

  /**
   * Panel that holds the chart.
   */
  protected ChartPane chartPanel;

  /**
   * The training method being used.
   */
  private MLTrain train;

  /**
   * The training data.
   */
  private MLDataSet trainingData;

  /**
   * The max allowed error.
   */
  private double maxError;

  /**
   * What iteration are we on.
   */
  private int iteration;

  /**
   * The font to use for headings.
   */
  private Font headFont;

  /**
   * The font for body text.
   */
  private Font bodyFont;

  /**
   * The current error.
   */
  private double currentError;

  /**
   * The last error.
   */
  private double lastError;

  /**
   * The error improvement.
   */
  private double errorImprovement;

  /**
   * When was training started.
   */
  private Date started;

  /**
   * When was the last update.
   */
  private long lastUpdate;

  /**
   * Number formatter.
   */
  private final NumberFormat nf = NumberFormat.getInstance();

  /**
   * Shorter number formatter.
   */
  private final NumberFormat nfShort = NumberFormat.getInstance();

  /**
   * The current performance count, how many iterations per minute.
   */
  private int performanceCount;

  /**
   * What time did we last take a performance sample.
   */
  private Date performanceLast;

  /**
   * What was the iteration number, the last time a performance sample was
   * taken.
   */
  private int performanceLastIteration;

  private String status;

  /**
   * Should the dialog box exit? Are we waiting for training to shut down
   * first.
   */
  private boolean shouldExit;

  private AtomicInteger resetOption = new AtomicInteger(-1);
 
  private boolean error = false;
 
 
  private String lastMessage = "";
 

  /**
   * Construct the dialog box.
   *
   * @param owner
   *            The owner of the dialog box.
   */
  public BasicTrainingProgress(MLTrain train, ProjectEGFile method,
      MLDataSet trainingData) {
    super(method);
   
    if( method instanceof MLMethod ) {
      ValidateNetwork.validateMethodToData((MLMethod)method.getObject(), trainingData);
    }
    List<String> list = new ArrayList<String>();
    list.add("<Select Option>");
    list.add("Reset");
    list.add("Perturb 1%");
    list.add("Perturb 5%");
    list.add("Perturb 10%");
    list.add("Perturb 15%");
    list.add("Perturb 20%");
    list.add("Perturb 50%");

    this.comboReset = new JComboBox(list.toArray());

    this.train = train;
    this.trainingData = trainingData;

    this.buttonStart = new JButton("Start");
    this.buttonStop = new JButton("Stop");
    this.buttonClose = new JButton("Close");

    this.buttonStart.addActionListener(this);
    this.buttonStop.addActionListener(this);
    this.buttonClose.addActionListener(this);
    this.comboReset.addActionListener(this);

    setLayout(new BorderLayout());
    this.panelBody = new JPanel();
    this.panelButtons = new JPanel();
    this.panelButtons.add(this.buttonStart);
    this.panelButtons.add(this.buttonStop);
    this.panelButtons.add(this.buttonClose);
    this.panelButtons.add(this.comboReset);
    add(this.panelBody, BorderLayout.CENTER);
    add(this.panelButtons, BorderLayout.SOUTH);
    this.panelBody.setLayout(new BorderLayout());
    this.panelBody.add(this.statusPanel = new TrainingStatusPanel(this),
        BorderLayout.NORTH);
    this.panelBody.add(this.chartPanel = new ChartPane(),
        BorderLayout.CENTER);
    this.buttonStop.setEnabled(false);

    this.shouldExit = false;
    this.bodyFont = EncogFonts.getInstance().getBodyFont();
    this.headFont = EncogFonts.getInstance().getHeadFont();
    this.status = "Ready to Start";
   
    if( train instanceof SVMSearchJob ) {
      ((SVMSearchJob)train).setReport(this);
    }
  }

  private void performClose() {
   
    if( error )
      return;

    if (EncogWorkBench.askQuestion("Training", "Save the training?")) {

      if( this.getEncogObject()!=null ) {

        ((ProjectEGFile)this.getEncogObject()).save(train.getMethod());
        if( this.getParentTab()!=null ) {
          this.getParentTab().setEncogObject(this.getEncogObject());
        }
      }
     
      if( this.train.canContinue() ) {
        TrainingContinuation cont = train.pause()
        String name = FileUtil.getFileName(this.getEncogObject().getFile());
        name = FileUtil.forceExtension(name + "-cont", "eg");
        File path = new File(name);
        EncogWorkBench.getInstance().save(path, cont);
        EncogWorkBench.getInstance().refresh();
      }
           
      EncogWorkBench.getInstance().refresh();
    } else {
      if( this.getEncogObject()!=null) {
        ((ProjectEGFile)this.getEncogObject()).revert();
      }
    }
  }

  /**
   * Track button presses.
   *
   * @param e
   *            Event info.
   */
  public void actionPerformed(final ActionEvent e) {
    if (e.getSource() == this.buttonClose) {
      dispose();
    } else if (e.getSource() == this.buttonStart) {
      performStart();
    } else if (e.getSource() == this.buttonStop) {
      performStop();
    } else if (e.getSource() == this.comboReset) {
      this.resetOption.set(this.comboReset.getSelectedIndex() - 1);
      this.comboReset.setSelectedIndex(0);
    }
  }

  public boolean close() {
    if (this.thread == null) {
      performClose();
      return true;
    } else {
      this.shouldExit = true;
      this.cancel = true;
      return false;
    }
  }

  /**
   * @return the train
   */
  public MLTrain getTrain() {
    return this.train;
  }

  /**
   * @return the trainingData
   */
  public MLDataSet getTrainingData() {
    return this.trainingData;
  }

  public void paintStatus(final Graphics g) {
    g.setColor(Color.white);
    final int width = getWidth();
    final int height = getHeight();
    g.fillRect(0, 0, width, height);
    g.setColor(Color.black);
    g.setFont(this.headFont);
    final FontMetrics fm = g.getFontMetrics();
    int y = fm.getHeight();
    g.drawString("Iteration:", 10, y);
    y += fm.getHeight();
    g.drawString("Current Error:", 10, y);
    y += fm.getHeight();
    g.drawString("Error Improvement:", 10, y);
    y += fm.getHeight();
    g.drawString("Message:", 10, y);

    y = fm.getHeight();
    g.drawString("Elapsed Time:", 400, y);
    y += fm.getHeight();
    g.drawString("Performance:", 400, y);

    y = fm.getHeight();
    g.setFont(this.bodyFont);
    String str = this.nf.format(this.iteration);

    str += " (" + this.status + ")";

    g.drawString(str, 150, y);
    y += fm.getHeight();
    g.drawString(Format.formatPercent(this.currentError), 150, y);
    y += fm.getHeight();
    g.drawString(Format.formatPercent(this.errorImprovement), 150, y);
    y += fm.getHeight();
    g.drawString(this.lastMessage, 150, y);

    y = fm.getHeight();
    long seconds = 0;
    if (this.started != null) {
      final Date now = new Date();
      seconds = (now.getTime() - this.started.getTime()) / 1000;
    }
    g.drawString(TimeSpanFormatter.formatTime(seconds), 500, y);

    y += fm.getHeight();

    if (this.performanceCount == -1) {
      str = "  (calculating performance)";
    } else {
      final double d = this.performanceCount / 60.0;
      str = "  (" + this.nfShort.format(d) + "/sec)";
    }
   

    g.drawString(str, 500, y);

  }

  /**
   * Start the training.
   */
  private void performStart() {

    /*
     * if (!EncogWorkBench.getInstance().getMainWindow().getTabManager()
     * .checkTrainingOrNetworkOpen()) return;
     */

    this.started = new Date();
    this.performanceLast = this.started;
    this.performanceCount = -1;
    this.performanceLastIteration = 0;

    this.buttonStart.setEnabled(false);
    this.buttonStop.setEnabled(true);
    this.cancel = false;
    this.status = "Started";
    this.thread = new Thread(this);
    this.thread.start();
  }

  /**
   * Request that the training stop.
   */
  private void performStop() {
    this.status = "Canceled";
    this.cancel = true;
  }

  /**
   * Repaint the dialog box.
   */
  public void redraw() {
    this.statusPanel.repaint();
    this.lastUpdate = System.currentTimeMillis();
    this.chartPanel.addData(this.iteration, this.train.getError(),
        this.errorImprovement);
  }

  /**
   * Process the background thread. Cycle through training iterations. If the
   * cancel flag is set, then exit.
   */
  public void run() {

    try {

      startup();

      // this.method = (MLMethod) method.clone();

      // see if we need to continue training.
      if( this.train.canContinue() ) {
        String name = FileUtil.getFileName( getEncogObject().getFile() );
        name+="-cont.eg";
        File path = new File(name);
        if( path.exists() ) {
          try {
            TrainingContinuation cont = (TrainingContinuation)EncogDirectoryPersistence.loadObject(path);
            train.resume(cont);
          } catch(Exception ex) {
            EncogWorkBench.displayError("Trainning Resume Incompatible", "Cannot use previous training data, training will begin as best it can.");
            path.delete();
          }
        }
      }

      while (!this.cancel) {
        this.iteration++;
        this.lastError = this.train.getError();

        if (this.resetOption.get() != -1) {
          MLMethod method = null;
         
          if( getEncogObject() instanceof ProjectEGFile ) {
            method = (MLMethod)((ProjectEGFile)getEncogObject()).getObject();
          }
         
          if( method==null )
          {
            this.resetOption.set(-1);
            EncogWorkBench.displayError("Error", "This machine learning method cannot be reset or randomized.");
            return;
          }
         
          switch (this.resetOption.get()) {
          case 0:
            if (method instanceof MLResettable) {
              ((MLResettable)method).reset();
            } else {
              EncogWorkBench
                  .displayError("Error",
                      "This Machine Learning method cannot be reset.");
            }
            break;
          case 1:
            (new Distort(0.01)).randomize(method);
            break;
          case 2:
            (new Distort(0.05)).randomize(method);
            break;
          case 3:
            (new Distort(0.1)).randomize(method);
            break;
          case 4:
            (new Distort(0.15)).randomize(method);
            break;
          case 5:
            (new Distort(0.20)).randomize(method);
            break;
          case 6:
            (new Distort(0.50)).randomize(method);
            break;

          }

          this.resetOption.set(-1);
        }

        this.train.iteration();

        this.currentError = this.train.getError();

        if (this.currentError < this.maxError) {
          this.status = "Max Error Reached";
          this.cancel = true;
        }

        if (this.train.isTrainingDone()) {
          this.status = "Training Complete";
          this.cancel = true;
        }
       
        this.errorImprovement = (this.lastError - this.currentError)
            / this.lastError;
        if( Double.isInfinite(this.errorImprovement) || Double.isNaN(this.errorImprovement)) {
          this.errorImprovement = 100.0;
        }
       
       
        if (System.currentTimeMillis() - this.lastUpdate > 1000
            || this.cancel) {
          redraw();
        }

        final Date now = new Date();
        if (now.getTime() - this.performanceLast.getTime() > 60000) {
          this.performanceLast = now;
          this.performanceCount = this.iteration
              - this.performanceLastIteration;
          this.performanceLastIteration = this.iteration;
        }
      }
      this.train.finishTraining();
      shutdown();
      stopped();

      if (this.shouldExit) {
        dispose();
      }
    } catch (Throwable t) {
      this.error = true;
      EncogWorkBench.displayError("Error", t, this.getEncogObject(),this.trainingData);
      shutdown();
      stopped();
      dispose();
    }
  }

  /**
   * @param maxError
   *            the maxError to set
   */
  public void setMaxError(final double maxError) {
    this.maxError = maxError;
  }

  /**
   * @param train
   *            the train to set
   */
  public void setTrain(final MLTrain train) {
    this.train = train;
  }

  /**
   * @param trainingData
   *            the trainingData to set
   */
  public void setTrainingData(final MLDataSet trainingData) {
    this.trainingData = trainingData;
  }

  /**
   * Implemented by subclasses to perform any shutdown after training.
   */
  public void shutdown() {

  }

  /**
   * Implemented by subclasses to perform any activity before training.
   */
  public void startup() {

  }

  /**
   * Called when training has stopped.
   */
  private void stopped() {
    this.thread = null;
    this.buttonStart.setEnabled(true);
    this.buttonStop.setEnabled(false);
    this.cancel = true;
  }

  @Override
  public String getName() {
    return "Training Progress";
  }

  @Override
  public void report(int total, int current, String message) {
    this.lastMessage = message;
    redraw();   
  }
}
// TOP
//
// Related Classes of org.encog.workbench.tabs.training.BasicTrainingProgress
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.