package trust.jfcm.learning.training;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import trust.jfcm.Concept;
import trust.jfcm.FcmConnection;
import trust.jfcm.WeightedConnection;
import trust.jfcm.learning.EntryStructure;
import trust.jfcm.learning.FcmLearning;
import trust.jfcm.learning.FcmTrainingSet;
import trust.jfcm.learning.FcmTrainingSetEntry;
import trust.jfcm.learning.InputLearningConcept;
import trust.jfcm.learning.LearningConcept;
import trust.jfcm.learning.LearningWeightedConnection;
import trust.jfcm.learning.OutputLearningConcept;
import trust.jfcm.utils.FcmRunner;
import trust.jfcm.utils.SimpleFcmRunner;
/**
* This class implements linear-regression training for an FCM:
* output errors are propagated backwards through the map and
* used to update the connection weights.
* @author pc
*
*/
public class LinearRegression implements FcmTrainer{
    /**
     * Sum of the squared output errors accumulated over the whole training set,
     * one slot per epoch (index = epoch).
     */
    double[] error_vector;
    /**
     * Map being trained.
     */
    FcmLearning map;
    /******* TRAINING PARAMETERS ********/
    /**
     * Maximum number of epochs to train through.
     */
    int trainingTime = 500;
    /**
     * Training stops once the square root of an epoch's summed squared error
     * is less than or equal to this value.
     */
    double accuracy = 0.01;
    /**
     * Training stops once the square-rooted error changes by no more than this
     * value between two consecutive epochs.
     */
    double epsilon = 0.0000001;
    /**
     * Runner used to execute {@link #map} until it converges.
     */
    FcmRunner runner;
    /**
     * Current training epoch (0-based); also the index into {@link #error_vector}.
     */
    int epoch;
    /**
     * Wall-clock timestamps (milliseconds) taken at the start and end of training.
     */
    long start, end;

    /**
     * Creates a trainer for the given map.
     * @param map the map to train; its training set is read on every {@link #train()}
     */
    public LinearRegression(FcmLearning map){
        this.map = map;
        // field initializers have already run, so trainingTime holds its default here
        error_vector = new double[trainingTime];
        runner = createRunner(map);
        epoch = 0;
    }

    /**
     * Builds the runner used to execute a map.
     * NOTE(review): the numeric arguments are forwarded unchanged to SimpleFcmRunner;
     * confirm their meaning there before tuning them.
     */
    private static FcmRunner createRunner(FcmLearning map){
        return new SimpleFcmRunner(map, 0.000001, 1000);
    }

    /**
     * Trains the map and updates the weights on the current training data set.
     *
     * For every epoch, each training entry is run through the map. The output
     * errors are back-propagated, the weights are updated and then normalized.
     * Training stops after {@link #trainingTime} epochs, or earlier when
     * {@link #checkTerminalConditions()} holds. A report is then printed.
     * The per-run state is always reset afterwards, even if training fails.
     *
     * @throws IllegalStateException if errors cannot be propagated to every learning
     *         concept (see {@link #calculateErrors(FcmTrainingSetEntry)})
     */
    public void train(){
        start = System.currentTimeMillis();
        FcmTrainingSet trainingSet = map.getTrainingSet();
        try {
            for(epoch = 0; epoch < trainingTime; epoch++){
                for(int i = 0; i < trainingSet.size(); i++){
                    FcmTrainingSetEntry entry = trainingSet.get(i);
                    // reset concept values and errors before processing each entry
                    map.resetConcepts();
                    calculateErrors(entry);
                    updateNetworkWeights();
                    map.normalizeNetworkWeights(true);
                }
                if(checkTerminalConditions()){
                    break;
                }
            }
            end = System.currentTimeMillis();
            printInfo();
        } finally {
            // leave the trainer reusable even if propagation failed mid-run
            cleanUp();
        }
    }

    /**
     * Prints, for every training entry, the converged output of each output concept
     * next to its desired value, followed by a summary of the training run.
     */
    private void printInfo(){
        System.out.println("\n****** TRAINING MAP ********");
        FcmTrainingSet trainingSet = map.getTrainingSet();
        for(int k = 0; k < trainingSet.size(); k++){
            FcmTrainingSetEntry e = trainingSet.get(k);
            System.out.println("\n- "+e);
            map.resetConcepts();
            setTrainingInputs(e.getInputs());
            setDesirableOutputs(e.getOutputs());
            runner.converge();
            for(EntryStructure entry : e.getOutputs()){
                OutputLearningConcept oc = toOutputConcept(entry);
                if(oc == null){
                    continue;
                }
                double output = oc.getOutput().doubleValue();
                System.out.println(" Output concept: "+oc.getName());
                System.out.println(" Desirable output: "+oc.getDesirableOutput()+", output: "+output);
                System.out.println(" Error on output concept: "+oc.getName()+": "+(oc.getDesirableOutput() - output));
            }
        }
        // After a full run epoch == error_vector.length, so clamp to the last slot.
        // The printed value is the square root of the last epoch's summed squared error.
        if(error_vector.length > 0){
            int last = Math.min(epoch, error_vector.length - 1);
            System.out.println("\n Squeared mean error: "+Math.sqrt(error_vector[last]));
        }
        System.out.println(" Training epoch: "+(epoch)+"/"+trainingTime);
        System.out.println(" Elapsed time: "+(end-start)/1000F+" sec \n");
        map.resetConcepts();
    }

    /**
     * Resolves the concept of a training entry as an output learning concept.
     * A missing concept is reported on stderr.
     *
     * @param entry training-set entry describing an output
     * @return the output concept, or {@code null} if the entry has no concept, or its
     *         concept is not a learning concept flagged as output
     */
    private OutputLearningConcept toOutputConcept(EntryStructure entry){
        Concept c = entry.getConcept();
        if(c == null){
            // c is null here, so its name cannot be printed; identify the entry instead
            System.err.println("Error: Concept not found for training entry "+entry);
            return null;
        }
        if(c instanceof LearningConcept && ((LearningConcept) c).isOutput()){
            return (OutputLearningConcept) c;
        }
        return null;
    }

    /**
     * Adds this entry's squared output error to the current epoch's slot
     * of {@link #error_vector}.
     * @param outputs training-set entries describing the expected outputs
     */
    private void measureOutputError(List<EntryStructure> outputs){
        double error = 0;
        for(EntryStructure entry : outputs){
            OutputLearningConcept oc = toOutputConcept(entry);
            if(oc != null){
                double diff = oc.getOutput() - entry.getValue();
                error += diff * diff;
            }
        }
        error_vector[epoch] += error;
    }

    /**
     * Stores the expected value of each output entry as the desirable output
     * of the corresponding output concept.
     * @param outputs training-set entries describing the expected outputs
     */
    private void setDesirableOutputs(List<EntryStructure> outputs){
        for(EntryStructure entry : outputs){
            OutputLearningConcept oc = toOutputConcept(entry);
            if(oc != null){
                oc.setDesirableOutput(entry.getValue());
            }
        }
    }

    /**
     * Runs the map on one training entry and propagates the errors backwards.
     *
     * Errors go from the output concepts to every learning concept. The entry's
     * squared output error is then accumulated into {@link #error_vector}.
     *
     * @param e the training entry to process
     * @throws IllegalStateException if a full pass over the concepts computes no new
     *         error while some remain uncomputed (e.g. errors that depend on each other
     *         through cyclic connections); the original loop would spin forever here
     */
    private void calculateErrors(FcmTrainingSetEntry e){
        setTrainingInputs(e.getInputs());
        setDesirableOutputs(e.getOutputs());
        runner.converge();

        // Sweep the concepts repeatedly: a concept's error can be computed only once
        // the errors of all its learning successors are known.
        int numCompleted = 0;
        while(numCompleted < map.getNumLearningConcepts()){
            int completedThisPass = 0;
            Iterator<Concept> it = map.getConceptsIterator();
            while(it.hasNext()){
                Concept c = it.next();
                if(!(c instanceof LearningConcept)){
                    continue;
                }
                LearningConcept lc = (LearningConcept) c;
                if(!lc.isErrorCalculated() && calculateErrorOnConcept(lc)){
                    completedThisPass++;
                }
            }
            if(completedThisPass == 0){
                throw new IllegalStateException("Error propagation is stuck: "
                        + numCompleted + "/" + map.getNumLearningConcepts()
                        + " learning concepts completed and no further error can be computed");
            }
            numCompleted += completedThisPass;
        }

        measureOutputError(e.getOutputs());
    }

    /**
     * Computes the error of one concept in back-propagation fashion.
     *
     * An output concept's error is its desirable output minus its actual output.
     * Any other concept's error is the sum, over its weighted out-connections to
     * learning concepts, of connection weight times target error. Connections that
     * are not weighted, or do not lead to a learning concept, are skipped.
     *
     * @param c the concept to calculate the error for
     * @return {@code true} if the error was set; {@code false} if it was already
     *         calculated or a successor's error is not yet available
     */
    private boolean calculateErrorOnConcept(LearningConcept c){
        if(c.isErrorCalculated()){
            return false;
        }
        if(c.isOutput()){
            OutputLearningConcept oc = (OutputLearningConcept) c;
            oc.setError(oc.getDesirableOutput() - oc.getOutput().doubleValue());
            return true;
        }
        double weightedErrorSum = 0;
        Set<FcmConnection> out = c.getOutConnections();
        for(FcmConnection conn : out){
            if(!(conn instanceof WeightedConnection)){
                continue; // no weight to propagate the error through
            }
            WeightedConnection wconn = (WeightedConnection) conn;
            Concept target = wconn.getTo();
            if(!(target instanceof LearningConcept)){
                continue; // target carries no error
            }
            LearningConcept learningTarget = (LearningConcept) target;
            if(!learningTarget.isErrorCalculated()){
                // cannot compute yet; retried on a later sweep
                return false;
            }
            weightedErrorSum += wconn.getWeight() * learningTarget.getError();
        }
        c.setError(weightedErrorSum);
        return true;
    }

    /**
     * Decides whether training should stop after the current epoch.
     * @return {@code true} if the current error is within {@link #accuracy}, or it
     *         changed by at most {@link #epsilon} since the previous epoch
     */
    private boolean checkTerminalConditions(){
        double currentError = Math.sqrt(error_vector[epoch]);
        if(currentError <= accuracy){
            return true;
        }
        if(epoch > 0){
            double previousError = Math.sqrt(error_vector[epoch - 1]);
            if(Math.abs(currentError - previousError) <= epsilon){
                System.out.println("STOP epsilon");
                return true;
            }
        }
        return false;
    }

    /**
     * Resets the per-run data after training.
     *
     * Clears the whole error vector. When training stopped early, this includes the
     * slot at {@code epoch}, which would otherwise leak into the next run. Also zeroes
     * the counters, timestamps and the change in weight of every learning connection.
     */
    private void cleanUp(){
        Arrays.fill(error_vector, 0.0);
        epoch = 0;
        start = 0;
        end = 0;
        Collection<FcmConnection> cons = map.getConnections().values();
        for(FcmConnection f : cons){
            if(f instanceof LearningWeightedConnection){
                ((LearningWeightedConnection) f).setChangeInWeight(0);
            }
        }
    }

    /**
     * Sets the map inputs from the given training entries.
     *
     * Entries without a concept, or whose concept is not a learning concept
     * flagged as input, are ignored.
     *
     * @param inputs training-set entries describing the input data
     */
    private void setTrainingInputs(List<EntryStructure> inputs){
        for(EntryStructure entry : inputs){
            Concept c = entry.getConcept();
            if(c instanceof LearningConcept && ((LearningConcept) c).isInput()){
                ((InputLearningConcept) c).setInput(entry.getValue());
            }
        }
    }

    /**
     * Updates the weights of the connections.
     *
     * Every learning-weighted out-connection of every learning concept is passed to
     * that concept's training function.
     */
    private void updateNetworkWeights(){
        Collection<Concept> concepts = map.getConcepts().values();
        for(Concept c : concepts){
            if(!(c instanceof LearningConcept)){
                continue;
            }
            LearningConcept lc = (LearningConcept) c;
            Set<FcmConnection> out = lc.getOutConnections();
            for(FcmConnection conn : out){
                if(conn instanceof LearningWeightedConnection){
                    lc.getTrainingFunction().update((LearningWeightedConnection) conn);
                }
            }
        }
    }

    /**
     * Replaces the map to train.
     *
     * The runner is rebuilt so that {@code converge()} executes the new map
     * rather than the one passed to the constructor.
     */
    @Override
    public void setMap(FcmLearning map){
        this.map = map;
        runner = createRunner(map);
    }
}