Package bgu.bio.learning.classifier

Source Code of bgu.bio.learning.classifier.Category

package bgu.bio.learning.classifier;

import gnu.trove.list.array.TLongArrayList;
import gnu.trove.map.hash.TObjectLongHashMap;

import java.util.HashMap;

import bgu.bio.adt.tuples.LongPair;
import bgu.bio.util.MathOperations;

public class NaiveBayes {
  private final int alpha;
  private HashMap<String, Category> categories;
  private TObjectLongHashMap<String> featuresCounters;

  public NaiveBayes() {
    alpha = 1;
    categories = new HashMap<String, Category>();
    featuresCounters = new TObjectLongHashMap<String>();
  }

  public void train(String categoryName, String... features) {
    if (!categories.containsKey(categoryName)) {
      categories.put(categoryName, new Category(categoryName));
    }
    Category category = categories.get(categoryName);
    category.incrementCounter();
    for (int i = 0; i < features.length; i++) {
      category.addFeature(features[i]);
      featuresCounters.adjustOrPutValue(features[i], 1, 1);
    }
  }

  private double calculateCategory(String categoryName, String... features) {
    TLongArrayList denomList = new TLongArrayList();
    TLongArrayList numList = new TLongArrayList();

    LongPair pair = new LongPair();
    /*
     * categoryProbability(categoryName, pair); gcd(pair);
     * denomList.add(pair.getSecond()); numList.add(pair.getFirst());
     */
    Category category = categories.get(categoryName);
    for (int i = 0; i < features.length; i++) {
      pair.setFirst(category.featureOccurences(features[i]) + alpha);
      pair.setSecond(featureOccurences(features[i]) + categories.size()
          * alpha);
      gcd(pair);
      numList.add(pair.getFirst());
      denomList.add(pair.getSecond());
    }
    // count the amount of zeros in the lists

    long numerator = 1, denominator = 1;
    for (int i = 0; i < denomList.size(); i++) {
      numerator *= numList.get(i);
      long tmp = MathOperations.gcd(numerator, denominator);
      if (tmp > 1) {
        numerator = numerator / tmp;
        denominator = denominator / tmp;
      }
      denominator *= denomList.get(i);
      tmp = MathOperations.gcd(numerator, denominator);
      if (tmp > 1) {
        numerator = numerator / tmp;
        denominator = denominator / tmp;
      }
    }
    return (1.0 * numerator) / denominator;
  }

  public String classify(String... features) {
    double maxVal = Double.NEGATIVE_INFINITY;
    String maxCategory = "";
    for (String categoryName : categories.keySet()) {
      final double ans = calculateCategory(categoryName, features);
      if (ans > maxVal) {
        maxCategory = categoryName;
        maxVal = ans;
      }
    }
    return maxCategory;
  }

  private void gcd(LongPair pair) {
    final long tmp = MathOperations.gcd(pair.getFirst(), pair.getSecond());
    pair.setFirst(pair.getFirst() / tmp);
    pair.setSecond(pair.getSecond() / tmp);
  }

  private long featureOccurences(String feature) {
    if (featuresCounters.contains(feature)) {
      return featuresCounters.get(feature);
    }
    return 0;

  }
}

class Category {
  private String name;
  private long count;
  private TObjectLongHashMap<String> featuresCounters;

  public Category(String name) {
    this.name = name;
    count = 0;
    featuresCounters = new TObjectLongHashMap<String>();
  }

  public long getCount() {
    return count;
  }

  public String getName() {
    return name;
  }

  public void addFeature(String feature) {
    featuresCounters.adjustOrPutValue(feature, 1, 1);

  }

  public void incrementCounter() {
    count++;
  }

  public long featureOccurences(String feature) {
    if (featuresCounters.contains(feature)) {
      return featuresCounters.get(feature);
    }
    return 0;

  }
}
TOP

Related Classes of bgu.bio.learning.classifier.Category

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.