// Candidate nodes for marginalisation: every node of c1 that still has more than one
// state (i.e. it has not been observed).
for (Node node : c1.getNodes()) {
    if (node.getStatesSize() > 1) {
        marginaliseOver.add(node);
    }
}
// Split the marginalisation set by node type: continuous -> y1, probabilistic -> d1.
for (Node node : marginaliseOver) {
    if (node.getType() == Node.CONTINUOUS_NODE_TYPE) {
        y1.add(node);
    } else if (node.getType() == Node.PROBABILISTIC_NODE_TYPE) {
        d1.add(node);
    }
}
// Nodes of c1 that were not selected above are the ones that remain after marginalisation:
// continuous -> y2, probabilistic -> d2.
for (Node node : c1.getNodes()) {
    if (node.getType() == Node.CONTINUOUS_NODE_TYPE && !y1.contains(node)) {
        y2.add(node);
    } else if (node.getType() == Node.PROBABILISTIC_NODE_TYPE && !d1.contains(node)) {
        d2.add(node);
    }
}
// Remove duplicate variables from each of the four partitions.
y1 = removeDuplicates(y1); // continuous variables to be marginalised over
y2 = removeDuplicates(y2); // continuous variables remaining
d1 = removeDuplicates(d1); // discrete variables to be marginalised over
d2 = removeDuplicates(d2); // discrete variables remaining
System.out.println("y1 " + y1.toString()
        + " y2 " + y2.toString()
        + " d1 " + d1.toString()
        + " d2 " + d2.toString());
// Scratch matrices for the K/h block partitions (K11, K22, K12, K21, h1, h2) used by the
// continuous marginalisation step.
Matrix k11, k22, k12, k21, h1, h2;
/* if (continuous node)
* marginalize over continuous node;
* update clique potential;
* else if (discrete node)
* convert clique potential to moment form;
* marginalize over discrete node;
* update clique potential;
* end
*
* get the normalization constant from the updated clique potential
* return the constant.
*/
if (y1.size() != 0 && !result.hasNoDescreteStates) {
for (int i = 0; i < ghkPot.tableSize(); i++) {
System.out.println("Enter marginalize continuous nodes at pos "+i);
System.out.println("Continuous nodes are "+ghkPot.getContinuousNodeList().toString());
System.out.println("H Size "+ghkPot.getHValue(i).gethMatrix().length);
printMatrix(ghkPot.getHValue(i).gethMatrix());
System.out.println("K Size "+ghkPot.getKValue(i).getkMatrix().length);
printMatrix(ghkPot.getKValue(i).getkMatrix());
// G component
double d = 0;
k11 = new Matrix(getK11(ghkPot.getKValue(i).getkMatrix(),ghkPot));
h1 = new Matrix(getH1(ghkPot.getHValue(i).gethMatrix(), ghkPot));
double logValue = k11.det() > 0 ? Math.log(k11.det()):k11.det();
System.out.println("Calculated values are "+ghkPot.getGValue(i)+" and "+ (0.5* (y1.size() * Math.log(2 * Math.PI)- logValue + (h1.transpose().times(matInverse(k11)).times(h1)).det())));
d = ghkPot.getGValue(i)
+ 0.5
* (y1.size() * Math.log(2 * Math.PI)
- logValue + (h1.transpose().times(matInverse(k11)).times(h1)).det());
result.setGValue(i,d);
System.out.println("setting g value as "+d);
// H component
if(y2.size()==0)
result.getHValue(i).sethMatrix(getZeroMatrix(result.getHValue(i).gethMatrix()));
else{
k11 = new Matrix(getK11(ghkPot.getKValue(i).getkMatrix(),ghkPot));
k21 = new Matrix(getK21(ghkPot.getKValue(i).getkMatrix(),ghkPot));
h1 = new Matrix(getH1(ghkPot.getHValue(i).gethMatrix(),ghkPot));
h2 = new Matrix(getH2(ghkPot.getHValue(i).gethMatrix(),ghkPot));
result.getHValue(i).sethMatrix(h2.minus((k21.times(matInverse(k11)).times(h1)))
.getArray());
}
System.out.println("setting h value as ");
printMatrix(result.getHValue(i).gethMatrix());
// K Component
if(y2.size()==0)
result.getKValue(i).setkMatrix(getZeroMatrix(result.getKValue(i).getkMatrix()));
else{
k11 = new Matrix(getK11(ghkPot.getKValue(i).getkMatrix(),ghkPot));
k21 = new Matrix(getK21(ghkPot.getKValue(i).getkMatrix(),ghkPot));
k12 = new Matrix(getK12(ghkPot.getKValue(i).getkMatrix(),ghkPot));
k22 = new Matrix(getK22(ghkPot.getKValue(i).getkMatrix(),ghkPot));
result.getKValue(i).setkMatrix((k22.minus(k21.times(matInverse(k11)).times(
k12))).getArray());
}
System.out.println("setting k value as ");
printMatrix(result.getKValue(i).getkMatrix());
result.setContinuousNodeList(y2);
}
System.out.println("Continuous marginalisation result is");
printGHK(result);
} else // has no descrete states
if (y1.size() != 0 && result.hasNoDescreteStates){
System.out.println("enter continuous marginalize at singleton state");
GHKPotential ghk = result.getGhkPot();
System.out.println("H Size "+ghk.getH().gethMatrix().length);
printMatrix(ghk.getH().gethMatrix());
System.out.println("K Size "+ghk.getK().getkMatrix().length);
printMatrix(ghk.getK().getkMatrix());
// G component
double d = 0;
k11 = new Matrix(getK11(ghk.getK().getkMatrix(),ghkPot));
h1 = new Matrix(getH1(ghk.getH().gethMatrix(), ghkPot));
double logValue = k11.det() > 0 ? Math.log(k11.det()):k11.det();
d = ghk.getG()
+ 0.5
* (y1.size() * Math.log(2 * Math.PI)
- logValue + (h1.transpose().times(
matInverse(k11)).times(h1)).det());
ghk.setG(d);
System.out.println("setting g as "+d);
// H component
if(y2.size()==0)
ghk.getH().sethMatrix(getZeroMatrix(ghk.getH().gethMatrix()));
else{
k11 = new Matrix(getK11(ghk.getK().getkMatrix(),ghkPot));
k21 = new Matrix(getK21(ghk.getK().getkMatrix(),ghkPot));
h1 = new Matrix(getH1(ghk.getH().gethMatrix(),ghkPot));
h2 = new Matrix(getH2(ghk.getH().gethMatrix(),ghkPot));
ghk.getH().sethMatrix(h2.minus((k21.times(matInverse(k11)).times(h1)))
.getArray());
}
System.out.println("Setting h as ");
printMatrix(ghk.getH().gethMatrix());
// K Component
if(y2.size()==0)
ghk.getK().setkMatrix(getZeroMatrix(ghk.getK().getkMatrix()));
else{
k11 = new Matrix(getK11(ghk.getK().getkMatrix(),ghkPot));
k21 = new Matrix(getK21(ghk.getK().getkMatrix(),ghkPot));
k12 = new Matrix(getK12(ghk.getK().getkMatrix(),ghkPot));
k22 = new Matrix(getK22(ghk.getK().getkMatrix(),ghkPot));
ghk.getK().setkMatrix((k22.minus(k21.times(matInverse(k11)).times(
k12))).getArray());
System.out.println("Setting k as ");
printMatrix(ghk.getK().getkMatrix());
// set in result
result.setGhkPot(ghk);
result.setContinuousNodeList(y2);
System.out.println("continuous marg result is");
printGHK(result);
}
}
// marginalise over discrete nodes only, no continuous remaining
// not handled the situation where I sum over all descrete variables
if (d1.size() != 0 && y1.size() == 0 && y2.size() == 0) {
// simple addition when you have only discrete variables; add up jpd
// i.e we do not have a continuous component residue
if( countInference == 1){
GHKPotentialTable temp = new GHKPotentialTable();
temp = getCopyof(result);
for (Node node : d1){
System.out.println(node.getName()+" has "+node.getStatesSize()+" states");
temp.sum(result.getVariableList().indexOf(node));
temp.getVariableList().remove(node);
}
ArrayList<Node> varList = (ArrayList<Node>) result.getVariableList();
temp.setVariableList((List<Node>) varList.clone());
result = getCopyof(temp);
//set variables
result.setVariableList((List<Node>) d2.clone());
result.setContinuousNodeList((List<Node>) y2.clone());
System.out.println("The result from simple add is ");
printGHK(result);
}else{
PMCPotentialTable mchar = convertToMomentCharec(result);
PMCPotentialTable temp = new PMCPotentialTable();
temp = (PMCPotentialTable) getCopyof(mchar);
// load the current values in the table
double prob[] = new double[temp.tableSize()];
for(int i=0;i<temp.tableSize();i++) prob[i]=temp.getJPDValue(i);
ArrayList<MeanMatrix> mm = new ArrayList<MeanMatrix>();
for(int i=0;i<temp.tableSize();i++) mm.add(temp.getMeanMatValue(i));
ArrayList<CoVarMatrix> cvm = new ArrayList<CoVarMatrix>();
for(int i=0;i<temp.tableSize();i++) cvm.add(temp.getCoVarMatValue(i));
// final values
DoubleCollection resJpd = new DoubleCollection();
ArrayList<MeanMatrix> resMM = new ArrayList<MeanMatrix>();
ArrayList<CoVarMatrix> resCVM = new ArrayList<CoVarMatrix>();
temp = (PMCPotentialTable) getCopyof(mchar);
// new jpd
// sum it
for(Node node:d1){
temp.sum(temp.getVariableList().indexOf(node));
temp.getVariableList().remove(temp.getVariableList().indexOf(node));
}
resJpd = temp.getJpd();
}
/*GHKPotentialTable temp = new GHKPotentialTable();
temp = getCopyof(result);
for (Node node : d1){
System.out.println(node.getName()+" has "+node.getStatesSize()+" states");
temp.sum(result.getVariableList().indexOf(node));
temp.getVariableList().remove(node);
}
ArrayList<Node> varList = (ArrayList<Node>) result.getVariableList();
temp.setVariableList((List<Node>) varList.clone());
result = getCopyof(temp);
//set variables
result.setVariableList((List<Node>) d2.clone());
result.setContinuousNodeList((List<Node>) y2.clone());
System.out.println("The result from simple add is ");
printGHK(result);
*/
} else if (d1.size() != 0) { // this case is to marginalize discrete variable
// do I do array or matrix multiplication ?
// convert
PMCPotentialTable mchar = convertToMomentCharec(result);
PMCPotentialTable temp = new PMCPotentialTable();
temp = (PMCPotentialTable) getCopyof(mchar);
// load the current values in the table
double prob[] = new double[temp.tableSize()];
for(int i=0;i<temp.tableSize();i++) prob[i]=temp.getJPDValue(i);
ArrayList<MeanMatrix> mm = new ArrayList<MeanMatrix>();
for(int i=0;i<temp.tableSize();i++) mm.add(temp.getMeanMatValue(i));
ArrayList<CoVarMatrix> cvm = new ArrayList<CoVarMatrix>();
for(int i=0;i<temp.tableSize();i++) cvm.add(temp.getCoVarMatValue(i));
// final values
DoubleCollection resJpd = new DoubleCollection();
ArrayList<MeanMatrix> resMM = new ArrayList<MeanMatrix>();
ArrayList<CoVarMatrix> resCVM = new ArrayList<CoVarMatrix>();
temp = (PMCPotentialTable) getCopyof(mchar);
// new jpd
// sum it
for(Node node:d1){
temp.sum(temp.getVariableList().indexOf(node));
temp.getVariableList().remove(temp.getVariableList().indexOf(node));
}
resJpd = temp.getJpd();
temp = new PMCPotentialTable();
temp = (PMCPotentialTable) getCopyof(mchar);
// new mean matrix
ArrayList<MeanMatrix> mm2 = (ArrayList<MeanMatrix>) mm.clone();
for(int i=0;i<temp.tableSize();i++){
for(int j=0;j<mm2.get(i).getMeanMatrix().length;j++)
for(int k=0;k<mm2.get(i).getMeanMatrix()[0].length;k++)
if(mm2.get(i).getMeanMatrix()[j][k]!=0)
mm2.get(i).getMeanMatrix()[j][k] = mm2.get(i).getMeanMatrix()[j][k]*prob[i];
}
temp.setMeanMatrices(mm2);
// sum it
for(Node node:d1){
temp.sum(temp.getVariableList().indexOf(node));
temp.getVariableList().remove(temp.getVariableList().indexOf(node));
}
mm2 = temp.getMeanMatrices();
for(int i=0;i<temp.tableSize();i++){
for(int j=0;j<mm2.get(i).getMeanMatrix().length;j++)
for(int k=0;k<mm2.get(i).getMeanMatrix()[0].length;k++)
if(mm2.get(i).getMeanMatrix()[j][k]!=0 && resJpd.get(i)!=0)
mm2.get(i).getMeanMatrix()[j][k] = mm2.get(i).getMeanMatrix()[j][k]/resJpd.get(i);
}
resMM=(ArrayList<MeanMatrix>)mm2.clone();
temp = new PMCPotentialTable();
temp = (PMCPotentialTable)getCopyof(mchar);
// new co-variance matrix
ArrayList<CoVarMatrix> cvm2 = new ArrayList<CoVarMatrix>();
cvm2 = (ArrayList<CoVarMatrix>) cvm.clone();
for(int i=0;i<temp.tableSize();i++){
for(int j=0;j<cvm2.get(i).getCoVarMatrix().length;j++)
for(int k=0;k<cvm2.get(i).getCoVarMatrix()[0].length;k++)
if(cvm2.get(i).getCoVarMatrix()[j][k]!=0)
cvm2.get(i).getCoVarMatrix()[j][k] =cvm2.get(i).getCoVarMatrix()[j][k]*prob[i];
}
temp.setCoVarMatrices(cvm2);
//sum it
for(Node node:d1){
temp.sum(temp.getVariableList().indexOf(node));
temp.getVariableList().remove(temp.getVariableList().indexOf(node));
}
cvm2= temp.getCoVarMatrices();
for(int i=0;i<temp.tableSize();i++){
for(int j=0;j<cvm2.get(i).getCoVarMatrix().length;j++)
for(int k=0;k<cvm2.get(i).getCoVarMatrix()[0].length;k++)
if(cvm2.get(i).getCoVarMatrix()[j][k]!=0 && resJpd.get(i)!=0)
cvm2.get(i).getCoVarMatrix()[j][k] = cvm2.get(i).getCoVarMatrix()[j][k]/resJpd.get(i);
}
PMCPotentialTable uOriginal = (PMCPotentialTable) getCopyof(mchar);
PMCPotentialTable uNew = (PMCPotentialTable) getCopyof(mchar);
uNew.setMeanMatrices(resMM);
uNew.setContinuousNodeList(y2);
uNew.setVariableList(d2);
uNew.computeFactors();
uOriginal.PMCTableOp(uNew, PMCPotentialTable.MINUS_OPERATOR);
ArrayList<MeanMatrix> uMinusNewU = uOriginal.getMeanMatrices();
temp = new PMCPotentialTable();
temp = (PMCPotentialTable) getCopyof(mchar);
ArrayList<CoVarMatrix> cvm3 = new ArrayList<CoVarMatrix>();
for(int i=0;i<temp.tableSize();i++){
Matrix uMinusu2 = new Matrix(uMinusNewU.get(i).getMeanMatrix());
double[][] tmp = (uMinusu2.transpose()).times(uMinusu2).times(prob[i]).getArray();
CoVarMatrix cv = new CoVarMatrix();
cv.setCoVarMatrix(tmp);
cvm3.add(cv);
}
temp.setCoVarMatrices(cvm3);