Package org.encog.engine.network.train.prop

Source Code of org.encog.engine.network.train.prop.TrainFlatNetworkProp

/*
* Encog(tm) Core v2.5 - Java Version
* http://www.heatonresearch.com/encog/
* http://code.google.com/p/encog-java/
* Copyright 2008-2010 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*  
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/

package org.encog.engine.network.train.prop;

import java.util.Arrays;

import org.encog.engine.EncogEngineError;
import org.encog.engine.concurrency.DetermineWorkload;
import org.encog.engine.concurrency.EngineConcurrency;
import org.encog.engine.concurrency.TaskGroup;
import org.encog.engine.data.EngineDataSet;
import org.encog.engine.data.EngineIndexableSet;
import org.encog.engine.network.flat.FlatNetwork;
import org.encog.engine.network.train.TrainFlatNetwork;
import org.encog.engine.network.train.gradient.FlatGradientWorker;
import org.encog.engine.network.train.gradient.GradientWorkerCPU;
import org.encog.engine.util.EngineArray;
import org.encog.engine.util.IntRange;

/**
* Train a flat network using multithreading, and GPU support.
*
* The training data must be indexable, it will be broken into groups for each
* thread to process.
*
* At the end of each iteration the training from each thread is aggregated back
* to the neural network.
*
*/
public abstract class TrainFlatNetworkProp implements TrainFlatNetwork {

  /**
   * The number of threads to use.
   */
  private int numThreads;

  /**
   * The gradients.
   */
  protected double[] gradients;

  /**
   * The last gradients, from the last training iteration.
   */
  private double[] lastGradient;

  /**
   * The network to train.
   */
  protected final FlatNetwork network;

  /**
   * The training data.
   */
  private final EngineDataSet training;

  /**
   * The network in indexable form.
   */
  private final EngineIndexableSet indexable;

  /**
   * The workers.
   */
  private FlatGradientWorker[] workers;

  /**
   * The total error. Used to take the average of.
   */
  private double totalError;

  /**
   * The current error is the average error over all of the threads.
   */
  protected double currentError;

  /**
   * Reported exception from the threads.
   */
  private Throwable reportedException;

  /**
   * The iteration.
   */
  private int iteration;

  /**
   * Train a flat network multithreaded.
   *
   * @param network
   *            The network to train.
   * @param training
   *            The training data to use.
   */
  public TrainFlatNetworkProp(final FlatNetwork network,
      final EngineDataSet training) {

    if (!(training instanceof EngineIndexableSet)) {
      throw new EncogEngineError(
          "Training data must be Indexable for this training type.");
    }

    this.training = training;
    this.network = network;

    this.gradients = new double[this.network.getWeights().length];
    this.lastGradient = new double[this.network.getWeights().length];

    this.indexable = (EngineIndexableSet) training;
    this.numThreads = 0;
    this.reportedException = null;
  }

  /**
   * Calculate the gradients.
   */
  public void calculateGradients() {
    if (this.workers == null) {
      init();
    }

    this.workers[0].getNetwork().clearContext();
    this.totalError = 0;

    if (this.workers.length > 1) {

      final TaskGroup group = EngineConcurrency.getInstance()
          .createTaskGroup();

      for (final FlatGradientWorker worker : this.workers) {
        EngineConcurrency.getInstance().processTask(worker, group);
      }

      group.waitForComplete();
    } else {
      this.workers[0].run();
    }

    this.currentError = this.totalError / this.workers.length;

  }

  /**
   * Copy the contexts to keep them consistent with multithreaded training.
   */
  private void copyContexts() {
   
    // copy the contexts(layer outputO from each group to the next group
    for (int i = 0; i < (this.workers.length - 1); i++) {
      final double[] src = this.workers[i].getNetwork().getLayerOutput();
      final double[] dst = this.workers[i + 1].getNetwork()
          .getLayerOutput();
      EngineArray.arrayCopy(src, dst);
    }
    
         // copy the contexts from the final group to the real network
         EngineArray.arrayCopy(
             this.workers[this.workers.length - 1].getNetwork().getLayerOutput(),
             this.network.getLayerOutput());
  }

  /**
   * {@inheritDoc}
   */
  public void finishTraining() {
    // nothing to do
  }

  /**
   * {@inheritDoc}
   */
  public double getError() {
    return this.currentError;
  }

  /**
   * @return The gradients from the last iteration;
   */
  public double[] getLastGradient() {
    return this.lastGradient;
  }

  /**
   * {@inheritDoc}
   */
  public FlatNetwork getNetwork() {
    return this.network;
  }

  /**
   * {@inheritDoc}
   */
  public int getNumThreads() {
    return this.numThreads;
  }

  /**
   * {@inheritDoc}
   */
  public EngineDataSet getTraining() {
    return this.training;
  }

  /**
   * Init the process.
   */
  private void init() {

    final DetermineWorkload determine = new DetermineWorkload(
        this.numThreads, (int) this.indexable.getRecordCount());

    this.workers = new FlatGradientWorker[determine.getThreadCount()];

    int index = 0;

    // handle CPU
    for (final IntRange r : determine.calculateWorkers()) {
      this.workers[index++] = new GradientWorkerCPU(this.network.clone(),
          this, this.indexable.openAdditional(), r.getLow(),
          r.getHigh());
    }
  }

  /**
   * {@inheritDoc}
   */
  public void iteration() {

    this.iteration++;

    calculateGradients();

    if (this.network.isLimited()) {
      learnLimited();
    } else {
      learn();
    }

    for (final FlatGradientWorker worker : this.workers) {
      EngineArray.arrayCopy(this.network.getWeights(), 0,
          worker.getWeights(), 0, this.network.getWeights().length);
    }

    copyContexts();

    if (this.reportedException != null) {
      throw (new EncogEngineError(this.reportedException));
    }
  }

  /**
   * Apply and learn.
   */
  protected void learn() {
    final double[] weights = this.network.getWeights();
    for (int i = 0; i < this.gradients.length; i++) {
      weights[i] += updateWeight(this.gradients, this.lastGradient, i);
      this.gradients[i] = 0;
    }
  }

  /**
   * Apply and learn. This is the same as learn, but it checks to see if any
   * of the weights are below the limit threshold. In this case, these weights
   * are zeroed out. Having two methods allows the regular learn method, which
   * is what is usually use, to be as fast as possible.
   */
  protected void learnLimited() {
    final double limit = this.network.getConnectionLimit();
    final double[] weights = this.network.getWeights();
    for (int i = 0; i < this.gradients.length; i++) {
      if (weights[i] < limit) {
        weights[i] = 0;
      } else {
        weights[i] += updateWeight(this.gradients, this.lastGradient, i);
      }
      this.gradients[i] = 0;
    }
  }

  /**
   * Called by the worker threads to report the progress at each step.
   *
   * @param gradients
   *            The gradients from that worker.
   * @param error
   *            The error for that worker.
   * @param ex
   *            The exception.
   */
  public void report(final double[] gradients, final double error,
      final Throwable ex) {
    synchronized (this) {
      if (ex == null) {

        for (int i = 0; i < gradients.length; i++) {
          this.gradients[i] += gradients[i];
        }
        this.totalError += error;
      } else {
        this.reportedException = ex;
      }
    }
  }

  /**
   * {@inheritDoc}
   */
  public void setNumThreads(final int numThreads) {
    this.numThreads = numThreads;
  }

  /**
   * Update a weight, the means by which weights are updated vary depending on
   * the training.
   *
   * @param gradients
   *            The gradients.
   * @param lastGradient
   *            The last gradients.
   * @param index
   *            The index.
   * @return The update value.
   */
  public abstract double updateWeight(double[] gradients,
      double[] lastGradient, int index);

  /**
   * Perform the specified number of training iterations. This is a basic
   * implementation that just calls iteration the specified number of times.
   * However, some training methods, particularly with the GPU, benefit
   * greatly by calling with higher numbers than 1.
   *
   * @param count
   *            The number of training iterations.
   */
  public void iteration(final int count) {
    for (int i = 0; i < count; i++) {
      iteration();
    }
  }

  /**
   * {@inheritDoc}
   */
  public int getIteration() {
    return this.iteration;
  }
 
  /**
   * {@inheritDoc}
   */
  public void setIteration(final int iteration) {
    this.iteration = iteration;
  }
}
TOP

Related Classes of org.encog.engine.network.train.prop.TrainFlatNetworkProp

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.