return new BackPropagationTrainer<NeuralNetwork>(p);
}
private static BackPropagationLayerCalculatorImpl bplc(NeuralNetworkImpl nn, Properties p) {
BackPropagationLayerCalculatorImpl blc = new BackPropagationLayerCalculatorImpl();
LayerCalculatorImpl lc = (LayerCalculatorImpl) nn.getLayerCalculator();
List<ConnectionCandidate> connections = new BreadthFirstOrderStrategy(nn, nn.getOutputLayer()).order();
if (connections.size() > 0) {
Layer current = null;
List<Connections> chunk = new ArrayList<>();
Set<Layer> convCalculatedLayers = new HashSet<>(); // tracks
// convolutional
// layers
// (because their
// calculations
// are
// interlinked)
convCalculatedLayers.add(nn.getOutputLayer());
for (int i = 0; i < connections.size(); i++) {
ConnectionCandidate c = connections.get(i);
chunk.add(c.connection);
if (i == connections.size() - 1 || connections.get(i + 1).target != c.target) {
current = c.target;
ConnectionCalculator result = null;
ConnectionCalculator ffcc = null;
if (Util.isBias(current)) {
ffcc = lc.getConnectionCalculator(current.getConnections().get(0).getOutputLayer());
} else if (Util.isConvolutional(current) || Util.isSubsampling(current)) {
if (chunk.size() != 1) {
throw new IllegalArgumentException("Convolutional layer with more than one connection");
}
ffcc = lc.getConnectionCalculator(Util.getOppositeLayer(chunk.iterator().next(), current));
} else {
ffcc = lc.getConnectionCalculator(current);
}
if (ffcc instanceof AparapiSigmoid) {
result = new BackPropagationSigmoid(p);
} else if (ffcc instanceof AparapiTanh) {
result = new BackPropagationTanh(p);
} else if (ffcc instanceof AparapiSoftReLU) {
result = new BackPropagationSoftReLU(p);
} else if (ffcc instanceof AparapiReLU) {
result = new BackPropagationReLU(p);
} else if (ffcc instanceof AparapiMaxout) {
result = new BackpropagationMaxout(p);
} else if (ffcc instanceof AparapiMaxPooling2D || ffcc instanceof AparapiStochasticPooling2D) {
result = new BackpropagationMaxPooling2D();
} else if (ffcc instanceof AparapiAveragePooling2D) {
result = new BackpropagationAveragePooling2D();
} else if (ffcc instanceof ConnectionCalculatorConv) {
Layer opposite = Util.getOppositeLayer(chunk.iterator().next(), current);
if (!convCalculatedLayers.contains(opposite)) {
convCalculatedLayers.add(opposite);
if (ffcc instanceof AparapiConv2DSigmoid) {
result = new BackPropagationConv2DSigmoid(p);
} else if (ffcc instanceof AparapiConv2DTanh) {
result = new BackPropagationConv2DTanh(p);
} else if (ffcc instanceof AparapiConv2DSoftReLU) {
result = new BackPropagationConv2DSoftReLU(p);
} else if (ffcc instanceof AparapiConv2DReLU) {
result = new BackPropagationConv2DReLU(p);
}
} else {
result = new BackPropagationConv2D(p);
}
}
if (result != null) {
blc.addConnectionCalculator(current, result);
}
chunk.clear();
}
}