Package org.neuroph.core.learning

Source Code of org.neuroph.core.learning.TrainingSet

/**
* Copyright 2010 Neuroph Project http://neuroph.sourceforge.net
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*    http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.neuroph.core.learning;

import java.io.Closeable;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import org.encog.engine.data.EngineData;
import org.encog.engine.data.EngineDataSet;
import org.encog.engine.data.EngineIndexableSet;
import org.neuroph.core.exceptions.VectorSizeMismatchException;

/**
* A set of training elements for training neural network.
*
* @author Zoran Sevarac <sevarac@gmail.com>
*/
public class TrainingSet implements Serializable, EngineIndexableSet {

  /**
   * The class fingerprint that is set to indicate serialization compatibility
   * with a previous version of the class
   */
  private static final long serialVersionUID = 2L;

  /**
   * Collection of training elements
   */
  private List<TrainingElement> elements;

  private int inputVectorSize = 0;
  private int outputVectorSize = 0;

  /**
   * Label for this training set
   */
  private String label;

  /**
   * Full file path including file name
   */
  private transient String filePath;

  /**
   * Creates an instance of new empty training set
   */
  public TrainingSet() {
    this.elements = new ArrayList<TrainingElement>();
  }

  /**
   * Creates an instance of new empty training set with given label
   *
   * @param label
   *            training set label
   */
  public TrainingSet(String label) {
    this.label = label;
    this.elements = new ArrayList<TrainingElement>();
  }

  /**
   * Creates an instance of new empty training set
   *
   * @param inputVectorSize
   */
  public TrainingSet(int inputVectorSize) {
    this.elements = new ArrayList<TrainingElement>();
    this.inputVectorSize = inputVectorSize;
  }

  /**
   * Creates an instance of new empty training set
   *
   * @param inputVectorSize
   * @param outputVectorSize
   */
  public TrainingSet(int inputVectorSize, int outputVectorSize) {
    this.elements = new ArrayList<TrainingElement>();
    this.inputVectorSize = inputVectorSize;
    this.outputVectorSize = outputVectorSize;
  }

  /**
   * Adds new training element to this training set
   *
   * @param el
   *            training element to add
   */
  public void addElement(TrainingElement el)
      throws VectorSizeMismatchException {
    // check input vector size if it is predefined
    if ((this.inputVectorSize != 0)
        && (el.getInput().length != this.inputVectorSize)) {
      throw new VectorSizeMismatchException(
          "Input vector size does not match training set!");
    }
    // check output vector size if it is predefined
    if (el instanceof SupervisedTrainingElement) {
      SupervisedTrainingElement sel = (SupervisedTrainingElement) el;
      if ((this.outputVectorSize != 0)
          && (sel.getDesiredOutput().length != this.outputVectorSize)) {
        throw new VectorSizeMismatchException(
            "Output vector size does not match training set!");
      }
    }
    // if everything went ok add training element
    this.elements.add(el);
  }

  /**
   * Removes training element at specified index position
   *
   * @param idx
   *            position of element to remove
   */
  public void removeElementAt(int idx) {
    this.elements.remove(idx);
  }

  /**
   * Returns Iterator for iterating training elements collection
   *
   * @return Iterator for iterating training elements collection
   */
  public Iterator<TrainingElement> iterator() {
    return this.elements.iterator();
  }

  /**
   * Returns training elements collection
   *
   * @return training elements collection
   */
  public List<TrainingElement> trainingElements() {
    return this.elements;
  }

  /**
   * Returns training element at specified index position
   *
   * @param idx
   *            index position of training element to return
   * @return training element at specified index position
   */
  public TrainingElement elementAt(int idx) {
    return this.elements.get(idx);
  }

  /**
   * Removes all alements from training set
   */
  public void clear() {
    this.elements.clear();
  }

  /**
   * Returns true if training set is empty, false otherwise
   *
   * @return true if training set is empty, false otherwise
   */
  public boolean isEmpty() {
    return this.elements.isEmpty();
  }

  /**
   * Returns number of training elements in this training set set
   *
   * @return number of training elements in this training set set
   */
  public int size() {
    return this.elements.size();
  }

  /**
   * Returns label for this training set
   *
   * @return label for this training set
   */
  public String getLabel() {
    return label;
  }

  /**
   * Sets label for this training set
   *
   * @param label
   *            label for this training set
   */
  public void setLabel(String label) {
    this.label = label;
  }

  /**
   * Sets full file path for this training set
   *
   * @param filePath
   */
  public void setFilePath(String filePath) {
    this.filePath = filePath;
  }

  /**
   * Returns full file path for this training set
   *
   * @return full file path for this training set
   */
  public String getFilePath() {
    return filePath;
  }

  /**
   * Returns label of this training set
   *
   * @return label of this training set
   */
  @Override
  public String toString() {
    return this.label;
  }

  /**
   * Saves this training set to the specified file
   *
   * @param filePath
   */
  public void save(String filePath) {
    this.filePath = filePath;
    this.save();
  }

  /**
   * Saves this training set to file specified in its filePath field
   */
  public void save() {
    ObjectOutputStream out = null;

    try {
      File file = new File(this.filePath);
      out = new ObjectOutputStream(new FileOutputStream(file));
      out.writeObject(this);
      out.flush();

    } catch (Exception e) {
      System.err.println(e.getMessage());
      e.printStackTrace();
    } finally {
      if (out != null) {
        try {
          out.close();
        } catch (IOException ioe) {
        }
      }
    }
  }

  /**
   * Loads training set from the specified file
   *
   * @param filePath
   *            training set file
   * @return loded training set
   */
  public static TrainingSet load(String filePath) {
    ObjectInputStream oistream = null;

    try {
      File file = new File(filePath);
      if (!file.exists()) {
        throw new FileNotFoundException("Cannot find file: " + filePath);
      }

      oistream = new ObjectInputStream(new FileInputStream(filePath));
      TrainingSet tSet = (TrainingSet) oistream.readObject();

      return tSet;

    } catch (IOException ioe) {
      ioe.printStackTrace();
    } catch (ClassNotFoundException cnfe) {
      cnfe.printStackTrace();
    } finally {
      if (oistream != null) {
        try {
          oistream.close();
        } catch (IOException ioe) {
        }
      }
    }

    return null;
  }

  /**
   * Returns output vector size of training elements in this training set This
   * method is implementation of EngineIndexableSet interface, and it is added
   * to provide compatibility with Encog data sets and FlatNetwork
   */
  @Override
  public int getIdealSize() {
    return this.outputVectorSize;
  }

  /**
   * Returns output vector size of training elements in this training set.
   */
  public int getOutputSize() {
    return this.outputVectorSize;
  }

  /**
   * Returns input vector size of training elements in this training set This
   * method is implementation of EngineIndexableSet interface, and it is added
   * to provide compatibility with Encog data sets and FlatNetwork
   */
  @Override
  public int getInputSize() {
    return this.inputVectorSize;
  }

  /**
   * Returns true if training set contains supervised training elements This
   * method is implementation of EngineIndexableSet interface, and it is added
   * to provide compatibility with Encog data sets and FlatNetwork
   */
  @Override
  public boolean isSupervised() {
    return this.outputVectorSize > 0;
  }

  /**
   * Gets training data/record at specified index position. This method is
   * implementation of EngineIndexableSet interface. It is added for
   * Encog-Engine compatibility.
   */
  @Override
  public void getRecord(long index, EngineData pair) {
    EngineData item = this.elements.get((int) index);
    pair.setInputArray(item.getInputArray());
    pair.setIdealArray(item.getIdealArray());
  }

  /**
   * Returns training elements/records count This method is implementation of
   * EngineIndexableSet interface. It is added for Encog-Engine compatibility.
   */
  @Override
  public long getRecordCount() {
    return this.elements.size();
  }

  /**
   * This method is implementation of EngineIndexableSet interface, and it is
   * added to provide compatibility with Encog data sets and FlatNetwork.
   *
   * Some datasets are not memory based, they may make use of a SQL connection
   * or a binary flat file. Because of this these datasets need to be cloned
   * for multi-threaded training or performance will greatly suffer. Because
   * this is a memory-based dataset, no cloning takes place and the "this"
   * object is returned.
   */
  @Override
  public EngineIndexableSet openAdditional() {
    return this;
  }
}
TOP

Related Classes of org.neuroph.core.learning.TrainingSet

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., now owned by Oracle Inc. Contact coftware#gmail.com.