package org.neuralnet;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Locale;
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;
import javax.swing.JFrame;
import com.mxgraph.swing.mxGraphComponent;
import com.mxgraph.view.mxGraph;
/**
 * Swing window that draws a {@link Network} as a graph using JGraphX ({@code mxGraph}).
 *
 * <p>Every neuron becomes an ellipse vertex labelled with its id. Neurons are placed in rows
 * grouped by {@code Neuron.getOutputDistance()}, in ascending order of that value, and within
 * a row in ascending id order. Every connection becomes an edge labelled with its weight
 * rounded to two decimals.
 *
 * <p>Closing the window exits the JVM ({@code EXIT_ON_CLOSE}).
 */
public class Visualisation extends JFrame {
    static final long serialVersionUID = 1L;

    /** Width and height, in pixels, of every neuron vertex. */
    private static final int NODESIZE = 30;
    /** Horizontal gap, in pixels, between neighbouring vertices in the same row. */
    private static final int HORIZSPACING = 80;
    /** Vertical gap, in pixels, between consecutive rows. */
    private static final int VERTICSPACING = 100;

    /** The network being displayed; {@code null} for the demo constructor. */
    private Network network;

    /**
     * Opens a small hard-coded demo graph.
     *
     * <p>The demo has three ellipse vertices, an edge from the first to the second vertex, and a
     * self-loop on the first vertex.
     */
    public Visualisation() {
        mxGraph graph = createEllipseGraph();
        Object parent = graph.getDefaultParent();
        graph.getModel().beginUpdate();
        try {
            Object v1 = graph.insertVertex(parent, null, null, 30, 30, 30, 30);
            Object v2 = graph.insertVertex(parent, null, null, 30, 130, 30, 30);
            graph.insertVertex(parent, null, null, 110, 30, 30, 30);
            graph.insertEdge(parent, null, null, v1, v2);
            graph.insertEdge(parent, null, null, v1, v1);
        } finally {
            graph.getModel().endUpdate();
        }
        showGraph(graph);
    }

    /**
     * Opens a window showing the neurons and connections of {@code net}.
     *
     * @param net the network to draw; its neurons are grouped into rows by
     *     {@code getOutputDistance()}
     */
    public Visualisation(Network net) {
        this.network = net;
        mxGraph graph = createEllipseGraph();

        Collection<Neuron> neurons = this.network.getNeurons();
        Collection<NeuronConnection> connections = this.network.getNeuronConnectionCollection();
        // Sorted by layer key, so rows are drawn in order even if the output distances are
        // not exactly 0..n-1 (the previous index-based lookup hit a null layer in that case).
        SortedMap<Integer, Map<Integer, Neuron>> layers = groupByLayer(neurons);

        // Neuron id -> inserted vertex, used to attach the connection edges.
        Map<Integer, Object> vertices = new HashMap<Integer, Object>();
        Object parent = graph.getDefaultParent();
        graph.getModel().beginUpdate();
        try {
            int row = 0;
            for (Map<Integer, Neuron> layer : layers.values()) {
                row++;
                int column = 0;
                for (Neuron neuron : layer.values()) {
                    column++;
                    Object vertex = graph.insertVertex(parent, null, neuron.getId(),
                            column * (HORIZSPACING + NODESIZE),
                            row * (VERTICSPACING + NODESIZE),
                            NODESIZE, NODESIZE);
                    vertices.put(neuron.getId(), vertex);
                }
            }
            for (NeuronConnection connection : connections) {
                // NOTE(review): if a connection endpoint is not among getNeurons(), the lookup
                // yields null and mxGraph receives a null terminal; behaviour kept as before.
                graph.insertEdge(parent, null, this.weightToString(connection.getWeight()),
                        vertices.get(connection.getSource().getId()),
                        vertices.get(connection.getTarget().getId()));
            }
        } finally {
            graph.getModel().endUpdate();
        }
        showGraph(graph);
    }

    /** Creates an empty graph whose default vertex shape is an ellipse. */
    private static mxGraph createEllipseGraph() {
        mxGraph graph = new mxGraph();
        Map<String, Object> vertexStyle = graph.getStylesheet().getDefaultVertexStyle();
        vertexStyle.put(com.mxgraph.util.mxConstants.STYLE_SHAPE,
                com.mxgraph.util.mxConstants.SHAPE_ELLIPSE);
        graph.getStylesheet().setDefaultVertexStyle(vertexStyle);
        return graph;
    }

    /**
     * Groups neurons by {@code getOutputDistance()} and then by id, both in ascending order.
     * A later neuron with the same id replaces an earlier one within its layer.
     */
    private static SortedMap<Integer, Map<Integer, Neuron>> groupByLayer(
            Collection<Neuron> neurons) {
        SortedMap<Integer, Map<Integer, Neuron>> layers =
                new TreeMap<Integer, Map<Integer, Neuron>>();
        for (Neuron neuron : neurons) {
            Integer layerKey = neuron.getOutputDistance();
            Map<Integer, Neuron> layer = layers.get(layerKey);
            if (layer == null) {
                // TreeMap keeps the left-to-right order stable (by id) instead of hash order.
                layer = new TreeMap<Integer, Neuron>();
                layers.put(layerKey, layer);
            }
            layer.put(neuron.getId(), neuron);
        }
        return layers;
    }

    /** Wraps {@code graph} in a component, adds it to this frame, and shows the window. */
    private void showGraph(mxGraph graph) {
        mxGraphComponent graphComponent = new mxGraphComponent(graph);
        this.getContentPane().add(graphComponent);
        this.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        this.setSize(800, 600);
        this.setVisible(true);
    }

    /** Opens the hard-coded demo graph. */
    public static void main(String[] args) {
        // NOTE(review): Swing UIs are normally built on the EDT (SwingUtilities.invokeLater).
        new Visualisation();
    }

    /**
     * Rounds {@code weight} to at most two fractional digits for use as an edge label.
     *
     * <p>Despite its name, this returns a {@code double}, not a {@code String}. Rounding follows
     * {@link DecimalFormat}'s default mode (HALF_EVEN). The symbols are pinned to
     * {@link Locale#ROOT} so that the intermediate text always uses {@code '.'} as the decimal
     * separator and can be parsed back on any default locale.
     *
     * @param weight the connection weight
     * @return the rounded weight; NaN and infinities are returned unchanged
     */
    double weightToString(double weight) {
        if (Double.isNaN(weight) || Double.isInfinite(weight)) {
            // DecimalFormat renders these as symbols that Double.valueOf cannot parse.
            return weight;
        }
        DecimalFormat format = new DecimalFormat("#.##", new DecimalFormatSymbols(Locale.ROOT));
        return Double.valueOf(format.format(weight));
    }
}