Package com.ansj.vec

Source Code of com.ansj.vec.Word2VEC

package com.ansj.vec;

import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.TreeSet;

import org.ansj.util.MatrixUtil;

import com.ansj.vec.domain.WordEntry;

public class Word2VEC {

  public static void main(String[] args) throws IOException {

    // Learn learn = new Learn();
    // learn.learnFile(new File("library/xh.txt"));
    // learn.saveModel(new File("library/javaSkip1"));

    Word2VEC vec = new Word2VEC();
    vec.loadJavaModel("library/javaSkip1");

    // System.out.println("中国" + "\t" +
    // Arrays.toString(vec.getWordVector("中国")));
    // ;
    // System.out.println("毛泽东" + "\t" +
    // Arrays.toString(vec.getWordVector("毛泽东")));
    // ;
    // System.out.println("足球" + "\t" +
    // Arrays.toString(vec.getWordVector("足球")));

    // Word2VEC vec2 = new Word2VEC();
    // vec2.loadGoogleModel("library/vectors.bin") ;
    //
    //
    String str = "毛泽东";
    long start = System.currentTimeMillis();
    for (int i = 0; i < 100; i++) {
      System.out.println(vec.distance(str));
      ;
    }
    System.out.println(System.currentTimeMillis() - start);

    System.out.println(System.currentTimeMillis() - start);
    // System.out.println(vec2.distance(str));
    //
    //
    // //男人 国王 女人
    // System.out.println(vec.analogy("邓小平", "毛泽东思想", "毛泽东"));
    // System.out.println(vec2.analogy("毛泽东", "毛泽东思想", "邓小平"));
  }

  private HashMap<String, float[]> wordMap = new HashMap<String, float[]>();

  private int words;
  private int size;
  private int topNSize = 40;

  /**
   * 加载模型
   *
   * @param path
   *            模型的路径
   * @throws IOException
   */
  public void loadGoogleModel(String path) throws IOException {
    DataInputStream dis = null;
    BufferedInputStream bis = null;
    double len = 0;
    float vector = 0;
    try {
      bis = new BufferedInputStream(new FileInputStream(path));
      dis = new DataInputStream(bis);
      // //读取词数
      words = Integer.parseInt(readString(dis));
      // //大小
      size = Integer.parseInt(readString(dis));
      String word;
      float[] vectors = null;
      for (int i = 0; i < words; i++) {
        word = readString(dis);
        vectors = new float[size];
        len = 0;
        for (int j = 0; j < size; j++) {
          vector = readFloat(dis);
          len += vector * vector;
          vectors[j] = (float) vector;
        }
        len = Math.sqrt(len);

        for (int j = 0; j < size; j++) {
          vectors[j] /= len;
        }

        wordMap.put(word, vectors);
        dis.read();
      }
    } finally {
      bis.close();
      dis.close();
    }
  }

  /**
   * 加载模型
   *
   * @param path
   *            模型的路径
   * @throws IOException
   */
  public void loadJavaModel(String path) throws IOException {
    try (DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(path)))) {
      words = dis.readInt();
      size = dis.readInt();

      float vector = 0;

      String key = null;
      float[] value = null;
      for (int i = 0; i < words; i++) {
        double len = 0;
        key = dis.readUTF();
        value = new float[size];
        for (int j = 0; j < size; j++) {
          vector = dis.readFloat();
          len += vector * vector;
          value[j] = vector;
        }

        len = Math.sqrt(len);

        for (int j = 0; j < size; j++) {
          value[j] /= len;
        }
        wordMap.put(key, value);
      }

    }
  }

  private static final int MAX_SIZE = 50;

  /**
   * 近义词
   *
   * @return
   */
  public TreeSet<WordEntry> analogy(String word0, String word1, String word2) {
    float[] wv0 = getWordVector(word0);
    float[] wv1 = getWordVector(word1);
    float[] wv2 = getWordVector(word2);

    if (wv1 == null || wv2 == null || wv0 == null) {
      return null;
    }
    float[] wordVector = new float[size];
    for (int i = 0; i < size; i++) {
      wordVector[i] = wv1[i] - wv0[i] + wv2[i];
    }
    float[] tempVector;
    String name;
    List<WordEntry> wordEntrys = new ArrayList<WordEntry>(topNSize);
    for (Entry<String, float[]> entry : wordMap.entrySet()) {
      name = entry.getKey();
      if (name.equals(word0) || name.equals(word1) || name.equals(word2)) {
        continue;
      }
      float dist = 0;
      tempVector = entry.getValue();
      for (int i = 0; i < wordVector.length; i++) {
        dist += wordVector[i] * tempVector[i];
      }
      insertTopN(name, dist, wordEntrys);
    }
    return new TreeSet<WordEntry>(wordEntrys);
  }

  private void insertTopN(String name, float score, List<WordEntry> wordsEntrys) {
    // TODO Auto-generated method stub
    if (wordsEntrys.size() < topNSize) {
      wordsEntrys.add(new WordEntry(name, score));
      return;
    }
    float min = Float.MAX_VALUE;
    int minOffe = 0;
    for (int i = 0; i < topNSize; i++) {
      WordEntry wordEntry = wordsEntrys.get(i);
      if (min > wordEntry.score) {
        min = wordEntry.score;
        minOffe = i;
      }
    }

    if (score > min) {
      wordsEntrys.set(minOffe, new WordEntry(name, score));
    }

  }

  public Set<WordEntry> distance(String queryWord) {

    float[] center = wordMap.get(queryWord);
    if (center == null) {
      return Collections.emptySet();
    }

    int resultSize = wordMap.size() < topNSize ? wordMap.size() : topNSize;
    TreeSet<WordEntry> result = new TreeSet<WordEntry>();

    double min = Float.MIN_VALUE;
    for (Map.Entry<String, float[]> entry : wordMap.entrySet()) {
      float[] vector = entry.getValue();
      float dist = 0;
      for (int i = 0; i < vector.length; i++) {
        dist += center[i] * vector[i];
      }

      if (dist > min) {
        result.add(new WordEntry(entry.getKey(), dist));
        if (resultSize < result.size()) {
          result.pollLast();
        }
        min = result.last().score;
      }
    }
    result.pollFirst();

    return result;
  }

  public Set<WordEntry> distance(List<String> words) {

    float[] center = null;
    for (String word : words) {
      center = sum(center, wordMap.get(word));
    }

    if (center == null) {
      return Collections.emptySet();
    }

    int resultSize = wordMap.size() < topNSize ? wordMap.size() : topNSize;
    TreeSet<WordEntry> result = new TreeSet<WordEntry>();

    double min = Float.MIN_VALUE;
    for (Map.Entry<String, float[]> entry : wordMap.entrySet()) {
      float[] vector = entry.getValue();
      float dist = 0;
      for (int i = 0; i < vector.length; i++) {
        dist += center[i] * vector[i];
      }

      if (dist > min) {
        result.add(new WordEntry(entry.getKey(), dist));
        if (resultSize < result.size()) {
          result.pollLast();
        }
        min = result.last().score;
      }
    }
    result.pollFirst();

    return result;
  }

  private float[] sum(float[] center, float[] fs) {
    // TODO Auto-generated method stub

    if (center == null && fs == null) {
      return null;
    }

    if (fs == null) {
      return center;
    }

    if (center == null) {
      return fs;
    }

    for (int i = 0; i < fs.length; i++) {
      center[i] += fs[i];
    }

    return center;
  }

  /**
   * 得到词向量
   *
   * @param word
   * @return
   */
  public float[] getWordVector(String word) {
    return wordMap.get(word);
  }

  public static float readFloat(InputStream is) throws IOException {
    byte[] bytes = new byte[4];
    is.read(bytes);
    return getFloat(bytes);
  }

  /**
   * 读取一个float
   *
   * @param b
   * @return
   */
  public static float getFloat(byte[] b) {
    int accum = 0;
    accum = accum | (b[0] & 0xff) << 0;
    accum = accum | (b[1] & 0xff) << 8;
    accum = accum | (b[2] & 0xff) << 16;
    accum = accum | (b[3] & 0xff) << 24;
    return Float.intBitsToFloat(accum);
  }

  /**
   * 读取一个字符串
   *
   * @param dis
   * @return
   * @throws IOException
   */
  private static String readString(DataInputStream dis) throws IOException {
    // TODO Auto-generated method stub
    byte[] bytes = new byte[MAX_SIZE];
    byte b = dis.readByte();
    int i = -1;
    StringBuilder sb = new StringBuilder();
    while (b != 32 && b != 10) {
      i++;
      bytes[i] = b;
      b = dis.readByte();
      if (i == 49) {
        sb.append(new String(bytes));
        i = -1;
        bytes = new byte[MAX_SIZE];
      }
    }
    sb.append(new String(bytes, 0, i + 1));
    return sb.toString();
  }

  public int getTopNSize() {
    return topNSize;
  }

  public void setTopNSize(int topNSize) {
    this.topNSize = topNSize;
  }

  public HashMap<String, float[]> getWordMap() {
    return wordMap;
  }

  public int getWords() {
    return words;
  }

  public int getSize() {
    return size;
  }

}
TOP

Related Classes of com.ansj.vec.Word2VEC

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.