package trust.jfcm.learning.training;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.apache.log4j.Logger;
import trust.jfcm.Concept;
import trust.jfcm.FcmConnection;
import trust.jfcm.WeightedConnection;
import trust.jfcm.learning.EntryStructure;
import trust.jfcm.learning.FcmLearning;
import trust.jfcm.learning.FcmTrainingSet;
import trust.jfcm.learning.FcmTrainingSetEntry;
import trust.jfcm.learning.InputLearningConcept;
import trust.jfcm.learning.LearningConcept;
import trust.jfcm.learning.LearningWeightedConnection;
import trust.jfcm.learning.OutputLearningConcept;
import trust.jfcm.utils.FcmRunner;
import trust.jfcm.utils.SimpleFcmRunner;
/**
 * Trains a learning FCM by forwarding the errors of the output nodes straight to the input
 * nodes and then propagating them through the network.
 *
 * <p>For each training entry the concepts are reset, the inputs are applied, the map is run to
 * convergence, the mean output error is placed on every input concept, errors are propagated
 * along the incoming connections of the remaining learning concepts, every learning connection
 * is updated through its source concept's training function, and finally the network weights
 * are normalized.
 *
 * <p>Additionally, the weight can be updated using the error_weight parameter (instead of the
 * weight parameter itself) of the learning connection.
 *
 * @author pc
 */
public class ForwardPropagation implements FcmTrainer {

    /**
     * Per-epoch accumulated squared output error: entry {@code k} is the sum, over all training
     * entries processed in epoch {@code k}, of the squared differences between the actual and
     * the desired value of each output concept.
     */
    double[] error_vector;

    /** The map being trained. */
    FcmLearning map;

    /******* TRAINING PARAMETERS ********/

    /** The maximum number of epochs to train through. */
    int trainingTime = 500;

    /** Target accuracy: training stops once sqrt(error of the current epoch) is at or below it. */
    double accuracy = 0.01;

    /** Tolerance on the change of sqrt(epoch error) between two consecutive epochs. */
    double epsilon = 0.0000001;

    /** Runner used to execute {@link #map} until convergence. */
    FcmRunner runner;

    /** Current training epoch (0-based index into {@link #error_vector}). */
    int epoch;

    /** Training start/end timestamps in milliseconds, used to report the elapsed time. */
    long start, end;

    /** Logger used to report training progress and results; never null after construction. */
    Logger logger;

    /**
     * Creates a trainer for the given map that logs through this class's default log4j logger.
     *
     * @param map the map to train
     */
    public ForwardPropagation(FcmLearning map) {
        this(map, null);
    }

    /**
     * Creates a trainer for the given map.
     *
     * @param map the map to train
     * @param logger the logger to report to; when {@code null} this class's default log4j
     *     logger is used, because the training code logs unconditionally
     */
    public ForwardPropagation(FcmLearning map, Logger logger) {
        this.map = map;
        this.error_vector = new double[trainingTime];
        this.runner = createRunner(map);
        this.epoch = 0;
        this.logger = (logger != null) ? logger : Logger.getLogger(ForwardPropagation.class);
    }

    /**
     * Builds the runner that executes the given map, using the same settings everywhere in
     * this trainer.
     *
     * @param m the map the runner will execute
     * @return a new runner bound to {@code m}
     */
    private static FcmRunner createRunner(FcmLearning m) {
        // NOTE(review): the meaning of 0.000001 and 1000 is defined by SimpleFcmRunner
        // (presumably a convergence threshold and an iteration cap) — confirm there.
        return new SimpleFcmRunner(m, 0.000001, 1000);
    }

    /**
     * Trains the map on its current training set and updates its weights.
     *
     * <p>Runs up to {@link #trainingTime} epochs, each processing every training entry once, and
     * stops early as soon as {@link #checkTerminalConditions()} is satisfied. Afterwards the
     * results are logged and the per-training state is reset.
     */
    public void train() {
        start = System.currentTimeMillis();

        /*
         * Calculate delta on weights (mean connection weight) and hand it to the map.
         */
        map.addDeltaWeight(calculateDeltaOnWeights());

        FcmTrainingSet trainingSet = map.getTrainingSet();
        for (epoch = 0; epoch < trainingTime; epoch++) {
            for (int i = 0; i < trainingSet.size(); i++) {
                FcmTrainingSetEntry entry = trainingSet.get(i);
                map.resetConcepts();
                calculateErrors(entry);
                updateNetworkWeights();
                map.normalizeNetworkWeights(true);
            }
            if (checkTerminalConditions()) {
                break;
            }
        }

        end = System.currentTimeMillis();
        printInfo();
        cleanUp();
    }

    /**
     * Computes the mean weight of the map's weighted connections.
     *
     * @return the average of {@code getWeight()} over all {@link WeightedConnection}s of the map,
     *     or 0 when there are none (instead of the NaN a 0/0 division would produce)
     */
    private double calculateDeltaOnWeights() {
        Collection<FcmConnection> connectionSet = map.getConnections().values();
        double weightSum = 0;
        int count = 0;
        for (FcmConnection conn : connectionSet) {
            if (conn instanceof WeightedConnection) {
                weightSum += ((WeightedConnection) conn).getWeight();
                count++;
            }
        }
        return count == 0 ? 0 : weightSum / count;
    }

    /**
     * Logs, for every training entry, the desired and actual value of each output concept,
     * followed by the error of the last epoch, the epoch counter and the elapsed time.
     * Concept values are reset afterwards.
     */
    private void printInfo() {
        logger.debug("\n****** TRAINING MAP ********");
        FcmTrainingSet trainingSet = map.getTrainingSet();
        for (int k = 0; k < trainingSet.size(); k++) {
            FcmTrainingSetEntry entry = trainingSet.get(k);
            logger.debug("");
            logger.debug("\n- " + entry);

            map.resetConcepts();
            setTrainingInputs(entry.getInputs());
            runner.converge();
            // Also stores the desired outputs on the output concepts, which are read below.
            setInputErrors(entry);

            List<EntryStructure> outputs = entry.getOutputs();
            for (EntryStructure output : outputs) {
                OutputLearningConcept oc = asOutputConcept(output);
                if (oc == null) {
                    continue;
                }
                Concept c = output.getConcept();
                logger.debug("");
                logger.debug(" Output concept: " + c.getName());
                logger.debug(" Desirable output: " + oc.getDesirableOutput() + ", output: " + oc.getOutput().doubleValue());
                logger.debug(" Error on output concept: " + c.getName() + ": " + (oc.getDesirableOutput() - oc.getOutput().doubleValue()));
            }
        }

        // After an early stop 'epoch' is the index of the last epoch run; after a full run it
        // equals error_vector.length. Clamp to the last valid index (none if the vector is empty).
        int lastEpoch = Math.min(epoch, error_vector.length - 1);
        if (lastEpoch >= 0) {
            logger.debug("\n Squeared mean error: " + Math.sqrt(error_vector[lastEpoch]));
        }
        logger.debug(" Training epoch: " + (epoch) + "/" + trainingTime);
        logger.debug(" Elapsed time: " + (end - start) / 1000F + " sec \n");
        map.resetConcepts();
    }

    /**
     * Resolves the output learning concept referenced by a training-set output entry.
     *
     * @param entry an output entry of a training set entry
     * @return the entry's concept as an {@link OutputLearningConcept}, or {@code null} when the
     *     concept is missing (reported on standard error) or is not an output learning concept
     */
    private static OutputLearningConcept asOutputConcept(EntryStructure entry) {
        Concept c = entry.getConcept();
        if (c == null) {
            // The concept is null, so its name is unavailable; identify the entry instead.
            System.err.println("Error: Concept for entry " + entry + " not found");
            return null;
        }
        if (!(c instanceof LearningConcept)) {
            return null;
        }
        LearningConcept lc = (LearningConcept) c;
        return lc.isOutput() ? (OutputLearningConcept) lc : null;
    }

    /**
     * Adds the squared error of the output nodes for the current training entry to the error of
     * the current epoch, {@code error_vector[epoch]}.
     *
     * @param outputs the desired output values of the current training entry
     */
    private void measureOutputError(List<EntryStructure> outputs) {
        double error = 0;
        for (EntryStructure output : outputs) {
            OutputLearningConcept oc = asOutputConcept(output);
            if (oc != null) {
                error += Math.pow(oc.getOutput() - output.getValue(), 2);
            }
        }
        error_vector[epoch] += error;
    }

    /**
     * Stores the desired outputs of a training entry on its output concepts, computes the mean
     * signed error ({@code desired - actual}) over those output concepts and sets it as the
     * error of every input concept of the map.
     *
     * @param e the training entry whose outputs are used
     * @return the number of input concepts whose error was set
     */
    private int setInputErrors(FcmTrainingSetEntry e) {
        List<EntryStructure> outputs = e.getOutputs();
        double sum = 0;
        int numOutputs = 0;
        for (EntryStructure output : outputs) {
            OutputLearningConcept oc = asOutputConcept(output);
            if (oc != null) {
                oc.setDesirableOutput(output.getValue());
                sum += oc.getDesirableOutput() - oc.getOutput().doubleValue();
                numOutputs++;
            }
        }
        // Average only over entries that matched an output concept; no outputs gives 0, not NaN.
        double meanError = numOutputs == 0 ? 0 : sum / numOutputs;

        int num = 0;
        Iterator<Concept> it = map.getConceptsIterator();
        while (it.hasNext()) {
            Concept c = it.next();
            if (c instanceof LearningConcept) {
                LearningConcept lc = (LearningConcept) c;
                if (lc.isInput()) {
                    ((InputLearningConcept) lc).setError(meanError);
                    num++;
                }
            }
        }
        return num;
    }

    /**
     * Runs the map on a training entry and propagates errors to the learning concepts.
     *
     * <p>Input concepts receive the mean output error; every other learning concept is then
     * computed, pass after pass, once all of its incoming sources have an error. If a whole pass
     * makes no progress (e.g. concepts waiting on each other in a cycle), propagation stops
     * instead of looping forever. Finally the squared output error is accumulated for the
     * current epoch.
     *
     * @param e the training entry to process
     */
    private void calculateErrors(FcmTrainingSetEntry e) {
        setTrainingInputs(e.getInputs());
        runner.converge();

        int numCompleted = setInputErrors(e);
        while (numCompleted < map.getNumLearningConcepts()) {
            int completedBeforePass = numCompleted;
            Iterator<Concept> it = map.getConceptsIterator();
            while (it.hasNext()) {
                Concept c = it.next();
                if (c instanceof LearningConcept) {
                    LearningConcept lc = (LearningConcept) c;
                    if (!lc.isErrorCalculated() && calculateErrorOnConcept(lc)) {
                        numCompleted++;
                    }
                }
            }
            if (numCompleted == completedBeforePass) {
                // Nothing changed during this pass, so another pass cannot change anything either.
                logger.warn("Error propagation stalled: " + numCompleted + "/"
                        + map.getNumLearningConcepts() + " concepts computed");
                break;
            }
        }

        measureOutputError(e.getOutputs());
    }

    /**
     * Computes the error of a non-input concept as the weighted sum of the errors of the source
     * concepts of its incoming connections.
     *
     * @param c the concept to compute the error for
     * @return {@code true} if the concept is an input concept or its error was set;
     *     {@code false} if its error was already calculated or a source error is still missing
     * @throws IllegalStateException if an incoming connection is not a
     *     {@link LearningWeightedConnection} or its source is not a {@link LearningConcept}
     */
    private boolean calculateErrorOnConcept(LearningConcept c) {
        if (c.isErrorCalculated()) {
            return false;
        }
        if (c.isInput()) {
            return true;
        }

        Set<FcmConnection> in = c.getInConnections();
        double val = 0;
        for (FcmConnection conn : in) {
            if (!(conn instanceof LearningWeightedConnection)) {
                throw new IllegalStateException("Cannot propagate error into " + c
                        + ": incoming connection " + conn + " is not a LearningWeightedConnection");
            }
            LearningWeightedConnection wconn = (LearningWeightedConnection) conn;
            Concept source = wconn.getFrom();
            if (!(source instanceof LearningConcept)) {
                throw new IllegalStateException("Cannot propagate error into " + c
                        + ": source " + source + " is not a LearningConcept");
            }
            LearningConcept lin = (LearningConcept) source;
            if (!lin.isErrorCalculated()) {
                // The source error is not available yet; retry on a later pass.
                return false;
            }
            val += wconn.getWeight() * lin.getError();
        }
        c.setError(val);
        return true;
    }

    /**
     * Decides whether training should stop after the current epoch.
     *
     * @return {@code true} if sqrt(error of the current epoch) is at most {@link #accuracy}, or if
     *     it changed by at most {@link #epsilon} with respect to the previous epoch
     */
    private boolean checkTerminalConditions() {
        double currentError = Math.sqrt(error_vector[epoch]);
        if (currentError <= accuracy) {
            return true;
        }
        // Compare the epoch just completed with the one before it (not the two before it).
        if (epoch > 0) {
            double previousError = Math.sqrt(error_vector[epoch - 1]);
            if (Math.abs(currentError - previousError) <= epsilon) {
                logger.debug("STOP epsilon");
                return true;
            }
        }
        return false;
    }

    /**
     * Resets the per-training state: the whole error vector, the epoch counter, the timestamps
     * and the change in weight of every learning connection.
     */
    private void cleanUp() {
        // Clear every slot: after an early stop, index 'epoch' itself holds data that would
        // otherwise be accumulated into the next training run.
        for (int i = 0; i < error_vector.length; i++) {
            error_vector[i] = 0;
        }
        epoch = 0;
        start = 0;
        end = 0;
        Collection<FcmConnection> cons = map.getConnections().values();
        for (FcmConnection f : cons) {
            if (f instanceof LearningWeightedConnection) {
                ((LearningWeightedConnection) f).setChangeInWeight(0);
            }
        }
    }

    /**
     * Sets the input values of a training entry on the corresponding input learning concepts.
     * Entries whose concept is missing or is not an input learning concept are ignored.
     *
     * @param inputs the input entries of a training set entry
     */
    private void setTrainingInputs(List<EntryStructure> inputs) {
        for (EntryStructure input : inputs) {
            Concept c = input.getConcept();
            if (c instanceof LearningConcept) {
                LearningConcept lc = (LearningConcept) c;
                if (lc.isInput()) {
                    ((InputLearningConcept) lc).setInput(input.getValue());
                }
            }
        }
    }

    /**
     * Updates every learning weighted connection of the map by applying, to each outgoing
     * connection of a learning concept, that concept's training function.
     */
    private void updateNetworkWeights() {
        Iterator<Concept> it = map.getConcepts().values().iterator();
        while (it.hasNext()) {
            Concept c = it.next();
            if (!(c instanceof LearningConcept)) {
                continue;
            }
            LearningConcept lc = (LearningConcept) c;
            Set<FcmConnection> out = lc.getOutConnections();
            for (FcmConnection conn : out) {
                if (conn instanceof LearningWeightedConnection) {
                    lc.getTrainingFunction().update((LearningWeightedConnection) conn);
                }
            }
        }
    }

    /**
     * Replaces the map to train and rebuilds the runner so that it executes the new map rather
     * than the previous one.
     *
     * @param map the new map to train
     */
    @Override
    public void setMap(FcmLearning map) {
        this.map = map;
        this.runner = createRunner(map);
    }
}