Package com.github.neuralnetworks.samples.mnist

Source Code of com.github.neuralnetworks.samples.mnist.MnistInputProvider

package com.github.neuralnetworks.samples.mnist;

import java.io.IOException;
import java.io.RandomAccessFile;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import com.github.neuralnetworks.input.InputConverter;
import com.github.neuralnetworks.training.TrainingInputData;
import com.github.neuralnetworks.training.TrainingInputDataImpl;
import com.github.neuralnetworks.training.TrainingInputProviderImpl;
import com.github.neuralnetworks.util.Matrix;

/**
* MNIST data set with random order
*/
public class MnistInputProvider extends TrainingInputProviderImpl {

    private static final long serialVersionUID = 1L;

    private RandomAccessFile images;
    private RandomAccessFile labels;
    private int epochs;
    private int currentEpoch;
    private int rows;
    private int cols;
    private int inputSize;
    private final int batchSize;
    private List<Integer> elementsOrder;
    private Random random;
    private final InputConverter targetConverter;
    private Matrix tempImages;
    private byte[] current;

    public MnistInputProvider(String imagesFile, String labelsFile, int batchSize, int epochs, InputConverter targetConverter) {
  super();

  this.epochs = epochs;
  this.batchSize = batchSize;
  this.targetConverter = targetConverter;

  try {
      this.images = new RandomAccessFile(imagesFile, "r");
      this.labels = new RandomAccessFile(labelsFile, "r");

      // magic numbers
      images.readInt();
      inputSize = images.readInt();
      rows = images.readInt();
      cols = images.readInt();
      current = new byte[rows * cols];

      random = new Random();
  } catch (IOException e) {
      e.printStackTrace();
  }
    }

    @Override
    public TrainingInputData getNextUnmodifiedInput() {
  TrainingInputData result = null;

  if (elementsOrder.size() == 0 && currentEpoch < epochs) {
      resetOrder();
      currentEpoch++;
  }

  if (elementsOrder.size() > 0) {
      int length = elementsOrder.size() > batchSize ? batchSize : elementsOrder.size();
      int[] indexes = new int[length];
      for (int i = 0; i < length; i++) {
    indexes[i] = elementsOrder.remove(random.nextInt(elementsOrder.size()));
      }

      Matrix input = getImages(indexes);

      Matrix target = targetConverter.convert(getLabels(indexes));

      result = new TrainingInputDataImpl(input, target);
  }

  return result;
    }

    @Override
    public void reset() {
  currentEpoch = 1;
  resetOrder();
    }

    public void resetOrder() {
  elementsOrder = new ArrayList<Integer>(inputSize);
  for (int i = 0; i < inputSize; i++) {
      elementsOrder.add(i);
  }
    }

    @Override
    public int getInputSize() {
  return inputSize * epochs;
    }

    private Matrix getImages(int[] indexes) {
  int size = cols * rows;
  if (tempImages == null || tempImages.getRows() != indexes.length) {
      tempImages = new Matrix(size, indexes.length);
  }

  try {
      for (int i = 0; i < indexes.length; i++) {
    images.seek(16 + size * indexes[i]);
    images.readFully(current);
    for (int j = 0; j < size; j++) {
        tempImages.set(current[j] & 0xFF, j, i);
    }
      }
  } catch (IOException e) {
      e.printStackTrace();
  }

  return tempImages;
    }

    private Integer[] getLabels(int indexes[]) {
  Integer[] result = new Integer[indexes.length];

  try {
      for (int i = 0; i < indexes.length; i++) {
    labels.seek(8 + indexes[i]);
    result[i] = labels.readUnsignedByte();
      }
  } catch (IOException e) {
      e.printStackTrace();
  }

  return result;
    }

    public int getRows() {
  return rows;
    }

    public int getCols() {
  return cols;
    }
}
TOP

Related Classes of com.github.neuralnetworks.samples.mnist.MnistInputProvider

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark originally of Sun Microsystems, Inc., now owned by Oracle Inc. Contact coftware#gmail.com.