/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.zookeeper;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;
import java.util.LinkedList;
import java.util.List;
import java.util.ListIterator;
import java.util.Set;
import org.apache.zookeeper.ClientCnxn.EndOfStreamException;
import org.apache.zookeeper.ClientCnxn.Packet;
import org.apache.zookeeper.ZooDefs.OpCode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class ClientCnxnSocketNIO extends ClientCnxnSocket {
// Class-wide logger.
private static final Logger LOG = LoggerFactory
.getLogger(ClientCnxnSocketNIO.class);
// Selector all channels of this instance are registered with. It is opened
// in this field initializer (which is why the constructor declares
// IOException) and closed in close().
private final Selector selector = Selector.open();
// Registration of the current channel with the selector. Assigned in
// registerAndConnect() and reset to null by cleanup(); isConnected() reports
// true exactly when this is non-null.
private SelectionKey sockKey;
// Addresses cached by updateSocketAddresses(); they stay null until that
// method has run once and are not cleared by cleanup().
private SocketAddress localSocketAddress;
private SocketAddress remoteSocketAddress;
ClientCnxnSocketNIO() throws IOException {
super();
}
/**
* Reports whether a selection key is currently held.
*
* @return true while {@code sockKey} is non-null, i.e. between a call to
* registerAndConnect() and the next cleanup(); note this does not
* by itself check that the TCP connect has finished
*/
@Override
boolean isConnected() {
return sockKey != null;
}
/**
* Services the read and/or write readiness of the channel behind
* {@code sockKey}. The method returns nothing; its effects are on the
* buffers, counters and the two queues.
* <p>
* Read side: bytes are read into {@code incomingBuffer}. Once that buffer is
* full it is flipped and handled in one of three ways: if it is the length
* buffer, {@code readLength()} is called; if the connection is not yet
* initialized, the connect result is read and the connection is marked
* initialized; otherwise the buffer is handed to
* {@code sendThread.readResponse}. In the last two cases the length buffer
* is re-armed for the next message.
* <p>
* Write side: under the {@code outgoingQueue} lock, one packet chosen by
* {@code findSendablePacket} is written (possibly partially). A fully
* written packet is removed from {@code outgoingQueue} and, unless it has no
* request header or is a ping/auth packet, appended to {@code pendingQueue}.
* The write-interest flag is then adjusted.
*
* @param pendingQueue queue receiving fully sent packets that are added
* here under the conditions above
* @param outgoingQueue queue of packets still to be written; also used as
* the lock for the write side
* @param cnxn connection used to obtain xids and the SASL-in-progress state
* @throws InterruptedException declared by this method; which callee raises
* it is not visible in this file
* @throws IOException on channel I/O failure; an
* {@link EndOfStreamException} is thrown when the read returns a
* negative count
*/
void doIO(List<Packet> pendingQueue, LinkedList<Packet> outgoingQueue, ClientCnxn cnxn)
throws InterruptedException, IOException {
SocketChannel sock = (SocketChannel) sockKey.channel();
if (sock == null) {
throw new IOException("Socket is null!");
}
if (sockKey.isReadable()) {
int rc = sock.read(incomingBuffer);
if (rc < 0) {
throw new EndOfStreamException(
"Unable to read additional data from server sessionid 0x"
+ Long.toHexString(sessionId)
+ ", likely server has closed socket");
}
// Only act once the current buffer is completely filled; a partial read
// simply waits for the next readable event.
if (!incomingBuffer.hasRemaining()) {
incomingBuffer.flip();
if (incomingBuffer == lenBuffer) {
// A length prefix was completed.
// NOTE(review): readLength() is defined outside this file; presumably it
// replaces incomingBuffer with a buffer sized for the body - confirm.
recvCount++;
readLength();
} else if (!initialized) {
// First full message on a new connection: the connect response.
readConnectResult();
enableRead();
if (findSendablePacket(outgoingQueue,
cnxn.sendThread.clientTunneledAuthenticationInProgress()) != null) {
// Since SASL authentication has completed (if client is configured to do so),
// outgoing packets waiting in the outgoingQueue can now be sent.
enableWrite();
}
// Re-arm the length buffer for the next message.
lenBuffer.clear();
incomingBuffer = lenBuffer;
updateLastHeard();
initialized = true;
} else {
// Regular response on an initialized connection.
sendThread.readResponse(incomingBuffer);
// Re-arm the length buffer for the next message.
lenBuffer.clear();
incomingBuffer = lenBuffer;
updateLastHeard();
}
}
}
if (sockKey.isWritable()) {
synchronized(outgoingQueue) {
Packet p = findSendablePacket(outgoingQueue,
cnxn.sendThread.clientTunneledAuthenticationInProgress());
if (p != null) {
updateLastSend();
// If we already started writing p, p.bb will already exist
if (p.bb == null) {
// Assign an xid only to packets that carry a header and are neither
// ping nor auth requests.
if ((p.requestHeader != null) &&
(p.requestHeader.getType() != OpCode.ping) &&
(p.requestHeader.getType() != OpCode.auth)) {
p.requestHeader.setXid(cnxn.getXid());
}
p.createBB();
}
// May write only part of the buffer; the remainder is picked up on a
// later writable event because p.bb is then non-null.
sock.write(p.bb);
if (!p.bb.hasRemaining()) {
// Packet fully written: move it from outgoing to pending (when it
// is a header-bearing, non-ping, non-auth request).
sentCount++;
outgoingQueue.removeFirstOccurrence(p);
if (p.requestHeader != null
&& p.requestHeader.getType() != OpCode.ping
&& p.requestHeader.getType() != OpCode.auth) {
synchronized (pendingQueue) {
pendingQueue.add(p);
}
}
}
}
if (outgoingQueue.isEmpty()) {
// No more packets to send: turn off write interest flag.
// Will be turned on later by a later call to enableWrite(),
// from within ZooKeeperSaslClient (if client is configured
// to attempt SASL authentication), or in either doIO() or
// in doTransport() if not.
disableWrite();
} else if (!initialized && p != null && !p.bb.hasRemaining()) {
// On initial connection, write the complete connect request
// packet, but then disable further writes until after
// receiving a successful connection response. If the
// session is expired, then the server sends the expiration
// response and immediately closes its end of the socket. If
// the client is simultaneously writing on its end, then the
// TCP stack may choose to abort with RST, in which case the
// client would never receive the session expired event. See
// http://docs.oracle.com/javase/6/docs/technotes/guides/net/articles/connection_release.html
disableWrite();
} else {
// Just in case
enableWrite();
}
}
}
}
/**
 * Selects the packet that may be written next, or null if none may be.
 * <p>
 * When the queue is non-empty, the head packet is returned if it has already
 * been partially written (its {@code bb} exists) or if no SASL authentication
 * is in progress. Otherwise only a packet without a request header (the
 * priming packet) is eligible: the first such packet is moved to the head of
 * the queue and returned, while every packet inspected before it is left in
 * place and logged at debug level as deferred.
 *
 * @param outgoingQueue queue of packets awaiting transmission; it is locked
 *        for the duration of the call and may be reordered
 * @param clientTunneledAuthenticationInProgress true while SASL
 *        authentication with the server has not completed
 * @return the packet to write next, or null when the queue is empty or only
 *         deferred packets are queued
 */
private Packet findSendablePacket(LinkedList<Packet> outgoingQueue,
        boolean clientTunneledAuthenticationInProgress) {
    synchronized (outgoingQueue) {
        if (outgoingQueue.isEmpty()) {
            return null;
        }
        Packet head = outgoingQueue.getFirst();
        // If we've already started sending the first packet, we had better
        // finish it; and without authentication in progress the head is
        // always sendable.
        if (head.bb != null || !clientTunneledAuthenticationInProgress) {
            return head;
        }
        // Since client's authentication with server is in progress,
        // send only the null-header packet queued by primeConnection().
        // This packet must be sent so that the SASL authentication process
        // can proceed, but all other packets should wait until
        // SASL authentication completes.
        ListIterator<Packet> iter = outgoingQueue.listIterator();
        while (iter.hasNext()) {
            Packet p = iter.next();
            if (p.requestHeader == null) {
                // We've found the priming-packet. Move it to the beginning
                // of the queue. The iterator is not used again after this
                // structural change.
                iter.remove();
                outgoingQueue.addFirst(p);
                return p;
            }
            // Non-priming packet: defer it until later, leaving it in the
            // queue until authentication completes.
            if (LOG.isDebugEnabled()) {
                LOG.debug("deferring non-priming packet: " + p
                        + " until SASL authentication completes.");
            }
        }
        // no sendable packet found.
        return null;
    }
}
/**
 * Tears down the current connection, if any: cancels the selection key, then
 * shuts down input and output, closes the socket and closes the channel.
 * Each step is attempted independently and an {@link IOException} from any
 * of them is only logged at debug level. Afterwards the method pauses for
 * 100 ms (an interrupt during the pause is logged and otherwise ignored) and
 * clears {@code sockKey}.
 */
@Override
void cleanup() {
    if (sockKey != null) {
        final SocketChannel channel = (SocketChannel) sockKey.channel();
        sockKey.cancel();
        final Socket rawSocket = channel.socket();

        // Half-close the read side first.
        try {
            rawSocket.shutdownInput();
        } catch (IOException inputFailure) {
            if (LOG.isDebugEnabled()) {
                LOG.debug("Ignoring exception during shutdown input", inputFailure);
            }
        }

        // Then the write side.
        try {
            rawSocket.shutdownOutput();
        } catch (IOException outputFailure) {
            if (LOG.isDebugEnabled()) {
                LOG.debug("Ignoring exception during shutdown output",
                        outputFailure);
            }
        }

        // Close the socket view of the channel ...
        try {
            rawSocket.close();
        } catch (IOException socketFailure) {
            if (LOG.isDebugEnabled()) {
                LOG.debug("Ignoring exception during socket close", socketFailure);
            }
        }

        // ... and finally the channel itself.
        try {
            channel.close();
        } catch (IOException channelFailure) {
            if (LOG.isDebugEnabled()) {
                LOG.debug("Ignoring exception during channel close", channelFailure);
            }
        }
    }

    // Fixed pause performed whether or not a key was present.
    // NOTE(review): the reason for the 100 ms delay is not visible in this file.
    try {
        Thread.sleep(100);
    } catch (InterruptedException interrupted) {
        if (LOG.isDebugEnabled()) {
            LOG.debug("SendThread interrupted during sleep, ignoring");
        }
    }

    sockKey = null;
}
/**
 * Closes the selector owned by this instance. A failure to close is logged
 * as a warning and not propagated; trace messages bracket the close when
 * trace logging is enabled.
 */
@Override
void close() {
    try {
        if (LOG.isTraceEnabled()) {
            LOG.trace("Doing client selector close");
        }

        selector.close();

        // Only reached when the close above did not throw.
        if (LOG.isTraceEnabled()) {
            LOG.trace("Closed client selector");
        }
    } catch (IOException closeFailure) {
        LOG.warn("Ignoring exception during selector close", closeFailure);
    }
}
/**
 * Creates a socket channel configured for this client: non-blocking mode,
 * SO_LINGER disabled and TCP_NODELAY enabled.
 * <p>
 * If any configuration step fails, the freshly opened channel is closed
 * before the failure propagates so that no file descriptor is leaked.
 *
 * @return the created, not yet connected socket channel
 * @throws IOException if the channel cannot be opened or configured
 */
SocketChannel createSock() throws IOException {
    SocketChannel sock = SocketChannel.open();
    boolean configured = false;
    try {
        sock.configureBlocking(false);
        sock.socket().setSoLinger(false, -1);
        sock.socket().setTcpNoDelay(true);
        configured = true;
        return sock;
    } finally {
        if (!configured) {
            // Configuration failed: release the channel, but let the
            // original exception be the one the caller sees.
            try {
                sock.close();
            } catch (IOException e) {
                if (LOG.isDebugEnabled()) {
                    LOG.debug("Ignoring exception during channel close", e);
                }
            }
        }
    }
}
/**
 * Registers the channel with the selector for connect events and starts
 * the connection attempt.
 * <p>
 * If the connection is established immediately, no OP_CONNECT event will be
 * handled by {@code doTransport} for it, so the work done there on a
 * finished connect is done here instead: the cached local/remote socket
 * addresses are refreshed and the connection is primed.
 *
 * @param sock the {@link SocketChannel}
 * @param addr the address of remote host
 * @throws IOException if registration or the connect attempt fails
 */
void registerAndConnect(SocketChannel sock, InetSocketAddress addr)
        throws IOException {
    sockKey = sock.register(selector, SelectionKey.OP_CONNECT);
    boolean immediateConnect = sock.connect(addr);
    if (immediateConnect) {
        // Keep getLocalSocketAddress()/getRemoteSocketAddress() in step
        // with the new connection before priming it; otherwise they would
        // still report null or the addresses of a previous connection.
        updateSocketAddresses();
        sendThread.primeConnection();
    }
}
/**
 * Opens a new channel and begins connecting it to {@code addr}, then resets
 * the receive state so the next bytes read are treated as a length prefix of
 * a not-yet-initialized connection.
 *
 * @param addr the server address to connect to
 * @throws IOException if the channel cannot be created, or if registering
 *         or connecting fails (in which case the channel is closed and the
 *         failure is logged before being rethrown)
 */
@Override
void connect(InetSocketAddress addr) throws IOException {
    final SocketChannel channel = createSock();
    try {
        registerAndConnect(channel, addr);
    } catch (IOException connectFailure) {
        LOG.error("Unable to open socket to " + addr);
        channel.close();
        throw connectFailure;
    }

    // The connect response has not been received on this connection yet.
    initialized = false;

    // Reset incomingBuffer so reading starts with the length buffer.
    lenBuffer.clear();
    incomingBuffer = lenBuffer;
}
/**
* Returns the cached address of the remote end of the connection.
*
* @return the remote address last recorded by
* {@code updateSocketAddresses()}, or null if that method has not
* run yet; the value is a cached copy and is not cleared by
* {@code cleanup()}
*/
@Override
SocketAddress getRemoteSocketAddress() {
return remoteSocketAddress;
}
/**
* Returns the cached local address to which the socket is bound.
*
* @return the local address last recorded by
* {@code updateSocketAddresses()}, or null if that method has not
* run yet; the value is a cached copy and is not cleared by
* {@code cleanup()}
*/
@Override
SocketAddress getLocalSocketAddress() {
return localSocketAddress;
}
/**
 * Refreshes the cached local and remote addresses from the socket of the
 * channel behind {@code sockKey}. Requires {@code sockKey} to be non-null.
 */
private void updateSocketAddresses() {
    final SocketChannel channel = (SocketChannel) sockKey.channel();
    final Socket connectedSocket = channel.socket();
    localSocketAddress = connectedSocket.getLocalSocketAddress();
    remoteSocketAddress = connectedSocket.getRemoteSocketAddress();
}
/**
* Wakes up the selector so that a {@code select} call blocked in
* {@code doTransport} returns. The method is synchronized on this instance,
* the same monitor {@code doTransport} holds while it fetches the
* selected-key set.
*/
@Override
synchronized void wakeupCnxn() {
selector.wakeup();
}
/**
 * Runs one iteration of the selector loop: waits up to {@code waitTimeOut}
 * milliseconds for readiness, completes a pending connect or performs I/O
 * for each selected key, re-enables write interest if a sendable packet is
 * queued while connected, and finally clears the selected-key set.
 *
 * @param waitTimeOut maximum time to block in select, in milliseconds
 * @param pendingQueue passed through to {@code doIO}
 * @param outgoingQueue packets awaiting transmission; also used as a lock
 * @param cnxn connection supplying the SASL-in-progress state
 * @throws IOException if select, finishing the connect, or I/O fails
 * @throws InterruptedException propagated from {@code doIO}
 */
@Override
void doTransport(int waitTimeOut, List<Packet> pendingQueue, LinkedList<Packet> outgoingQueue,
        ClientCnxn cnxn)
        throws IOException, InterruptedException {
    selector.select(waitTimeOut);

    final Set<SelectionKey> readyKeys;
    synchronized (this) {
        readyKeys = selector.selectedKeys();
    }

    // Nothing from here until the next select blocks, so the clock is
    // sampled a single time for the whole pass.
    updateNow();

    final int readOrWrite = SelectionKey.OP_READ | SelectionKey.OP_WRITE;
    for (SelectionKey key : readyKeys) {
        final SocketChannel channel = (SocketChannel) key.channel();
        final int ready = key.readyOps();
        final boolean connectReady = (ready & SelectionKey.OP_CONNECT) != 0;
        if (connectReady) {
            // A connect event is handled on its own; read/write readiness
            // is not examined for this key in the same pass.
            if (channel.finishConnect()) {
                updateLastSendAndHeard();
                updateSocketAddresses();
                sendThread.primeConnection();
            }
        } else if ((ready & readOrWrite) != 0) {
            doIO(pendingQueue, outgoingQueue, cnxn);
        }
    }

    if (sendThread.getZkState().isConnected()) {
        synchronized (outgoingQueue) {
            final Packet sendable = findSendablePacket(outgoingQueue,
                    cnxn.sendThread.clientTunneledAuthenticationInProgress());
            if (sendable != null) {
                enableWrite();
            }
        }
    }

    readyKeys.clear();
}
//TODO should this be synchronized?
/**
 * Closes the underlying socket of the current connection, if there is one,
 * after logging that the method was called.
 *
 * @throws IOException if closing the socket fails
 */
@Override
void testableCloseSocket() throws IOException {
    LOG.info("testableCloseSocket() called");
    // sockKey may be changed by another thread; read it once into a local
    // so the null check and the use see the same value.
    final SelectionKey keySnapshot = sockKey;
    if (keySnapshot == null) {
        return;
    }
    final SocketChannel channel = (SocketChannel) keySnapshot.channel();
    channel.socket().close();
}
/**
 * Adds OP_WRITE to the key's interest set if it is not already present.
 * Requires {@code sockKey} to be non-null.
 */
@Override
synchronized void enableWrite() {
    final int currentOps = sockKey.interestOps();
    final boolean writeInterestSet = (currentOps & SelectionKey.OP_WRITE) != 0;
    if (!writeInterestSet) {
        sockKey.interestOps(currentOps | SelectionKey.OP_WRITE);
    }
}
/**
 * Removes OP_WRITE from the key's interest set if it is currently present.
 * Requires {@code sockKey} to be non-null.
 */
@Override
public synchronized void disableWrite() {
    final int currentOps = sockKey.interestOps();
    final boolean writeInterestSet = (currentOps & SelectionKey.OP_WRITE) != 0;
    if (writeInterestSet) {
        sockKey.interestOps(currentOps & ~SelectionKey.OP_WRITE);
    }
}
/**
 * Adds OP_READ to the key's interest set if it is not already present.
 * Requires {@code sockKey} to be non-null.
 */
private synchronized void enableRead() {
    final int currentOps = sockKey.interestOps();
    final boolean readInterestSet = (currentOps & SelectionKey.OP_READ) != 0;
    if (!readInterestSet) {
        sockKey.interestOps(currentOps | SelectionKey.OP_READ);
    }
}
/**
* Replaces the key's interest set with exactly OP_READ and OP_WRITE, which
* drops any other interest previously set (such as the OP_CONNECT used at
* registration). Requires {@code sockKey} to be non-null.
*/
@Override
synchronized void enableReadWriteOnly() {
sockKey.interestOps(SelectionKey.OP_READ | SelectionKey.OP_WRITE);
}
/**
* Exposes the selector owned by this instance to code in the same package.
* NOTE(review): no caller is visible in this file; presumably used by tests.
*
* @return the selector opened when this instance was constructed
*/
Selector getSelector() {
return selector;
}
/**
 * Serializes the packet and writes it to the channel directly, bypassing
 * the outgoing queue.
 * <p>
 * NOTE(review): a single {@code write} is issued. The channel is created in
 * non-blocking mode, so the write may transfer only part of the buffer and
 * the remainder is not retried here - confirm that callers only send
 * packets small enough for this to be acceptable.
 *
 * @param p the packet to send; its byte buffer is (re)created by this call
 * @throws IOException if there is no current connection or the write fails
 */
@Override
void sendPacket(Packet p) throws IOException {
    // sockKey is reset to null by cleanup(); read it once and report a
    // missing connection as the IOException callers already expect rather
    // than failing with a NullPointerException.
    final SelectionKey key = sockKey;
    if (key == null) {
        throw new IOException("Socket is null!");
    }
    SocketChannel sock = (SocketChannel) key.channel();
    if (sock == null) {
        throw new IOException("Socket is null!");
    }
    p.createBB();
    ByteBuffer pbb = p.bb;
    sock.write(pbb);
}
}