Package winterwell.utils.io

Source Code of winterwell.utils.io.RSIterator

package winterwell.utils.io;

import java.io.InputStream;
import java.io.ObjectInputStream;
import java.lang.reflect.Field;
import java.math.BigInteger;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import javax.persistence.EntityManager;
import javax.persistence.Query;

import winterwell.utils.Printer;
import winterwell.utils.ReflectionUtils;
import winterwell.utils.StrUtils;
import winterwell.utils.Utils;
import winterwell.utils.containers.Cache;
import winterwell.utils.containers.Containers;
import winterwell.utils.containers.Pair;

class RSIterator implements Iterator<Object[]> {

  private int cols;
  private Boolean hasNext;
  private final ResultSet rs;

  public RSIterator(ResultSet rs) {
    assert rs != null;
    this.rs = rs;
    try {
      cols = rs.getMetaData().getColumnCount();
    } catch (SQLException e) {
      throw Utils.runtime(e);
    }
  }

  private void advanceIfNeeded() {
    if (hasNext != null)
      return;
    try {
      hasNext = rs.next();
    } catch (SQLException e) {
      throw Utils.runtime(e);
    }
  }

  @Override
  public boolean hasNext() {
    // handle repeated calls without repeated advances
    advanceIfNeeded();
    return hasNext;
  }

  @Override
  public Object[] next() {
    // do we need to advance? Not if #hasNext() was called just before
    advanceIfNeeded();
    // Either next or hasNext will now trigger an advance
    hasNext = null;
    try {
      Object[] row = new Object[cols];
      for (int i = 0; i < row.length; i++) {
        row[i] = rs.getObject(i + 1);
      }
      return row;
    } catch (SQLException e) {
      throw Utils.runtime(e);
    }
  }

  @Override
  public void remove() {
    try {
      rs.deleteRow();
    } catch (SQLException e) {
      throw Utils.runtime(e);
    }
  }

}

/**
* Static general purpose SQL utils
*
* @author daniel
* @testedby {@link SqlUtilsTest}
*/
public class SqlUtils {

  /**
   * table name is lower-cased.
   *
   * @see SqlUtils#getTableColumns(String)
   */
  static final Cache<String, List<Pair<String>>> table2columns = new Cache(
      100);

  /**
   * Provides streaming row-by-row access, if the resultset was created for
   * that.
   *
   * @param rs
   * @return
   */
  public static Iterable<Object[]> asIterable(final ResultSet rs) {
    return new Iterable<Object[]>() {
      @Override
      public Iterator<Object[]> iterator() {
        return new RSIterator(rs);
      }
    };
  }

  private static void close(Statement statement) {
    try {
      statement.close();
    } catch (SQLException e) {
      throw Utils.runtime(e);
    }
  }

  // public <X> List<X> deserialise(ResultSet results, Class<X> klass) {
  // }

  public static boolean existsTable(String table, EntityManager em) {
    try {
      List<Pair<String>> cols = getTableColumns(em, table);
      assert !cols.isEmpty();
      return true;
    } catch (IllegalArgumentException ex) {
      // it doesn't exist
      return false;
    }
  }

  /**
   *
   * @param col
   *            Case insensitive
   * @param table
   *            Case insensitive
   * @param em
   * @return -1 if not found
   * @throws IllegalArgumentException
   *             if the table does not exist
   */
  public static int getColumnIndex(EntityManager em, String col, String table) {
    assert table != null && col != null;
    List<Pair<String>> cols = getTableColumns(em, table);
    return getColumnIndex(col, cols);
  }

  /**
   *
   * @param col
   * @param cols
   * @return -1 if not found
   */
  private static int getColumnIndex(String col, List<Pair<String>> cols) {
    assert cols != null;
    for (int i = 0; i < cols.size(); i++) {
      Pair<String> pair = cols.get(i);
      if (pair.first.equalsIgnoreCase(col))
        return i;
    }
    return -1;
  }

  /**
   * TODO ??How to use the connection pool?? (we favour bonecp over c3p0)
   *
   * @param dbUrl
   * @param user
   * @param password
   * @return
   */
  public static Connection getConnection(String dbUrl, String user,
      String password) {
    try {
      // HACK assume Postgres & load the Postgres driver if needed
      Class.forName("org.postgresql.Driver");
      Connection con = DriverManager.getConnection(dbUrl, user, password);

      // This is needed for streaming mode, so set it on by default
      // -- you must then explicitly call commit
      con.setAutoCommit(false);

      return con;
    } catch (Exception e) {
      throw Utils.runtime(e);
    }
  }

  public static List<Pair<String>> getTableColumns(Connection con, String tbl) {
    Statement statement = null;
    try {
      statement = con.createStatement();
      statement.setMaxRows(1);
      ResultSet rs = statement.executeQuery("select * from " + tbl);
      ResultSetMetaData meta = rs.getMetaData();
      List<Pair<String>> list = new ArrayList();
      for (int i = 0; i < meta.getColumnCount(); i++) {
        String name = meta.getColumnName(i + 1);
        String type = meta.getColumnTypeName(i + 1);
        Pair p = new Pair(name, type);
        list.add(p);
      }
      return list;
    } catch (SQLException e) {
      throw Utils.runtime(e);
    } finally {
      close(statement);
    }
  }

  /**
   * Get info on a table scheme. This uses a cache, so repeated calls are
   * cheap.
   *
   * @param table
   *            Case insensitive
   * @return List of (column_name, data_type) pairs
   * @throws IllegalArgumentException
   *             if the table does not exist
   */
  public static List<Pair<String>> getTableColumns(EntityManager em,
      String table) {
    table = table.toLowerCase();

    // cached?
    List<Pair<String>> cols = table2columns.get(table);
    if (cols != null)
      return cols;

    Query q = em
        .createNativeQuery("SELECT column_name, data_type FROM information_schema.columns WHERE lower(table_name) = '"
            + table + "' order by ordinal_position;");
    List<Object[]> rs = q.getResultList();

    if (rs.isEmpty())
      throw new IllegalArgumentException("No such table: " + table);

    cols = new ArrayList();
    for (Object[] objects : rs) {
      Pair p = new Pair(objects[0], objects[1]);
      cols.add(p);
    }

    // cache it
    table2columns.put(table, cols);
    return cols;
  }

  /**
   * A crude imitation of what Hibernate does: mapping database rows into Java
   * objects.
   *
   * @param select
   *            The select clause. E.g. "x.xid,x.contents". This will be split
   *            to WARNING: assumes you use x. to refer to the main object.
   * @param rs
   *            Database rows
   * @param klass
   * @return
   */
  public static <X> List<X> inflate(String select, List<Object[]> rs,
      Class<X> klass) {
    try {
      // x.id, x.contents, etc.
      Field[] fields = inflate2_whichColumns(select, klass);

      // process rows
      List<X> rows = new ArrayList(rs.size());
      for (Object[] row : rs) {
        X x = klass.newInstance();
        for (int i = 0; i < row.length; i++) {
          Field f = fields[i];
          if (f == null) {
            continue;
          }
          Object val = row[i];
          // conversions!
          val = inflate2_convert(val, f.getType());
          f.set(x, val);
        }
        rows.add(x);
      }

      return rows;
    } catch (Exception ex) {
      throw Utils.runtime(ex);
    }

  }

  private static Object inflate2_convert(Object val, Class type)
      throws Exception {
    if (val == null)
      return null;
    Class<? extends Object> vc = val.getClass();
    if (vc == type)
      return val;
    // numbers
    if (vc == BigInteger.class) {
      if (type == Long.class)
        return ((BigInteger) val).longValue();
      if (type == Double.class)
        return ((BigInteger) val).doubleValue();
    }
    // enums
    if (ReflectionUtils.isa(type, Enum.class)) {
      Object[] ks = type.getEnumConstants();
      assert ks != null : type;
      assert val instanceof Integer : val;
      int i = (Integer) val;
      return ks[i];
    }
    // exceptions
    if (ReflectionUtils.isa(type, Throwable.class)) {
      byte[] bytes = (byte[]) val;
      InputStream in = new FastByteArrayInputStream(bytes, bytes.length);
      ObjectInputStream objIn = new ObjectInputStream(in);
      Object data = objIn.readObject();
      objIn.close();
      return data;
    }
    return val;
  }

  /**
   * Assume column-names = field-names, and we use "x.field" to refer to them
   * in the select!
   *
   * @param select
   * @param klass
   * @return
   */
  static <X> Field[] inflate2_whichColumns(String select, Class<X> klass) {
    // We just want the top-level selected columns
    Pattern pSelectColsFrom = Pattern.compile("^select\\s+(.+?)\\s+from",
        Pattern.CASE_INSENSITIVE);
    Matcher m = pSelectColsFrom.matcher(select);
    if (m.find()) {
      select = m.group(1);
    }

    // NB: these are "x.field"
    // Build map of column-num to Field
    String[] cols = select.split(",");
    Field[] fields = new Field[cols.length];
    for (int i = 0; i < fields.length; i++) {
      if (!cols[i].startsWith("x.")) {
        // ignore it?!
        // This allows for other columns to be mixed in
      }
      String col = cols[i].substring(2);
      Field field = ReflectionUtils.getField(klass, col);
      if (field == null) {
        // ignore it?!
        // This allows for other columns to be mixed in
        continue;
      }
      field.setAccessible(true);
      fields[i] = field;
    }
    return fields;
  }

  /**
   * Encode the text to escape any SQL characters, and add surrounding 's.
   * This is for use as String constants -- not for use in LIKE statements
   * (which need extra escaping).
   *
   * @param text
   *            Can be null (returns "null").
   * @return e.g. don't => 'don''t'
   */
  public static String sqlEncode(String text) {
    if (text == null)
      return "null";
    text = "'" + text.replace("'", "''") + "'";
    return text;
  }

  /**
   * upsert = update if exists + insert if new
   *
   * @param table
   * @param idColumns
   *            The columns which identify the row. E.g. the primary key. Must
   *            not contain any nulls. Should be a subset of columns.
   * @param col2val
   *            The values to set for every non-null column. Any missing
   *            columns will be set to null.
   * @param specialCaseId
   *            If true, the "id" column is treated specially --
   *            nextval(hibernate_sequence) is used for the insert, and no
   *            change is made on update. This is a hack 'cos row ids don't
   *            travel across servers.
   * @return An upsert query, with it's parameters set.
   */
  public static Query upsert(EntityManager em, String table,
      String[] idColumns, Map<String, ?> col2val, boolean specialCaseId) {
    List<Pair<String>> columnInfo = getTableColumns(em, table);
    // safety check inputs
    assert idColumns.length <= columnInfo.size();
    for (String idc : idColumns) {
      assert idc != null : Printer.toString(idColumns);
      int i = getColumnIndex(idc, columnInfo);
      assert i != -1 : idc + " vs " + columnInfo;
    }
    if (specialCaseId) {
      assert !Containers.contains("id", idColumns);
      assert getColumnIndex(em, "id", table) != -1 : table + " = "
          + getTableColumns(em, table);
    } else {
      assert !col2val.containsKey("id") : col2val;
    }

    // the identifying where clause
    StringBuilder whereClause = new StringBuilder(" where (");
    for (int i = 0; i < idColumns.length; i++) {
      String col = idColumns[i];
      // SQL hack: " is null" not "=null"
      if (col2val.get(col) == null) {
        whereClause.append(col + " is null and ");
        continue;
      }
      whereClause.append(col + "=:" + col + " and ");
    }
    StrUtils.pop(whereClause, 4);
    whereClause.append(")");

    StringBuilder upsert = new StringBuilder();

    // 1. update where exists
    // ... create a=:a parameters
    upsert.append("update " + table + " set ");
    for (Pair<String> colInfo : columnInfo) {
      // Keep Hibernate happy - avoid oid and setParameter with null
      if ("oid".equals(colInfo.second)
          || col2val.get(colInfo.first) == null) {
        upsert.append(colInfo.first + "=null,");
        continue;
      }
      // SoDash/Hibernate HACK: don't change the id!
      if (specialCaseId && "id".equals(colInfo.first)) {
        continue;
      }
      // a normal update
      upsert.append(colInfo.first + "=:" + colInfo.first + ",");
    }
    // lose the trailing ,
    if (upsert.charAt(upsert.length() - 1) == ',') {
      StrUtils.pop(upsert, 1);
    }
    upsert.append(whereClause);
    upsert.append(";\n");

    // 2. insert where not exists
    upsert.append("insert into " + table + " select ");
    for (Pair<String> colInfo : columnInfo) {
      // Keep Hibernate happy - avoid oid and setParameter with null
      if ("oid".equals(colInfo.second)
          || col2val.get(colInfo.first) == null) {
        upsert.append("null,");
        continue;
      }
      // SoDash/Hibernate HACK: ignore any given id value & use the local
      // sequence?
      if (specialCaseId && "id".equals(colInfo.first)) {
        upsert.append("nextval('hibernate_sequence'),");
        continue;
      }
      // a normal insert
      upsert.append(":" + colInfo.first + ",");
    }
    StrUtils.pop(upsert, 1);
    upsert.append(" where not exists (select 1 from " + table + whereClause
        + ");");

    // create query
    Query q = em.createNativeQuery(upsert.toString());

    // set params
    for (Pair<String> colInfo : columnInfo) {
      Object v = col2val.get(colInfo.first);
      // did this column get ignored?
      if ("oid".equals(colInfo.second) || v == null) {
        continue;
      }
      if (specialCaseId && "id".equals(colInfo.first)) {
        continue;
      }
      // set param
      q.setParameter(colInfo.first, v);
    }

    return q;
  }

}
TOP

Related Classes of winterwell.utils.io.RSIterator

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.