package trust.jfcm.learning;
import java.io.BufferedReader;
import java.io.Closeable;
import java.io.FileReader;
import java.io.Flushable;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import javax.xml.XMLConstants;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import javax.xml.transform.Source;
import javax.xml.transform.Transformer;
import javax.xml.transform.TransformerFactory;
import javax.xml.transform.dom.DOMSource;
import javax.xml.transform.stream.StreamResult;
import javax.xml.transform.stream.StreamSource;
import javax.xml.validation.Schema;
import javax.xml.validation.SchemaFactory;
import javax.xml.xpath.XPath;
import javax.xml.xpath.XPathConstants;
import javax.xml.xpath.XPathExpressionException;
import javax.xml.xpath.XPathFactory;
import org.apache.commons.lang.StringUtils;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;
import org.xml.sax.SAXException;
import trust.jfcm.CognitiveMap;
import trust.jfcm.Concept;
import trust.jfcm.FcmConnection;
import trust.jfcm.HyperbolicTangentActivator;
import trust.jfcm.IdentityActivator;
import trust.jfcm.SigmoidActivator;
import trust.jfcm.SignumActivator;
import trust.jfcm.WeightedConnection;
import trust.jfcm.learning.supervised.*;
import trust.jfcm.learning.unsupervised.*;
import trust.jfcm.utils.FcmIO;
/**
 * Reads and writes fuzzy cognitive maps that support learning.
 * <p>
 * Maps are serialized as an XML document of the form
 * {@code <maps><map name="..."><concepts>...</concepts><connections>...</connections></map></maps>}.
 * {@link #loadMapFromXml(InputStream)} reads that format back into {@link FcmLearning}
 * instances. It additionally understands an optional concept {@code type} attribute
 * ({@code LEARNING}, {@code INPUT_LEARNING}, {@code OUTPUT_LEARNING}) and a {@code method}
 * parameter ({@code LinearBP}, {@code NHL}).
 * <p>
 * Training sets are read from tab-separated text by
 * {@link #loadTrainingSetFromFile(FileReader, FcmLearning)}.
 */
public class FcmLearningIO{

    /** Classpath location of the XSD used when strict (validating) parsing is requested. */
    private static final String SCHEMA_RESOURCE = "/JFCM-map-v-1.0.xsd";

    /** {@code <param>} name carrying a connection weight; shared by the writer and the reader. */
    private static final String PARAM_WEIGHT = "weight";

    /** {@code <param>} name carrying a connection's weight uncertainty. */
    private static final String PARAM_UNCERTAINTY = "uncertainty";

    /**
     * Writes a single map as an XML {@code <maps>} document.
     *
     * @param map          map to serialize
     * @param outputStream destination; always flushed and closed by this method
     * @throws RuntimeException if the document cannot be built, written or closed
     */
    public static void saveAsXml(CognitiveMap map, OutputStream outputStream) {
        List<CognitiveMap> maps = new ArrayList<CognitiveMap>(1);
        maps.add(map);
        saveAsXml(maps, outputStream);
    }

    /**
     * Writes the given maps, in iteration order, as one XML {@code <maps>} document.
     *
     * @param maps         maps to serialize
     * @param outputStream destination; always flushed and closed by this method, whether or
     *                     not writing succeeded
     * @throws RuntimeException "Error saving XML file" wrapping the cause if the document
     *                          cannot be built or written (including unsupported connection
     *                          types). "Error closing stream" if writing succeeded but the
     *                          stream could not be flushed or closed.
     */
    public static void saveAsXml(List<CognitiveMap> maps, OutputStream outputStream) {
        boolean succeeded = false;
        try {
            Document doc = getDocumentBuilder(false).newDocument();
            Element mapsElem = doc.createElement("maps");
            doc.appendChild(mapsElem);
            for (CognitiveMap map : maps) {
                appendMap(map, doc, mapsElem);
            }
            Transformer transformer = TransformerFactory.newInstance().newTransformer();
            transformer.transform(new DOMSource(doc), new StreamResult(outputStream));
            succeeded = true;
        } catch (Exception ex) {
            throw new RuntimeException("Error saving XML file", ex);
        } finally {
            // A close failure is only reported when it is the sole problem. Otherwise it would
            // replace the more informative exception that is already propagating.
            closeStream(outputStream, succeeded);
        }
    }

    /**
     * Parses a {@code <maps>} document into learning-capable maps.
     *
     * @author MPiunti
     * @param inputStream XML source; always closed by this method
     * @return maps keyed by their {@code name} attribute, in document order. A map without a
     *         name is keyed by the empty string; a later map with a duplicate name replaces an
     *         earlier one.
     * @throws RuntimeException "Error loading XML file" wrapping the cause if the document is
     *                          malformed or references unknown types or concepts. "Error
     *                          closing stream" if parsing succeeded but closing failed.
     */
    public static Map<String,FcmLearning> loadMapFromXml(InputStream inputStream) {
        boolean succeeded = false;
        try {
            DocumentBuilder documentBuilder = getDocumentBuilder(false);
            Document doc = documentBuilder.parse(inputStream);
            XPath xpath = XPathFactory.newInstance().newXPath();
            Map<String, FcmLearning> maps = new LinkedHashMap<String, FcmLearning>();
            NodeList mapNodes = (NodeList) xpath.evaluate("/maps/map", doc, XPathConstants.NODESET);
            for (int i = 0; i < mapNodes.getLength(); i++) {
                Element mapElem = (Element) mapNodes.item(i);
                maps.put(mapElem.getAttribute("name"), parseMap(xpath, mapElem));
            }
            succeeded = true;
            return maps;
        } catch (Exception ex) {
            throw new RuntimeException("Error loading XML file", ex);
        } finally {
            closeStream(inputStream, succeeded);
        }
    }

    /**
     * Reads a tab-separated training set.
     * <p>
     * Blank lines are skipped. Every other line must contain at least five tab-separated
     * fields:
     * <ol>
     * <li>three input concept names, each given the input value {@code 1.0};</li>
     * <li>an output concept name;</li>
     * <li>the numeric value recorded for that output.</li>
     * </ol>
     * Any further fields are ignored.
     *
     * @param input reader over the training file; always closed by this method
     * @param map   map the training set belongs to
     * @return the populated training set
     * @throws IllegalArgumentException if a non-blank line has fewer than five fields
     * @throws NumberFormatException    if the output value is not a number
     * @throws RuntimeException         if the file cannot be read or closed
     */
    public static FcmTrainingSet loadTrainingSetFromFile(FileReader input, FcmLearning map){
        BufferedReader br = new BufferedReader(input);
        FcmTrainingSet dataset = new FcmTrainingSet(map);
        boolean succeeded = false;
        try {
            int lineNumber = 0;
            String line;
            while ((line = br.readLine()) != null) {
                lineNumber++;
                if (StringUtils.isBlank(line)) {
                    continue;
                }
                String[] st = line.split("\t");
                if (st.length < 5) {
                    throw new IllegalArgumentException("Malformed training entry at line "
                            + lineNumber + ": expected 5 tab-separated fields but found "
                            + st.length);
                }
                HashMap<String, Double> inputs = new HashMap<String, Double>();
                HashMap<String, Double> outputs = new HashMap<String, Double>();
                inputs.put(st[0], 1.0);
                inputs.put(st[1], 1.0);
                inputs.put(st[2], 1.0);
                outputs.put(st[3], Double.parseDouble(st[4]));
                dataset.addEntry(inputs, outputs);
            }
            succeeded = true;
        } catch (IOException e) {
            throw new RuntimeException("Error reading training set", e);
        } finally {
            closeStream(br, succeeded);
        }
        return dataset;
    }

    /*
     * private stuff
     */

    /**
     * Creates a DOM builder hardened against XML external entity (XXE) attacks.
     *
     * @param strict when {@code true}, validate documents against the bundled map XSD
     * @return a new document builder
     * @throws SAXException                 if the schema cannot be parsed
     * @throws ParserConfigurationException if no builder can be created
     * @throws IllegalStateException        if strict and the XSD resource is missing
     */
    private static DocumentBuilder getDocumentBuilder(boolean strict) throws SAXException,
            ParserConfigurationException {
        DocumentBuilderFactory documentBuilderFactory = DocumentBuilderFactory.newInstance();
        disableExternalEntities(documentBuilderFactory);
        if (strict) {
            documentBuilderFactory.setSchema(loadSchema());
            documentBuilderFactory.setValidating(true);
        }
        return documentBuilderFactory.newDocumentBuilder();
    }

    /**
     * Loads the bundled map XSD from the classpath, closing the resource afterwards.
     *
     * @return the compiled schema
     * @throws SAXException          if the schema cannot be parsed
     * @throws IllegalStateException if the resource is not on the classpath
     */
    private static Schema loadSchema() throws SAXException {
        InputStream xsd = FcmIO.class.getResourceAsStream(SCHEMA_RESOURCE);
        if (xsd == null) {
            throw new IllegalStateException("Schema resource not found: " + SCHEMA_RESOURCE);
        }
        try {
            Source schemaSource = new StreamSource(xsd);
            return SchemaFactory.newInstance("http://www.w3.org/2001/XMLSchema")
                    .newSchema(schemaSource);
        } finally {
            closeStream(xsd, false);
        }
    }

    /**
     * Best-effort XXE hardening. It enables secure processing and turns off external entity
     * and external DTD loading. Features the JAXP implementation does not recognize are
     * skipped.
     */
    private static void disableExternalEntities(DocumentBuilderFactory factory) {
        trySetFeature(factory, XMLConstants.FEATURE_SECURE_PROCESSING, true);
        trySetFeature(factory, "http://xml.org/sax/features/external-general-entities", false);
        trySetFeature(factory, "http://xml.org/sax/features/external-parameter-entities", false);
        trySetFeature(factory, "http://apache.org/xml/features/nonvalidating/load-external-dtd",
                false);
    }

    /** Sets a factory feature, ignoring implementations that do not support it. */
    private static void trySetFeature(DocumentBuilderFactory factory, String feature,
            boolean value) {
        try {
            factory.setFeature(feature, value);
        } catch (ParserConfigurationException ignored) {
            // Not every JAXP implementation recognizes every feature; hardening is best effort.
        }
    }

    /**
     * Flushes (when the stream is {@link Flushable}) and closes a stream. The stream is
     * closed even if flushing fails.
     *
     * @param stream        stream to close; {@code null} is ignored
     * @param reportFailure if {@code true}, an I/O failure is rethrown; if {@code false}, it
     *                      is dropped so it cannot mask an exception already in flight
     * @throws RuntimeException "Error closing stream" on failure when {@code reportFailure}
     */
    private static void closeStream(Closeable stream, boolean reportFailure) {
        if (stream == null) {
            return;
        }
        IOException failure = null;
        try {
            if (stream instanceof Flushable) {
                ((Flushable) stream).flush();
            }
        } catch (IOException ioex) {
            failure = ioex;
        }
        try {
            stream.close();
        } catch (IOException ioex) {
            if (failure == null) {
                failure = ioex;
            }
        }
        if (failure != null && reportFailure) {
            throw new RuntimeException("Error closing stream", failure);
        }
    }

    /** Appends a {@code <map>} element (concepts, then connections) for {@code map}. */
    private static void appendMap(CognitiveMap map, Document doc, Element mapsElem) {
        Element mapElem = doc.createElement("map");
        mapElem.setAttribute("name", map.getName());
        mapsElem.appendChild(mapElem);
        appendConcepts(map, doc, mapElem);
        appendConnections(map, doc, mapElem);
    }

    /**
     * Appends a {@code <concepts>} element holding one {@code <concept>} per map concept.
     * Each concept carries its name, optional description, activator, input, output and
     * fixed flag.
     */
    private static void appendConcepts(CognitiveMap map, Document doc, Element elem) {
        Element elemConcepts = doc.createElement("concepts");
        elem.appendChild(elemConcepts);
        Iterator<Concept> cIter = map.getConceptsIterator();
        while (cIter.hasNext()) {
            Concept c = cIter.next();
            Element elemConcept = doc.createElement("concept");
            elemConcept.setAttribute("name", c.getName());
            if (c.getDescription() != null) {
                Element elemDescription = doc.createElement("description");
                elemDescription.setTextContent(c.getDescription());
                elemConcept.appendChild(elemDescription);
            }
            appendActivator(c, doc, elemConcept);
            if (c.getInput() != null) {
                elemConcept.setAttribute("input", c.getInput().toString());
            }
            if (c.getOutput() != null) {
                elemConcept.setAttribute("output", c.getOutput().toString());
            }
            if (c.isFixedOutput()) {
                elemConcept.setAttribute("fixed", "true");
            }
            elemConcepts.appendChild(elemConcept);
        }
    }

    /**
     * Writes the concept's activator as an {@code act} attribute plus its parameter, using
     * the same names the loader reads. Activator types not listed here are not written.
     */
    private static void appendActivator(Concept c, Document doc, Element elemConcept) {
        Object activator = c.getConceptActivator();
        if (activator instanceof SignumActivator) {
            elemConcept.setAttribute("act", "SIGNUM");
            appendParam(doc, elemConcept, "threshold",
                    ((SignumActivator) activator).getThreshold());
        } else if (activator instanceof SigmoidActivator) {
            elemConcept.setAttribute("act", "SIGMOID");
            appendParam(doc, elemConcept, "k", ((SigmoidActivator) activator).getK());
        } else if (activator instanceof HyperbolicTangentActivator) {
            elemConcept.setAttribute("act", "TANH");
            // NOTE(review): the loader also accepts an "alpha" param for TANH, which is not
            // written here; confirm whether HyperbolicTangentActivator exposes it.
            appendParam(doc, elemConcept, "threshold",
                    ((HyperbolicTangentActivator) activator).getThreshold());
        } else if (activator instanceof IdentityActivator) {
            // Written so the loader's "IDENTITY" branch can restore it on reload.
            elemConcept.setAttribute("act", "IDENTITY");
        }
    }

    /** Appends {@code <param name="name" value="value"/>} to {@code parent}. */
    private static void appendParam(Document doc, Element parent, String name, double value) {
        Element paramElem = doc.createElement("param");
        paramElem.setAttribute("name", name);
        paramElem.setAttribute("value", Double.toString(value));
        parent.appendChild(paramElem);
    }

    /**
     * Appends a {@code <connections>} element holding one {@code <connection>} per map
     * connection.
     *
     * @throws UnsupportedOperationException for connections that are not
     *                                       {@link WeightedConnection}s
     */
    private static void appendConnections(CognitiveMap map, Document doc, Element elem) {
        Element elemConnections = doc.createElement("connections");
        elem.appendChild(elemConnections);
        Iterator<FcmConnection> connIter = map.getConnectionsIterator();
        while (connIter.hasNext()) {
            FcmConnection conn = connIter.next();
            Element elemConn = doc.createElement("connection");
            elemConn.setAttribute("name", conn.getName());
            if (conn.getDescription() != null) {
                Element elemDescription = doc.createElement("description");
                elemDescription.setTextContent(conn.getDescription());
                elemConn.appendChild(elemDescription);
            }
            if (conn.getFrom() != null) {
                elemConn.setAttribute("from", conn.getFrom().getName());
            }
            if (conn.getTo() != null) {
                elemConn.setAttribute("to", conn.getTo().getName());
            }
            if (conn instanceof WeightedConnection) {
                elemConn.setAttribute("type", "WEIGHTED");
                // Must match the name the loader looks up, or the weight is lost on reload.
                // NOTE(review): weight uncertainty is read by the loader but not written here.
                appendParam(doc, elemConn, PARAM_WEIGHT, ((WeightedConnection) conn).getWeight());
            } else {
                throw new UnsupportedOperationException(
                        "FcmConnection implementation not supported: " + conn.getClass().getName());
            }
            elemConnections.appendChild(elemConn);
        }
    }

    /**
     * Builds an {@link FcmLearning} from a {@code <map>} element.
     * <p>
     * Concepts are added before connections, because connections resolve their endpoints by
     * concept name.
     */
    private static FcmLearning parseMap(XPath xpath, Element mapElem) throws Exception {
        FcmLearning map = new FcmLearning();
        map.setName(xpath.evaluate("@name", mapElem));
        NodeList conceptNodes = (NodeList) xpath.evaluate("concepts/concept", mapElem,
                XPathConstants.NODESET);
        for (int i = 0; i < conceptNodes.getLength(); i++) {
            map.addConcept(parseLearningConcept(xpath, (Element) conceptNodes.item(i)));
        }
        NodeList connNodes = (NodeList) xpath.evaluate("connections/connection", mapElem,
                XPathConstants.NODESET);
        for (int i = 0; i < connNodes.getLength(); i++) {
            map.addConnection(parseLearningConnection(map, xpath, (Element) connNodes.item(i)));
        }
        // Call connect(from, connection, to) for every parsed connection.
        // TODO optimize
        Iterator<FcmConnection> iter = map.getConnectionsIterator();
        while (iter.hasNext()) {
            FcmConnection conn = iter.next();
            map.connect(conn.getFrom().getName(), conn.getName(), conn.getTo().getName());
        }
        return map;
    }

    /**
     * Parses a {@code <concept>} element.
     * <p>
     * With no {@code type} attribute, a plain {@link Concept} is returned. Otherwise a
     * learning concept of the requested kind is built, and its training function is taken
     * from the optional {@code method} param ({@code LinearBP} or {@code NHL}; other values
     * are ignored).
     *
     * @throws UnsupportedOperationException if {@code type} is present but not one of
     *                                       {@code LEARNING}, {@code INPUT_LEARNING},
     *                                       {@code OUTPUT_LEARNING}
     */
    private static Concept parseLearningConcept(XPath xpath, Element conceptElem)
            throws XPathExpressionException {
        String type = conceptElem.getAttribute("type");
        if (StringUtils.isBlank(type)) {
            return parseConcept(xpath, conceptElem);
        }
        LearningConcept c;
        if ("LEARNING".equals(type)) {
            c = new LearningConcept();
        } else if ("INPUT_LEARNING".equals(type)) {
            c = new InputLearningConcept();
        } else if ("OUTPUT_LEARNING".equals(type)) {
            c = new OutputLearningConcept();
        } else {
            throw new UnsupportedOperationException("Concept type not supported: \"" + type
                    + "\"");
        }
        populateConcept(xpath, conceptElem, c);
        String method = findParamValue(xpath, conceptElem, "method");
        if ("LinearBP".equals(method)) {
            c.setTrainingFunction(new LinearBP());
        } else if ("NHL".equals(method)) {
            c.setTrainingFunction(new NHL());
        }
        return c;
    }

    /**
     * Parses a {@code <connection>} element into a {@link LearningWeightedConnection}.
     * <p>
     * The optional {@code weight} param sets the weight and, by default, the uncertainty. An
     * optional {@code uncertainty} param then overrides the uncertainty. Params are looked up
     * by name, so their order does not matter.
     *
     * @throws IllegalArgumentException      if the name is missing, or {@code from}/{@code to}
     *                                       is missing or names an unknown concept
     * @throws UnsupportedOperationException if {@code type} is not {@code WEIGHTED}
     */
    private static FcmConnection parseLearningConnection(FcmLearning map, XPath xpath, Element connElem)
            throws Exception {
        String name = xpath.evaluate("@name", connElem);
        if (StringUtils.isBlank(name)) {
            throw new IllegalArgumentException("Missing connection name");
        }
        String type = xpath.evaluate("@type", connElem);
        if (!"WEIGHTED".equalsIgnoreCase(type)) {
            throw new UnsupportedOperationException("Connection type not supported: \"" + type
                    + "\"");
        }
        LearningWeightedConnection wConn = new LearningWeightedConnection();
        String weight = findParamValue(xpath, connElem, PARAM_WEIGHT);
        if (weight != null) {
            wConn.setWeight(Double.parseDouble(weight));
            // By default the uncertainty equals the weight.
            wConn.setWeightUncertainty(wConn.getWeight());
        }
        String uncertainty = findParamValue(xpath, connElem, PARAM_UNCERTAINTY);
        if (uncertainty != null) {
            wConn.setWeightUncertainty(Double.parseDouble(uncertainty));
        }
        FcmConnection conn = wConn;
        conn.setName(name);
        String description = xpath.evaluate("description/text()", connElem);
        if (StringUtils.isNotBlank(description)) {
            conn.setDescription(description);
        }
        conn.setFrom(resolveEndpoint(map, xpath, connElem, "from", name));
        conn.setTo(resolveEndpoint(map, xpath, connElem, "to", name));
        return conn;
    }

    /**
     * Resolves the concept named by the connection attribute {@code attr}
     * ({@code "from"} or {@code "to"}).
     *
     * @throws IllegalArgumentException if the attribute is blank or names an unknown concept
     */
    private static Concept resolveEndpoint(FcmLearning map, XPath xpath, Element connElem,
            String attr, String connName) throws Exception {
        String conceptName = xpath.evaluate("@" + attr, connElem);
        Concept c = StringUtils.isBlank(conceptName) ? null : map.getConcept(conceptName);
        if (c == null) {
            throw new IllegalArgumentException("Missing \"" + attr
                    + "\" reference in connection \"" + connName + "\"");
        }
        return c;
    }

    /** Parses a {@code <concept>} element with no learning {@code type} into a plain {@link Concept}. */
    private static Concept parseConcept(XPath xpath, Element conceptElem)
            throws XPathExpressionException {
        Concept c = new Concept();
        populateConcept(xpath, conceptElem, c);
        return c;
    }

    /**
     * Copies the fields shared by all concept kinds from {@code conceptElem} into {@code c}:
     * name, description, activator, and the {@code input}, {@code output} and {@code fixed}
     * attributes (each applied only when non-blank).
     */
    private static void populateConcept(XPath xpath, Element conceptElem, Concept c)
            throws XPathExpressionException {
        c.setName(xpath.evaluate("@name", conceptElem));
        String description = xpath.evaluate("description/text()", conceptElem);
        if (StringUtils.isNotBlank(description)) {
            c.setDescription(description);
        }
        parseActivator(xpath, conceptElem, c);
        String s = xpath.evaluate("@input", conceptElem);
        if (StringUtils.isNotBlank(s)) {
            c.setInput(Double.parseDouble(s));
        }
        s = xpath.evaluate("@output", conceptElem);
        if (StringUtils.isNotBlank(s)) {
            c.setOutput(Double.parseDouble(s));
        }
        s = xpath.evaluate("@fixed", conceptElem);
        if (StringUtils.isNotBlank(s)) {
            c.setFixedOutput(Boolean.parseBoolean(s));
        }
    }

    /**
     * Installs the activator named by the {@code act} attribute:
     * <ul>
     * <li>{@code SIGNUM}, with an optional {@code threshold};</li>
     * <li>{@code SIGMOID}, with an optional {@code k};</li>
     * <li>{@code TANH}, with optional {@code alpha} and {@code threshold};</li>
     * <li>{@code IDENTITY}.</li>
     * </ul>
     * Unknown or absent values leave the concept's activator unchanged.
     */
    private static void parseActivator(XPath xpath, Element conceptElem, Concept c)
            throws XPathExpressionException {
        String actAttr = conceptElem.getAttribute("act");
        // Read up front, so a malformed threshold value is rejected whatever the activator type.
        Double threshold = findDoubleParam(xpath, conceptElem, "threshold");
        if ("SIGNUM".equals(actAttr)) {
            SignumActivator act = new SignumActivator();
            if (threshold != null) {
                act.setThreshold(threshold);
            }
            c.setConceptActivator(act);
        } else if ("SIGMOID".equals(actAttr)) {
            SigmoidActivator act = new SigmoidActivator();
            Double k = findDoubleParam(xpath, conceptElem, "k");
            if (k != null) {
                act.setK(k);
            }
            c.setConceptActivator(act);
        } else if ("TANH".equals(actAttr)) {
            HyperbolicTangentActivator act = new HyperbolicTangentActivator();
            Double alpha = findDoubleParam(xpath, conceptElem, "alpha");
            if (alpha != null) {
                act.setAlpha(alpha);
            }
            if (threshold != null) {
                act.setThreshold(threshold);
            }
            c.setConceptActivator(act);
        } else if ("IDENTITY".equals(actAttr)) {
            c.setConceptActivator(new IdentityActivator());
        }
    }

    /**
     * Returns the {@code value} attribute of the first {@code <param name="paramName">} child
     * of {@code elem}. The result is {@code null} when there is no such param, and
     * {@code ""} when the param has no value attribute.
     */
    private static String findParamValue(XPath xpath, Element elem, String paramName)
            throws XPathExpressionException {
        Element param = (Element) xpath.evaluate("param[@name='" + paramName + "']", elem,
                XPathConstants.NODE);
        return param == null ? null : xpath.evaluate("@value", param);
    }

    /**
     * Like {@link #findParamValue}, but parses the value as a double.
     *
     * @return the parsed value, or {@code null} if the param is absent
     * @throws NumberFormatException if the param is present but its value is not a number
     */
    private static Double findDoubleParam(XPath xpath, Element elem, String paramName)
            throws XPathExpressionException {
        String value = findParamValue(xpath, elem, paramName);
        return value == null ? null : Double.valueOf(value);
    }
}