Package cc.mrlda

Source Code of cc.mrlda.InformedPrior

package cc.mrlda;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.StringTokenizer;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.CommandLineParser;
import org.apache.commons.cli.GnuParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.OptionBuilder;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
import org.apache.log4j.Logger;

import com.google.common.base.Preconditions;

import edu.umd.cloud9.io.array.ArrayListOfIntsWritable;
import edu.umd.cloud9.util.map.HMapIV;

public class InformedPrior extends Configured implements Tool {
  static final Logger sLogger = Logger.getLogger(InformedPrior.class);

  public static final String ETA = "eta";
  public static final String INFORMED_PRIOR_OPTION = "informedprior";

  // informed prior on beta matrix
  public static final float DEFAULT_INFORMED_LOG_ETA = (float) Math.log(1000.0);
  public static final float DEFAULT_UNINFORMED_LOG_ETA = (float) Math.log(0.001);

  @SuppressWarnings("unchecked")
  public int run(String[] args) throws Exception {
    Options options = new Options();

    options.addOption(Settings.HELP_OPTION, false, "print the help message");
    options.addOption(OptionBuilder.withArgName(Settings.PATH_INDICATOR).hasArg()
        .withDescription("input file").create(Settings.INPUT_OPTION));
    options.addOption(OptionBuilder.withArgName(Settings.PATH_INDICATOR).hasArg()
        .withDescription("output file").create(Settings.OUTPUT_OPTION));
    options.addOption(OptionBuilder.withArgName(Settings.PATH_INDICATOR).hasArg()
        .withDescription("term index file").create(ParseCorpusOptions.INDEX));

    String termIndex = null;
    String output = null;
    String input = null;

    CommandLineParser parser = new GnuParser();
    HelpFormatter formatter = new HelpFormatter();
    try {
      CommandLine line = parser.parse(options, args);

      if (line.hasOption(Settings.HELP_OPTION)) {
        formatter.printHelp(InformedPrior.class.getName(), options);
        System.exit(0);
      }

      if (line.hasOption(Settings.INPUT_OPTION)) {
        input = line.getOptionValue(Settings.INPUT_OPTION);
      } else {
        throw new ParseException("Parsing failed due to " + Settings.INPUT_OPTION
            + " not initialized...");
      }

      if (line.hasOption(Settings.OUTPUT_OPTION)) {
        output = line.getOptionValue(Settings.OUTPUT_OPTION);
        if (output.endsWith(Path.SEPARATOR)) {
          output = output + ETA;
        }
      } else {
        throw new ParseException("Parsing failed due to " + Settings.OUTPUT_OPTION
            + " not initialized...");
      }

      if (line.hasOption(ParseCorpusOptions.INDEX)) {
        termIndex = line.getOptionValue(ParseCorpusOptions.INDEX);
      } else {
        throw new ParseException("Parsing failed due to " + ParseCorpusOptions.INDEX
            + " not initialized...");
      }
    } catch (ParseException pe) {
      System.err.println(pe.getMessage());
      formatter.printHelp(InformedPrior.class.getName(), options);
      System.exit(0);
    } catch (NumberFormatException nfe) {
      System.err.println(nfe.getMessage());
      System.exit(0);
    }

    // Delete the output directory if it exists already
    JobConf conf = new JobConf(InformedPrior.class);
    FileSystem fs = FileSystem.get(conf);

    Path inputPath = new Path(input);
    Preconditions.checkArgument(fs.exists(inputPath) && fs.isFile(inputPath),
        "Illegal input file...");

    Path termIndexPath = new Path(termIndex);
    Preconditions.checkArgument(fs.exists(termIndexPath) && fs.isFile(termIndexPath),
        "Illegal term index file...");

    Path outputPath = new Path(output);
    fs.delete(outputPath, true);

    SequenceFile.Reader sequenceFileReader = null;
    SequenceFile.Writer sequenceFileWriter = null;
    BufferedReader bufferedReader = null;
    fs.createNewFile(outputPath);
    try {
      bufferedReader = new BufferedReader(new InputStreamReader(fs.open(inputPath)));
      sequenceFileReader = new SequenceFile.Reader(fs, termIndexPath, conf);
      sequenceFileWriter = new SequenceFile.Writer(fs, conf, outputPath, IntWritable.class,
          ArrayListOfIntsWritable.class);
      exportTerms(bufferedReader, sequenceFileReader, sequenceFileWriter);
      sLogger.info("Successfully index the informed prior to " + outputPath);
    } finally {
      bufferedReader.close();
      IOUtils.closeStream(sequenceFileReader);
      IOUtils.closeStream(sequenceFileWriter);
    }

    return 0;
  }

  public static void exportTerms(BufferedReader bufferedReader,
      SequenceFile.Reader sequenceFileReader, SequenceFile.Writer sequenceFileWriter)
      throws IOException {
    Map<String, Integer> termIndex = ParseCorpus.importParameter(sequenceFileReader);

    IntWritable intWritable = new IntWritable();
    ArrayListOfIntsWritable arrayListOfIntsWritable = new ArrayListOfIntsWritable();

    StringTokenizer stk = null;
    String temp = null;

    String line = bufferedReader.readLine();
    int index = 0;
    while (line != null) {
      index++;
      intWritable.set(index);
      arrayListOfIntsWritable.clear();

      stk = new StringTokenizer(line);
      while (stk.hasMoreTokens()) {
        temp = stk.nextToken();
        if (termIndex.containsKey(temp)) {
          arrayListOfIntsWritable.add(termIndex.get(temp));
        } else {
          sLogger.info("How embarrassing! Term " + temp + " not found in the index file...");
        }
      }

      sequenceFileWriter.append(intWritable, arrayListOfIntsWritable);
      line = bufferedReader.readLine();
    }
  }

  public static float getLogEta(int termID, Set<Integer> knownTerms) {
    if (knownTerms != null && knownTerms.contains(termID)) {
      return DEFAULT_INFORMED_LOG_ETA;
    }
    return DEFAULT_UNINFORMED_LOG_ETA;
  }

  public static HMapIV<Set<Integer>> importEta(SequenceFile.Reader sequenceFileReader)
      throws IOException {
    HMapIV<Set<Integer>> lambdaMap = new HMapIV<Set<Integer>>();

    IntWritable intWritable = new IntWritable();
    ArrayListOfIntsWritable arrayListOfInts = new ArrayListOfIntsWritable();

    while (sequenceFileReader.next(intWritable, arrayListOfInts)) {
      Preconditions.checkArgument(intWritable.get() > 0, "Invalid eta prior for term "
          + intWritable.get() + "...");

      // topic is from 1 to K
      int topicIndex = intWritable.get();
      Set<Integer> hashset = new HashSet<Integer>();

      Iterator<Integer> itr = arrayListOfInts.iterator();
      while (itr.hasNext()) {
        hashset.add(itr.next());
      }

      lambdaMap.put(topicIndex, hashset);
    }
    return lambdaMap;
  }

  public static void main(String[] args) throws Exception {
    int res = ToolRunner.run(new Configuration(), new InformedPrior(), args);
    System.exit(res);
  }
}
TOP

Related Classes of cc.mrlda.InformedPrior

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.