package storm.starter.bolt;
import backtype.storm.Config;
import backtype.storm.generated.GlobalStreamId;
import backtype.storm.task.OutputCollector;
import backtype.storm.task.TopologyContext;
import backtype.storm.topology.OutputFieldsDeclarer;
import backtype.storm.topology.base.BaseRichBolt;
import backtype.storm.tuple.Fields;
import backtype.storm.tuple.Tuple;
import backtype.storm.utils.TimeCacheMap;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
 * Bolt that joins exactly one tuple from each of its input streams into a single output tuple.
 *
 * <p>The join key is the set of field names that every input stream declares (the intersection
 * computed in {@link #prepare}). Tuples carrying equal values for those fields are collected until
 * one tuple from every input stream has arrived; the bolt then emits the requested output fields,
 * anchored on all contributing tuples, and acks them.
 *
 * <p>Incomplete joins are held in a {@link TimeCacheMap} created with the topology message timeout;
 * when an entry expires, every tuple collected so far for that key is failed.
 */
public class SingleJoinBolt extends BaseRichBolt {
    /** Collector used to emit joined tuples and to ack/fail the input tuples. */
    OutputCollector _collector;
    /** Field names common to every input stream; their values form the join key. */
    Fields _idFields;
    /** Fields of the emitted (joined) tuple, as requested by the constructor caller. */
    Fields _outFields;
    /** Number of input streams; a join is complete once this many sides have arrived. */
    int _numSources;
    /** Partially completed joins: join key -> (input stream -> tuple received from it). */
    TimeCacheMap<List<Object>, Map<GlobalStreamId, Tuple>> _pending;
    /** For each output field, the input stream whose tuple supplies its value. */
    Map<String, GlobalStreamId> _fieldLocations;

    /**
     * @param outFields the fields the joined tuple should contain; each must be declared by at
     *                  least one input stream (checked in {@link #prepare})
     */
    public SingleJoinBolt(Fields outFields) {
        _outFields = outFields;
    }

    /**
     * Computes the join key (fields shared by all input streams), records which input stream
     * provides each output field, and creates the expiring cache of pending joins.
     *
     * @throws RuntimeException if the bolt has no input streams, or if some output field is not
     *                          declared by any input stream
     */
    @Override
    public void prepare(Map conf, TopologyContext context, OutputCollector collector) {
        _fieldLocations = new HashMap<String, GlobalStreamId>();
        _collector = collector;

        // NOTE(review): assumes the topology config always carries a message timeout — confirm.
        int timeout = ((Number) conf.get(Config.TOPOLOGY_MESSAGE_TIMEOUT_SECS)).intValue();
        _pending = new TimeCacheMap<List<Object>, Map<GlobalStreamId, Tuple>>(timeout, new ExpireCallback());

        Set<GlobalStreamId> sources = context.getThisSources().keySet();
        _numSources = sources.size();

        Set<String> idFields = null;
        for (GlobalStreamId source : sources) {
            Fields fields = context.getComponentOutputFields(source.get_componentId(), source.get_streamId());
            Set<String> setFields = new HashSet<String>(fields.toList());

            // If several streams declare the same output field, the stream visited last wins.
            for (String outfield : _outFields) {
                if (setFields.contains(outfield)) {
                    _fieldLocations.put(outfield, source);
                }
            }

            // Narrow the join key down to the fields every stream has in common.
            if (idFields == null) {
                idFields = setFields;
            } else {
                idFields.retainAll(setFields);
            }
        }

        if (idFields == null) {
            // No input streams at all: fail with a clear message instead of a NullPointerException.
            throw new RuntimeException("SingleJoinBolt has no input streams to join");
        }
        _idFields = new Fields(new ArrayList<String>(idFields));

        if (_fieldLocations.size() != _outFields.size()) {
            throw new RuntimeException("Cannot find all outfields among sources");
        }
    }

    /**
     * Stores the tuple as its stream's side of the join for the tuple's key and, once every input
     * stream has contributed a tuple for that key, emits the joined tuple and acks all sides.
     *
     * @throws RuntimeException if a second tuple with the same key arrives from the same stream
     *                          before the join for that key has completed
     */
    @Override
    public void execute(Tuple tuple) {
        List<Object> id = tuple.select(_idFields);
        GlobalStreamId streamId = new GlobalStreamId(tuple.getSourceComponent(), tuple.getSourceStreamId());

        // Single lookup with a null check rather than containsKey() followed by get(): if the
        // entry disappears between those two calls, get() returns null and the code below throws.
        // NOTE(review): presumably TimeCacheMap expires entries outside of execute() — confirm.
        Map<GlobalStreamId, Tuple> parts = _pending.get(id);
        if (parts == null) {
            parts = new HashMap<GlobalStreamId, Tuple>();
            _pending.put(id, parts);
        }

        if (parts.containsKey(streamId)) {
            throw new RuntimeException("Received same side of single join twice");
        }
        parts.put(streamId, tuple);

        if (parts.size() == _numSources) {
            // All sides present: take the entry out of the cache so it cannot expire afterwards.
            _pending.remove(id);

            List<Object> joinResult = new ArrayList<Object>();
            for (String outField : _outFields) {
                GlobalStreamId loc = _fieldLocations.get(outField);
                joinResult.add(parts.get(loc).getValueByField(outField));
            }

            // Anchor the output on every contributing tuple, then ack each of them.
            _collector.emit(new ArrayList<Tuple>(parts.values()), joinResult);
            for (Tuple part : parts.values()) {
                _collector.ack(part);
            }
        }
    }

    /** Declares the output fields passed to the constructor. */
    @Override
    public void declareOutputFields(OutputFieldsDeclarer declarer) {
        declarer.declare(_outFields);
    }

    /** Fails every tuple of a pending join whose cache entry expired before it completed. */
    private class ExpireCallback implements TimeCacheMap.ExpiredCallback<List<Object>, Map<GlobalStreamId, Tuple>> {
        @Override
        public void expire(List<Object> id, Map<GlobalStreamId, Tuple> tuples) {
            for (Tuple tuple : tuples.values()) {
                _collector.fail(tuple);
            }
        }
    }
}