Package org.apache.mahout.classifier.naivebayes

Source Code of org.apache.mahout.classifier.naivebayes.NaiveBayesModel

/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mahout.classifier.naivebayes;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.mahout.classifier.naivebayes.trainer.NaiveBayesTrainer;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.SparseMatrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;

/**
* NaiveBayesModel holds the weight Matrix, the feature and label sums and the weight normalizer vectors.
*/
public class NaiveBayesModel {

  private static final String MODEL = "NaiveBayesModel";

  private Vector labelSum;
  private Vector perlabelThetaNormalizer;
  private Vector featureSum;
  private Matrix weightMatrix;
  private float alphaI;
  private double vocabCount;
  private double totalSum;
 
  private NaiveBayesModel() {
    // do nothing
  }
 
  public NaiveBayesModel(Matrix matrix, Vector featureSum, Vector labelSum, Vector thetaNormalizer, float alphaI) {
    this.weightMatrix = matrix;
    this.featureSum = featureSum;
    this.labelSum = labelSum;
    this.perlabelThetaNormalizer = thetaNormalizer;
    this.vocabCount = featureSum.getNumNondefaultElements();
    this.totalSum = labelSum.zSum();
    this.alphaI = alphaI;
  }

  private void setLabelSum(Vector labelSum) {
    this.labelSum = labelSum;
  }


  public void setPerlabelThetaNormalizer(Vector perlabelThetaNormalizer) {
    this.perlabelThetaNormalizer = perlabelThetaNormalizer;
  }


  public void setFeatureSum(Vector featureSum) {
    this.featureSum = featureSum;
  }


  public void setWeightMatrix(Matrix weightMatrix) {
    this.weightMatrix = weightMatrix;
  }


  public void setAlphaI(float alphaI) {
    this.alphaI = alphaI;
  }


  public void setVocabCount(double vocabCount) {
    this.vocabCount = vocabCount;
  }


  public void setTotalSum(double totalSum) {
    this.totalSum = totalSum;
  }
 
  public Vector getLabelSum() {
    return labelSum;
  }

  public Vector getPerlabelThetaNormalizer() {
    return perlabelThetaNormalizer;
  }

  public Vector getFeatureSum() {
    return featureSum;
  }

  public Matrix getWeightMatrix() {
    return weightMatrix;
  }

  public float getAlphaI() {
    return alphaI;
  }

  public double getVocabCount() {
    return vocabCount;
  }

  public double getTotalSum() {
    return totalSum;
  }
 
  public int getNumLabels() {
    return labelSum.size();
  }

  public static String getModelName() {
    return MODEL;
  }
 
  // CODE USED FOR SERIALIZATION
  public static NaiveBayesModel fromMRTrainerOutput(Path output, Configuration conf) {
    Path classVectorPath = new Path(output, NaiveBayesTrainer.CLASS_VECTORS);
    Path sumVectorPath = new Path(output, NaiveBayesTrainer.SUM_VECTORS);
    Path thetaSumPath = new Path(output, NaiveBayesTrainer.THETA_SUM);

    NaiveBayesModel model = new NaiveBayesModel();
    model.setAlphaI(conf.getFloat(NaiveBayesTrainer.ALPHA_I, 1.0f));

    int featureCount = 0;
    int labelCount = 0;
    // read feature sums and label sums
    for (Pair<Text,VectorWritable> record
         : new SequenceFileIterable<Text, VectorWritable>(sumVectorPath, true, conf)) {
      Text key = record.getFirst();
      VectorWritable value = record.getSecond();
      if (key.toString().equals(BayesConstants.FEATURE_SUM)) {
        model.setFeatureSum(value.get());
        featureCount = value.get().getNumNondefaultElements();
        model.setVocabCount(featureCount);      
      } else  if (key.toString().equals(BayesConstants.LABEL_SUM)) {
        model.setLabelSum(value.get());
        model.setTotalSum(value.get().zSum());
        labelCount = value.get().size();
      }
    }

    // read the class matrix
    Matrix matrix = new SparseMatrix(new int[] {labelCount, featureCount});
    for (Pair<IntWritable,VectorWritable> record
         : new SequenceFileIterable<IntWritable,VectorWritable>(classVectorPath, true, conf)) {
      IntWritable label = record.getFirst();
      VectorWritable value = record.getSecond();
      matrix.assignRow(label.get(), value.get());
    }
   
    model.setWeightMatrix(matrix);

    // read theta normalizer
    for (Pair<Text,VectorWritable> record
         : new SequenceFileIterable<Text,VectorWritable>(thetaSumPath, true, conf)) {
      Text key = record.getFirst();
      VectorWritable value = record.getSecond();
      if (key.toString().equals(BayesConstants.LABEL_THETA_NORMALIZER)) {
        model.setPerlabelThetaNormalizer(value.get());
      }
    }

    return model;
  }
 
  public static void validate(NaiveBayesModel model) {
    if (model == null) {
      return; // empty models are valid
    }

    if (model.getAlphaI() <= 0) {
      throw new IllegalArgumentException(
          "Error: AlphaI has to be greater than 0!");
    }

    if (model.getVocabCount() <= 0) {
      throw new IllegalArgumentException(
          "Error: The vocab count has to be greater than 0!");
    }

    if (model.getVocabCount() <= 0) {
      throw new IllegalArgumentException(
          "Error: The vocab count has to be greater than 0!");
    }
   
    if (model.getTotalSum() <= 0) {
      throw new IllegalArgumentException(
          "Error: The vocab count has to be greater than 0!");
    }   

    if (model.getLabelSum() == null || model.getLabelSum().getNumNondefaultElements() <= 0) {
      throw new IllegalArgumentException(
          "Error: The number of labels has to be greater than 0 or defined!");
   
   
    if (model.getPerlabelThetaNormalizer() == null
        || model.getPerlabelThetaNormalizer().getNumNondefaultElements() <= 0) {
      throw new IllegalArgumentException(
          "Error: The number of theta normalizers has to be greater than 0 or defined!");
    }
   
    if (model.getFeatureSum() == null || model.getFeatureSum().getNumNondefaultElements() <= 0) {
      throw new IllegalArgumentException(
          "Error: The number of features has to be greater than 0 or defined!");
    }
  }
}
TOP

Related Classes of org.apache.mahout.classifier.naivebayes.NaiveBayesModel

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and owned by ORACLE Inc. Contact coftware#gmail.com.