/**
* JBoss, Home of Professional Open Source
* Copyright Red Hat, Inc., and individual contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.jboss.aerogear.simplepush.server.datastore;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import javax.persistence.EntityManager;
import javax.persistence.NoResultException;
import javax.persistence.Persistence;
import javax.persistence.Query;
import javax.persistence.TypedQuery;
import org.jboss.aerogear.simplepush.protocol.Ack;
import org.jboss.aerogear.simplepush.protocol.impl.AckImpl;
import org.jboss.aerogear.simplepush.server.Channel;
import org.jboss.aerogear.simplepush.server.DefaultChannel;
import org.jboss.aerogear.simplepush.server.datastore.model.AckDTO;
import org.jboss.aerogear.simplepush.server.datastore.model.ChannelDTO;
import org.jboss.aerogear.simplepush.server.datastore.model.Server;
import org.jboss.aerogear.simplepush.server.datastore.model.UserAgentDTO;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* A DataStore implementation that use Java Persistence API (JPA) to store data for the SimplePush Server.
*/
public final class JpaDataStore implements DataStore {
private final Logger logger = LoggerFactory.getLogger(JpaDataStore.class);
private final JpaExecutor jpaExecutor;
private final static Charset UTF_8 = Charset.forName("UTF-8");
/**
* Sole constructor.
*
* @param persistenceUnit the name of the persistence unit to be used.
*/
public JpaDataStore(final String persistenceUnit) {
jpaExecutor = new JpaExecutor(Persistence.createEntityManagerFactory(persistenceUnit));
}
@Override
public void savePrivateKeySalt(final byte[] salt) {
final byte[] privateKeySalt = getPrivateKeySalt();
if (privateKeySalt.length != 0) {
return;
}
final JpaOperation<Void> saveSalt = new JpaOperation<Void>() {
@Override
public Void perform(final EntityManager em) {
em.persist(new Server(new String(salt, UTF_8)));
return null;
}
};
jpaExecutor.execute(saveSalt);
}
@Override
public byte[] getPrivateKeySalt() {
final JpaOperation<byte[]> saveChannel = new JpaOperation<byte[]>() {
@Override
public byte[] perform(final EntityManager em) {
final Query query = em.createQuery("SELECT s FROM Server s");
final Server server = (Server) query.getSingleResult();
return server.getSalt().getBytes(UTF_8);
}
};
try {
return jpaExecutor.execute(saveChannel);
} catch (final Exception e) {
if (! (e instanceof NoResultException)) {
logger.debug("Exception while trying to find a the servers salt");
}
return new byte[]{};
}
}
@Override
public boolean saveChannel(final Channel channel) {
final JpaOperation<Boolean> saveChannel = new JpaOperation<Boolean>() {
@Override
public Boolean perform(final EntityManager em) {
UserAgentDTO userAgent = em.find(UserAgentDTO.class, channel.getUAID());
if (userAgent == null) {
userAgent = new UserAgentDTO(channel.getUAID());
}
userAgent.addChannel(channel.getChannelId(), channel.getVersion(), channel.getEndpointToken());
em.merge(userAgent);
return Boolean.TRUE;
}
};
try {
return jpaExecutor.execute(saveChannel);
} catch (final Exception e) {
logger.error("Could not save channel [" + channel.getChannelId() + "]", e);
return false;
}
}
@Override
public Channel getChannel(final String channelId) throws ChannelNotFoundException {
final JpaOperation<ChannelDTO> findChannel = new JpaOperation<ChannelDTO>() {
@Override
public ChannelDTO perform(EntityManager em) {
return em.find(ChannelDTO.class, channelId);
}
};
final ChannelDTO dto = jpaExecutor.execute(findChannel);
if (dto == null) {
throw new ChannelNotFoundException("No Channel for [" + channelId + "] was found", channelId);
}
return new DefaultChannel(dto.getUserAgent().getUaid(), dto.getChannelId(), dto.getVersion(), dto.getEndpointToken());
}
@Override
public void removeChannels(final Set<String> channelIds) {
if (channelIds == null || channelIds.isEmpty()) {
return;
}
final JpaOperation<Integer> removeChannel = new JpaOperation<Integer>() {
@Override
public Integer perform(EntityManager em) {
final Query delete = em.createQuery("DELETE from ChannelDTO c where c.channelId in (:channelIds)");
delete.setParameter("channelIds", channelIds);
return delete.executeUpdate();
}
};
jpaExecutor.execute(removeChannel);
}
@Override
public Set<String> getChannelIds(final String uaid) {
final JpaOperation<Set<String>> getChannelIds = new JpaOperation<Set<String>>() {
@Override
public Set<String> perform(final EntityManager em) {
final Set<String> channels = new HashSet<String>();
final UserAgentDTO userAgent = em.find(UserAgentDTO.class, uaid);
if (userAgent != null) {
for (ChannelDTO dto : userAgent.getChannels()) {
channels.add(dto.getChannelId());
}
}
return channels;
}
};
return jpaExecutor.execute(getChannelIds);
}
@Override
public void removeChannels(final String uaid) {
final JpaOperation<Void> removeChannels = new JpaOperation<Void>() {
@Override
public Void perform(final EntityManager em) {
final UserAgentDTO userAgent = em.find(UserAgentDTO.class, uaid);
if (userAgent != null) {
final Set<ChannelDTO> channels = userAgent.getChannels();
for (ChannelDTO channelDTO : channels) {
em.remove(channelDTO);
}
channels.clear();
userAgent.setChannels(channels);
}
return null;
}
};
jpaExecutor.execute(removeChannels);
logger.debug("Deleted all channels for UserAgent [" + uaid + "]");
}
@Override
public String updateVersion(final String endpointToken, final long version) throws VersionException, ChannelNotFoundException {
final JpaOperation<ChannelDTO> updateVersion = new JpaOperation<ChannelDTO>() {
@Override
public ChannelDTO perform(final EntityManager em) {
final TypedQuery<ChannelDTO> select = em.createQuery("SELECT c FROM ChannelDTO c where c.endpointToken = :endpointToken", ChannelDTO.class);
select.setParameter("endpointToken", endpointToken);
final List<ChannelDTO> resultList = select.getResultList();
if (resultList.isEmpty()) {
return null;
}
final ChannelDTO channelDTO = resultList.get(0);
if (channelDTO != null) {
if (version > channelDTO.getVersion()) {
channelDTO.setVersion(version);
em.merge(channelDTO);
} else {
throw new VersionException("New version [" + version + "] must be greater than current version [" + channelDTO.getVersion() + "]");
}
}
return channelDTO;
}
};
try {
final ChannelDTO channelDto = jpaExecutor.execute(updateVersion);
if (channelDto == null) {
throw new ChannelNotFoundException("No Channel for endpointToken [" + endpointToken + "] was found", endpointToken);
}
return channelDto.getChannelId();
} catch (final JpaException e) {
final Throwable cause = e.getCause();
if (cause instanceof VersionException) {
throw (VersionException) cause;
}
throw e;
}
}
@Override
public String saveUnacknowledged(final String channelId, final long version) throws ChannelNotFoundException {
final JpaOperation<String> saveAcks = new JpaOperation<String>() {
@Override
public String perform(final EntityManager em) {
final ChannelDTO channel = em.find(ChannelDTO.class, channelId);
final UserAgentDTO userAgent = channel.getUserAgent();
final Set<AckDTO> dtos = new HashSet<AckDTO>();
dtos.add(new AckDTO(userAgent, channel.getChannelId(), version));
userAgent.setAcks(dtos);
em.merge(userAgent);
return userAgent.getUaid();
}
};
return jpaExecutor.execute(saveAcks);
}
@Override
public Set<Ack> getUnacknowledged(final String uaid) {
final JpaOperation<Set<Ack>> getUnacks = new JpaOperation<Set<Ack>>() {
@Override
public Set<Ack> perform(final EntityManager em) {
final UserAgentDTO userAgent = em.find(UserAgentDTO.class, uaid);
if (userAgent == null) {
return Collections.emptySet();
}
final HashSet<Ack> acks = new HashSet<Ack>();
for (AckDTO ackDTO : userAgent.getAcks()) {
acks.add(new AckImpl(ackDTO.getChannelId(), ackDTO.getVersion()));
}
return acks;
}
};
return jpaExecutor.execute(getUnacks);
}
@Override
public Set<Ack> removeAcknowledged(final String uaid, final Set<Ack> acked) {
final JpaOperation<Set<Ack>> removeAck = new JpaOperation<Set<Ack>>() {
@Override
public Set<Ack> perform(final EntityManager em) {
final List<String> channelIds = new ArrayList<String>(acked.size());
for (Ack ack : acked) {
channelIds.add(ack.getChannelId());
}
final Query delete = em.createQuery("DELETE from AckDTO c where c.channelId in (:channelIds)");
delete.setParameter("channelIds", channelIds);
delete.executeUpdate();
final UserAgentDTO userAgent = em.find(UserAgentDTO.class, uaid);
final Set<AckDTO> acks = userAgent.getAcks();
final Set<Ack> unacked = new HashSet<Ack>(acks.size());
for (AckDTO ackDto : acks) {
unacked.add(new AckImpl(ackDto.getChannelId(), ackDto.getVersion()));
}
return unacked;
}
};
return jpaExecutor.execute(removeAck);
}
}