package gannuCF.classifiers;
import gannuNLP.data.AmbiguousWord;
import gannuNLP.data.Sense;
import gannuUtil.Util;
import java.util.ArrayList;
import java.util.Collections;
/**
* Implementation of a Gaussian Naive-Bayes classifier: every attribute is modelled, per class, as an independent normal distribution.
* @author Francisco Viveros-Jiménez
*/
public class NaiveBayes extends Classifier {
    private static final long serialVersionUID = 1L;
    /**
     * Pseudocount used when the "pseudocount" parameter is not set.
     */
    private static final float DEFAULT_PSEUDOCOUNT = 3.0f;
    /**
     * Mean values extracted for each attribute per class.
     * The first dimension controls the target class number.
     * The second dimension controls the attribute number.
     * Each mean is computed as (pseudocount + sum of the attribute values) divided by
     * the number of training samples of the class.
     */
    float [][]u;
    /**
     * Variance values (squared deviations summed and divided by n - 1) extracted for
     * each attribute per class. These are variances, not standard deviations:
     * {@link #classify(float[])} uses them directly as sigma squared.
     * The first dimension controls the target class number.
     * The second dimension controls the attribute number.
     */
    float [][]s;

    /**
     * Instantiate this classifier.
     */
    public NaiveBayes() {
        super();
        this.name = "NaiveBayes";
    }

    /**
     * Trains the classifier from a numeric sample matrix.
     * @param classes number of target classes.
     * @param sampleCount number of rows of {@code samples} used for training.
     * @param attributeCount number of attributes per sample.
     * @param samples training rows; row {@code i} holds the attribute values in columns
     *        {@code 0..attributeCount-1} and the class index (rounded to the nearest
     *        integer) in column {@code attributeCount}.
     */
    @Override
    public void train(int classes, int sampleCount, int attributeCount, float[][] samples) {
        this.classes = classes;
        this.sampleCount = sampleCount;
        this.attributeCount = attributeCount;
        initializeGaussianParameters(readPseudocount());
        int[] classCount = new int[classes];
        // calculate means
        for (int i = 0; i < sampleCount; i++) {
            int classIndex = classIndexOf(samples[i]);
            classCount[classIndex]++;
            for (int j = 0; j < attributeCount; j++) {
                u[classIndex][j] += samples[i][j];
            }
        }
        divideMeansByClassCount(classCount);
        // calculate variance values
        for (int i = 0; i < sampleCount; i++) {
            int classIndex = classIndexOf(samples[i]);
            for (int j = 0; j < attributeCount; j++) {
                s[classIndex][j] += (float) Math.pow(samples[i][j] - u[classIndex][j], 2.0);
            }
        }
        divideVariancesByClassCount(classCount);
    }

    /**
     * Scores a sample against every class.
     * <p>
     * NOTE(review): the product of many densities can underflow to 0 (or overflow for
     * very small variances) when there are many attributes; a log-domain score would be
     * more robust.
     * @param sample attribute values; at least {@code attributeCount} entries are read.
     * @return one weight per class: the uniform prior {@code 1/classes} multiplied by the
     *         normal density of every attribute under that class.
     */
    @Override
    public float[] classify(float[] sample) {
        float[] w = new float[this.classes];
        float twoPi = 2.0f * ((float) Math.PI);
        for (int c = 0; c < this.classes; c++) {
            w[c] = 1.0f / ((float) this.classes);
            for (int j = 0; j < this.attributeCount; j++) {
                float variance = s[c][j];
                float a = 1.0f / ((float) Math.sqrt(twoPi * variance));
                float b = (float) Math.exp(
                        -((float) Math.pow(sample[j] - u[c][j], 2.0f)) / (2.0f * variance));
                // Multiply (not assign): every attribute likelihood and the prior must
                // contribute to the class score.
                w[c] *= a * b;
            }
        }
        return w;
    }

    /**
     * Trains the classifier from the senses of an ambiguous word.
     * <p>
     * Features are the sorted, de-duplicated words of all the senses' bags of words.
     * Each class (sense) contributes its bag-of-words occurrence counts to the mean,
     * and the number of its samples is used as the class size.
     * @param target word whose senses become the target classes.
     * @throws Exception propagated from the underlying data accessors.
     */
    @Override
    public void train(AmbiguousWord target) throws Exception {
        this.features = new ArrayList<String>();
        this.classes = target.getSenses().size();
        for (Sense sense : target.getSenses()) {
            for (String word : sense.getBagOfWords()) {
                features.add(word);
            }
        }
        features = Util.removeDuplicates(features);
        // Sorted so binarySearch can map a word to its attribute index.
        Collections.sort(features);
        this.attributeCount = this.features.size();
        this.features.trimToSize();
        initializeGaussianParameters(readPseudocount());
        int[] classCount = new int[classes];
        // Occurrences of every feature inside each sense's bag of words,
        // computed once and reused for both the means and the variances.
        int[][] wordCounts = new int[classes][];
        // calculate means
        for (int classIndex = 0; classIndex < classes; classIndex++) {
            Sense sense = target.getSenses().get(classIndex);
            classCount[classIndex] = sense.getSamples().size();
            int[] counts = new int[attributeCount];
            for (String word : sense.getBagOfWords()) {
                // Every bag word is a feature, so the search always succeeds.
                counts[Collections.binarySearch(this.features, word)]++;
            }
            wordCounts[classIndex] = counts;
            for (int j = 0; j < attributeCount; j++) {
                u[classIndex][j] += counts[j];
            }
        }
        divideMeansByClassCount(classCount);
        // calculate variance values; absent words contribute (0 - mean)^2
        for (int classIndex = 0; classIndex < classes; classIndex++) {
            int[] counts = wordCounts[classIndex];
            for (int j = 0; j < attributeCount; j++) {
                this.s[classIndex][j] += (float) Math.pow(counts[j] - u[classIndex][j], 2.0);
            }
        }
        divideVariancesByClassCount(classCount);
    }

    /**
     * Classifies a bag of words by converting it to a feature-count vector.
     * Words that are not known features are ignored.
     * @param words the words of the sample to classify.
     * @return one weight per class, as returned by {@link #classify(float[])}.
     */
    @Override
    public float[] classify(ArrayList<String> words) {
        // Java zero-initializes the array, so only the known words need counting.
        float[] sample = new float[features.size()];
        for (String word : words) {
            int index = Collections.binarySearch(features, word);
            if (index >= 0) {
                sample[index] += 1.0f;
            }
        }
        return this.classify(sample);
    }

    /** Reads the "pseudocount" parameter, falling back to {@link #DEFAULT_PSEUDOCOUNT}. */
    private float readPseudocount() {
        String value = this.getValue("pseudocount");
        return value != null ? Float.parseFloat(value) : DEFAULT_PSEUDOCOUNT;
    }

    /** Allocates u and s; every mean starts at the pseudocount and every variance at 0. */
    private void initializeGaussianParameters(float pseudocount) {
        u = new float[classes][attributeCount];
        s = new float[classes][attributeCount];
        for (float[] row : u) {
            for (int j = 0; j < row.length; j++) {
                row[j] = pseudocount;
            }
        }
    }

    /** Returns the class index stored in the last column of a training row. */
    private int classIndexOf(float[] row) {
        return Math.round(row[this.attributeCount]);
    }

    /**
     * Turns the accumulated sums in u into means by dividing by each class size.
     * NOTE(review): a class with 0 samples produces non-finite means here.
     */
    private void divideMeansByClassCount(int[] classCount) {
        for (int c = 0; c < classes; c++) {
            for (int j = 0; j < attributeCount; j++) {
                u[c][j] = u[c][j] / ((float) classCount[c]);
            }
        }
    }

    /**
     * Turns the accumulated squared deviations in s into sample variances (divide by n - 1).
     * NOTE(review): a class with a single sample divides by zero here.
     */
    private void divideVariancesByClassCount(int[] classCount) {
        for (int c = 0; c < classes; c++) {
            for (int j = 0; j < attributeCount; j++) {
                s[c][j] = s[c][j] / (((float) classCount[c]) - 1.0f);
            }
        }
    }
}