package winterwell.utils.io;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.lang.reflect.Field;
import java.math.BigInteger;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.persistence.EntityManager;
import javax.persistence.Query;
import winterwell.utils.Printer;
import winterwell.utils.ReflectionUtils;
import winterwell.utils.StrUtils;
import winterwell.utils.Utils;
import winterwell.utils.containers.Cache;
import winterwell.utils.containers.Containers;
import winterwell.utils.containers.Pair;
/**
 * Iterator over the rows of a JDBC ResultSet. Each row is returned as an
 * Object[] holding the values of columns 1..n, as given by
 * {@link ResultSet#getObject(int)}.
 * <p>
 * The cursor is moved lazily: {@link #hasNext()} advances the ResultSet at
 * most once per row, so it can be called repeatedly. SQLExceptions are
 * rethrown unchecked via Utils.runtime().
 */
class RSIterator implements Iterator<Object[]> {

    /** Column count, read once from the ResultSet metadata. */
    private final int cols;

    /**
     * Outcome of an rs.next() call which next() has not consumed yet. null
     * means the cursor has not been advanced since the last row was handed
     * out (or since construction).
     */
    private Boolean hasNext;

    private final ResultSet rs;

    public RSIterator(ResultSet rs) {
        assert rs != null;
        this.rs = rs;
        try {
            cols = rs.getMetaData().getColumnCount();
        } catch (SQLException e) {
            throw Utils.runtime(e);
        }
    }

    /** Advance the cursor, unless an advance is already pending. */
    private void advanceIfNeeded() {
        if (hasNext != null) {
            return;
        }
        try {
            hasNext = rs.next();
        } catch (SQLException e) {
            throw Utils.runtime(e);
        }
    }

    @Override
    public boolean hasNext() {
        // handle repeated calls without repeated advances
        advanceIfNeeded();
        return hasNext;
    }

    /**
     * @return the column values of the next row
     * @throws NoSuchElementException
     *             if the ResultSet has no more rows
     */
    @Override
    public Object[] next() {
        // do we need to advance? Not if #hasNext() was called just before
        advanceIfNeeded();
        if (!hasNext) {
            // Iterator contract -- the cursor is past the last row, so there
            // is nothing to read. hasNext stays false, so we don't advance again.
            throw new NoSuchElementException();
        }
        // Either next or hasNext will now trigger an advance
        hasNext = null;
        try {
            Object[] row = new Object[cols];
            for (int i = 0; i < cols; i++) {
                // JDBC columns are 1-indexed
                row[i] = rs.getObject(i + 1);
            }
            return row;
        } catch (SQLException e) {
            throw Utils.runtime(e);
        }
    }

    /**
     * Deletes the row last returned by {@link #next()}, via
     * {@link ResultSet#deleteRow()}. This needs an updatable ResultSet.
     *
     * @throws IllegalStateException
     *             if {@link #hasNext()} was called since that next(), because
     *             the cursor has then moved on and deleteRow() would remove a
     *             different row.
     */
    @Override
    public void remove() {
        if (hasNext != null) {
            throw new IllegalStateException(
                    "remove() must be called straight after next(), before hasNext()");
        }
        try {
            rs.deleteRow();
        } catch (SQLException e) {
            throw Utils.runtime(e);
        }
    }
}
/**
 * Static general purpose SQL utils.
 * <p>
 * Written with Postgres + JPA/Hibernate in mind: several methods use
 * information_schema, nextval('hibernate_sequence') and native queries with
 * named parameters.
 *
 * @author daniel
 * @testedby {@link SqlUtilsTest}
 */
public class SqlUtils {

    /**
     * Schema cache: lower-cased table name to (column_name, data_type) pairs,
     * filled by {@link #getTableColumns(EntityManager, String)}. NB: keyed by
     * table name only -- it does not distinguish between databases.
     *
     * @see SqlUtils#getTableColumns(EntityManager, String)
     */
    static final Cache<String, List<Pair<String>>> table2columns = new Cache<String, List<Pair<String>>>(
            100);

    /**
     * Picks out the column list of a "select ... from" query. DOTALL so that
     * multi-line queries match. Compiled once, as inflate() may be called
     * often.
     */
    private static final Pattern SELECT_COLS_FROM = Pattern.compile(
            "^select\\s+(.+?)\\s+from", Pattern.CASE_INSENSITIVE
                    | Pattern.DOTALL);

    /**
     * Wraps a ResultSet as an Iterable of rows (one Object[] per row).
     * Provides streaming row-by-row access, if the ResultSet was created for
     * that.
     * <p>
     * NB: every iterator shares the same ResultSet cursor, so the rows can
     * only be walked through once.
     *
     * @param rs
     *            Must not be null.
     * @return an Iterable whose iterators read rows from rs
     */
    public static Iterable<Object[]> asIterable(final ResultSet rs) {
        return new Iterable<Object[]>() {
            @Override
            public Iterator<Object[]> iterator() {
                return new RSIterator(rs);
            }
        };
    }

    /**
     * Closes the statement. A null statement is ignored, so this is safe to
     * call from a finally block when the statement was never created (and
     * won't mask the original exception with a NullPointerException).
     */
    private static void close(Statement statement) {
        if (statement == null) {
            return;
        }
        try {
            statement.close();
        } catch (SQLException e) {
            throw Utils.runtime(e);
        }
    }

    /**
     * @param table
     *            Case insensitive
     * @param em
     * @return true if information_schema lists columns for the table. NB: this
     *         uses the schema cache, so a table dropped after it was cached
     *         still reports true.
     */
    public static boolean existsTable(String table, EntityManager em) {
        try {
            List<Pair<String>> cols = getTableColumns(em, table);
            assert !cols.isEmpty();
            return true;
        } catch (IllegalArgumentException ex) {
            // getTableColumns() signals "no such table" this way
            return false;
        }
    }

    /**
     * @param col
     *            Case insensitive
     * @param table
     *            Case insensitive
     * @param em
     * @return the 0-based position of col within the table, or -1 if not found
     * @throws IllegalArgumentException
     *             if the table does not exist
     */
    public static int getColumnIndex(EntityManager em, String col, String table) {
        assert table != null && col != null;
        List<Pair<String>> cols = getTableColumns(em, table);
        return getColumnIndex(col, cols);
    }

    /**
     * @param col
     *            Case insensitive
     * @param cols
     *            (column_name, data_type) pairs
     * @return the index of the first pair whose name matches col, or -1 if not
     *         found
     */
    private static int getColumnIndex(String col, List<Pair<String>> cols) {
        assert cols != null;
        for (int i = 0; i < cols.size(); i++) {
            Pair<String> pair = cols.get(i);
            if (pair.first.equalsIgnoreCase(col)) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Open a JDBC connection with auto-commit switched off.
     * <p>
     * TODO ??How to use the connection pool?? (we favour bonecp over c3p0)
     *
     * @param dbUrl
     *            JDBC url, e.g. jdbc:postgresql://host/db
     * @param user
     * @param password
     * @return a new connection with auto-commit off -- you must explicitly
     *         call commit
     */
    public static Connection getConnection(String dbUrl, String user,
            String password) {
        try {
            // HACK: explicitly load the Postgres driver for Postgres urls
            // (presumably for setups where it is not auto-registered). Other
            // databases rely on DriverManager finding their driver.
            if (dbUrl != null && dbUrl.startsWith("jdbc:postgresql")) {
                Class.forName("org.postgresql.Driver");
            }
            Connection con = DriverManager.getConnection(dbUrl, user, password);
            try {
                // This is needed for streaming mode, so set it on by default
                // -- you must then explicitly call commit
                con.setAutoCommit(false);
            } catch (SQLException e) {
                // don't leak the connection if it cannot be configured
                try {
                    con.close();
                } catch (SQLException ignored) {
                    // report the original failure instead
                }
                throw e;
            }
            return con;
        } catch (Exception e) {
            throw Utils.runtime(e);
        }
    }

    /**
     * Get info on a table scheme by selecting (at most) one row from it and
     * reading the ResultSet metadata. Unlike
     * {@link #getTableColumns(EntityManager, String)}, this is not cached.
     *
     * @param con
     * @param tbl
     *            Table name. WARNING: this is inserted into the SQL as-is, so
     *            it must come from a trusted source.
     * @return List of (column name, column type name) pairs, in column order
     */
    public static List<Pair<String>> getTableColumns(Connection con, String tbl) {
        Statement statement = null;
        try {
            statement = con.createStatement();
            statement.setMaxRows(1);
            ResultSet rs = statement.executeQuery("select * from " + tbl);
            ResultSetMetaData meta = rs.getMetaData();
            int n = meta.getColumnCount();
            List<Pair<String>> list = new ArrayList<Pair<String>>(n);
            // JDBC columns are 1-indexed
            for (int i = 1; i <= n; i++) {
                String name = meta.getColumnName(i);
                String type = meta.getColumnTypeName(i);
                list.add(new Pair<String>(name, type));
            }
            return list;
        } catch (SQLException e) {
            throw Utils.runtime(e);
        } finally {
            close(statement);
        }
    }

    /**
     * Get info on a table scheme from information_schema.columns. This uses a
     * cache, so repeated calls are cheap.
     *
     * @param table
     *            Case insensitive
     * @return List of (column_name, data_type) pairs, in ordinal position
     *         order. NB: the cached list itself is returned -- do not modify
     *         it.
     * @throws IllegalArgumentException
     *             if the table does not exist
     */
    @SuppressWarnings({ "unchecked", "rawtypes" })
    public static List<Pair<String>> getTableColumns(EntityManager em,
            String table) {
        // Locale-independent, so e.g. a Turkish default locale can't turn "I"
        // into a dotless i
        table = table.toLowerCase(Locale.ENGLISH);
        // cached?
        List<Pair<String>> cols = table2columns.get(table);
        if (cols != null) {
            return cols;
        }
        // Bind the table name as a parameter rather than splicing it into the
        // SQL, which would allow injection via the table name.
        Query q = em
                .createNativeQuery("SELECT column_name, data_type FROM information_schema.columns"
                        + " WHERE lower(table_name) = :tableName order by ordinal_position;");
        q.setParameter("tableName", table);
        List<Object[]> rs = q.getResultList();
        if (rs.isEmpty()) {
            throw new IllegalArgumentException("No such table: " + table);
        }
        cols = new ArrayList<Pair<String>>(rs.size());
        for (Object[] objects : rs) {
            Pair p = new Pair(objects[0], objects[1]);
            cols.add(p);
        }
        // cache it
        table2columns.put(table, cols);
        return cols;
    }

    /**
     * A crude imitation of what Hibernate does: mapping database rows into Java
     * objects.
     *
     * @param select
     *            The select clause, e.g. "x.xid,x.contents", or a whole
     *            "select ... from ..." query (then only the columns between
     *            select and from are used). The clause is split on commas, and
     *            the i-th piece is matched to column i of each row. WARNING:
     *            assumes you use "x." to refer to the main object, and that
     *            column names equal field names. Other columns are skipped.
     * @param rs
     *            Database rows
     * @param klass
     *            Must have an accessible no-argument constructor.
     * @return one new X per row, with the matched fields set
     */
    public static <X> List<X> inflate(String select, List<Object[]> rs,
            Class<X> klass) {
        try {
            // x.id, x.contents, etc.
            Field[] fields = inflate2_whichColumns(select, klass);
            // process rows
            List<X> rows = new ArrayList<X>(rs.size());
            for (Object[] row : rs) {
                X x = klass.getDeclaredConstructor().newInstance();
                for (int i = 0; i < row.length; i++) {
                    Field f = fields[i];
                    if (f == null) {
                        // not an x.field column
                        continue;
                    }
                    // conversions!
                    Object val = inflate2_convert(row[i], f.getType());
                    f.set(x, val);
                }
                rows.add(x);
            }
            return rows;
        } catch (Exception ex) {
            throw Utils.runtime(ex);
        }
    }

    /**
     * Convert a database value to suit a field type.
     * <ul>
     * <li>Numbers are converted to Long/long or Double/double fields (NB:
     * converting a fractional number to a long truncates it).
     * <li>Enum fields take an ordinal (any Number) or a constant name (String).
     * <li>Throwable fields take Java-serialised bytes.
     * </ul>
     * Anything else is returned unchanged.
     */
    @SuppressWarnings("rawtypes")
    private static Object inflate2_convert(Object val, Class type)
            throws Exception {
        if (val == null || type.isInstance(val)) {
            return val;
        }
        // numbers, e.g. Postgres bigint arriving as BigInteger
        if (val instanceof Number) {
            Number num = (Number) val;
            if (type == Long.class || type == long.class) {
                return num.longValue();
            }
            if (type == Double.class || type == double.class) {
                return num.doubleValue();
            }
        }
        // enums
        if (ReflectionUtils.isa(type, Enum.class)) {
            Object[] ks = type.getEnumConstants();
            assert ks != null : type;
            if (val instanceof String) {
                // stored by name
                for (Object k : ks) {
                    if (((Enum) k).name().equals(val)) {
                        return k;
                    }
                }
                throw new IllegalArgumentException("No constant " + val
                        + " in " + type);
            }
            // stored by ordinal
            assert val instanceof Number : val;
            return ks[((Number) val).intValue()];
        }
        // exceptions
        if (ReflectionUtils.isa(type, Throwable.class)) {
            // NOTE(review): Java-native deserialisation of database content --
            // only safe if this column can never hold untrusted bytes.
            byte[] bytes = (byte[]) val;
            InputStream in = new FastByteArrayInputStream(bytes, bytes.length);
            ObjectInputStream objIn = new ObjectInputStream(in);
            try {
                return objIn.readObject();
            } finally {
                objIn.close();
            }
        }
        return val;
    }

    /**
     * Assume column-names = field-names, and we use "x.field" to refer to them
     * in the select!
     * <p>
     * NB: the select is split naively on commas, so columns containing commas
     * (e.g. function calls) or a sub-select before "from" will throw the
     * column numbering off.
     *
     * @param select
     *            Either the column list, or a whole "select ... from" query.
     * @param klass
     * @return fields[i] = the (accessible) field for the i-th selected column,
     *         or null if that column is not an x.field of klass
     */
    static <X> Field[] inflate2_whichColumns(String select, Class<X> klass) {
        // We just want the selected columns
        Matcher m = SELECT_COLS_FROM.matcher(select.trim());
        if (m.find()) {
            select = m.group(1);
        }
        // Build map of column-num to Field
        String[] cols = select.split(",");
        Field[] fields = new Field[cols.length];
        for (int i = 0; i < cols.length; i++) {
            // trim, so that "x.a, x.b" works
            String col = cols[i].trim();
            if (!col.startsWith("x.")) {
                // Not one of ours: this allows for other columns to be mixed in
                continue;
            }
            Field field = ReflectionUtils.getField(klass, col.substring(2));
            if (field == null) {
                // no such field: ignore it
                continue;
            }
            field.setAccessible(true);
            fields[i] = field;
        }
        return fields;
    }

    /**
     * Encode the text to escape any SQL characters, and add surrounding 's.
     * This is for use as String constants -- not for use in LIKE statements
     * (which need extra escaping).
     *
     * @param text
     *            Can be null (returns "null").
     * @return e.g. don't => 'don''t'
     */
    public static String sqlEncode(String text) {
        if (text == null) {
            return "null";
        }
        return "'" + text.replace("'", "''") + "'";
    }

    /**
     * upsert = update if exists + insert if new.
     * <p>
     * The query is an update followed by an
     * "insert ... where not exists", as two ;-separated statements.
     * NOTE(review): presumably not safe against concurrent upserts of the same
     * row unless the caller serialises them.
     *
     * @param table
     * @param idColumns
     *            The columns which identify the row. E.g. the primary key. Must
     *            be non-empty, must not contain any nulls, and should be a
     *            subset of columns. A null value for an id column is matched
     *            with "is null".
     * @param col2val
     *            The values to set for every non-null column. Keys must match
     *            the database's column names. Any missing columns will be set
     *            to null.
     * @param specialCaseId
     *            If true, the "id" column is treated specially --
     *            nextval(hibernate_sequence) is used for the insert, and no
     *            change is made on update (whether or not col2val holds an id).
     *            This is a hack 'cos row ids don't travel across servers.
     * @return An upsert query, with its parameters set.
     * @throws IllegalArgumentException
     *             if idColumns is empty, or the table does not exist
     */
    public static Query upsert(EntityManager em, String table,
            String[] idColumns, Map<String, ?> col2val, boolean specialCaseId) {
        if (idColumns.length == 0) {
            // without id columns there is no row to identify (and the where
            // clause would be malformed)
            throw new IllegalArgumentException("upsert into " + table
                    + " needs at least one id column");
        }
        List<Pair<String>> columnInfo = getTableColumns(em, table);
        // safety check inputs
        assert idColumns.length <= columnInfo.size();
        for (String idc : idColumns) {
            assert idc != null : Printer.toString(idColumns);
            assert getColumnIndex(idc, columnInfo) != -1 : idc + " vs "
                    + columnInfo;
        }
        if (specialCaseId) {
            assert !Containers.contains("id", idColumns);
            assert getColumnIndex(em, "id", table) != -1 : table + " = "
                    + getTableColumns(em, table);
        } else {
            assert !col2val.containsKey("id") : col2val;
        }
        // the identifying where clause
        String whereClause = upsert2_whereClause(idColumns, col2val);

        StringBuilder upsert = new StringBuilder();
        // 1. update where exists
        upsert.append("update ").append(table).append(" set ");
        for (Pair<String> colInfo : columnInfo) {
            // SoDash/Hibernate HACK: don't change the id!
            // NB: checked before the null test, so a missing id is not nulled
            if (upsert2_isLocalId(colInfo, specialCaseId)) {
                continue;
            }
            // Keep Hibernate happy - avoid oid and setParameter with null
            if (upsert2_isNullOrOid(colInfo, col2val)) {
                upsert.append(colInfo.first).append("=null,");
                continue;
            }
            // a normal update
            upsert.append(colInfo.first).append("=:").append(colInfo.first)
                    .append(',');
        }
        // lose the trailing ,
        if (upsert.charAt(upsert.length() - 1) == ',') {
            upsert.setLength(upsert.length() - 1);
        }
        upsert.append(whereClause).append(";\n");

        // 2. insert where not exists
        upsert.append("insert into ").append(table).append(" select ");
        for (Pair<String> colInfo : columnInfo) {
            // SoDash/Hibernate HACK: ignore any given id value & use the local
            // sequence
            if (upsert2_isLocalId(colInfo, specialCaseId)) {
                upsert.append("nextval('hibernate_sequence'),");
                continue;
            }
            // Keep Hibernate happy - avoid oid and setParameter with null
            if (upsert2_isNullOrOid(colInfo, col2val)) {
                upsert.append("null,");
                continue;
            }
            // a normal insert
            upsert.append(':').append(colInfo.first).append(',');
        }
        // lose the trailing , (the table has at least one column)
        upsert.setLength(upsert.length() - 1);
        upsert.append(" where not exists (select 1 from ").append(table)
                .append(whereClause).append(");");

        // create query
        Query q = em.createNativeQuery(upsert.toString());
        // set params -- exactly for the columns given a :name above
        for (Pair<String> colInfo : columnInfo) {
            if (upsert2_isLocalId(colInfo, specialCaseId)
                    || upsert2_isNullOrOid(colInfo, col2val)) {
                continue;
            }
            q.setParameter(colInfo.first, col2val.get(colInfo.first));
        }
        return q;
    }

    /**
     * Build " where (a=:a and b is null)" for the upsert's id columns.
     * Null-valued columns use "is null", since "=null" never matches in SQL.
     */
    private static String upsert2_whereClause(String[] idColumns,
            Map<String, ?> col2val) {
        StringBuilder where = new StringBuilder(" where (");
        for (int i = 0; i < idColumns.length; i++) {
            String col = idColumns[i];
            if (i != 0) {
                where.append(" and ");
            }
            if (col2val.get(col) == null) {
                where.append(col).append(" is null");
            } else {
                where.append(col).append("=:").append(col);
            }
        }
        return where.append(')').toString();
    }

    /** @return true if this is the "id" column and the id hack is on. */
    private static boolean upsert2_isLocalId(Pair<String> colInfo,
            boolean specialCaseId) {
        return specialCaseId && "id".equals(colInfo.first);
    }

    /**
     * @return true if the column gets no parameter: it is oid-typed, or
     *         col2val has no non-null value for it.
     */
    private static boolean upsert2_isNullOrOid(Pair<String> colInfo,
            Map<String, ?> col2val) {
        return "oid".equals(colInfo.second)
                || col2val.get(colInfo.first) == null;
    }
}