// Package: com.github.neuralnetworks.training
// Source code of com.github.neuralnetworks.training.Trainer

package com.github.neuralnetworks.training;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

import com.github.neuralnetworks.architecture.Layer;
import com.github.neuralnetworks.architecture.NeuralNetwork;
import com.github.neuralnetworks.calculation.OutputError;
import com.github.neuralnetworks.calculation.ValuesProvider;
import com.github.neuralnetworks.events.TrainingEvent;
import com.github.neuralnetworks.events.TrainingEventListener;
import com.github.neuralnetworks.training.events.MiniBatchFinishedEvent;
import com.github.neuralnetworks.training.events.TestingFinishedEvent;
import com.github.neuralnetworks.training.events.TestingStartedEvent;
import com.github.neuralnetworks.training.random.NNRandomInitializer;
import com.github.neuralnetworks.util.Constants;
import com.github.neuralnetworks.util.Properties;
import com.github.neuralnetworks.util.UniqueList;

/**
* Base class for training (used for both supervised and unsupervised learning)
*/
public abstract class Trainer<N extends NeuralNetwork> implements Serializable {

    private static final long serialVersionUID = 1L;

    /**
     * Properties for the training (for example learnig rate, weight decay etc)
     */
    protected Properties properties;

    /**
     * Training event listeners
     */
    protected List<TrainingEventListener> listeners;

    public Trainer() {
  super();
    }

    public Trainer(Properties properties) {
  super();
  this.properties = properties;
    }

    /**
     * Training method
     */
    public abstract void train();

    /**
     * The network is tested via the testing input provider and the training error is aggregated for each example.
     */
    public void test() {
  TrainingInputProvider ip = getTestingInputProvider();
  NeuralNetwork n = getNeuralNetwork();

  if (ip != null && n != null && n.getLayerCalculator() != null) {
      ip.reset();

      triggerEvent(new TestingStartedEvent(this));

      Set<Layer> calculatedLayers = new UniqueList<>();
      ValuesProvider results = new ValuesProvider();
      TrainingInputData input = null;

      if (getOutputError() != null) {
    getOutputError().reset();
      }

      while ((input = ip.getNextInput()) != null) {
    calculatedLayers.clear();
    calculatedLayers.add(n.getInputLayer());
    results.addValues(n.getInputLayer(), input.getInput());
    n.getLayerCalculator().calculate(n, n.getOutputLayer(), calculatedLayers, results);

    if (getOutputError() != null) {
        getOutputError().addItem(results.getValues(n.getOutputLayer()), input.getTarget());
    }

    triggerEvent(new MiniBatchFinishedEvent(this, input, results, null));
      }
     
      triggerEvent(new TestingFinishedEvent(this));
  }
    }

    public Properties getProperties() {
  return properties;
    }

    public void setProperties(Properties properties) {
  this.properties = properties;
    }

    public N getNeuralNetwork() {
  return properties.getParameter(Constants.NEURAL_NETWORK);
    }

    public void setNeuralNetwork(N neuralNetwork) {
  properties.setParameter(Constants.NEURAL_NETWORK, neuralNetwork);
    }

    public TrainingInputProvider getTrainingInputProvider() {
  return properties.getParameter(Constants.TRAINING_INPUT_PROVIDER);
    }

    public void setTrainingInputProvider(TrainingInputProvider trainingInputProvider) {
  properties.setParameter(Constants.TRAINING_INPUT_PROVIDER, trainingInputProvider);
    }

    public TrainingInputProvider getTestingInputProvider() {
  return properties.getParameter(Constants.TESTING_INPUT_PROVIDER);
    }

    public void setTestingInputProvider(TrainingInputProvider testingInputProvider) {
  properties.setParameter(Constants.TESTING_INPUT_PROVIDER, testingInputProvider);
    }

    public OutputError getOutputError() {
  return properties.getParameter(Constants.OUTPUT_ERROR);
    }

    public void setOutputError(OutputError outputError) {
  properties.setParameter(Constants.OUTPUT_ERROR, outputError);
    }
   
    public NNRandomInitializer getRandomInitializer() {
  return properties.getParameter(Constants.RANDOM_INITIALIZER);
    }
   
    public void setRandomInitializer(NNRandomInitializer randomInitializer) {
  properties.setParameter(Constants.RANDOM_INITIALIZER, randomInitializer);
    }

    public void addEventListener(TrainingEventListener listener) {
  if (listeners == null) {
      listeners = new UniqueList<>();
  }

  listeners.add(listener);
    }

    public void removeEventListener(TrainingEventListener listener) {
  if (listeners != null) {
      listeners.remove(listener);
  }
    }

    protected void triggerEvent(TrainingEvent event) {
  if (listeners != null) {
      List<TrainingEventListener> listeners = new ArrayList<>(this.listeners);
      listeners.forEach(l -> l.handleEvent(event));
  }
    }

    protected boolean stopTraining(int index) {
  return index >= getTestingInputProvider().getInputSize();
    }
}
// Related classes of com.github.neuralnetworks.training.Trainer
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code is the property of its respective owners. Java is a trademark of
// Sun Microsystems, Inc. and owned by Oracle Inc. Contact: coftware#gmail.com.