/*
*
* Copyright 2013 Netflix, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package com.netflix.zeno.fastblob.state;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.concurrent.atomic.AtomicLongArray;
import java.util.concurrent.atomic.AtomicReference;
/**
* This is a lock-free, thread-safe version of a {@link java.util.BitSet}.<p>
*
* Instead of a long array to hold the bits, this implementation uses an AtomicLongArray, then
* does the appropriate compare-and-swap operations when setting the bits.
*
* @author dkoszewnik
*
*/
public class ThreadSafeBitSet {

    /** A long holds 2^6 = 64 bits, so a segment must span at least 2^6 bits. */
    private static final int LOG2_BITS_PER_LONG = 6;

    /** Default segment size: 2^14 = 16384 bits, 2048 bytes, 256 longs per segment. */
    private static final int DEFAULT_LOG2_SEGMENT_SIZE = 14;

    /** Number of longs in every segment (always a power of two). */
    private final int numLongsPerSegment;

    /** log2 of the number of bits held by each segment. */
    private final int log2SegmentSize;

    /** Mask which extracts the index of a long within its segment. */
    private final int segmentMask;

    /** The currently visible list of segments; swapped atomically whenever the bit set grows. */
    private final AtomicReference<ThreadSafeBitSetSegments> segments;

    /**
     * Create a bit set using the default segment size of 2^14 bits, initially holding one segment.
     */
    public ThreadSafeBitSet() {
        this(DEFAULT_LOG2_SEGMENT_SIZE);
    }

    /**
     * Create a bit set whose segments each hold <code>2^log2SegmentSizeInBits</code> bits,
     * initially holding one segment.
     *
     * @param log2SegmentSizeInBits log2 of the number of bits per segment; must be at least 6
     * @throws IllegalArgumentException if <code>log2SegmentSizeInBits</code> is less than 6
     */
    public ThreadSafeBitSet(int log2SegmentSizeInBits) {
        if (log2SegmentSizeInBits < LOG2_BITS_PER_LONG) {
            throw new IllegalArgumentException("Cannot specify fewer than 64 bits in each segment!");
        }

        this.log2SegmentSize = log2SegmentSizeInBits;
        this.numLongsPerSegment = 1 << (log2SegmentSizeInBits - LOG2_BITS_PER_LONG);
        this.segmentMask = numLongsPerSegment - 1;
        this.segments = new AtomicReference<ThreadSafeBitSetSegments>(
                new ThreadSafeBitSetSegments(1, numLongsPerSegment));
    }

    /**
     * Wrap an already populated list of segments. Callers inside this class are responsible for
     * passing a segment size which matches the length of the supplied segments.
     */
    private ThreadSafeBitSet(ThreadSafeBitSetSegments segments, int log2SegmentSizeInBits) {
        this.log2SegmentSize = log2SegmentSizeInBits;
        this.numLongsPerSegment = 1 << (log2SegmentSizeInBits - LOG2_BITS_PER_LONG);
        this.segmentMask = numLongsPerSegment - 1;
        this.segments = new AtomicReference<ThreadSafeBitSetSegments>(segments);
    }

    /**
     * Set the bit at <code>position</code> to 1, growing the bit set if the position lies beyond
     * the current capacity.
     *
     * @param position the index of the bit to set
     */
    public void set(int position) {
        int segmentPosition = position >>> log2SegmentSize;              /// which segment -- div by num bits per segment
        int longPosition = (position >>> LOG2_BITS_PER_LONG) & segmentMask; /// which long in the segment
        int bitPosition = position & 0x3F;                               /// which bit in the long

        AtomicLongArray segment = getSegment(segmentPosition);
        long mask = 1L << bitPosition;

        /// Thread safety: loop until we win the race to publish the long with our bit added.
        while (true) {
            long currentLongValue = segment.get(longPosition);
            long newLongValue = currentLongValue | mask;

            /// if no other thread has modified the value since we read it, we won the race and we are done.
            if (segment.compareAndSet(longPosition, currentLongValue, newLongValue)) {
                return;
            }
        }
    }

    /**
     * Test the bit at <code>position</code>.<p>
     *
     * Reading never allocates: a position which falls in a segment that has not been created yet
     * is reported as unset, and the capacity of the bit set is left unchanged.
     *
     * @param position the index of the bit to test
     * @return <code>true</code> if the bit is set, <code>false</code> otherwise
     */
    public boolean get(int position) {
        int segmentPosition = position >>> log2SegmentSize;

        /// Work from a snapshot; a segment which does not exist in it cannot contain a set bit.
        ThreadSafeBitSetSegments visibleSegments = segments.get();
        if (segmentPosition >= visibleSegments.numSegments()) {
            return false;
        }

        int longPosition = (position >>> LOG2_BITS_PER_LONG) & segmentMask;
        long mask = 1L << (position & 0x3F);

        return (visibleSegments.getSegment(segmentPosition).get(longPosition) & mask) != 0;
    }

    /**
     * @return the number of bits which are set in this bit set.
     */
    public int cardinality() {
        ThreadSafeBitSetSegments visibleSegments = this.segments.get();

        int numSetBits = 0;
        for (int i = 0; i < visibleSegments.numSegments(); i++) {
            AtomicLongArray segment = visibleSegments.getSegment(i);
            for (int j = 0; j < segment.length(); j++) {
                numSetBits += Long.bitCount(segment.get(j));
            }
        }
        return numSetBits;
    }

    /**
     * @return the number of bits which are currently specified by this bit set. This is the maximum value
     * to which you might need to iterate, if you were to iterate over all bits in this set.
     */
    public int currentCapacity() {
        return segments.get().numSegments() << log2SegmentSize;
    }

    /**
     * Clear all bits to 0. Each long is reset individually, so the currently allocated segments are kept.
     */
    public void clearAll() {
        ThreadSafeBitSetSegments visibleSegments = this.segments.get();

        for (int i = 0; i < visibleSegments.numSegments(); i++) {
            AtomicLongArray segment = visibleSegments.getSegment(i);
            for (int j = 0; j < segment.length(); j++) {
                segment.set(j, 0L);
            }
        }
    }

    /**
     * Return a new bit set which contains all bits which are contained in this bit set, and which are
     * NOT contained in the <code>other</code> bit set.<p>
     *
     * In other words, return a new bit set, which is a bitwise and with the bitwise not of the other bit set.
     * Neither this bit set nor <code>other</code> is modified.
     *
     * @param other the bit set whose bits are removed from the result
     * @return a new bit set with the same number of segments as this bit set
     * @throws IllegalArgumentException if the two bit sets do not use the same segment size
     */
    public ThreadSafeBitSet andNot(ThreadSafeBitSet other) {
        if (other.log2SegmentSize != log2SegmentSize) {
            throw new IllegalArgumentException("Segment sizes must be the same");
        }

        ThreadSafeBitSetSegments thisSegments = this.segments.get();
        ThreadSafeBitSetSegments otherSegments = other.segments.get();
        ThreadSafeBitSetSegments newSegments =
                new ThreadSafeBitSetSegments(thisSegments.numSegments(), numLongsPerSegment);

        for (int i = 0; i < thisSegments.numSegments(); i++) {
            AtomicLongArray thisArray = thisSegments.getSegment(i);
            /// segments which the other bit set never allocated are treated as all zeros
            AtomicLongArray otherArray = (i < otherSegments.numSegments()) ? otherSegments.getSegment(i) : null;
            AtomicLongArray newArray = newSegments.getSegment(i);

            for (int j = 0; j < thisArray.length(); j++) {
                long thisLong = thisArray.get(j);
                long otherLong = (otherArray == null) ? 0 : otherArray.get(j);
                newArray.set(j, thisLong & ~otherLong);
            }
        }

        /// wrap the computed segments directly, rather than allocating a segment only to discard it
        return new ThreadSafeBitSet(newSegments, log2SegmentSize);
    }

    /**
     * Return a new bit set which contains all bits which are contained in *any* of the specified bit sets.
     *
     * @param bitSets the bit sets to combine; when none are given an empty, default-sized bit set is returned
     * @return a new bit set with as many segments as the largest of the inputs
     * @throws IllegalArgumentException if the bit sets do not all use the same segment size
     */
    public static ThreadSafeBitSet orAll(ThreadSafeBitSet... bitSets) {
        if (bitSets.length == 0) {
            return new ThreadSafeBitSet();
        }

        int log2SegmentSize = bitSets[0].log2SegmentSize;
        int numLongsPerSegment = bitSets[0].numLongsPerSegment;

        /// take one snapshot of each input, and find the largest number of segments among them
        ThreadSafeBitSetSegments[] inputSegments = new ThreadSafeBitSetSegments[bitSets.length];
        int maxNumSegments = 0;
        for (int i = 0; i < bitSets.length; i++) {
            if (bitSets[i].log2SegmentSize != log2SegmentSize) {
                throw new IllegalArgumentException("Segment sizes must be the same");
            }

            inputSegments[i] = bitSets[i].segments.get();
            if (inputSegments[i].numSegments() > maxNumSegments) {
                maxNumSegments = inputSegments[i].numSegments();
            }
        }

        ThreadSafeBitSetSegments newSegments = new ThreadSafeBitSetSegments(maxNumSegments, numLongsPerSegment);
        AtomicLongArray[] currentInputs = new AtomicLongArray[inputSegments.length];

        for (int i = 0; i < maxNumSegments; i++) {
            /// inputs which are shorter than the result contribute nothing to this segment
            for (int j = 0; j < inputSegments.length; j++) {
                currentInputs[j] = (i < inputSegments[j].numSegments()) ? inputSegments[j].getSegment(i) : null;
            }

            AtomicLongArray newSegment = newSegments.getSegment(i);
            for (int j = 0; j < numLongsPerSegment; j++) {
                long value = 0;
                for (int k = 0; k < currentInputs.length; k++) {
                    if (currentInputs[k] != null) {
                        value |= currentInputs[k].get(j);
                    }
                }
                newSegment.set(j, value);
            }
        }

        /// wrap the computed segments directly, rather than allocating a segment only to discard it
        return new ThreadSafeBitSet(newSegments, log2SegmentSize);
    }

    /**
     * Get the segment at <code>segmentIndex</code>. If this segment does not yet exist, create it
     * (along with every missing segment before it).
     *
     * @param segmentIndex the index of the required segment
     * @return the canonical segment at that index
     */
    private AtomicLongArray getSegment(int segmentIndex) {
        ThreadSafeBitSetSegments visibleSegments = segments.get();

        while (visibleSegments.numSegments() <= segmentIndex) {
            /// Thread safety: newVisibleSegments contains all of the segments from the currently visible segments, plus extra.
            /// all of the segments in the currently visible segments are canonical and will not change.
            ThreadSafeBitSetSegments newVisibleSegments =
                    new ThreadSafeBitSetSegments(visibleSegments, segmentIndex + 1, numLongsPerSegment);

            /// because we are using a compareAndSet, if this thread "wins the race" and successfully sets this variable,
            /// then the segments which are newly defined in newVisibleSegments become canonical.
            if (segments.compareAndSet(visibleSegments, newVisibleSegments)) {
                visibleSegments = newVisibleSegments;
            } else {
                /// If we "lose the race", re-read the canonical segments published by the winner and re-check on the
                /// next iteration. The segments newly defined in newVisibleSegments are discarded.
                visibleSegments = segments.get();
            }
        }

        return visibleSegments.getSegment(segmentIndex);
    }

    /**
     * An immutable list of segments. Growing the bit set replaces the whole list with a longer one
     * which shares the existing segment arrays.
     */
    private static class ThreadSafeBitSetSegments {

        private final AtomicLongArray segments[];

        /** Create <code>numSegments</code> new, zeroed segments of <code>segmentLength</code> longs each. */
        private ThreadSafeBitSetSegments(int numSegments, int segmentLength) {
            AtomicLongArray segments[] = new AtomicLongArray[numSegments];

            for (int i = 0; i < numSegments; i++) {
                segments[i] = new AtomicLongArray(segmentLength);
            }

            /// Thread safety: Because this.segments is final, the preceding operations in this constructor are
            /// guaranteed to be visible to any other thread which accesses this.segments.
            this.segments = segments;
        }

        /**
         * Create a list of <code>numSegments</code> segments which reuses every segment of
         * <code>copyFrom</code> and appends new, zeroed segments for the remaining indices.
         */
        private ThreadSafeBitSetSegments(ThreadSafeBitSetSegments copyFrom, int numSegments, int segmentLength) {
            AtomicLongArray segments[] = new AtomicLongArray[numSegments];

            for (int i = 0; i < numSegments; i++) {
                segments[i] = i < copyFrom.numSegments() ? copyFrom.getSegment(i) : new AtomicLongArray(segmentLength);
            }

            /// see above re: thread-safety of this assignment
            this.segments = segments;
        }

        public int numSegments() {
            return segments.length;
        }

        public AtomicLongArray getSegment(int index) {
            return segments[index];
        }
    }

    /**
     * Serialize this ThreadSafeBitSet to an OutputStream.<p>
     *
     * The format is: one byte holding the log2 segment size, an int holding the number of segments,
     * then every long of every segment in order.
     *
     * @param os the stream to write to
     * @throws IOException if the underlying stream fails
     */
    public void serializeTo(DataOutputStream os) throws IOException {
        os.write(log2SegmentSize);

        ThreadSafeBitSetSegments visibleSegments = this.segments.get();
        os.writeInt(visibleSegments.numSegments());

        for (int i = 0; i < visibleSegments.numSegments(); i++) {
            AtomicLongArray arr = visibleSegments.getSegment(i);
            for (int j = 0; j < arr.length(); j++) {
                os.writeLong(arr.get(j));
            }
        }
    }

    /**
     * Deserialize a ThreadSafeBitSet from an InputStream, reading the format written by
     * {@link #serializeTo(DataOutputStream)}.
     *
     * @param dis the stream to read from
     * @return the deserialized bit set
     * @throws IOException if the stream ends prematurely, fails, or starts with a header which
     *         does not describe a valid bit set (segment size below 2^6 bits, or a negative segment count)
     */
    public static ThreadSafeBitSet deserializeFrom(DataInputStream dis) throws IOException {
        /// readUnsignedByte signals end-of-stream with an exception instead of returning -1
        int log2SegmentSize = dis.readUnsignedByte();
        if (log2SegmentSize < LOG2_BITS_PER_LONG) {
            throw new IOException("Invalid ThreadSafeBitSet data: log2 segment size " + log2SegmentSize
                    + " is smaller than " + LOG2_BITS_PER_LONG);
        }

        int numSegments = dis.readInt();
        if (numSegments < 0) {
            throw new IOException("Invalid ThreadSafeBitSet data: negative segment count " + numSegments);
        }

        int numLongsPerSegment = 1 << (log2SegmentSize - LOG2_BITS_PER_LONG);
        ThreadSafeBitSetSegments segments = new ThreadSafeBitSetSegments(numSegments, numLongsPerSegment);

        for (int i = 0; i < segments.numSegments(); i++) {
            AtomicLongArray arr = segments.getSegment(i);
            for (int j = 0; j < numLongsPerSegment; j++) {
                arr.set(j, dis.readLong());
            }
        }

        return new ThreadSafeBitSet(segments, log2SegmentSize);
    }
}