/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.client.api.async.impl;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock.ReadLock;
import java.util.concurrent.locks.ReentrantReadWriteLock.WriteLock;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.classification.InterfaceAudience.Private;
import org.apache.hadoop.classification.InterfaceStability.Unstable;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.api.records.Container;
import org.apache.hadoop.yarn.api.records.ContainerId;
import org.apache.hadoop.yarn.api.records.ContainerLaunchContext;
import org.apache.hadoop.yarn.api.records.ContainerStatus;
import org.apache.hadoop.yarn.api.records.NodeId;
import org.apache.hadoop.yarn.api.records.Token;
import org.apache.hadoop.yarn.client.api.NMClient;
import org.apache.hadoop.yarn.client.api.async.NMClientAsync;
import org.apache.hadoop.yarn.client.api.impl.NMClientImpl;
import org.apache.hadoop.yarn.conf.YarnConfiguration;
import org.apache.hadoop.yarn.event.AbstractEvent;
import org.apache.hadoop.yarn.event.EventHandler;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.ipc.RPCUtil;
import org.apache.hadoop.yarn.state.InvalidStateTransitonException;
import org.apache.hadoop.yarn.state.MultipleArcTransition;
import org.apache.hadoop.yarn.state.SingleArcTransition;
import org.apache.hadoop.yarn.state.StateMachine;
import org.apache.hadoop.yarn.state.StateMachineFactory;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
@Private
@Unstable
public class NMClientAsyncImpl extends NMClientAsync {
/** Logger used by this class and by its nested helper classes. */
private static final Log LOG = LogFactory.getLog(NMClientAsyncImpl.class);
/**
* Baseline worker-pool size. serviceStart() uses it (capped by
* maxThreadPoolSize) as the initial core size, and the dispatcher adds it as
* headroom each time it grows the pool.
*/
protected static final int INITIAL_THREAD_POOL_SIZE = 10;
/** Executor that runs the container event processors; created in serviceStart(). */
protected ThreadPoolExecutor threadPool;
/** Upper bound for the pool's core size; read from the configuration in serviceInit(). */
protected int maxThreadPoolSize;
/** Thread that takes events off the queue and submits them to threadPool. */
protected Thread eventDispatcherThread;
/** Set to true by serviceStop(); the dispatcher loop checks it on every iteration. */
protected AtomicBoolean stopped = new AtomicBoolean(false);
/** Container events queued by the *Async methods and consumed by the dispatcher thread. */
protected BlockingQueue<ContainerEvent> events =
new LinkedBlockingQueue<ContainerEvent>();
/**
* State machines of the tracked containers, keyed by container id. Entries
* are added by startContainerAsync() and removed by ContainerEventProcessor
* once a container is DONE or FAILED.
*/
protected ConcurrentMap<ContainerId, StatefulContainer> containers =
new ConcurrentHashMap<ContainerId, StatefulContainer>();
/**
* Creates a client named after the NMClientAsync class by delegating to the
* two-argument constructor.
*
* @param callbackHandler handler passed on unchanged to the delegate constructor
*/
public NMClientAsyncImpl(CallbackHandler callbackHandler) {
this(NMClientAsync.class.getName(), callbackHandler);
}
/**
* Creates a client with the given name that wraps a freshly constructed
* NMClientImpl.
*
* @param name name passed on to the three-argument constructor
* @param callbackHandler handler passed on to the three-argument constructor
*/
public NMClientAsyncImpl(String name, CallbackHandler callbackHandler) {
this(name, new NMClientImpl(), callbackHandler);
}
/**
* Creates a client around a caller-supplied NMClient; marked
* VisibleForTesting.
*
* @param name name passed to the superclass constructor
* @param client client stored in the client field
* @param callbackHandler handler stored in the callbackHandler field
*/
@Private
@VisibleForTesting
protected NMClientAsyncImpl(String name, NMClient client,
CallbackHandler callbackHandler) {
super(name, client, callbackHandler);
// NOTE(review): super(...) receives the same arguments; whether it already
// stores them is not visible here, so these assignments may be redundant.
this.client = client;
this.callbackHandler = callbackHandler;
}
/**
 * Reads the thread-pool upper bound from the configuration, initializes the
 * wrapped client and then delegates to the parent's initialization.
 *
 * @param conf configuration supplying the pool limit; also handed to the
 *     wrapped client and to the parent
 * @throws Exception whatever the wrapped client's init or the parent's
 *     serviceInit raises
 */
@Override
protected void serviceInit(Configuration conf) throws Exception {
  final int configuredMax = conf.getInt(
      YarnConfiguration.NM_CLIENT_ASYNC_THREAD_POOL_MAX_SIZE,
      YarnConfiguration.DEFAULT_NM_CLIENT_ASYNC_THREAD_POOL_MAX_SIZE);
  this.maxThreadPoolSize = configuredMax;
  LOG.info("Upper bound of the thread pool size is " + configuredMax);

  client.init(conf);
  super.serviceInit(conf);
}
/**
 * Starts the wrapped client, creates the worker pool and launches the
 * dispatcher thread that hands queued container events to that pool.
 *
 * <p>Workers are daemon threads named "&lt;class name&gt; #N". The pool
 * starts with min(INITIAL_THREAD_POOL_SIZE, maxThreadPoolSize) core threads;
 * the dispatcher raises the core size as events for more distinct nodes
 * arrive, never beyond maxThreadPoolSize.
 *
 * @throws Exception whatever the wrapped client's start or the parent's
 *     serviceStart raises
 */
@Override
protected void serviceStart() throws Exception {
  client.start();

  ThreadFactory workerFactory = new ThreadFactoryBuilder()
      .setNameFormat(this.getClass().getName() + " #%d")
      .setDaemon(true)
      .build();

  // Begin with a modest core size; the dispatcher enlarges it on demand.
  int initialCoreSize = Math.min(INITIAL_THREAD_POOL_SIZE, maxThreadPoolSize);
  threadPool = new ThreadPoolExecutor(initialCoreSize, Integer.MAX_VALUE, 1,
      TimeUnit.HOURS, new LinkedBlockingQueue<Runnable>(), workerFactory);

  eventDispatcherThread = new Thread() {
    @Override
    public void run() {
      // String forms of every node id seen in an event so far.
      Set<String> knownNodes = new HashSet<String>();
      while (!stopped.get() && !Thread.currentThread().isInterrupted()) {
        ContainerEvent next = null;
        try {
          next = events.take();
        } catch (InterruptedException e) {
          // An interrupt during shutdown is expected; only report others.
          if (!stopped.get()) {
            LOG.error("Returning, thread interrupted", e);
          }
          return;
        }
        knownNodes.add(next.getNodeId().toString());

        int coreSize = threadPool.getCorePoolSize();
        // Number of nodes containers have been addressed to so far. This is
        // *not* the cluster size and doesn't need to be.
        int nodeCount = knownNodes.size();
        int idealSize = Math.min(maxThreadPoolSize, nodeCount);
        // Growing is only possible while the maximum has not been reached.
        boolean belowLimit = coreSize != maxThreadPoolSize;
        if (belowLimit && coreSize < idealSize) {
          // Grow to idealSize + INITIAL_THREAD_POOL_SIZE; the latter is just
          // a buffer so the pool is not resized on every new node.
          int grownSize = Math.min(maxThreadPoolSize,
              idealSize + INITIAL_THREAD_POOL_SIZE);
          LOG.info("Set NMClientAsync thread pool size to " +
              grownSize + " as the number of nodes to talk to is "
              + nodeCount);
          threadPool.setCorePoolSize(grownSize);
        }

        // Events taken from the queue are processed in parallel by the pool.
        threadPool.execute(getContainerEventProcessor(next));
        // TODO: Group launching of multiple containers to a single
        // NodeManager into a single connection
      }
    }
  };
  eventDispatcherThread.setName("Container Event Dispatcher");
  eventDispatcherThread.setDaemon(false);
  eventDispatcherThread.start();

  super.serviceStart();
}
/**
 * Stops the dispatcher thread, the worker pool and the wrapped client, then
 * delegates to the parent's stop. Only the first call does any work; later
 * calls return immediately.
 *
 * <p>If the calling thread is interrupted while waiting for the dispatcher
 * to finish, shutdown still continues and the interrupt status is restored
 * once cleanup is done, so the caller can observe it.
 *
 * @throws Exception whatever the wrapped client's stop or the parent's
 *     serviceStop raises
 */
@Override
protected void serviceStop() throws Exception {
  if (stopped.getAndSet(true)) {
    // return if already stopped
    return;
  }

  boolean interrupted = false;
  try {
    if (eventDispatcherThread != null) {
      eventDispatcherThread.interrupt();
      try {
        eventDispatcherThread.join();
      } catch (InterruptedException e) {
        // Remember the interrupt; re-asserting it right away could disturb
        // the remaining cleanup steps.
        interrupted = true;
        LOG.error("The thread of " + eventDispatcherThread.getName() +
            " didn't finish normally.", e);
      }
    }

    if (threadPool != null) {
      threadPool.shutdownNow();
    }

    if (client != null) {
      // If NMClientImpl doesn't stop running containers, the states don't
      // need to be cleared.
      if (!(client instanceof NMClientImpl) ||
          ((NMClientImpl) client).getCleanupRunningContainers().get()) {
        if (containers != null) {
          containers.clear();
        }
      }
      client.stop();
    }

    super.serviceStop();
  } finally {
    if (interrupted) {
      // Hand the interrupt back to the caller instead of swallowing it.
      Thread.currentThread().interrupt();
    }
  }
}
/**
 * Registers the container and queues a request to start it.
 *
 * <p>If the container id is already tracked, onStartContainerError is
 * reported and nothing is queued: a second start event could only yield an
 * invalid or ignored state transition. If the calling thread is interrupted
 * while queuing, the registration made by this call is rolled back, the
 * interrupt status is restored and onStartContainerError is reported.
 *
 * @param container the container to start
 * @param containerLaunchContext launch context carried by the start event
 */
public void startContainerAsync(
    Container container, ContainerLaunchContext containerLaunchContext) {
  final ContainerId containerId = container.getId();
  final StatefulContainer tracked = new StatefulContainer(this, containerId);
  if (containers.putIfAbsent(containerId, tracked) != null) {
    callbackHandler.onStartContainerError(containerId,
        RPCUtil.getRemoteException("Container " + containerId +
            " is already started or scheduled to start"));
    return;
  }
  try {
    events.put(new StartContainerEvent(container, containerLaunchContext));
  } catch (InterruptedException e) {
    // No event will ever drive this entry, so drop it; otherwise the id
    // would stay "already started" forever.
    containers.remove(containerId, tracked);
    LOG.warn("Exception when scheduling the event of starting Container " +
        containerId, e);
    Thread.currentThread().interrupt();
    callbackHandler.onStartContainerError(containerId, e);
  }
}
/**
 * Queues a request to stop a tracked container.
 *
 * <p>If the container id is not tracked, onStopContainerError is reported
 * and nothing is queued, since the event processor would only discard such
 * an event. If the calling thread is interrupted while queuing, the
 * interrupt status is restored and onStopContainerError is reported.
 *
 * @param containerId id of the container to stop
 * @param nodeId node carried by the stop event
 */
public void stopContainerAsync(ContainerId containerId, NodeId nodeId) {
  if (containers.get(containerId) == null) {
    callbackHandler.onStopContainerError(containerId,
        RPCUtil.getRemoteException("Container " + containerId +
            " is neither started nor scheduled to start"));
    return;
  }
  try {
    events.put(new ContainerEvent(containerId, nodeId, null,
        ContainerEventType.STOP_CONTAINER));
  } catch (InterruptedException e) {
    LOG.warn("Exception when scheduling the event of stopping Container " +
        containerId, e);
    Thread.currentThread().interrupt();
    callbackHandler.onStopContainerError(containerId, e);
  }
}
/**
 * Queues a request to query the status of a container. The container does
 * not have to be tracked by this client.
 *
 * <p>If the calling thread is interrupted while queuing, the interrupt
 * status is restored and onGetContainerStatusError is reported.
 *
 * @param containerId id of the container to query
 * @param nodeId node carried by the query event
 */
public void getContainerStatusAsync(ContainerId containerId, NodeId nodeId) {
  try {
    events.put(new ContainerEvent(containerId, nodeId, null,
        ContainerEventType.QUERY_CONTAINER));
  } catch (InterruptedException e) {
    LOG.warn("Exception when scheduling the event of querying the status" +
        " of Container " + containerId, e);
    Thread.currentThread().interrupt();
    callbackHandler.onGetContainerStatusError(containerId, e);
  }
}
/**
* States of the per-container state machine held by StatefulContainer.
*/
protected static enum ContainerState {
// PREP: initial state of every state machine.
// FAILED: the start or stop call raised an exception.
// RUNNING: the start call succeeded.
// DONE: the stop call succeeded, or a stop arrived while still in PREP.
PREP, FAILED, RUNNING, DONE,
}
/**
 * Tells whether a container has reached one of the states after which the
 * event processor stops tracking it.
 *
 * @param container the container whose state is inspected
 * @return true if the state is DONE or FAILED, false otherwise
 */
protected boolean isCompletelyDone(StatefulContainer container) {
  if (container.getState() == ContainerState.DONE) {
    return true;
  }
  return container.getState() == ContainerState.FAILED;
}
/**
* Creates the task that the dispatcher thread submits to the thread pool for
* one event. Being protected, it can be overridden by subclasses.
*
* @param event the event the new processor will handle
* @return a new ContainerEventProcessor wrapping the event
*/
protected ContainerEventProcessor getContainerEventProcessor(
ContainerEvent event) {
return new ContainerEventProcessor(event);
}
/**
* The type of a queued request to interact with a container.
* START_CONTAINER and STOP_CONTAINER events are fed to the container's state
* machine; QUERY_CONTAINER events are answered directly by the event
* processor without touching the state machine.
*/
protected static enum ContainerEventType {
START_CONTAINER,
STOP_CONTAINER,
QUERY_CONTAINER
}
/**
 * Immutable description of one queued request against a container: which
 * container, on which node, with which token, and what to do (the event
 * type).
 */
protected static class ContainerEvent
    extends AbstractEvent<ContainerEventType> {
  // Assigned once in the constructor; final because events are handed from
  // the caller's thread to the dispatcher and on to worker threads.
  private final ContainerId containerId;
  private final NodeId nodeId;
  private final Token containerToken;

  /**
   * @param containerId id of the container the request targets
   * @param nodeId node the request is addressed to
   * @param containerToken token of the container; the stop and query events
   *     created in this file pass null
   * @param type what to do with the container
   */
  public ContainerEvent(ContainerId containerId, NodeId nodeId,
      Token containerToken, ContainerEventType type) {
    super(type);
    this.containerId = containerId;
    this.nodeId = nodeId;
    this.containerToken = containerToken;
  }

  /** @return id of the container the request targets */
  public ContainerId getContainerId() {
    return containerId;
  }

  /** @return node the request is addressed to */
  public NodeId getNodeId() {
    return nodeId;
  }

  /** @return the container token given at construction, possibly null */
  public Token getContainerToken() {
    return containerToken;
  }
}
/**
 * START_CONTAINER event that additionally carries the container and the
 * launch context needed to start it. The container id, node id and token of
 * the base event are taken from the container.
 */
protected static class StartContainerEvent extends ContainerEvent {
  // Assigned once in the constructor; final because the event is shared
  // between threads.
  private final Container container;
  private final ContainerLaunchContext containerLaunchContext;

  /**
   * @param container the container to start; must not be null because its
   *     id, node id and token are read here
   * @param containerLaunchContext launch context to start the container with
   */
  public StartContainerEvent(Container container,
      ContainerLaunchContext containerLaunchContext) {
    super(container.getId(), container.getNodeId(),
        container.getContainerToken(), ContainerEventType.START_CONTAINER);
    this.container = container;
    this.containerLaunchContext = containerLaunchContext;
  }

  /** @return the container to start */
  public Container getContainer() {
    return container;
  }

  /** @return the launch context to start the container with */
  public ContainerLaunchContext getContainerLaunchContext() {
    return containerLaunchContext;
  }
}
protected static class StatefulContainer implements
EventHandler<ContainerEvent> {
/**
* Transition table shared by all StatefulContainer instances; each instance
* obtains its own state machine from it, starting in PREP.
*
* No transition is registered for START_CONTAINER in RUNNING, nor for
* QUERY_CONTAINER in any state. handle() logs the resulting
* InvalidStateTransitonException instead of propagating it.
*/
protected final static StateMachineFactory<StatefulContainer,
ContainerState, ContainerEventType, ContainerEvent> stateMachineFactory
= new StateMachineFactory<StatefulContainer, ContainerState,
ContainerEventType, ContainerEvent>(ContainerState.PREP)
// Transitions from PREP state
// A start request ends in RUNNING or FAILED, as decided by
// StartContainerTransition.
.addTransition(ContainerState.PREP,
EnumSet.of(ContainerState.RUNNING, ContainerState.FAILED),
ContainerEventType.START_CONTAINER,
new StartContainerTransition())
// A stop request that arrives before the start moves straight to DONE.
.addTransition(ContainerState.PREP, ContainerState.DONE,
ContainerEventType.STOP_CONTAINER, new OutOfOrderTransition())
// Transitions from RUNNING state
// RUNNING -> RUNNING should be the invalid transition
.addTransition(ContainerState.RUNNING,
EnumSet.of(ContainerState.DONE, ContainerState.FAILED),
ContainerEventType.STOP_CONTAINER,
new StopContainerTransition())
// Transition from DONE state
// Start and stop requests leave the state at DONE; no hook is attached.
.addTransition(ContainerState.DONE, ContainerState.DONE,
EnumSet.of(ContainerEventType.START_CONTAINER,
ContainerEventType.STOP_CONTAINER))
// Transition from FAILED state
// Start and stop requests leave the state at FAILED; no hook is attached.
.addTransition(ContainerState.FAILED, ContainerState.FAILED,
EnumSet.of(ContainerEventType.START_CONTAINER,
ContainerEventType.STOP_CONTAINER));
/**
 * Transition run for START_CONTAINER in the PREP state. It asks the wrapped
 * client to start the container and notifies the callback handler: on
 * success via onContainerStarted (next state RUNNING), on any failure via
 * onStartContainerError (next state FAILED). Anything thrown by the handler
 * itself is logged and otherwise ignored.
 */
protected static class StartContainerTransition implements
    MultipleArcTransition<StatefulContainer, ContainerEvent,
    ContainerState> {

  @Override
  public ContainerState transition(
      StatefulContainer container, ContainerEvent event) {
    final ContainerId containerId = event.getContainerId();
    try {
      StartContainerEvent startEvent =
          (event instanceof StartContainerEvent)
              ? (StartContainerEvent) event : null;
      assert startEvent != null;
      Map<String, ByteBuffer> serviceResponses =
          container.nmClientAsync.getClient().startContainer(
              startEvent.getContainer(),
              startEvent.getContainerLaunchContext());
      try {
        container.nmClientAsync.getCallbackHandler().onContainerStarted(
            containerId, serviceResponses);
      } catch (Throwable thr) {
        // Don't process user created unchecked exception
        LOG.info("Unchecked exception is thrown from onContainerStarted for "
            + "Container " + containerId, thr);
      }
      return ContainerState.RUNNING;
    } catch (Throwable t) {
      // Every failure, checked or unchecked, is reported the same way.
      return onExceptionRaised(container, event, t);
    }
  }

  /**
   * Reports a failed start to the callback handler.
   *
   * @return always FAILED
   */
  private ContainerState onExceptionRaised(StatefulContainer container,
      ContainerEvent event, Throwable t) {
    try {
      container.nmClientAsync.getCallbackHandler().onStartContainerError(
          event.getContainerId(), t);
    } catch (Throwable thr) {
      // Don't process user created unchecked exception
      LOG.info(
          "Unchecked exception is thrown from onStartContainerError for " +
              "Container " + event.getContainerId(), thr);
    }
    return ContainerState.FAILED;
  }
}
/**
 * Transition run for STOP_CONTAINER in the RUNNING state. It asks the
 * wrapped client to stop the container and notifies the callback handler:
 * on success via onContainerStopped (next state DONE), on any failure via
 * onStopContainerError (next state FAILED). Anything thrown by the handler
 * itself is logged and otherwise ignored.
 */
protected static class StopContainerTransition implements
    MultipleArcTransition<StatefulContainer, ContainerEvent,
    ContainerState> {

  @Override
  public ContainerState transition(
      StatefulContainer container, ContainerEvent event) {
    final ContainerId containerId = event.getContainerId();
    try {
      container.nmClientAsync.getClient().stopContainer(
          containerId, event.getNodeId());
      try {
        container.nmClientAsync.getCallbackHandler().onContainerStopped(
            event.getContainerId());
      } catch (Throwable thr) {
        // Don't process user created unchecked exception
        LOG.info("Unchecked exception is thrown from onContainerStopped for "
            + "Container " + event.getContainerId(), thr);
      }
      return ContainerState.DONE;
    } catch (Throwable t) {
      // Every failure, checked or unchecked, is reported the same way.
      return onExceptionRaised(container, event, t);
    }
  }

  /**
   * Reports a failed stop to the callback handler.
   *
   * @return always FAILED
   */
  private ContainerState onExceptionRaised(StatefulContainer container,
      ContainerEvent event, Throwable t) {
    try {
      container.nmClientAsync.getCallbackHandler().onStopContainerError(
          event.getContainerId(), t);
    } catch (Throwable thr) {
      // Don't process user created unchecked exception
      LOG.info("Unchecked exception is thrown from onStopContainerError for "
          + "Container " + event.getContainerId(), thr);
    }
    return ContainerState.FAILED;
  }
}
/**
 * Transition run when STOP_CONTAINER arrives while the container is still
 * in PREP. The container is never started; the callback handler is told
 * through onStartContainerError that it was killed before launch.
 */
protected static class OutOfOrderTransition implements
    SingleArcTransition<StatefulContainer, ContainerEvent> {

  protected static final String STOP_BEFORE_START_ERROR_MSG =
      "Container was killed before it was launched";

  @Override
  public void transition(StatefulContainer container, ContainerEvent event) {
    try {
      final ContainerId stoppedId = event.getContainerId();
      container.nmClientAsync.getCallbackHandler().onStartContainerError(
          stoppedId,
          RPCUtil.getRemoteException(STOP_BEFORE_START_ERROR_MSG));
    } catch (Throwable thr) {
      // Don't process user created unchecked exception
      LOG.info(
          "Unchecked exception is thrown from onStartContainerError for " +
              "Container " + event.getContainerId(), thr);
    }
  }
}
// Owning async client; the transitions reach the wrapped client and the
// callback handler through it.
private final NMClientAsync nmClientAsync;
// Id of the container this state machine tracks.
private final ContainerId containerId;
// Per-instance state machine created from stateMachineFactory.
private final StateMachine<ContainerState, ContainerEventType,
ContainerEvent> stateMachine;
// Two halves of one ReentrantReadWriteLock: getState() takes the read lock,
// handle() takes the write lock.
private final ReadLock readLock;
private final WriteLock writeLock;
/**
 * Creates the state holder for one container, with a fresh state machine
 * and a read/write lock pair guarding it.
 *
 * @param client the async client whose wrapped client and callback handler
 *     the transitions use
 * @param containerId id of the tracked container
 */
public StatefulContainer(NMClientAsync client, ContainerId containerId) {
  final ReentrantReadWriteLock stateLock = new ReentrantReadWriteLock();
  this.readLock = stateLock.readLock();
  this.writeLock = stateLock.writeLock();
  this.nmClientAsync = client;
  this.containerId = containerId;
  this.stateMachine = stateMachineFactory.make(this);
}
/**
 * Feeds an event to the state machine while holding the write lock. An
 * event that is not valid in the current state is logged and dropped.
 *
 * @param event the event to apply
 */
@Override
public void handle(ContainerEvent event) {
  writeLock.lock();
  try {
    this.stateMachine.doTransition(event.getType(), event);
  } catch (InvalidStateTransitonException e) {
    LOG.error("Can't handle this event at current state", e);
  } finally {
    writeLock.unlock();
  }
}
/**
* @return the id of the container this object tracks, as given to the
* constructor
*/
public ContainerId getContainerId() {
return containerId;
}
/**
 * Reads the current state of the state machine while holding the read lock.
 *
 * @return the current container state
 */
public ContainerState getState() {
  final ContainerState current;
  readLock.lock();
  try {
    current = stateMachine.getCurrentState();
  } finally {
    readLock.unlock();
  }
  return current;
}
}
/**
 * Worker task that processes one container event on a pool thread.
 *
 * <p>QUERY_CONTAINER events are answered by asking the wrapped client for
 * the container status and reporting the result, or the failure, to the
 * callback handler. All other events are applied to the container's state
 * machine; once that container is DONE or FAILED it is no longer tracked.
 */
protected class ContainerEventProcessor implements Runnable {
  protected ContainerEvent event;

  /**
   * @param event the event this task will process
   */
  public ContainerEventProcessor(ContainerEvent event) {
    this.event = event;
  }

  @Override
  public void run() {
    ContainerId containerId = event.getContainerId();
    LOG.info("Processing Event " + event + " for Container " + containerId);
    if (event.getType() == ContainerEventType.QUERY_CONTAINER) {
      try {
        ContainerStatus containerStatus = client.getContainerStatus(
            containerId, event.getNodeId());
        try {
          callbackHandler.onContainerStatusReceived(
              containerId, containerStatus);
        } catch (Throwable thr) {
          // Don't process user created unchecked exception
          LOG.info(
              "Unchecked exception is thrown from onContainerStatusReceived" +
                  " for Container " + event.getContainerId(), thr);
        }
      } catch (Throwable t) {
        // Every failure, checked or unchecked, is reported the same way.
        onExceptionRaised(containerId, t);
      }
    } else {
      StatefulContainer container = containers.get(containerId);
      if (container == null) {
        LOG.info("Container " + containerId + " is already stopped or failed");
      } else {
        container.handle(event);
        if (isCompletelyDone(container)) {
          // Remove only the instance this task worked on. Another worker may
          // already have removed it, and the same id may have been
          // registered again since; an unconditional remove would evict
          // that newer entry.
          containers.remove(containerId, container);
        }
      }
    }
  }

  /** Reports a failed status query to the callback handler. */
  private void onExceptionRaised(ContainerId containerId, Throwable t) {
    try {
      callbackHandler.onGetContainerStatusError(containerId, t);
    } catch (Throwable thr) {
      // Don't process user created unchecked exception
      LOG.info("Unchecked exception is thrown from onGetContainerStatusError" +
          " for Container " + containerId, thr);
    }
  }
}
}