Package org.apache.mahout.utils.vectors

Source Code of org.apache.mahout.utils.vectors.VectorHelper$TDoublePQ

/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mahout.utils.vectors;

import com.google.common.base.Function;
import com.google.common.collect.Collections2;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.lucene.util.PriorityQueue;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.iterator.FileLineIterator;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
import org.apache.mahout.math.NamedVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.Vector.Element;
import org.apache.mahout.math.map.OpenObjectIntHashMap;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Comparator;
import java.util.regex.Pattern;

/** Static utility methods related to vectors. */
public final class VectorHelper {

  private static final Pattern TAB_PATTERN = Pattern.compile("\t");


  private VectorHelper() {
  }

  public static String vectorToCSVString(Vector vector, boolean namesAsComments) throws IOException {
    Appendable bldr = new StringBuilder(2048);
    vectorToCSVString(vector, namesAsComments, bldr);
    return bldr.toString();
  }

  public static String buildJson(Iterable<Pair<String, Double>> iterable) {
    return buildJson(iterable, new StringBuilder(2048));
  }

  public static String buildJson(Iterable<Pair<String, Double>> iterable, StringBuilder bldr) {
    bldr.append('{');
    for (Pair<String, Double> p : iterable) {
      bldr.append(p.getFirst());
      bldr.append(':');
      bldr.append(p.getSecond());
      bldr.append(',');
    }
    if (bldr.length() > 1) {
      bldr.setCharAt(bldr.length() - 1, '}');
    }
    return bldr.toString();
  }

  public static List<Pair<Integer, Double>> topEntries(Vector vector, int maxEntries) {

    // Get the size of nonZero elements in the input vector
    int sizeOfNonZeroElementsInVector = Iterables.size(vector.nonZeroes());

    // If the sizeOfNonZeroElementsInVector < maxEntries then set maxEntries = sizeOfNonZeroElementsInVector
    // otherwise the call to queue.pop() returns a Pair(null, null) and the subsequent call
    // to pair.getFirst() throws a NullPointerException
    if (sizeOfNonZeroElementsInVector < maxEntries) {
      maxEntries = sizeOfNonZeroElementsInVector;
    }

    PriorityQueue<Pair<Integer, Double>> queue = new TDoublePQ<Integer>(-1, maxEntries);
    for (Element e : vector.nonZeroes()) {
      queue.insertWithOverflow(Pair.of(e.index(), e.get()));
    }
    List<Pair<Integer, Double>> entries = Lists.newArrayList();
    Pair<Integer, Double> pair;
    while ((pair = queue.pop()) != null) {
      if (pair.getFirst() > -1) {
        entries.add(pair);
      }
    }
    Collections.sort(entries, new Comparator<Pair<Integer, Double>>() {
      @Override
      public int compare(Pair<Integer, Double> a, Pair<Integer, Double> b) {
        return b.getSecond().compareTo(a.getSecond());
      }
    });
    return entries;
  }

  public static List<Pair<Integer, Double>> firstEntries(Vector vector, int maxEntries) {
    List<Pair<Integer, Double>> entries = Lists.newArrayList();
    Iterator<Vector.Element> it = vector.nonZeroes().iterator();
    int i = 0;
    while (it.hasNext() && i++ < maxEntries) {
      Vector.Element e = it.next();
      entries.add(Pair.of(e.index(), e.get()));
    }
    return entries;
  }

  public static List<Pair<String, Double>> toWeightedTerms(Collection<Pair<Integer, Double>> entries,
                                                           final String[] dictionary) {
    if (dictionary != null) {
      return Lists.newArrayList(Collections2.transform(entries,
        new Function<Pair<Integer, Double>, Pair<String, Double>>() {
          @Override
          public Pair<String, Double> apply(Pair<Integer, Double> p) {
            return Pair.of(dictionary[p.getFirst()], p.getSecond());
          }
        }));
    } else {
      return Lists.newArrayList(Collections2.transform(entries,
        new Function<Pair<Integer, Double>, Pair<String, Double>>() {
          @Override
          public Pair<String, Double> apply(Pair<Integer, Double> p) {
            return Pair.of(Integer.toString(p.getFirst()), p.getSecond());
          }
        }));
    }
  }

  public static String vectorToJson(Vector vector, String[] dictionary, int maxEntries, boolean sort) {
    return buildJson(toWeightedTerms(sort
            ? topEntries(vector, maxEntries)
            : firstEntries(vector, maxEntries), dictionary));
  }

  public static void vectorToCSVString(Vector vector,
                                       boolean namesAsComments,
                                       Appendable bldr) throws IOException {
    if (namesAsComments && vector instanceof NamedVector) {
      bldr.append('#').append(((NamedVector) vector).getName()).append('\n');
    }
    Iterator<Vector.Element> iter = vector.all().iterator();
    boolean first = true;
    while (iter.hasNext()) {
      if (first) {
        first = false;
      } else {
        bldr.append(',');
      }
      Vector.Element elt = iter.next();
      bldr.append(String.valueOf(elt.get()));
    }
    bldr.append('\n');
  }

  /**
   * Read in a dictionary file. Format is:
   * <p/>
   * <pre>
   * term DocFreq Index
   * </pre>
   */
  public static String[] loadTermDictionary(File dictFile) throws IOException {
    InputStream in = new FileInputStream(dictFile);
    try {
      return loadTermDictionary(in);
    } finally {
      in.close();
    }
  }

  /**
   * Read a dictionary in {@link org.apache.hadoop.io.SequenceFile} generated by
   * {@link org.apache.mahout.vectorizer.DictionaryVectorizer}
   *
   * @param filePattern <PATH TO DICTIONARY>/dictionary.file-*
   */
  public static String[] loadTermDictionary(Configuration conf, String filePattern) {
    OpenObjectIntHashMap<String> dict = new OpenObjectIntHashMap<String>();
    for (Pair<Text, IntWritable> record
        : new SequenceFileDirIterable<Text, IntWritable>(new Path(filePattern), PathType.GLOB, null, null, true,
                                                         conf)) {
      dict.put(record.getFirst().toString(), record.getSecond().get());
    }
    String[] dictionary = new String[dict.size()];
    for (String feature : dict.keys()) {
      dictionary[dict.get(feature)] = feature;
    }
    return dictionary;
  }

  /**
   * Read in a dictionary file. Format is: First line is the number of entries
   * <p/>
   * <pre>
   * term DocFreq Index
   * </pre>
   */
  private static String[] loadTermDictionary(InputStream is) throws IOException {
    FileLineIterator it = new FileLineIterator(is);

    int numEntries = Integer.parseInt(it.next());
    String[] result = new String[numEntries];

    while (it.hasNext()) {
      String line = it.next();
      if (line.startsWith("#")) {
        continue;
      }
      String[] tokens = TAB_PATTERN.split(line);
      if (tokens.length < 3) {
        continue;
      }
      int index = Integer.parseInt(tokens[2]); // tokens[1] is the doc freq
      result[index] = tokens[0];
    }
    return result;
  }

  private static final class TDoublePQ<T> extends PriorityQueue<Pair<T, Double>> {
    private final T sentinel;

    private TDoublePQ(T sentinel, int size) {
      super(size);
      this.sentinel = sentinel;
    }

    @Override
    protected boolean lessThan(Pair<T, Double> a, Pair<T, Double> b) {
      return a.getSecond().compareTo(b.getSecond()) < 0;
    }

    @Override
    protected Pair<T, Double> getSentinelObject() {
      return Pair.of(sentinel, Double.NEGATIVE_INFINITY);
    }
  }
}
TOP

Related Classes of org.apache.mahout.utils.vectors.VectorHelper$TDoublePQ

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.