/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.thrift;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.FutureTask;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
/**
* This class uses a single thread to set up non-blocking sockets to a set
* of remote servers (hostname and port pairs), and sends the same request to
* all of these servers. It then fetches the responses from the servers.
*
* Parameters:
* int maxRecvBufBytesPerServer - an upper limit for the receive buffer size
* per server (in bytes). If a response from a server exceeds this limit, the
* client will not allocate memory for it or read its response data.
*
* int fetchTimeoutSeconds - time limit for fetching responses from all
* servers (in seconds). After the timeout, the fetch job is stopped and the
* responses available so far are returned.
*
* ByteBuffer requestBuf - request message that is sent to all servers.
*
* Output:
* Responses are stored in an array of ByteBuffers. The index of each element
* in this array corresponds to the index of the server in the server list.
* The content of a ByteBuffer may take one of the following forms:
* 1. The first 4 bytes form an integer indicating the length of the
* following data, followed by the data itself.
* 2. The first 4 bytes form an integer indicating the length of the
* following data, followed by nothing - this happens when the response data
* size exceeds maxRecvBufBytesPerServer, and the client does not read any
* response data.
* 3. No data in the ByteBuffer - this happens when the server does not
* return any response within fetchTimeoutSeconds. If a response is only
* partially received before the timeout, the buffer holds the bytes that
* arrived so far.
* The buffers are returned without being flipped: the position of each
* buffer marks how many bytes were stored in it.
*
* In some special cases (no servers are given, fetchTimeoutSeconds is less
* than or equal to 0, or requestBuf is null), the return value is null.
*
* Note:
* It assumes that all remote servers are TNonblockingServers and use
* TFramedTransport.
*
*/
public class TNonblockingMultiFetchClient {

  private static final Logger LOGGER = LoggerFactory.getLogger(
      TNonblockingMultiFetchClient.class.getName()
  );

  // After cancelling the worker, fetch() waits at most this long (in
  // milliseconds) for it to exit before closing the sockets it uses.
  private static final long WORKER_EXIT_WAIT_MILLIS = 1000L;

  // if the size of the response msg exceeds this limit (in bytes, including
  // the 4-byte frame size), we will not read the msg
  private final int maxRecvBufBytesPerServer;

  // time limit for fetching data from all servers (in seconds)
  private final int fetchTimeoutSeconds;

  // store request that will be sent to servers
  private final ByteBuffer requestBuf;

  // a list of remote servers
  private final List<InetSocketAddress> servers;

  // store fetch results
  private final TNonblockingMultiFetchStats stats;

  // Assigned by the worker thread and returned by fetch() on the caller's
  // thread, hence volatile.
  private volatile ByteBuffer[] recvBuf;

  /**
   * Creates a client. No connection is made until {@link #fetch()} is called.
   *
   * @param maxRecvBufBytesPerServer upper limit (in bytes) on the buffer
   *        allocated for one server's response, including its 4-byte frame size
   * @param fetchTimeoutSeconds time limit for one {@link #fetch()} call
   * @param requestBuf request sent to every server; only duplicates of it are
   *        consumed, so its own position and limit are left untouched
   * @param servers servers to fetch from; the list is not copied
   */
  public TNonblockingMultiFetchClient(int maxRecvBufBytesPerServer,
      int fetchTimeoutSeconds, ByteBuffer requestBuf,
      List<InetSocketAddress> servers) {
    this.maxRecvBufBytesPerServer = maxRecvBufBytesPerServer;
    this.fetchTimeoutSeconds = fetchTimeoutSeconds;
    this.requestBuf = requestBuf;
    this.servers = servers;
    stats = new TNonblockingMultiFetchStats();
    recvBuf = null;
  }

  public synchronized int getMaxRecvBufBytesPerServer() {
    return maxRecvBufBytesPerServer;
  }

  public synchronized int getFetchTimeoutSeconds() {
    return fetchTimeoutSeconds;
  }

  /**
   * Returns a duplicate of requestBuf, so that requestBuf will not be
   * modified by others.
   *
   * Every call returns a fresh duplicate with its own position, limit and
   * mark, so one caller consuming it does not affect the next. The duplicate
   * shares its content with requestBuf.
   *
   * @return a new duplicate of the request, or null if no request was given
   */
  public synchronized ByteBuffer getRequestBuf() {
    return requestBuf == null ? null : requestBuf.duplicate();
  }

  /**
   * @return an unmodifiable view of the server list, or null if none was given
   */
  public synchronized List<InetSocketAddress> getServerList() {
    if (servers == null) {
      return null;
    }
    return Collections.unmodifiableList(servers);
  }

  public synchronized TNonblockingMultiFetchStats getFetchStats() {
    return stats;
  }

  /**
   * Main entry function for fetching from servers.
   *
   * Runs the fetch on a dedicated worker thread. It waits at most
   * fetchTimeoutSeconds, plus a short grace period for the worker to stop,
   * then returns whatever has been received. Statistics from the previous
   * call are cleared first.
   *
   * @return one buffer per server in server-list order. Returns null if there
   *         are no servers, no request, or a non-positive timeout, and also
   *         if the fetch was interrupted before the worker started.
   */
  public synchronized ByteBuffer[] fetch() {
    // clear previous results
    recvBuf = null;
    stats.clear();

    if (servers == null || servers.isEmpty() ||
        requestBuf == null || fetchTimeoutSeconds <= 0) {
      return recvBuf;
    }

    ExecutorService executor = Executors.newSingleThreadExecutor();
    MultiFetch multiFetch = new MultiFetch();
    FutureTask<Void> task = new FutureTask<Void>(multiFetch, null);
    executor.execute(task);
    try {
      task.get(fetchTimeoutSeconds, TimeUnit.SECONDS);
    } catch (InterruptedException ie) {
      // attempt to cancel execution of the task.
      task.cancel(true);
      // keep the caller's interrupt status visible to the caller
      Thread.currentThread().interrupt();
      LOGGER.error("interrupted during fetch: " + ie.toString());
    } catch (ExecutionException ee) {
      // attempt to cancel execution of the task.
      task.cancel(true);
      LOGGER.error("exception during fetch: " + ee.toString());
    } catch (TimeoutException te) {
      // attempt to cancel execution of the task.
      task.cancel(true);
      LOGGER.error("timeout for fetch: " + te.toString());
    }

    executor.shutdownNow();
    // Let the (interrupted) worker leave its event loop before its sockets
    // are closed underneath it. This also makes its writes to the result
    // buffers visible here.
    awaitWorkerExit(executor);
    multiFetch.close();
    return recvBuf;
  }

  // Waits a bounded time for the executor's worker to terminate. The
  // caller's interrupt status is preserved if the wait is interrupted.
  private static void awaitWorkerExit(ExecutorService executor) {
    try {
      if (!executor.awaitTermination(WORKER_EXIT_WAIT_MILLIS,
          TimeUnit.MILLISECONDS)) {
        LOGGER.error("fetch worker did not stop within "
            + WORKER_EXIT_WAIT_MILLIS + " ms");
      }
    } catch (InterruptedException ie) {
      Thread.currentThread().interrupt();
    }
  }

  /**
   * Private class that does the real fetch job.
   * Users are not allowed to directly use this class, as its run()
   * function may run until it is interrupted.
   *
   * Each instance serves a single fetch. All per-server state below is
   * created by run() and touched only by the worker thread. The exception is
   * the selector, which close() releases from the fetching thread.
   */
  private class MultiFetch implements Runnable {
    // written by the worker thread in run(), read by the fetching thread
    // in close()
    private volatile Selector selector;

    // time at which run() started, used for the read-time statistic
    private long startMillis;
    // this fetch's result array; also published through recvBuf
    private ByteBuffer[] responses;
    // per-server request buffer (a duplicate of requestBuf)
    private ByteBuffer[] sendBuf;
    // per-server bytes read so far, including the 4-byte frame size
    private long[] numBytesRead;
    // per-server frame size, meaningful once hasReadFrameSize is set
    private int[] frameSize;
    private boolean[] hasReadFrameSize;
    // per-server flag, set once the server needs no more work (response
    // complete or any failure); numFinished counts the set flags
    private boolean[] finished;
    private int numFinished;

    /**
     * Main entry function for fetching.
     *
     * Server responses are stored in TNonblockingMultiFetchClient.recvBuf,
     * and fetch statistics are in TNonblockingMultiFetchClient.stats.
     * Returns once every server has completed or failed, or when the
     * thread is interrupted.
     *
     * Sanity check for parameters has been done in
     * TNonblockingMultiFetchClient before calling this function.
     */
    @Override
    public void run() {
      startMillis = System.currentTimeMillis();
      int numTotalServers = servers.size();
      stats.setNumTotalServers(numTotalServers);

      // buffers for receiving responses from servers
      responses = new ByteBuffer[numTotalServers];
      recvBuf = responses;
      sendBuf = new ByteBuffer[numTotalServers];
      numBytesRead = new long[numTotalServers];
      frameSize = new int[numTotalServers];
      hasReadFrameSize = new boolean[numTotalServers];
      finished = new boolean[numTotalServers];
      numFinished = 0;

      Selector sel;
      try {
        sel = Selector.open();
      } catch (IOException e) {
        LOGGER.error("selector opens error: " + e.toString());
        return;
      }
      selector = sel;

      for (int i = 0; i < numTotalServers; i++) {
        connect(sel, i);
      }

      // wait for events until every server has completed or failed
      while (numFinished < numTotalServers) {
        // if the thread is interrupted (e.g., task is cancelled)
        if (Thread.currentThread().isInterrupted()) {
          return;
        }
        try {
          sel.select();
        } catch (Exception e) {
          LOGGER.error("selector selects error: " + e.toString());
          continue;
        }
        Iterator<SelectionKey> it = sel.selectedKeys().iterator();
        while (it.hasNext()) {
          SelectionKey selKey = it.next();
          it.remove();
          handleEvents(selKey);
        }
      }
    }

    // Prepares buffers for server i, then opens a non-blocking connection to
    // it and registers that connection with sel. On failure the server is
    // counted as a connect error and marked finished.
    private void connect(Selector sel, int i) {
      // create buffer to send request to server
      sendBuf[i] = requestBuf.duplicate();
      // create buffer to read response's frame size from server
      responses[i] = ByteBuffer.allocate(4);
      stats.incTotalRecvBufBytes(4);

      InetSocketAddress server = servers.get(i);
      SocketChannel s = null;
      SelectionKey key = null;
      try {
        s = SocketChannel.open();
        s.configureBlocking(false);
        // non-blocking: connect() only initiates the connection
        s.connect(server);
        key = s.register(sel, s.validOps());
        // attach the server index so events can be mapped back to it
        key.attach(i);
      } catch (Exception e) {
        stats.incNumConnectErrorServers();
        // %s formats a null address as "null" instead of throwing
        LOGGER.error(String.format("set up socket to server %s error: %s",
            server, e.toString()));
        if (key != null) {
          key.cancel();
        }
        finish(s, i);
      }
    }

    // Dispatches the ready operations of one selected key. Each step
    // re-checks isValid() because the previous one may have closed the
    // channel, which cancels the key.
    private void handleEvents(SelectionKey selKey) {
      // get previously attached index
      int index = (Integer) selKey.attachment();
      SocketChannel sChannel = (SocketChannel) selKey.channel();
      if (selKey.isValid() && selKey.isConnectable()) {
        onConnectable(sChannel, index);
      }
      if (selKey.isValid() && selKey.isWritable()) {
        onWritable(selKey, sChannel, index);
      }
      if (selKey.isValid() && selKey.isReadable()) {
        onReadable(sChannel, index);
      }
    }

    // Completes a pending connection. A failure (e.g. connection refused)
    // is counted as a connect error and the server is marked finished.
    private void onConnectable(SocketChannel sChannel, int index) {
      try {
        sChannel.finishConnect();
      } catch (Exception e) {
        stats.incNumConnectErrorServers();
        LOGGER.error(String.format("socket %d connects to server %s " +
            "error: %s", index, servers.get(index), e.toString()));
        finish(sChannel, index);
      }
    }

    // Sends as much of the remaining request as the socket accepts. Stops
    // watching for writability once the request is fully sent.
    private void onWritable(SelectionKey selKey, SocketChannel sChannel,
        int index) {
      ByteBuffer request = sendBuf[index];
      try {
        if (request.hasRemaining()) {
          sChannel.write(request);
        }
        if (!request.hasRemaining()) {
          // A connected socket is almost always writable. Keeping OP_WRITE
          // would make every select() return at once and busy-spin this
          // thread while it waits for responses.
          selKey.interestOps(selKey.interestOps() & ~SelectionKey.OP_WRITE);
        }
      } catch (Exception e) {
        LOGGER.error(String.format("socket %d writes to server %s " +
            "error: %s", index, servers.get(index), e.toString()));
        // the request cannot be delivered, so no response will follow
        finish(sChannel, index);
      }
    }

    // Reads available response bytes: first the 4-byte frame size, then
    // the frame itself. The server is finished when the frame is complete,
    // on end of stream, or on any error.
    private void onReadable(SocketChannel sChannel, int index) {
      try {
        int bytesRead = sChannel.read(responses[index]);
        if (bytesRead < 0) {
          // End of stream before the full frame arrived. Without closing,
          // the key would stay readable and select() would spin until the
          // timeout.
          LOGGER.error(String.format("server %s closed the connection after"
              + " %d bytes", servers.get(index), numBytesRead[index]));
          finish(sChannel, index);
          return;
        }
        if (bytesRead == 0) {
          return;
        }
        numBytesRead[index] += bytesRead;
        if (!hasReadFrameSize[index] && responses[index].remaining() == 0
            && !startFrame(sChannel, index)) {
          return;
        }
        if (hasReadFrameSize[index] &&
            numBytesRead[index] >= frameSize[index] + 4) {
          // has read all data
          finish(sChannel, index);
          stats.incNumReadCompletedServers();
          stats.setReadTime(System.currentTimeMillis() - startMillis);
        }
      } catch (Exception e) {
        LOGGER.error(String.format("socket %d reads from server %s " +
            "error: %s", index, servers.get(index), e.toString()));
        finish(sChannel, index);
      }
    }

    // Called once server index's 4-byte frame size has been read. Validates
    // it and, if acceptable, replaces the receive buffer with one large
    // enough for the whole frame. Returns false (after closing the channel)
    // if the frame is rejected.
    private boolean startFrame(SocketChannel sChannel, int index) {
      int size = responses[index].getInt(0);
      frameSize[index] = size;
      if (size <= 0) {
        stats.incNumInvalidFrameSize();
        LOGGER.error(String.format("Read an invalid frame size %d"
            + " from %s. Does the server use TFramedTransport? ",
            size, servers.get(index)));
        finish(sChannel, index);
        return false;
      }
      if (size + 4 > stats.getMaxResponseBytes()) {
        stats.setMaxResponseBytes(size + 4);
      }
      // Long arithmetic: in int, size + 4 wraps negative for sizes near
      // Integer.MAX_VALUE and would slip past the limit.
      if (size + 4L > maxRecvBufBytesPerServer) {
        stats.incNumOverflowedRecvBuf();
        LOGGER.error(String.format("Read frame size %d from %s,"
            + " total buffer size would exceed limit %d",
            size, servers.get(index), maxRecvBufBytesPerServer));
        finish(sChannel, index);
        return false;
      }
      // reallocate buffer for actual frame data
      ByteBuffer frame = ByteBuffer.allocate(size + 4);
      frame.putInt(size);
      responses[index] = frame;
      stats.incTotalRecvBufBytes(size);
      hasReadFrameSize[index] = true;
      return true;
    }

    // Closes the server's channel (which also cancels its key), if any, and
    // marks the server finished so the event loop ends once all are done.
    private void finish(SocketChannel channel, int index) {
      if (channel != null) {
        try {
          channel.close();
        } catch (IOException e) {
          LOGGER.error("free resource error: " + e.toString());
        }
      }
      if (!finished[index]) {
        finished[index] = true;
        numFinished++;
      }
    }

    /**
     * Disposes any resource allocated. Safe to call when run() never
     * started or failed to open its selector. A failure to close one
     * channel does not prevent closing the others.
     */
    public void close() {
      Selector sel = selector;
      if (sel == null) {
        // run() never opened a selector: nothing to release
        return;
      }
      try {
        if (sel.isOpen()) {
          for (SelectionKey selKey : sel.keys()) {
            try {
              selKey.channel().close();
            } catch (IOException e) {
              LOGGER.error("free resource error: " + e.toString());
            }
          }
          sel.close();
        }
      } catch (IOException e) {
        LOGGER.error("free resource error: " + e.toString());
      }
    }
  }
}