Package eas.simulation.brain.neural

Source Code of eas.simulation.brain.neural.GeneralNeuralNetwork

/*
* File name:        NeuralNetwork.java (package eas.simulation.brain.neural)
* Author(s):        Lukas König
* Java version:     6.0
* Generation date:  01.08.2011 (12:52:48)
*
* (c) This file and the EAS (Easy Agent Simulation) framework containing it
* is protected by Creative Commons by-nc-sa license. Any altered or
* further developed versions of this file have to meet the agreements
* stated by the license conditions.
*
* In a nutshell
* -------------
* You are free:
* - to Share -- to copy, distribute and transmit the work
* - to Remix -- to adapt the work
*
* Under the following conditions:
* - Attribution -- You must attribute the work in the manner specified by the
*   author or licensor (but not in any way that suggests that they endorse
*   you or your use of the work).
* - Noncommercial -- You may not use this work for commercial purposes.
* - Share Alike -- If you alter, transform, or build upon this work, you may
*   distribute the resulting work only under the same or a similar license to
*   this one.
*
* + Detailed license conditions (Germany):
*   http://creativecommons.org/licenses/by-nc-sa/3.0/de/
* + Detailed license conditions (unported):
*   http://creativecommons.org/licenses/by-nc-sa/3.0/deed.en
*
* This header must be placed in the beginning of any version of this file.
*/

package eas.simulation.brain.neural;

import java.awt.Color;
import java.awt.Font;
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedList;

import eas.math.geometry.Polygon2D;
import eas.math.geometry.Vector2D;
import eas.miscellaneous.StaticMethods;
import eas.simulation.brain.neural.functions.ActivationFunction;
import eas.simulation.brain.neural.functions.ActivationFunctionIdentity;
import eas.simulation.brain.neural.functions.ActivationFunctionSigmoid;
import eas.simulation.brain.neural.functions.NeuralCombinationFunctionAverage;
import eas.simulation.brain.neural.functions.TransitionFunction;
import eas.simulation.brain.neural.functions.TransitionFunctionWeightedSum;
import eas.startSetup.GlobalVariables;
import eas.startSetup.ParCollection;
import eas.startSetup.marbBuilder.zeichenModi.ArrowMaster;

/**
* TODO: Reflexive links are not displayed.
*
* @author Lukas König
*/
public class GeneralNeuralNetwork implements Serializable {

    /** Serialization version id. */
    private static final long serialVersionUID = 5981273888504161222L;

    /**
     * Large finite value (1e7). Judging by its name it serves as a practical
     * stand-in for infinity; it is not referenced in the visible part of this
     * class (NOTE(review): confirm external usage).
     */
    public static final double DE_FACTO_INFINITY = 10000000;

    /**
     * All neurons of the net keyed by their id. addNeuron(.) and
     * removeNeuron(.) keep the ids contiguous from 0 to neuronCount - 1.
     */
    private HashMap<Integer, Neuron> neurons;

    /** Number of neurons in the net; also the id used when appending. */
    private int neuronCount = 0;

    /**
     * The standard activation function, used for non-input neurons created by
     * the addNeuron(.) overloads that take no explicit activation function.
     * The constructors initialize it with ActivationFunctionSigmoid(1.0).
     * NOTE(review): an earlier comment stated that activation functions are
     * constrained to return values from [0, 1]; input neurons receive an
     * identity function, which does not satisfy this - confirm the contract.
     */
    private ActivationFunction standardActFct;

    /** Arrow builder ("Pfeil" = arrow) used by generateNeuroImage(.). */
    private ArrowMaster pfeilMaster;

    /**
     * The standard transition function used by the addNeuron(.) overloads
     * that take no explicit transition function; defaults to a weighted sum.
     */
    private TransitionFunction standardTrnFct = new TransitionFunctionWeightedSum();

    /** Parameter collection this net was created with. */
    private ParCollection pars;

    /** If false, addLink(.) rejects links whose target is an input neuron. */
    private boolean allowTransitionsToInputs = false;

    /**
     * If false, addLink(.) rejects links whose source id is greater than or
     * equal to the target id (this includes self-links).
     */
    private boolean allowRecurrentTransitions = false;

    /** If false, addLink(.) rejects links whose source is an output neuron. */
    private boolean allowTransitionsFromOutputs = true;

    public GeneralNeuralNetwork(eas.users.lukas.neuroCEGPM.simpleNeural.NeuralNetwork other) {
        this(other.getPars());
        this.allowRecurrentTransitions = true;
        this.allowTransitionsFromOutputs = true;
        this.allowTransitionsToInputs = true;

        for (int i = 0; i < other.getNeuronCount(); i++) {
            this.addNeuron(i, other.getActivationFunction(i).getCorrespondingDoubleBasedFunction(), Neuron.INPUT_OUTPUT_NEURON);
            this.setInput(i, other.getNeuronValue(i).doubleValue());
            this.setOutput(i, other.getActivationFunction(i).getCorrespondingDoubleBasedFunction().activationPhi(other.getNeuronValue(i).doubleValue()));
        }
        for (int i = 0; i < other.getNeuronCount(); i++) {
            for (int j = 0; j < other.getNeuronCount(); j++) {
                this.addLink(i, j, other.getWeight(i, j).doubleValue());
            }
        }
        this.standardActFct = new ActivationFunctionSigmoid(1.0);
    }
   
    public GeneralNeuralNetwork(eas.users.lukas.neuroCEGPM.simpleNeural.NeuralNetworkDoubleBase other) {
        this(other.getPars());
        this.allowRecurrentTransitions = true;
        this.allowTransitionsFromOutputs = true;
        this.allowTransitionsToInputs = true;

        for (int i = 0; i < other.getNeuronCount(); i++) {
            this.addNeuron(i, other.getActivationFunction(i), Neuron.INPUT_OUTPUT_NEURON);
            this.setInput(i, other.getNeuronValue(i));
            this.setOutput(i, other.getActivationFunction(i).activationPhi(other.getNeuronValue(i)));
        }
        for (int i = 0; i < other.getNeuronCount(); i++) {
            for (int j = 0; j < other.getNeuronCount(); j++) {
                this.addLink(i, j, other.getWeight(i, j));
            }
        }
        this.standardActFct = new ActivationFunctionSigmoid(1.0);
    }
   
    public GeneralNeuralNetwork(final ParCollection params) {
        this.neurons = new HashMap<Integer, Neuron>();
        this.pars = params;
        this.pfeilMaster = new ArrowMaster(this.pars);
        this.standardActFct = new ActivationFunctionSigmoid(1.0);
    }

    /**
     * Returns the neuron with this ID.
     *
     * @param id
     *            The neuron id to get.
     *
     * @return The neuron with the given ID or null if it does not exist.
     */
    public Neuron getNeuron(final int id) {
        return this.neurons.get(id);
    }

    public boolean existsNeuron(final int id) {
        return this.getNeuron(id) != null;
    }

    /**
     * Adds a new unconnected neuron to the network using standard transition
     * and activiation functions.
     *
     * @param neuronType
     *            The type of the neuron (Neuron.*_NEURON; * in {INPUT, HIDDEN,
     *            OUTPUT}).
     *
     * @return The id of the new neuron.
     */
    public int addNeuron(int neuronType) {
        if (neuronType == Neuron.INPUT_NEURON
                || neuronType == Neuron.INPUT_OUTPUT_NEURON) {
            return this.addNeuron(new ActivationFunctionIdentity(), neuronType);
        } else {
            return this.addNeuron(this.standardActFct, neuronType);
        }
    }

    /**
     * Adds a new unconnected neuron to the network using standard transition
     * function.
     *
     * @param actFctPhi
     *            The activation function of the new neuron.
     * @param neuronType
     *            The type of the neuron (Neuron.*_NEURON; * in {INPUT, HIDDEN,
     *            OUTPUT}).
     *
     * @return The id of the new neuron.
     */
    public int addNeuron(ActivationFunction actFctPhi, int neuronType) {
        return this.addNeuron(actFctPhi, standardTrnFct, neuronType);
    }

    /**
     * Adds a new unconnected neuron to the network.
     *
     * @param actFctPhi
     *            The activation function of the new neuron.
     * @param trnFctSigma
     *            The transition function of the new neuron.
     * @param neuronType
     *            The type of the neuron (Neuron.*_NEURON; * in {INPUT, HIDDEN,
     *            OUTPUT}).
     *
     * @return The id of the new neuron.
     */
    public int addNeuron(ActivationFunction actFctPhi,
            TransitionFunction trnFctSigma, int neuronType) {
        return this.addNeuron(this.neuronCount, actFctPhi, trnFctSigma,
                neuronType);
    }

    public int addNeuron(int neuronID, ActivationFunction actFctPhi,
            int neuronType) {
        return this.addNeuron(neuronID, actFctPhi, this.standardTrnFct,
                neuronType);
    }

    public int addNeuron(int neuronID, int neuronType) {
        if (neuronType == Neuron.INPUT_NEURON
                || neuronType == Neuron.INPUT_OUTPUT_NEURON) {
            return this.addNeuron(neuronID, new ActivationFunctionIdentity(),
                    this.standardTrnFct, neuronType);
        } else {
            return this.addNeuron(neuronID, this.standardActFct,
                    this.standardTrnFct, neuronType);
        }
    }

    /**
     * This method is eventually called by any of the addNeuron(.) methods.
     *
     * @param neuronID
     * @param actFctPhi
     * @param trnFctSigma
     * @param neuronType
     * @return
     */
    public int addNeuron(int neuronID, ActivationFunction actFctPhi,
            TransitionFunction trnFctSigma, int neuronType) {
        int id = neuronID;

        if (id > this.neuronCount || id < 0) {
            throw new RuntimeException("Neuron id cannot be inserted: " + id);
        }

        if (neuronID < this.neuronCount) {
            // Move all neurons beginning with id one step further.
            for (int i = this.neuronCount - 1; i >= id; i--) {
                Neuron neuron = this.getNeuron(i);
                neuron.setId(i + 1);
                this.neurons.put(i + 1, neuron);
            }
        }

        this.neurons.put(id, new Neuron(id, actFctPhi, trnFctSigma, neuronType,
                this));
        neuronCount++;

        if (neuronID < this.neuronCount) {
            // Move all neurons greater than or equal to id referenced in links.
            for (Neuron neuron : this.neurons.values()) {
                for (NeuralLink link : neuron.getIncomingLinks()) {
                    if (link.getSourceNeuronID() >= id) {
                        link.setSourceNeurID(link.getSourceNeuronID() + 1);
                    }
                }
            }
        }

        return id;
    }

    public void removeNeuron(int neuronID) {
        Neuron neuron = this.getNeuron(neuronID);

        if (neuron == null) {
            throw new RuntimeException("No such neuron to remove: " + neuronID);
        }

        // Remove neuron and all references to it.
        this.neurons.remove(neuronID);
        this.neuronCount--;

        for (Neuron n : this.neurons.values()) {
            n.removeIncomingLink(neuronID);
        }

        // Shift all ids greater than the removed one place higher.
        for (int i = neuronID + 1; i <= this.neuronCount; i++) {
            Neuron n = this.getNeuron(i);
            n.setId(i - 1);
            this.neurons.put(i - 1, n);
        }

        this.neurons.remove(this.neuronCount);

        // Adapt all links to the new order.
        for (Neuron n : this.neurons.values()) {
            for (NeuralLink link : n.getIncomingLinks()) {
                if (link.getSourceNeuronID() > neuronID) {
                    link.setSourceNeurID(link.getSourceNeuronID() - 1);
                }
            }
        }
    }

    public boolean addLink(int sourceNeuron, int targetNeuron, double weight) {
        if (sourceNeuron >= this.neuronCount
                || targetNeuron >= this.neuronCount || sourceNeuron < 0
                || targetNeuron < 0) {
            throw new RuntimeException("Bad neuron ID (" + sourceNeuron + ", " + targetNeuron + ") for addLink(.).");
        }

        // if (sourceNeuron == targetNeuron) {
        // throw new RuntimeException("Link cannot point to itself (neuron id: "
        // + sourceNeuron + ").");
        // }

        if (!this.isAllowLinksFromOutputs()
                && this.getNeuron(sourceNeuron).isOutput()) {
            return false;
        }
        if (!this.isAllowLinksToInputs()
                && this.getNeuron(targetNeuron).isInput()) {
            return false;
        }
        if (!this.isAllowRecurrentLinks() && sourceNeuron >= targetNeuron) {
            return false;
        }

        this.neurons.get(targetNeuron).addIncomingLink(
                new NeuralLink(sourceNeuron, weight));

        return true;
    }

    public boolean removeLink(int sourceID, int targetID) {
        Neuron neuron = this.neurons.get(targetID);

        if (neuron == null) {
            return false;
        } else {
            return neuron.removeIncomingLink(sourceID);
        }
    }

    public ActivationFunction getStandardActFct() {
        return this.standardActFct;
    }

    public TransitionFunction getStandardTrnFct() {
        return this.standardTrnFct;
    }

    public void setStandardActFct(ActivationFunction actFct) {
        this.standardActFct = actFct;
    }

    public void setStandardTrnFct(TransitionFunction trnFct) {
        this.standardTrnFct = trnFct;
    }

    /**
     * This method returns the number of neurons in the net and also the first
     * free neuron id which is simulatneously the only id that can be associated
     * to a newly inserted neuron.
     *
     * @return The neuron count. All neurons from id 0 to neuroncount - 1 exist.
     */
    public int getNeuronCount() {
        return this.neuronCount;
    }

    public BufferedImage generateNeuroImage(int width) {
        BufferedImage img = new BufferedImage(width, width - 75,
                BufferedImage.TYPE_INT_RGB);
        Graphics2D g = img.createGraphics();
        g.setColor(Color.white);
        g.fillRect(0, 0, img.getWidth() - 1, img.getHeight() - 1);
        g.setColor(Color.black);
        g.drawRect(0, 0, img.getWidth() - 1, img.getHeight() - 1);

        if (this.neuronCount == 0) {
            return img;
        }

        int nodesize = (width - 80 - 11) / this.neurons.size() - 10;

        nodesize = Math.max(2, nodesize);

        // Draw all possible links.
        // for (int i = 0; i < this.neurons.size(); i++) {
        // for (int j = 0; j < this.neurons.size(); j++) {
        // if (j > i && (this.neurons.get(i).isInput() ||
        // this.neurons.get(i).isOutput())
        // || j < i && (this.neurons.get(j).isInput() ||
        // this.neurons.get(j).isOutput())) {
        // g.setColor(Color.orange);
        // } else {
        // if (j > i) {
        // g.setColor(Color.blue);
        // } else {
        // g.setColor(Color.black);
        // }
        // }
        //
        // g.fillRect(
        // i * (nodesize + 10) + 48 + nodesize / 2,
        // j * (nodesize + 10) + 65 - nodesize / 2,
        // 3,
        // 7);
        // g.fillRect(
        // i * (nodesize + 10) + 46 + nodesize / 2,
        // j * (nodesize + 10) + 67 - nodesize / 2,
        // 7,
        // 3);
        // }
        // }

        for (int i = 0; i < this.neurons.size(); i++) {
            // Draw neuron.
            Neuron currentNeuron = this.neurons.get(i);

            g.drawImage(currentNeuron.generateNeuronView(nodesize, nodesize), i
                    * (nodesize + 10) + 50, i * (nodesize + 10) + 10, null);

            // Draw input / output.
            if (currentNeuron.isInput()) {
                double translation = 0;
                Color color = Color.orange;

                translation += 5;

                Polygon2D points = new Polygon2D();
                ArrayList<Double> thicks = new ArrayList<Double>(2);

                thicks.add(2.5);
                thicks.add(2.5);
                Vector2D startPoint = new Vector2D(10 + translation, i
                        * (nodesize + 10) + 10 + nodesize / 2 - 1);
                points.add(startPoint);
                points.add(new Vector2D(30 + translation, i * (nodesize + 10)
                        + 10 + nodesize / 2 - 1));

                Polygon2D arrow = this.pfeilMaster.segmentPfeilPol2D(points,
                        thicks, ArrowMaster.KUGEL_ENDE,
                        ArrowMaster.EINFACHE_SPITZE_1, new Vector2D(1, 1),
                        new Vector2D(1, 1));
                g.setColor(color);
                g.fillPolygon(arrow.toPol());
                g.setColor(Color.black);
                g.drawPolygon(arrow.toPol());
                g.drawString(StaticMethods.round(currentNeuron.getInput(), 2)
                        + "", (float) startPoint.x, (float) startPoint.y - 5);
            }

            if (currentNeuron.isOutput()) {
                Color color = Color.orange;
                double translation = 0;
                Polygon2D points = new Polygon2D();
                ArrayList<Double> thicks = new ArrayList<Double>(2);
                translation = this.neurons.size() * (nodesize + 10) + 40;
                translation += 5;

                thicks.add(2.5);
                thicks.add(2.5);
                Vector2D startPoint = new Vector2D(10 + translation, i
                        * (nodesize + 10) + 10 + nodesize / 2 - 1);
                points.add(startPoint);
                points.add(new Vector2D(30 + translation, i * (nodesize + 10)
                        + 10 + nodesize / 2 - 1));

                Polygon2D arrow = this.pfeilMaster.segmentPfeilPol2D(points,
                        thicks, ArrowMaster.KUGEL_ENDE,
                        ArrowMaster.EINFACHE_SPITZE_1, new Vector2D(1, 1),
                        new Vector2D(1, 1));
                g.setColor(color);
                g.fillPolygon(arrow.toPol());
                g.setColor(Color.black);
                g.drawPolygon(arrow.toPol());
                g.drawString(StaticMethods.round(currentNeuron.getOutput(), 2)
                        + "", (float) startPoint.x, (float) startPoint.y - 8);

                color = Color.orange;
            }

            // Draw links.
            double maxWeight = Double.NEGATIVE_INFINITY;
            double minWeight = Double.POSITIVE_INFINITY;

            for (NeuralLink link : currentNeuron.getIncomingLinks()) {
                if (link.getWeight() > maxWeight) {
                    maxWeight = link.getWeight();
                }
                if (link.getWeight() < minWeight) {
                    minWeight = link.getWeight();
                }
            }

            for (NeuralLink link : currentNeuron.getIncomingLinks()) {
                Polygon2D points = new Polygon2D();
                ArrayList<Double> thicks = new ArrayList<Double>(2);
                Vector2D weightPos = null;
                double weightFactor;

                if (Math.abs(maxWeight) > Math.abs(minWeight)) {
                    minWeight = -maxWeight;
                } else {
                    maxWeight = -minWeight;
                }

                if (link.getWeight() >= 0) {
                    weightFactor = link.getWeight() / maxWeight;
                } else {
                    weightFactor = link.getWeight() / minWeight;
                }
                if (weightFactor > 0 && weightFactor < 0.26) {
                    weightFactor = 0.26;
                }
                if (weightFactor <= 0 && weightFactor > -0.26) {
                    weightFactor = -0.26;
                }

                if (Double.isNaN(weightFactor)) {
                    weightFactor = 0.1;
                }

                thicks.add(4 * weightFactor);
                thicks.add(4 * weightFactor);
                thicks.add(4 * weightFactor);

                Neuron otherNeuron = this.getNeuron(link.getSourceNeuronID());

                Font font;
                if (Math.abs(weightFactor) > 0.5) {
                    font = new Font("", Font.BOLD, 12);
                } else {
                    font = new Font("", Font.PLAIN, 10);
                }

                if (otherNeuron.getId() < currentNeuron.getId()) {
                    // Regular link from earlier neuron.
                    points.add(new Vector2D(otherNeuron.getId()
                            * (nodesize + 10) + 50 + nodesize / 2, otherNeuron
                            .getId() * (nodesize + 10) + 10 + nodesize));
                    points.add(new Vector2D(otherNeuron.getId()
                            * (nodesize + 10) + 50 + nodesize / 2,
                            currentNeuron.getId() * (nodesize + 10) + 10
                                    + nodesize / 2));
                    points.add(new Vector2D(currentNeuron.getId()
                            * (nodesize + 10) + 40, currentNeuron.getId()
                            * (nodesize + 10) + 10 + nodesize / 2));
                    weightPos = new Vector2D(otherNeuron.getId()
                            * (nodesize + 10) + 53 + nodesize / 2,
                            currentNeuron.getId() * (nodesize + 10) + 10
                                    + nodesize / 2 - 3);
                } else if (otherNeuron.getId() > currentNeuron.getId()) {
                    // Reccurent link from later neuron.
                    points.add(new Vector2D(otherNeuron.getId()
                            * (nodesize + 10) + 50 + nodesize / 2, otherNeuron
                            .getId() * (nodesize + 10) + 10 + nodesize));
                    points.add(new Vector2D(otherNeuron.getId()
                            * (nodesize + 10) + 50 + nodesize / 2,
                            currentNeuron.getId() * (nodesize + 10) + 10
                                    + nodesize / 2));
                    points.add(new Vector2D(currentNeuron.getId()
                            * (nodesize + 10) + 60 + nodesize, currentNeuron
                            .getId() * (nodesize + 10) + 10 + nodesize / 2));
                    weightPos = new Vector2D(otherNeuron.getId()
                            * (nodesize + 10) + 53 + nodesize / 2,
                            currentNeuron.getId() * (nodesize + 10) + 10
                                    + nodesize / 2 - 3);
                } else {
//                     Reccurent link from same neuron.
                    double epsilon = 0.05;
                    double radius = nodesize / 2.0 - 2;
                    double x = currentNeuron.getId() * (nodesize + 10) + 51 + nodesize / 2.0;
                    double y = currentNeuron.getId() * (nodesize + 10) + 11 + nodesize / 2.0;
                    thicks.clear();
                   
                    for (double d = 0; d < Math.PI * 2; d += epsilon) {
                        points.add(new Vector2D(x + Math.cos(d) * radius, y + Math.sin(d) * radius));
                        thicks.add(0 * weightFactor);
                    }
                    
                    weightPos = new Vector2D(otherNeuron.getId()
                            * (nodesize + 10) + 60, currentNeuron.getId()
                            * (nodesize + 10) + 8);
                    weightFactor = Double.POSITIVE_INFINITY;
                }

                double spitzenFactor = 1.0;
                if (weightFactor < 0) {
                    spitzenFactor = -1.0;
                }

                if (points.size() > 0) {
                    Polygon2D arrow = ArrowMaster.DOPPELSPITZES_ENDE;
                    try {
                        arrow = this.pfeilMaster.segmentPfeilPol2D(points,
                                thicks, ArrowMaster.DOPPELSPITZES_ENDE,
                                ArrowMaster.EINFACHE_SPITZE_1, new Vector2D(
                                        spitzenFactor * 1 / weightFactor / 1.7,
                                        1 / weightFactor / 1.7), new Vector2D(
                                        spitzenFactor * 1 / weightFactor / 1.7,
                                        1 / weightFactor / 1.7));
                    } catch (Exception e) {
                        GlobalVariables.getPrematureParameters().logStage1(Arrays.deepToString(e.getStackTrace()));
                    }

                    if (link.getWeight() > 0) {
                        g.setColor(Color.black);
                    } else {
                        g.setColor(Color.red);
                    }
                    g.fillPolygon(arrow.toPol());
                }

                if (link.getWeight() > 0) {
                    g.setColor(Color.black);
                } else {
                    g.setColor(Color.red);
                }
                g.setFont(font);
                g.drawString("" + StaticMethods.round(link.getWeight(), 2),
                        (int) weightPos.x, (int) weightPos.y);
                g.setColor(Color.black);
            }

            // Neuron output.
            g.setColor(Color.blue);
            g.setFont(new Font("", Font.PLAIN, 10));
            g.drawString(
                    "[" + Math.round(currentNeuron.getNetOutput() * 10000.0)
                            / 10000.0 + "]", i * (nodesize + 10) + 50, (i + 1)
                            * (nodesize + 10) + 10);
        }

        return img;
    }

    /**
     * @return Returns the allowTransitionsToInputs.
     */
    public boolean isAllowLinksToInputs() {
        return this.allowTransitionsToInputs;
    }

    /**
     * @return Returns the allowRecurrentTransitions.
     */
    public boolean isAllowRecurrentLinks() {
        return this.allowRecurrentTransitions;
    }

    /**
     * @return Returns the allowTransitionsFromOutputs.
     */
    public boolean isAllowLinksFromOutputs() {
        return this.allowTransitionsFromOutputs;
    }

    /**
     * @param allowTransitionsToInputs
     *            The allowTransitionsToInputs to set.
     */
    public void setAllowLinksToInputs(boolean allowLinksToInputs) {
        this.allowTransitionsToInputs = allowLinksToInputs;
    }

    /**
     * @param allowRecurrentTransitions
     *            The allowRecurrentTransitions to set.
     */
    public void setAllowRecurrentLinks(boolean allowRecurrentLinks) {
        this.allowRecurrentTransitions = allowRecurrentLinks;
    }

    /**
     * @param allowTransitionsFromOutputs
     *            The allowTransitionsFromOutputs to set.
     */
    public void setAllowLinksFromOutputs(boolean allowLinksFromOutputs) {
        this.allowTransitionsFromOutputs = allowLinksFromOutputs;
    }

    public void setInput(int inputNeuronID, double value) {
        this.neurons.get(inputNeuronID).setInput(value);
    }

    public void setOutput(int outputNeuronID, double value) {
        this.neurons.get(outputNeuronID).setOutput(value);
    }

    // public static void main(String[] args) {
    // ParCollection params = GlobalVariables.getPrematureParameters();
    // params.overwriteParameterList(args);
    //
    // NeuralNetwork net = new NeuralNetwork(params);
    // Random rand = new Random();
    //
    // net.setAllowLinksFromOutputs(true);
    // net.setAllowLinksToInputs(false);
    // net.setAllowRecurrentLinks(true);
    //
    // net.addNeuron(0, new ActivationFunctionIdentity(), Neuron.INPUT_NEURON);
    // net.addNeuron(1, new ActivationFunctionStep(), Neuron.OUTPUT_NEURON);
    // for (int i = 0; i < 20; i++) {
    // net.addNeuron(rand.nextInt(net.getNeuronCount() + 1), rand.nextInt(4));
    // }
    // for (int i = 0; i < 100; i++) {
    // net.addLink(rand.nextInt(net.neuronCount), rand.nextInt(net.neuronCount),
    // Math.round(1000 * (rand.nextDouble() - 0.5)) / 1000.0);
    // }
    //
    // for (int i = 0; i < 10; i++) {
    // net.removeNeuron(rand.nextInt(net.neuronCount));
    // }
    //
    // for (int i = 0; i < 1; i++) {
    // BufferedImage img = net.generateNeuroImage(700);
    // net.propagate();
    // StaticMethods.showImage(img, "TEST" + i);
    // }
    // }

    public ParCollection getPars() {
        return this.pars;
    }

    /**
     * The soft links applied to this neural network via
     * applyNeuralLinkSoft(.), in order of application. An entry is recorded
     * for every call, whether or not the link finally took effect. Note that
     * this is NOT a complete construction history of the network; it covers
     * only the SOFT LINKS!
     */
    private LinkedList<NeuralLinkDummyDouble> constructionSoftLinkHistory = new LinkedList<NeuralLinkDummyDouble>();

    /**
     * The list of soft links applied to this neural network. Note that this is
     * NOT a complete history of construction of the network, but only of the
     * SOFT LINKS!
     *
     * @return Returns the constructionSoftLinkHistory.
     */
    public LinkedList<NeuralLinkDummyDouble> getConstructionSoftLinkHistory() {
        return this.constructionSoftLinkHistory;
    }

    /**
     * Absolute bound to which applyNeuralLinkSoft(.) clamps link weights. The
     * default Double.MAX_VALUE effectively disables clamping. NOTE(review): no
     * setter is visible in this part of the file - confirm whether it is ever
     * changed.
     */
    private double absMaxWeight = Double.MAX_VALUE;

    /**
     * Applies a dummy link softly to the network meaning "anything goes": - If
     * source or target neuron is missing, they are inserted beforehand as
     * hidden neurons. Additionally, all neurons with IDs between the highest ID
     * so far and the ID to insert have to be inserted. - If the link already
     * exists, the weight is updated. - All the constrains (such as forbidden
     * recurrent links) are preserved. - All exceptions are ignored. - In the
     * above two cases, nothing happens or a part of the action might be
     * executed (such as inserting the neurons only). - The activation function
     * is the standard activation function.
     *
     * @param link  The link to apply.
     */
    public void applyNeuralLinkSoft(NeuralLinkDummyDouble link) {
        System.out.println(link);
       
        // Store construction history.
        this.constructionSoftLinkHistory.add(new NeuralLinkDummyDouble(link
                .getSourceID(), link.getTargetID(), link.getWeight(), Integer.MAX_VALUE));

        // Apply link.
        int sourceID = (int) Math.abs(link.getSourceID());
        int targetID = (int) Math.abs(link.getTargetID());
        double weight = link.getWeight();

        if (weight > absMaxWeight) {
            weight = absMaxWeight;
        }

        if (weight < -absMaxWeight) {
            weight = -absMaxWeight;
        }

        if (targetID == 0) { // 0 is the bias neuron (no input).
            targetID = 1;
        }

//        if (targetID == sourceID) {
//            targetID = targetID + 1;
//        }

        // Neuron does not exist.
        int i = this.getNeuronCount();
        while (this.getNeuron(sourceID) == null) {
            this.addNeuron(i, Neuron.HIDDEN_NEURON);
            i++;
        }

        // Neuron does not exist.
        i = this.getNeuronCount();
        while (this.getNeuron(targetID) == null) {
            this.addNeuron(i, Neuron.HIDDEN_NEURON);
            i++;
        }

        // Link already exists?
        NeuralLink existingLink = this.getNeuron(targetID).containsLink(sourceID);
        if (existingLink != null) {
            this.setWeight(sourceID, targetID, weight);
        } else {
            this.addLink(sourceID, targetID, weight);
        }
    }

    public boolean setWeight(int sourceID, int targetID, double weight) {
        NeuralLink link = this.getNeuron(targetID).containsLink(sourceID);
       
        if (link != null) {
            link.setWeight(weight);
            return true;
        } else {
            return false;
        }
    }
   
    /**
     * The equals and hashCode methods are based on the current structure of the
     * neural network only. Two neural networks are equal if they have the same
     * number of neurons and every pair of neurons is connected with the same
     * weight (or unconnected) in both cases.
     */
    @Override
    public int hashCode() {
        final int prime = 31;
        int result = 1;
        result = prime * result
                + ((this.neurons == null) ? 0 : this.neurons.hashCode());
        return result;
    }

    /**
     * The equals and hashCode methods are based on the current structure of the
     * neural network only. Two neural networks are equal if they have the same
     * number of neurons and every pair of neurons is connected with the same
     * weight (or unconnected) in both cases.
     *
     * Strictly speaking, only the neurons lists are compared.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj)
            return true;
        if (obj == null)
            return false;
        if (getClass() != obj.getClass())
            return false;
        GeneralNeuralNetwork other = (GeneralNeuralNetwork) obj;
        if (this.neurons == null) {
            if (other.neurons != null)
                return false;
        } else if (!this.neurons.equals(other.neurons))
            return false;
        return true;
    }

    /**
     * Returns the weight of the link between the neuron with the fist ID and
     * the neuron with the second ID.
     *
     * @param sourceID
     *            The source neuron's ID.
     * @param targetID
     *            The target neuron's ID.
     *
     * @return The weight of the link if both neurons and the link exist, 0
     *         otherwise.
     */
    public double getWeight(int sourceID, int targetID) {
        if (!this.existsNeuron(sourceID) || !this.existsNeuron(targetID)) {
            return 0;
        }

        NeuralLink link = this.getNeuron(targetID).containsLink(sourceID);

        if (link != null) {
            return link.getWeight();
        } else {
            return 0;
        }
    }

    /**
     * Returns a new (!) neural network as a combination of the neural networks
     * given as a parameter. There, the number of neurons of the resulting
     * network is equal to the maximum number of neurons of the inserted
     * networks. For any link existing in any of the networks, a link is
     * inserted in the resulting network using a combination function to
     * calculate the value of the new weight (e.g., average of the links of the
     * networks). There, non-existing links count as 0 weight.
     *
     * @param networks
     *            The networks to combine.
     *
     * @return A combination of the networks (new instance).
     */
    public GeneralNeuralNetwork generateCombinationNetwork(
            Collection<GeneralNeuralNetwork> networks,
            NeuralCombinationFunctionAverage function) {
        GeneralNeuralNetwork resultingNet = new GeneralNeuralNetwork(pars);
        final double epsilon = 0.00000001;
        int neuronCount = 0;

        // Find maximal neuron count.
        for (GeneralNeuralNetwork n : networks) {
            neuronCount = Math.max(neuronCount, n.neuronCount);
        }

        // Insert new links including neurons.
        for (int sourceID = 0; sourceID < neuronCount; sourceID++) {
            for (int targetID = 0; targetID < neuronCount; targetID++) {
                ArrayList<Double> values = new ArrayList<Double>(
                        networks.size());
                double weight;

                for (GeneralNeuralNetwork n : networks) {
                    values.add(n.getWeight(sourceID, targetID));
                }
                weight = function.calculateValue(values);

                if (Math.abs(weight) > epsilon) {
                    resultingNet.applyNeuralLinkSoft(new NeuralLinkDummyDouble(
                            sourceID, targetID, weight, Integer.MAX_VALUE));
                }
            }
        }

        return resultingNet;
    }

    /**
     * Performs the propagation of values during one step of the neural network.
     * This means that all neurons, increaslingly beginning with 0 compute their
     * net output based on the computed (and given) values of their incoming
     * links.
     */
    public void propagate() {
        for (int i = 0; i < this.neuronCount; i++) {
            this.getNeuron(i).computeAndStoreNetOutput();
        }
    }

    public void resetAllNeuronInputsAndOutputs() {
        for (int i = 0; i < this.getNeuronCount(); i++) {
            this.getNeuron(i).reset();
        }
    }

    public void setAllNeuronOutputsTo(ArrayList<Double> setToValues) {
        for (int i = 0; i < this.neuronCount; i++) {
            this.getNeuron(i).setNetOutput(setToValues.get(i));
        }
    }

    protected HashMap<Integer, Neuron> getNeurons() {
        return this.neurons;
    }

    /**
     * Sets the forward mode and resets the net output to value for all neurons.
     *
     * @param forwardMode
     *            Set forward mode (true) or backward mode (false).
     */
    public void setForwardModeAndResetNetOutput(double value,
            boolean forwardMode) {
        for (Neuron n : this.getNeurons().values()) {
            n.setForwardModeAndResetNetOutput(value, forwardMode);
        }
    }
}
TOP

Related Classes of eas.simulation.brain.neural.GeneralNeuralNetwork

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.