/**
* JBoss, Home of Professional Open Source Copyright Red Hat, Inc., and individual contributors
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the
* License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific
* language governing permissions and limitations under the License.
*/
package org.jboss.aerogear.simplepush.server.datastore;
import java.net.MalformedURLException;
import java.nio.charset.Charset;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import org.ektorp.BulkDeleteDocument;
import org.ektorp.ViewQuery;
import org.ektorp.ViewResult;
import org.ektorp.ViewResult.Row;
import org.ektorp.http.HttpClient;
import org.ektorp.http.StdHttpClient;
import org.ektorp.impl.StdCouchDbConnector;
import org.ektorp.impl.StdCouchDbInstance;
import org.ektorp.support.DesignDocument;
import org.ektorp.support.DesignDocument.View;
import org.jboss.aerogear.simplepush.protocol.Ack;
import org.jboss.aerogear.simplepush.protocol.impl.AckImpl;
import org.jboss.aerogear.simplepush.server.Channel;
import org.jboss.aerogear.simplepush.server.DefaultChannel;
/**
* DataStore that uses a CouchDB database for storage.
*/
public class CouchDBDataStore implements DataStore {
private static final String UAID_FIELD = "uaid";
private static final String TYPE_FIELD = "type";
private static final String TOKEN_FIELD = "token";
private static final String CHID_FIELD = "chid";
private static final String VERSION_FIELD = "version";
private static final String DOC_FIELD = "doc";
private final HttpClient httpClient;
private final StdCouchDbInstance stdCouchDbInstance;
private final StdCouchDbConnector db;
private final DesignDocument designDocument;
private final static Charset UTF_8 = Charset.forName("UTF-8");
public CouchDBDataStore(final String url, final String dbName) {
try {
httpClient = new StdHttpClient.Builder().url(url).build();
} catch (final MalformedURLException e) {
throw new IllegalStateException(e);
}
stdCouchDbInstance = new StdCouchDbInstance(httpClient);
db = new StdCouchDbConnector(dbName, stdCouchDbInstance);
db.createDatabaseIfNotExists();
designDocument = new DesignDocument("_design/channels");
addView(designDocument, Views.CHANNEL);
addView(designDocument, Views.UAID);
addView(designDocument, Views.TOKEN);
addView(designDocument, Views.UNACKS);
addView(designDocument, Views.SERVER);
if (!db.contains(designDocument.getId())) {
db.create(designDocument);
}
}
private void addView(final DesignDocument doc, final Views view) {
if (!doc.containsView(view.viewName())) {
doc.addView(view.viewName(), new View(view.mapFunction()));
}
}
@Override
public void savePrivateKeySalt(final byte[] salt) {
final byte[] privateKeySalt = getPrivateKeySalt();
if (privateKeySalt.length == 0) {
final Map<String, String> map = new HashMap<String, String>(2);
map.put(TYPE_FIELD, Views.SERVER.viewName());
map.put("salt", new String(salt, UTF_8));
db.create(map);
}
}
@Override
public byte[] getPrivateKeySalt() {
final ViewQuery viewQuery = new ViewQuery().dbPath(db.path()).viewName(Views.SERVER.viewName()).designDocId(designDocument.getId());
final ViewResult viewResult = db.queryView(viewQuery);
if (viewResult.isEmpty()) {
return new byte[]{};
}
final Row row = viewResult.getRows().get(0);
return row.getKeyAsNode().get("salt").asText().getBytes(UTF_8);
}
@Override
public boolean saveChannel(final Channel channel) {
db.create(channelAsMap(channel));
return true;
}
private Map<String, String> channelAsMap(final Channel channel) {
final Map<String, String> map = new HashMap<String, String>(5);
map.put(UAID_FIELD, channel.getUAID());
map.put(TYPE_FIELD, Views.CHANNEL.viewName());
map.put(TOKEN_FIELD, channel.getEndpointToken());
map.put(CHID_FIELD, channel.getChannelId());
map.put(VERSION_FIELD, Long.toString(channel.getVersion()));
return map;
}
@Override
public Channel getChannel(final String channelId) throws ChannelNotFoundException {
return channelFromJson(getChannelJson(channelId));
}
private JsonNode getChannelJson(final String channelId) throws ChannelNotFoundException {
final ViewResult viewResult = db.queryView(query(Views.CHANNEL.viewName(), channelId));
final List<Row> rows = viewResult.getRows();
if (rows.isEmpty()) {
throw new ChannelNotFoundException("Cound not find channel", channelId);
}
if (rows.size() > 1) {
throw new IllegalStateException("There should not be multiple channelId with the same id");
}
return rows.get(0).getValueAsNode();
}
private Channel channelFromJson(final JsonNode node) {
final JsonNode doc = node.get("doc");
return new DefaultChannel(doc.get(UAID_FIELD).asText(),
doc.get(CHID_FIELD).asText(),
doc.get(VERSION_FIELD).asLong(),
doc.get(TOKEN_FIELD).asText());
}
@Override
public void removeChannels(final String uaid) {
final ViewResult viewResult = db.queryView(query(Views.UAID.viewName(), uaid));
final List<Row> rows = viewResult.getRows();
final Set<String> channelIds = new HashSet<String>(rows.size());
for (Row row : rows) {
final JsonNode json = row.getValueAsNode().get(DOC_FIELD);
channelIds.add(json.get(CHID_FIELD).asText());
}
removeChannels(channelIds);
}
private ViewQuery query(final String viewName, final String key) {
return new ViewQuery()
.dbPath(db.path())
.viewName(viewName)
.designDocId(designDocument.getId())
.key(key);
}
@Override
public void removeChannels(final Set<String> channelIds) {
final ViewResult viewResult = db.queryView(channelsQuery(channelIds));
final List<Row> rows = viewResult.getRows();
final Collection<BulkDeleteDocument> removals = new LinkedHashSet<BulkDeleteDocument>();
for (Row row : rows) {
final JsonNode json = row.getValueAsNode();
removals.add(BulkDeleteDocument.of(json.get(DOC_FIELD)));
}
db.executeBulk(removals);
}
private ViewQuery channelsQuery(final Set<String> keys) {
return new ViewQuery()
.dbPath(db.path())
.viewName(Views.CHANNEL.viewName())
.designDocId(designDocument.getId())
.keys(keys);
}
@Override
public Set<String> getChannelIds(final String uaid) {
final ViewResult viewResult = db.queryView(query(Views.UAID.viewName(), uaid));
final List<Row> rows = viewResult.getRows();
if (rows.isEmpty()) {
return Collections.emptySet();
}
final Set<String> channelIds = new HashSet<String> (rows.size());
for (Row row : rows) {
channelIds.add(row.getValueAsNode().get(DOC_FIELD).get(CHID_FIELD).asText());
}
return channelIds;
}
@Override
public String updateVersion(final String endpointToken, final long version) throws VersionException, ChannelNotFoundException {
final ViewResult viewResult = db.queryView(query(Views.TOKEN.viewName(), endpointToken));
final List<Row> rows = viewResult.getRows();
if (rows.isEmpty()) {
throw new ChannelNotFoundException("Cound not find channel for endpointToken", endpointToken);
}
final ObjectNode node = (ObjectNode) rows.get(0).getValueAsNode().get(DOC_FIELD);
final long currentVersion = node.get(VERSION_FIELD).asLong();
if (version <= currentVersion) {
throw new VersionException("version [" + version + "] must be greater than the current version [" + currentVersion + "]");
}
node.put(VERSION_FIELD, String.valueOf(version));
db.update(node);
return node.get(CHID_FIELD).asText();
}
@Override
public String saveUnacknowledged(final String channelId, final long version) throws ChannelNotFoundException {
final JsonNode json = getChannelJson(channelId);
final Map<String, String> unack = docToAckMap((ObjectNode) json.get(DOC_FIELD), version);
db.create(unack);
return unack.get(UAID_FIELD);
}
private Map<String, String> docToAckMap(final ObjectNode doc, final long version) {
final String uaid = doc.get(UAID_FIELD).asText();
final String chid = doc.get(CHID_FIELD).asText();
final String token = doc.get(TOKEN_FIELD).asText();
final Map<String, String> map = new HashMap<String, String>(5);
map.put(UAID_FIELD, uaid);
map.put(TYPE_FIELD, "ack");
map.put(TOKEN_FIELD, token);
map.put(CHID_FIELD, chid);
map.put(VERSION_FIELD, Long.toString(version));
return map;
}
@Override
public Set<Ack> getUnacknowledged(final String uaid) {
final ViewResult viewResult = db.queryView(query(Views.UNACKS.viewName(), uaid));
return rowsToAcks(viewResult.getRows());
}
@Override
public Set<Ack> removeAcknowledged(final String uaid, final Set<Ack> acked) {
final ViewResult viewResult = db.queryView(query(Views.UNACKS.viewName(), uaid));
final List<Row> rows = viewResult.getRows();
final Collection<BulkDeleteDocument> removals = new LinkedHashSet<BulkDeleteDocument>();
for (Iterator<Row> iter = rows.iterator(); iter.hasNext(); ) {
final Row row = iter.next();
final JsonNode json = row.getValueAsNode();
final JsonNode doc = json.get(DOC_FIELD);
final String channelId = doc.get(CHID_FIELD).asText();
for (Ack ack : acked) {
if (ack.getChannelId().equals(channelId)) {
removals.add(BulkDeleteDocument.of(doc));
iter.remove();
}
}
}
db.executeBulk(removals);
return rowsToAcks(rows);
}
private Set<Ack> rowsToAcks(final List<Row> rows) {
if (rows.isEmpty()) {
return Collections.emptySet();
}
final Set<Ack> unacks = new HashSet<Ack>(rows.size());
for (Row row : rows) {
final JsonNode json = row.getValueAsNode().get(DOC_FIELD);
unacks.add(new AckImpl(json.get(CHID_FIELD).asText(), json.get(VERSION_FIELD).asLong()));
}
return unacks;
}
}