Package gannuCF.classifiers

Source Code of gannuCF.classifiers.NaiveBayes

package gannuCF.classifiers;

import gannuNLP.data.AmbiguousWord;
import gannuNLP.data.Sense;
import gannuUtil.Util;

import java.util.ArrayList;
import java.util.Collections;
/**
* Implementation of a Naive-Bayes classifier.
* @author Francisco Viveros-Jiménez
*/
public class NaiveBayes extends Classifier {

  private static final long serialVersionUID = 1L;
  /**
   * Mean values extracted for each attribute per class.
   * The first dimension controls the target class number.
   * The second dimension controls the attribute number.
   */
  float [][]u;
  /**
   * Standard deviation values extracted for each attribute per class.
   * The first dimension controls the target class number.
   * The second dimension controls the attribute number.
   */
  float [][]s;
  /**
   * Instantiate this classifier.
   */
  public NaiveBayes()
  {
    super();
    this.name="NaiveBayes";
  }
  @Override
  public void train(int classes,int sampleCount, int attributeCount, float[][] samples) {
    this.classes=classes;
    this.sampleCount=sampleCount;
    this.attributeCount=attributeCount;
    u=new float[classes][attributeCount];
    s=new float[classes][attributeCount];
    int classCount[]=new int[classes];
    float count=3.0f;
    if(this.getValue("pseudocount")!=null)
      count=Float.parseFloat(this.getValue("pseudocount"));
    //initialize the arrays
    for(int c=0;c<classes;c++)
    {
      classCount[c]=0;
      for(int j=0;j<attributeCount;j++)
      {
        u[c][j]=count;
        s[c][j]=0.0f;
      }
    }
    //calculate means
    for(int i=0;i<sampleCount;i++)
    {
      int classIndex=(int)Math.round(samples[i][this.attributeCount]);
      classCount[classIndex]++;
      for(int j=0;j<attributeCount;j++)
      {
       
        u[classIndex][j]+=samples[i][j];
      }
    }
    for(int c=0;c<classes;c++)
    {
      for(int j=0;j<attributeCount;j++)
      {
        u[c][j]=u[c][j]/((float)classCount[c]);
      }
    }
    //calculate variance values
    for(int i=0;i<sampleCount;i++)
    {
      int classIndex=(int)Math.round(samples[i][this.attributeCount]);
     
      for(int j=0;j<attributeCount;j++)
      {
       
        s[classIndex][j]+=((float)Math.pow(samples[i][j]-u[classIndex][j],2.0));
      }
    }
    for(int c=0;c<classes;c++)
    {
      for(int j=0;j<attributeCount;j++)
      {
        s[c][j]=s[c][j]/(((float)classCount[c])-1.0f);
      }
    }
  }

  @Override
  public float[] classify(float[] sample) {
    float w[]=new float[this.classes];
    for(int c=0;c<this.classes;c++)
    {
      w[c]=1.0f/((float)this.classes);
      for(int j=0;j<this.attributeCount;j++)
      {
        float a=(1.0f/((float)Math.sqrt(2.0f*((float)Math.PI)*(s[c][j]))));
        float b=((float)Math.exp(
            -((float)Math.pow(sample[j]-u[c][j], 2.0f))
            /(2.0f*(s[c][j]))
            ));
        w[c]=a*b;
      }
    }
   
    return w;
  }
  @Override
  public void train(AmbiguousWord target) throws Exception {
    this.features=new ArrayList<String>();
    this.classes=target.getSenses().size();   
    for(Sense s:target.getSenses())
    {   
      for(String word:s.getBagOfWords())
      {
        features.add(word);
      }
    }
    features=Util.removeDuplicates(features);
    Collections.sort(features);
    this.attributeCount=this.features.size();
    this.features.trimToSize();
    u=new float[classes][attributeCount];
    s=new float[classes][attributeCount];
    int classCount[]=new int[classes];
    float count=3.0f;
    if(this.getValue("pseudocount")!=null)
      count=Float.parseFloat(this.getValue("pseudocount"));
    //initialize the arrays
    for(int c=0;c<classes;c++)
    {
      classCount[c]=0;
      for(int j=0;j<attributeCount;j++)
      {
        u[c][j]=count;
        s[c][j]=0.0f;
      }
    }
    //calculate means
    for(int classIndex=0;classIndex<target.getSenses().size();classIndex++)
    {
      Sense s=target.getSenses().get(classIndex);
      classCount[classIndex]+=s.getSamples().size();
      for(String word:s.getBagOfWords())
      {       
        u[classIndex][Collections.binarySearch(this.features, word)]+=1.0f;
      }
    }
    for(int c=0;c<classes;c++)
    {
      for(int j=0;j<attributeCount;j++)
      {
        u[c][j]=u[c][j]/((float)classCount[c]);
      }
    }
    //calculate variance values
    for(int classIndex=0;classIndex<target.getSenses().size();classIndex++)
    {
      Sense s=target.getSenses().get(classIndex);
      ArrayList<Integer> indexes=new ArrayList<Integer>(s.getBagOfWords().size());   
      for(String word:s.getBagOfWords())
      {
        Integer tmp=new Integer(Collections.binarySearch(this.features, word));
        if(!indexes.contains(tmp))
          indexes.add(tmp);
      }
      for(int j=0;j<attributeCount;j++)
      {
        Integer tmp=new Integer(j);
        if(!indexes.contains(tmp))
           this.s[classIndex][j]+=((float)Math.pow(u[classIndex][j],2.0));
        else
        {
          float w=0.0f;
          for(String word:s.getBagOfWords())
          {
            if(word.equals(this.features.get(tmp.intValue())))
            {
              w+=1.0f;
            }
          }
          this.s[classIndex][j]+=((float)Math.pow(w-u[classIndex][j],2.0));
        }
      }
    }
    for(int c=0;c<classes;c++)
    {
      for(int j=0;j<attributeCount;j++)
      {
        s[c][j]=s[c][j]/(((float)classCount[c])-1.0f);
      }
    }

  }
  @Override
  public float[] classify(ArrayList<String> Sample) {
    float sample[]=new float[features.size()];
    for(int j=0;j<features.size();j++)
      sample[j]=0.0f;
    for(String word:Sample)
    {
      int index=Collections.binarySearch(features, word);
      if(index>=0)
        sample[index]+=1.0f;
    }
    return this.classify(sample);
  }

}
TOP

Related Classes of gannuCF.classifiers.NaiveBayes

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and owned by ORACLE Inc. Contact coftware#gmail.com.