// Validate the layer specification before building anything. layers[0] describes the
// input as {rows, columns, filters}. An empty specification would otherwise fail with an
// ArrayIndexOutOfBoundsException on layers[0], so reject it with a clear message instead.
if (layers.length == 0) {
    throw new IllegalArgumentException("at least one layer (the input layer) is required");
}
if (layers[0].length != 3) {
    throw new IllegalArgumentException("first layer must be convolutional");
}

// The network being built, plus the connection factory used to wire its layers.
// The factory is also stored in the network's properties under Constants.CONNECTION_FACTORY.
NeuralNetworkImpl result = new NeuralNetworkImpl();
ConnectionFactory cf = new ConnectionFactory();
result.setProperties(new Properties());
result.getProperties().setParameter(Constants.CONNECTION_FACTORY, cf);

// Input layer. Its unit count is rows * columns * filters. multiplyExact makes an
// overflowing size fail loudly instead of silently wrapping to a bogus (possibly negative) count.
Layer prev = new Layer();
int prevUnitCount = Math.multiplyExact(Math.multiplyExact(layers[0][0], layers[0][1]), layers[0][2]);
result.addLayer(prev);
// Build every layer after the input layer from its spec row layers[i].
// The row length selects the kind of layer:
//   1 -> fully connected, l[0] = unit count (passed to cf.fullyConnected)
//   4 -> convolutional, l[0..3] passed to cf.conv2d (presumably kernel rows, kernel
//        columns, output filters, stride -- confirm against ConnectionFactory.conv2d)
//   2 -> subsampling, l[0], l[1] passed to cf.subsampling2D
// prevUnitCount carries the unit count produced by the most recently built layer.
for (int i = 1; i < layers.length; i++) {
int[] l = layers[i];
Layer newLayer = null;
Layer biasLayer = null;
if (l.length == 1) {
// Fully connected layer with l[0] units, fed by prev.
cf.fullyConnected(prev, newLayer = new Layer(), prevUnitCount, l[0]);
if (addBias) {
// Separate bias layer feeding newLayer, wired with an input count of 1.
cf.fullyConnected(biasLayer = new Layer(), newLayer, 1, l[0]);
}
prevUnitCount = l[0];
} else if (l.length == 4 || l.length == 2) {
// Find the input feature-map geometry (rows, columns, filter count) for this layer.
Integer inputFMRows = null;
Integer inputFMCols = null;
Integer filters = null;
if (i == 1) {
// First hidden layer: geometry comes straight from the input spec layers[0].
inputFMRows = layers[0][0];
inputFMCols = layers[0][1];
filters = layers[0][2];
} else {
// Later layers: read the output geometry of the first Conv2D or Subsampling2D
// connection whose output layer is prev.
// NOTE(review): if no such connection exists (e.g. prev was built by the fully
// connected branch above), all three values stay null. Passing them on presumably
// fails on unboxing -- confirm, and consider rejecting this layer order explicitly.
for (Connections c : prev.getConnections()) {
if (c.getOutputLayer() == prev) {
if (c instanceof Conv2DConnection) {
Conv2DConnection cc = (Conv2DConnection) c;
inputFMRows = cc.getOutputFeatureMapRows();
inputFMCols = cc.getOutputFeatureMapColumns();
filters = cc.getOutputFilters();
break;
} else if (c instanceof Subsampling2DConnection) {
Subsampling2DConnection sc = (Subsampling2DConnection) c;
inputFMRows = sc.getOutputFeatureMapRows();
inputFMCols = sc.getOutputFeatureMapColumns();
filters = sc.getFilters();
break;
}
}
}
}
if (l.length == 4) {
// Convolutional layer fed by prev.
Conv2DConnection c = cf.conv2d(prev, newLayer = new Layer(), inputFMRows, inputFMCols, filters, l[0], l[1], l[2], l[3]);
if (addBias) {
// Bias connection into newLayer. Its input geometry is c's output feature map, with
// 1, 1, 1 and then l[2], l[3] as the remaining conv2d arguments.
// NOTE(review): l[3] (presumably the stride) is reused with what looks like a 1x1 kernel.
// Verify against ConnectionFactory.conv2d that the bias output still matches c's
// output size when l[3] > 1.
cf.conv2d(biasLayer = new Layer(), newLayer, c.getOutputFeatureMapRows(), c.getOutputFeatureMapColumns(), 1, 1, 1, l[2], l[3]);
}
prevUnitCount = c.getOutputUnitCount();
} else if (l.length == 2) {
// Subsampling layer fed by prev. It keeps the incoming filter count, and no bias
// layer is created here even when addBias is set.
Subsampling2DConnection c = cf.subsampling2D(prev, newLayer = new Layer(), inputFMRows, inputFMCols, l[0], l[1], filters);
prevUnitCount = c.getOutputUnitCount();
}
}
// NOTE(review): for a spec row whose length is not 1, 2 or 4, no branch above runs.
// newLayer is then still null when it is passed to result.addLayer. Confirm whether
// such rows should be rejected with an IllegalArgumentException instead.
result.addLayer(newLayer);