Package cc.mallet.grmm.util

Source Code of cc.mallet.grmm.util.ModelReader

/* Copyright (C) 2006 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.util;


import java.io.BufferedReader;
import java.io.IOException;
import java.io.Reader;
import java.util.List;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.regex.Pattern;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;

import cc.mallet.grmm.types.*;

import gnu.trove.THashMap;
import bsh.Interpreter;
import bsh.EvalError;

/**
* $Id: ModelReader.java,v 1.1 2007/10/22 21:37:58 mccallum Exp $
*/
public class ModelReader {

  private static THashMap allClasses;
  static {
    allClasses = new THashMap ();
    // add new classes here
    allClasses.put ("potts", PottsTableFactor.class);
    allClasses.put ("unary", BoltzmannUnaryFactor.class);
    allClasses.put ("binaryunary", BinaryUnaryFactor.class);
    allClasses.put ("binarypair", BoltzmannPairFactor.class);
    allClasses.put ("uniform", UniformFactor.class);
    allClasses.put ("normal", UniNormalFactor.class);
    allClasses.put ("beta", BetaFactor.class);
  }

  private THashMap name2var = new THashMap ();

  public static Assignment readFromMatrix (VarSet vars, Reader in) throws IOException
  {
    Variable[] varr = vars.toVariableArray ();
    Interpreter interpreter = new Interpreter ();
    BufferedReader bIn = new BufferedReader (in);
    Assignment assn = new Assignment ();
    String line;

    while ((line = bIn.readLine ()) != null) {
      String[] fields = line.split ("\\s+");
      Object[] vals = new Object [fields.length];
      for (int i = 0; i < fields.length; i++) {
        try {
          vals[i] = interpreter.eval (fields[i]);
        } catch (EvalError e) {
          throw new RuntimeException ("Error reading line: "+line, e);
        }
      }
      assn.addRow (varr, vals);
    }

    return assn;
  }


  public FactorGraph readModel (BufferedReader in) throws IOException
  {
    List factors = new ArrayList ();

    String line;
    while ((line = in.readLine ()) != null) {
      try {
  if (Pattern.matches ("^\\s*$", line)) { continue; }
        String[] fields = line.split ("\\s+");
        if (fields[0].equalsIgnoreCase ("VAR")) {
          // a variable declaration
          handleVariableDecl (fields);
        } else {
          // a factor line
          Factor factor = factorFromLine (fields);
          factors.add (factor);
        }
      } catch (Exception e) {
        throw new RuntimeException ("Error reading line:\n"+line, e);
      }
    }

    FactorGraph fg = new FactorGraph ();
    for (Iterator it = factors.iterator (); it.hasNext ();) {
      Factor factor = (Factor) it.next ();
      fg.multiplyBy (factor);
    }

    return fg;
  }

  private void handleVariableDecl (String[] fields)
  {
    int colonIdx = findColon (fields);

    if (fields.length != colonIdx + 2) throw new IllegalArgumentException ("Invalid syntax");

    String numOutsString = fields[colonIdx+1];
    int numOutcomes;
    if (numOutsString.equalsIgnoreCase ("continuous")) {
      numOutcomes = Variable.CONTINUOUS;
    } else {
      numOutcomes = Integer.parseInt (numOutsString);
    }

    for (int i = 0; i < colonIdx; i++) {
      String name = fields[i];
      Variable var = new Variable (numOutcomes);
      var.setLabel (name);
      name2var.put (name, var);
    }
  }

  private int findColon (String[] fields)
  {
    for (int i = 0; i < fields.length; i++) {
      if (fields[i].equals (":")) {
        return i;
      }
    }
    throw new IllegalArgumentException ("Invalid syntax.");
  }

  private Factor factorFromLine (String[] fields)
  {
    int idx = findTwiddle (fields);
    return constructFactor (fields, idx);
  }

  private int findTwiddle (String[] fields)
  {
    for (int i = 0; i < fields.length; i++) {
      if (fields[i].equals ("~")) {
        return i;
      }
    }
    return -1;
  }

  private Factor constructFactor (String[] fields, int idx)
  {
    Class factorClass = determineFactorClass (fields, idx);
    Object[] args = determineFactorArgs (fields, idx);
    Constructor factorCtor = findCtor (factorClass, args);
    Factor factor;
    try {
      factor = (Factor) factorCtor.newInstance (args);
    } catch (InstantiationException e) {
      throw new RuntimeException (e);
    } catch (IllegalAccessException e) {
      throw new RuntimeException (e);
    } catch (InvocationTargetException e) {
      throw new RuntimeException (e);
    }
    return factor;
  }

  private Constructor findCtor (Class factorClass, Object[] args)
  {
    Class[] argClass = new Class[args.length];
    for (int i = 0; i < args.length; i++) {
      argClass[i] = args[i].getClass ();
      // special case
      if (argClass[i] == Double.class) { argClass[i] = double.class; }
    }
    try {
      return factorClass.getDeclaredConstructor (argClass);
    } catch (NoSuchMethodException e) {
  StringBuffer buf = new StringBuffer("Invalid argments for factor "+factorClass+"\n");
  buf.append ("Args were:\n");
  for (int i = 0; i < args.length; i++) {
      buf.append(args[i]);
      buf.append(" ");
  }
  buf.append("\n");
  for (int i = 0; i < args.length; i++) {
      buf.append(args[i].getClass());
      buf.append(" ");
  }
  buf.append("\n");
  throw new RuntimeException (buf.toString());
    }
  }

  private Class determineFactorClass (String[] fields, int twiddleIdx)
  {
    String factorName = fields [twiddleIdx + 1].toLowerCase ();
    Class theClass = (Class) allClasses.get (factorName);
    if (theClass != null) {
      return theClass;
    } else {
      throw new RuntimeException ("Could not determine factor class from "+factorName);
    }
  }

  private Object[] determineFactorArgs (String[] fields, int twiddleIdx)
  {
    List args = new ArrayList (fields.length);
    for (int i = 0; i < twiddleIdx; i++) {
      args.add (varFromName (fields[i], true));
    }

    for (int i = twiddleIdx+2; i < fields.length; i++) {
      args.add (varFromName (fields[i], false));
    }

    return args.toArray ();
  }

    private static Pattern nbrRegex = Pattern.compile ("[+-]?\\d+(?:\\.\\d+)?(E[+-]\\d+)?");

  private Object varFromName (String name, boolean preTwiddle)
  {
    if (nbrRegex.matcher(name).matches ()) {
      return new Double (Double.parseDouble (name));
    } else if (name2var.contains (name)) {
      return name2var.get (name);
    } else {
      Variable var = (preTwiddle) ? new Variable (2) : new Variable (Variable.CONTINUOUS);
      var.setLabel (name);
      name2var.put (name, var);
      return var;
    }
  }


}
TOP

Related Classes of cc.mallet.grmm.util.ModelReader

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.