Package mia.classifier.ch14

Source Code of mia.classifier.ch14.TrainNewsGroups

package mia.classifier.ch14;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.util.Version;
import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;
import org.apache.mahout.vectorizer.encoders.Dictionary;

import com.google.common.collect.ConcurrentHashMultiset;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.Iterables;
import com.google.common.collect.Multiset;

public class TrainNewsGroups {
  private static final int FEATURES = 10000;
  private static Multiset<String> overallCounts;
 
  public static void main(String[] args) {
    File base = new File(args[0]);
    overallCounts = HashMultiset.create();

    Map<String, Set<Integer>> traceDictionary = new TreeMap<String, Set<Integer>>();
    FeatureVectorEncoder encoder = new StaticWordValueEncoder("body");
    encoder.setProbes(2);
    encoder.setTraceDictionary(traceDictionary);
    FeatureVectorEncoder bias = new ConstantValueEncoder("Intercept");
    bias.setTraceDictionary(traceDictionary);
    FeatureVectorEncoder lines = new ConstantValueEncoder("Lines");
    lines.setTraceDictionary(traceDictionary);
    Dictionary newsGroups = new Dictionary();
   
    OnlineLogisticRegression learningAlgorithm =
        new OnlineLogisticRegression(
              20, FEATURES, new L1())
            .alpha(1).stepOffset(1000)
            .decayExponent(0.9)
            .lambda(3.0e-5)
            .learningRate(20);
   
    List<File> files = new ArrayList<File>();
    for (File newsgroup : base.listFiles()) {
      newsGroups.intern(newsgroup.getName());
      files.addAll(Arrays.asList(newsgroup.listFiles()));
    }

    Collections.shuffle(files);
    System.out.printf("%d training files\n", files.size());
   
    double averageLL = 0.0;
    double averageCorrect = 0.0;
    double averageLineCount = 0.0;
    int k = 0;
    double step = 0.0;
    int[] bumps = new int[]{1, 2, 5};
    double lineCount = 0;
   
    Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_31);
   
/*    for (File file : files) {
        BufferedReader reader = new BufferedReader(new FileReader(file));
        String ng = file.getParentFile().getName();    
        int actual = newsGroups.intern(ng);
        Multiset<String> words = ConcurrentHashMultiset.create();
       
        String line = reader.readLine();
        while (line != null && line.length() > 0) {
          if (line.startsWith("Lines:")) {             
            String count = Iterables.get(onColon.split(line), 1);
            try {
              lineCount = Integer.parseInt(count);
              averageLineCount += (lineCount - averageLineCount)
                  / Math.min(k + 1, 1000);
            } catch (NumberFormatException e) {
              lineCount = averageLineCount;
            }
          }
          boolean countHeader = (
              line.startsWith("From:") || line.startsWith("Subject:")||
              line.startsWith("Keywords:")|| line.startsWith("Summary:"));
          do {
            StringReader in = new StringReader(line);
            if (countHeader) {
              countWords(analyzer, words, in);   
            }
            line = reader.readLine();
          } while (line.startsWith(" "));
        }
        countWords(analyzer, words, reader);     
        reader.close();
      }
   
    Vector v = new RandomAccessSparseVector(FEATURES);
    bias.addToVector(null, 1, v);
    lines.addToVector(null, lineCount / 30, v);
    logLines.addToVector(null, Math.log(lineCount + 1), v);
    for (String word : words.elementSet()) {
      encoder.addToVector(word, Math.log(1 + words.count(word)), v);
    }
    */
   
    /*double mu = Math.min(k + 1, 200);
    double ll = learningAlgorithm.logLikelihood(actual, v);  #1
    averageLL = averageLL + (ll - averageLL) / mu;

    Vector p = new DenseVector(20);
    learningAlgorithm.classifyFull(p, v);
    int estimated = p.maxValueIndex();

    int correct = (estimated == actual? 1 : 0);
    averageCorrect = averageCorrect + (correct - averageCorrect) / mu;*/
   
/*    learningAlgorithm.train(actual, v);
    k++;
    int bump = bumps[(int) Math.floor(step) % bumps.length];
    int scale = (int) Math.pow(10, Math.floor(step / bumps.length));
    if (k % (bump * scale) == 0) {
      step += 0.25;
      System.out.printf("%10d %10.3f %10.3f %10.2f %s %s\n",
           k, ll, averageLL, averageCorrect * 100, ng,
           newsGroups.values().get(estimated));
    }
    learningAlgorithm.close();*/
   
  }

  private static void countWords(Analyzer analyzer, Collection<String> words, Reader in) throws IOException {
      TokenStream ts = analyzer.tokenStream("text", in);
      ts.addAttribute(CharTermAttribute.class);
      while (ts.incrementToken()) {
        String s = ts.getAttribute(CharTermAttribute.class).toString();
        words.add(s);
      }
      /*overallCounts.addAll(words);*/
    }
}
TOP

Related Classes of mia.classifier.ch14.TrainNewsGroups

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., owned by ORACLE Inc. Contact coftware#gmail.com.