package org.hivedb.hibernate;
import org.hibernate.Interceptor;
import org.hibernate.Session;
import org.hibernate.SessionFactory;
import org.hibernate.cfg.Configuration;
import org.hibernate.dialect.H2Dialect;
import org.hibernate.dialect.MySQLInnoDBDialect;
import org.hibernate.shards.Shard;
import org.hibernate.shards.ShardId;
import org.hibernate.shards.ShardedConfiguration;
import org.hibernate.shards.cfg.ConfigurationToShardConfigurationAdapter;
import org.hibernate.shards.cfg.ShardConfiguration;
import org.hibernate.shards.session.ShardedSessionFactory;
import org.hibernate.shards.session.ShardedSessionImpl;
import org.hibernate.shards.strategy.ShardStrategy;
import org.hibernate.shards.strategy.ShardStrategyFactory;
import org.hibernate.shards.strategy.ShardStrategyImpl;
import org.hibernate.shards.strategy.access.ShardAccessStrategy;
import org.hibernate.shards.util.Lists;
import org.hibernate.shards.util.Maps;
import org.hivedb.Hive;
import org.hivedb.HiveKeyNotFoundException;
import org.hivedb.Synchronizeable;
import org.hivedb.configuration.EntityConfig;
import org.hivedb.configuration.EntityHiveConfig;
import org.hivedb.meta.Node;
import org.hivedb.meta.persistence.CachingDataSourceProvider;
import org.hivedb.util.Combiner;
import org.hivedb.util.database.DriverLoader;
import org.hivedb.util.database.HiveDbDialect;
import org.hivedb.util.functional.Atom;
import org.hivedb.util.functional.Filter;
import org.hivedb.util.functional.Filter.BinaryPredicate;
import org.hivedb.util.functional.Transform;
import org.hivedb.util.functional.Transform.IdentityFunction;
import org.hivedb.util.functional.Unary;
import java.util.*;
import java.util.Map.Entry;
// TODO: Node-set session factories have to go; their number grows combinatorially with the number of nodes.
public class HiveSessionFactoryBuilderImpl implements HiveSessionFactoryBuilder, HiveSessionFactory, Observer, Synchronizeable {
private static final int NODE_SET_LIMIT = 1;
private static Map<HiveDbDialect, Class<?>> dialectMap = buildDialectMap();
private Map<Set<Integer>, SessionFactory> nodeSessionFactories;
private Collection<Class<?>> hibernateClasses;
private EntityHiveConfig config;
private ShardAccessStrategy accessStrategy;
private Properties overrides = new Properties();
private ShardedSessionFactory allNodesSessionFactory = null;
private Hive hive;
public HiveSessionFactoryBuilderImpl(String hiveUri, List<Class<?>> hibernateClasses, ShardAccessStrategy strategy) {
hive = Hive.load(hiveUri, CachingDataSourceProvider.getInstance());
this.hibernateClasses = hibernateClasses;
initialize(buildHiveConfiguration(hibernateClasses), hive, strategy);
}
public HiveSessionFactoryBuilderImpl(String hiveUri, List<Class<?>> mappedClasses, ShardAccessStrategy strategy, Properties overrides) {
this(hiveUri, mappedClasses, strategy);
this.overrides = overrides;
}
public HiveSessionFactoryBuilderImpl(EntityHiveConfig config, Hive hive, ShardAccessStrategy strategy) {
this.hive = hive;
this.hibernateClasses = flattenWithAssociatedClasses(config);
initialize(config, hive, strategy);
}
public HiveSessionFactoryBuilderImpl(EntityHiveConfig config, Collection<Class<?>> mappedClasses, Hive hive, ShardAccessStrategy strategy) {
this.hive = hive;
this.hibernateClasses = mappedClasses;
initialize(config, hive, strategy);
}
@SuppressWarnings("unchecked")
private Collection<Class<?>> flattenWithAssociatedClasses(EntityHiveConfig config) {
return Filter.getUnique(Transform.flatten(
Transform.map(new Unary<EntityConfig, Class<?>>() {
public Class<?> f(EntityConfig entityConfig) {
return entityConfig.getRepresentedInterface();
}
}, config.getEntityConfigs()),
Transform.flatMap(new Unary<EntityConfig, Collection<Class<?>>>() {
public Collection<Class<?>> f(EntityConfig entityConfig) {
return entityConfig.getAssociatedClasses();
}
}, config.getEntityConfigs())));
}
private void initialize(EntityHiveConfig config, Hive hive, ShardAccessStrategy strategy) {
this.accessStrategy = strategy;
this.config = config;
this.nodeSessionFactories = buildNodeSetSessionFactories();
this.allNodesSessionFactory = buildAllNodesSessionFactory();
hive.addObserver(this);
}
public ShardedSessionFactory getSessionFactory() {
return allNodesSessionFactory;
}
private Map<Set<Integer>, SessionFactory> buildNodeSetSessionFactories() {
final Map<Integer, Configuration> hibernateConfigs = getConfigurationsFromNodes(hive);
// Build non-sharded session factories for individual single-shard access
final Map<Integer, SessionFactory> hibernateSessionFactories = Transform.toMap(
new Transform.IdentityFunction<Integer>(),
new Unary<Integer, SessionFactory>() {
public SessionFactory f(Integer nodeId) {
return hibernateConfigs.get(nodeId).buildSessionFactory();
}
},
hibernateConfigs.keySet());
Collection<Set<Integer>> nodeSetCombinations = Combiner.generateSets(hibernateConfigs.keySet(), NODE_SET_LIMIT);
return Transform.toMap(
new IdentityFunction<Set<Integer>>(),
new Unary<Set<Integer>, SessionFactory>() {
public SessionFactory f(Set<Integer> nodeSet) {
return (nodeSet.size() == 1)
? hibernateSessionFactories.get(Atom.getFirstOrThrow(nodeSet)) // non-sharded
: buildMultiNodeSessionFactory(getSomeNodeConfigurations(nodeSet)); // sharded
}
}, nodeSetCombinations);
}
private ShardedSessionFactory buildAllNodesSessionFactory() {
List<ShardConfiguration> shardConfigs = getNodeConfigurations();
return buildMultiNodeSessionFactory(shardConfigs);
}
private ShardedSessionFactory buildMultiNodeSessionFactory(List<ShardConfiguration> shardConfigs) {
Configuration prototypeConfig = buildPrototypeConfiguration();
ShardedConfiguration shardedConfig = new ShardedConfiguration(prototypeConfig, shardConfigs, buildShardStrategyFactory());
return shardedConfig.buildShardedSessionFactory();
}
private Map<Integer, Configuration> getConfigurationsFromNodes(Hive hive) {
Map<Integer, Configuration> configMap = Maps.newHashMap();
for (Node node : hive.getNodes())
configMap.put(node.getId(), addClassesToConfig(createConfigurationFromNode(node, overrides)));
return configMap;
}
private List<ShardConfiguration> getNodeConfigurations() {
final Map<Integer, Configuration> nodeToHibernateConfigMap = getConfigurationsFromNodes(hive);
List<ShardConfiguration> configs = Lists.newArrayList();
for (Configuration hibernateConfig : nodeToHibernateConfigMap.values())
configs.add(new ConfigurationToShardConfigurationAdapter(hibernateConfig));
return configs;
}
private List<ShardConfiguration> getSomeNodeConfigurations(Set<Integer> nodeIds) {
return new ArrayList<ShardConfiguration>(Filter.grepAgainstList(
nodeIds,
getNodeConfigurations(),
new BinaryPredicate<Integer, ShardConfiguration>() {
@Override
public boolean f(Integer nodeId, ShardConfiguration shardConfiguration) {
return shardConfiguration.getShardId().equals(nodeId);
}
}));
}
private Configuration buildPrototypeConfiguration() {
Configuration hibernateConfig = null;
try {
hibernateConfig = createConfigurationFromNode(Atom.getFirstOrThrow(hive.getNodes()), overrides);
}
catch (Exception e) {
throw new RuntimeException("The hive has no nodes, so it is impossible to build a prototype configuration");
}
addClassesToConfig(hibernateConfig);
hibernateConfig.setProperty("hibernate.session_factory_name", "factory:prototype");
return hibernateConfig;
}
private Configuration addClassesToConfig(Configuration hibernateConfig) {
for (Class<?> clazz : hibernateClasses)
hibernateConfig.addClass(EntityResolver.getPersistedImplementation(clazz));
return hibernateConfig;
}
private EntityHiveConfig buildHiveConfiguration(Collection<Class<?>> classes) {
return new ConfigurationReader(classes).getHiveConfiguration();
}
private ShardStrategyFactory buildShardStrategyFactory() {
return new ShardStrategyFactory() {
public ShardStrategy newShardStrategy(List<ShardId> shardIds) {
return new ShardStrategyImpl(
new HiveShardSelector(config, hive),
new HiveShardResolver(config, hive),
accessStrategy);
}
};
}
public static Configuration createConfigurationFromNode(Node node, Properties overrides) {
Configuration config = new Configuration().configure();
config.setProperty("hibernate.session_factory_name", "factory:" + node.getName());
config.setProperty("hibernate.dialect", dialectMap.get(node.getDialect()).getName());
config.setProperty("hibernate.connection.driver_class", DriverLoader.getDriverClass(node.getDialect()));
config.setProperty("hibernate.connection.url", node.getUri());
config.setProperty("hibernate.connection.shard_id", node.getId().toString());
config.setProperty("hibernate.shard.enable_cross_shard_relationship_checks", "true");
for (Entry<Object, Object> prop : overrides.entrySet())
config.setProperty(prop.getKey().toString(), prop.getValue().toString());
return config;
}
public void update(Observable o, Object arg) {
sync();
}
public boolean sync() {
ShardedSessionFactory newFactory = buildAllNodesSessionFactory();
synchronized (this) {
this.allNodesSessionFactory = newFactory;
}
return true;
}
private static Map<HiveDbDialect, Class<?>> buildDialectMap() {
Map<HiveDbDialect, Class<?>> map = Maps.newHashMap();
map.put(HiveDbDialect.H2, H2Dialect.class);
map.put(HiveDbDialect.MySql, MySQLInnoDBDialect.class);
return map;
}
// ShardedSessionImpl
public Session openAllShardsSession() {
return openAllShardsSession(getDefaultInterceptor());
}
public Session openSession() {
return openAllShardsSession();
}
public Session openSession(Interceptor interceptor) {
return openAllShardsSession(interceptor);
}
private Session openAllShardsSession(Interceptor interceptor) {
return addOpenSessionEvents(allNodesSessionFactory.openSession(interceptor));
}
private Session addOpenSessionEvents(Session session) {
for (Shard shard : ((ShardedSessionImpl) session).getShards()) {
shard.addOpenSessionEvent(new RecordNodeOpenSessionEvent());
}
return session;
}
// SessionImpl
public Session openSession(Object primaryIndexKey) {
return openSession(
getNodeIdsOrThrow(primaryIndexKey),
getDefaultInterceptor());
}
private Collection<Integer> getNodeIdsOrThrow(Object primaryIndexKey) {
final Collection<Integer> nodeIds = hive.directory().getNodeIdsOfPrimaryIndexKey(primaryIndexKey);
if (nodeIds.size() == 0)
throw new HiveKeyNotFoundException(String.format("Primary index key %s was not found on any nodes", primaryIndexKey));
return nodeIds;
}
public Session openSession(Object primaryIndexKey, Interceptor interceptor) {
return openSession(
getNodeIdsOrThrow(primaryIndexKey),
interceptor);
}
public Session openSession(String resource, Object resourceId) {
final Collection<Integer> nodeIdsOfResourceId = hive.directory().getNodeIdsOfResourceId(resource, resourceId);
if (nodeIdsOfResourceId.size() == 0)
throw new UnsupportedOperationException(String.format("No nodes found for resource id %s of resource %s", resourceId, resource));
return openSession(
nodeIdsOfResourceId,
getDefaultInterceptor());
}
public Session openSession(String resource, Object resourceId, Interceptor interceptor) {
return openSession(
hive.directory().getNodeIdsOfResourceId(resource, resourceId),
interceptor);
}
public Session openSession(String resource, String indexName, Object secondaryIndexKey) {
return openSession(
hive.directory().getNodeIdsOfSecondaryIndexKey(resource, indexName, secondaryIndexKey),
getDefaultInterceptor());
}
public Session openSession(String resource, String indexName, Object secondaryIndexKey, Interceptor interceptor) {
return openSession(
hive.directory().getNodeIdsOfSecondaryIndexKey(resource, indexName, secondaryIndexKey),
interceptor);
}
@SuppressWarnings("unchecked")
private Session openSession(Collection<Integer> nodeIds, Interceptor interceptor) {
// We only create SessionFactories for 1 to NODE_SET_LIMIT nodes.
// If more are requested then we delegate to the allNodesSessionFactory
if (nodeIds.size() <= NODE_SET_LIMIT) {
Session session = nodeSessionFactories.get(new HashSet(nodeIds)).openSession(interceptor);
RecordNodeOpenSessionEvent.setNode(session);
return session;
} else {
return allNodesSessionFactory.openSession(interceptor);
}
}
public Interceptor getDefaultInterceptor() {
return new HiveInterceptorDecorator(config, hive);
}
@SuppressWarnings("unchecked")
public SessionFactory getSessionFactory(Integer nodeId) {
return nodeSessionFactories.get(new HashSet(Arrays.asList(nodeId)));
}
}