Package edu.umd.hooka.ttables

Source Code of edu.umd.hooka.ttables.TTable_monolithic

package edu.umd.hooka.ttables;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.IntBuffer;
import java.nio.FloatBuffer;

import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.log4j.Logger;
import org.apache.log4j.Level;

import edu.umd.hooka.alignment.IndexedFloatArray;


/**
* Data structure that stores translation probabilities p(f|e) or
* counts c(f,e).  The set of values (*,e) associated with a particular
* e are stored adjacent to one another in an array.  For a given f
* in (f,e) the location of the f is found using a binary search.
*
* Layout: http://www.umiacs.umd.edu/~redpony/ttable-structure.png
*
* @author redpony
*
*/
public class TTable_monolithic extends TTable implements Cloneable {

  int[] _ef;                       // length <= |E|x|F|
  int[] _e;                        // length = |E|
  float[] _values;
  IndexedFloatArray _nullValues;   // length = |F|
  Path _datapath;
  FileSystem _fs;
 
  private static final Logger myLogger = Logger.getLogger(TTable_monolithic.class);
 
  int eLen;
  int indexLen;
 
  public Object clone() {
    TTable_monolithic res = new TTable_monolithic();
    res._ef = _ef.clone();
    res._e = _e.clone();
    res._values = _values.clone();
    res.eLen = eLen;
    res.indexLen = indexLen;
    res._nullValues = (IndexedFloatArray)_nullValues.clone();
    return res;
  }
 
  public TTable_monolithic() {}
  public TTable_monolithic(FileSystem fs, Path p) throws IOException {
    _fs = fs; _datapath = p;
    this.readFields(_fs.open(_datapath));
  }

  public TTable_monolithic(int[] e, int[] ef, int maxF) {
    _ef = ef;
    _e = e;
    _nullValues = new IndexedFloatArray(maxF + 1);
    eLen = _e.length;
    indexLen = _ef.length;
    _values = new float[indexLen];
  }

  public TTable_monolithic(int[] e, int[] ef, int maxF, FileSystem fs, Path p) {
    _ef = ef;
    _e = e;
    _nullValues = new IndexedFloatArray(maxF + 1);
    eLen = _e.length;
    indexLen = _ef.length;
    _values = new float[indexLen];
    _fs = fs;
    _datapath = p;
  }
 
  public int getMaxF() {
    return _nullValues.size() - 1;
  }
 
  public int getMaxE() {
    return _e.length - 1;
  }
 
  int binSearch(int e, int f) {
    int min = _e[e];
    int max = _e[e+1] - 1;
    while (min <= max) {
      int mid = (min + max) / 2;
      if (_ef[mid] > f)
        max = mid - 1;
      else if (_ef[mid] < f)
        min = mid + 1;
      else
        return mid;
    }
    throw new RuntimeException("Couldn't find (" + f + "," + e +")");
  }
   
  public void add(int e, int f, float delta) {
    if (e == 0)
      _nullValues.add(f, delta);
    else
      _values[binSearch(e,f)] += delta;
  }
 
  public void set(int e, int f, float value) {
    if (e == 0)
      _nullValues.set(f, value);
    else
      _values[binSearch(e,f)] = value;   
  }
 
  public void setDistribution(float[] x) {
    _values = x.clone();
  }
 
  public void setNullDistribution(float[] x) {
    myLogger.setLevel(Level.DEBUG);
    myLogger.debug("Length of input array is " + x.length);
    _nullValues = new IndexedFloatArray(x.length);
    _nullValues.set(0, 0.0f);
    for(int i=1; i<x.length; i++)
      _nullValues.set(i, x[i]);
  }
 
  /*
  public void setNullDistribution(float[] x) {
    _nullValues = (IndexedFloatArray)(x.clone());
  }
  */

  public void set(int e, IndexedFloatArray fs) {
    if (e == 0) {
      _nullValues.copyFrom(fs);
      return;
    }
    int from = _e[e];
    int len = _e[e+1] - from;
    if (len != fs.size())
      throw new RuntimeException("Mismatch lengths: in ttable there are " +
          fs.size() + " parameters for e="+e+ ", but in the IA there are " + len);
    fs.copyTo(_values, from);
  }
 
  /*
  public void plusEquals(int e, IndexedFloatArray fs) {
    if (e == 0)
      _nullValues.plusEquals(fs);
    int from = _e[e];
    int len = _e[e+1] - from;
    if (len != fs.size())
      throw new RuntimeException("Mismatch lengths: in ttable there are " +
          fs.size() + " parameters for e="+e+ ", but in the IA there are " + len);
    fs.addTo(_values, from);
  }*/
 
  public float get(int e, int f) {
    if (e == 0)
      return _nullValues.get(f);
    else {
      int min = _e[e];
      int max = _e[e+1] - 1;
      while (min <= max) {
        int mid = (min + max) / 2;
        if (_ef[mid] > f)
          max = mid - 1;
        else if (_ef[mid] < f)
          min = mid + 1;
        else
          return _values[mid];
      }
      return 0.0f;
    }
  }
 
  public void clear() {
    java.util.Arrays.fill(_values, 0.0f);
  }
 
  public void prune(float threshold) {
    throw new RuntimeException("Not implemented");
  }
 
  public void normalize() {
    _nullValues.normalize();
    for (int e = 1; e<_e.length - 1; e++) {
      int bf = _e[e];
      int ef = _e[e+1];
      float total = 0.0f;
      for (int f = bf; f < ef; f++) {
        total += _values[f];
      }
      // make uniform
      if (total == 0.0f) {
        float u = 1.0f/(float)(ef - bf);
        for (int f = bf; f < ef; f++)
          _values[f] = u;
      } else {
        // normalize
        for (int f = bf; f < ef; f++) {
          _values[f] /= total;
        }
      }
    }
  }
   
  public void readFields(DataInput in) throws IOException {
    int bbLen = in.readInt();
    ByteBuffer bb=ByteBuffer.allocate(bbLen);
    in.readFully(bb.array());
    IntBuffer ib = bb.asIntBuffer();
    _e = new int[bbLen/4];
    ib.get(_e);
    eLen = _e.length;

    if (_nullValues == null)
      _nullValues = new IndexedFloatArray();
    _nullValues.readFields(in);
 
    bbLen = in.readInt();
    bb=ByteBuffer.allocate(bbLen);
    in.readFully(bb.array());
    ib = bb.asIntBuffer();
    _ef = new int[bbLen/4];
    ib.get(_ef);
   
    bb=ByteBuffer.allocate(bbLen);
    in.readFully(bb.array());
    FloatBuffer fb = bb.asFloatBuffer();
    _values = new float[bbLen/4];
    fb.get(_values);
    indexLen = _values.length;
  }

  public void write(DataOutput out) throws IOException {
    int bbLen = eLen * 4;
    out.writeInt(bbLen);
    ByteBuffer bb=ByteBuffer.allocate(bbLen);
    IntBuffer ib = bb.asIntBuffer();
    ib.put(_e, 0, eLen);
    out.write(bb.array());

    _nullValues.write(out);
   
    bbLen = indexLen * 4;
    out.writeInt(bbLen);
    bb=ByteBuffer.allocate(bbLen);
    ib=bb.asIntBuffer();
    ib.put(_ef, 0, indexLen);
    out.write(bb.array());
   
    bb=ByteBuffer.allocate(bbLen);
    FloatBuffer fb=bb.asFloatBuffer();
    fb.put(_values, 0, indexLen);
    out.write(bb.array());
  }
 
  public String toString() {
    StringBuffer sb = new StringBuffer();
    sb.append("NULL: ").append(_nullValues.toString()).append("\n");
    for (int e = 1; e<_e.length - 1; e++) {
      int bfi = _e[e];
      int efi = _e[e+1];
      for (int fi = bfi; fi < efi; fi++) {
        sb.append("e=").append(e)
            .append(" f=").append(_ef[fi]).append(" val=").append(_values[fi]).append("\n");
      }
    }
    return sb.toString();
  }

  @Override
  public void write() throws IOException {
    this.write(_fs.create(_datapath));
  }
}
TOP

Related Classes of edu.umd.hooka.ttables.TTable_monolithic

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.