/*
* Copyright 2013 The Regents of The University of California
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package edu.berkeley.sparrow.daemon.scheduler;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.commons.configuration.Configuration;
import org.apache.log4j.Logger;
import org.apache.thrift.TException;
import org.apache.thrift.async.AsyncMethodCallback;
import com.google.common.base.Optional;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import edu.berkeley.sparrow.daemon.SparrowConf;
import edu.berkeley.sparrow.daemon.util.Logging;
import edu.berkeley.sparrow.daemon.util.Network;
import edu.berkeley.sparrow.daemon.util.Serialization;
import edu.berkeley.sparrow.daemon.util.ThriftClientPool;
import edu.berkeley.sparrow.thrift.FrontendService;
import edu.berkeley.sparrow.thrift.FrontendService.AsyncClient.frontendMessage_call;
import edu.berkeley.sparrow.thrift.InternalService;
import edu.berkeley.sparrow.thrift.InternalService.AsyncClient;
import edu.berkeley.sparrow.thrift.InternalService.AsyncClient.enqueueTaskReservations_call;
import edu.berkeley.sparrow.thrift.TEnqueueTaskReservationsRequest;
import edu.berkeley.sparrow.thrift.TFullTaskId;
import edu.berkeley.sparrow.thrift.THostPort;
import edu.berkeley.sparrow.thrift.TPlacementPreference;
import edu.berkeley.sparrow.thrift.TSchedulingRequest;
import edu.berkeley.sparrow.thrift.TTaskLaunchSpec;
import edu.berkeley.sparrow.thrift.TTaskSpec;
/**
* This class implements the Sparrow scheduler functionality.
*/
public class Scheduler {
private final static Logger LOG = Logger.getLogger(Scheduler.class);
private final static Logger AUDIT_LOG = Logging.getAuditLogger(Scheduler.class);
/** Used to uniquely identify requests arriving at this scheduler. */
private AtomicInteger counter = new AtomicInteger(0);
/** How many times the special case has been triggered. */
private AtomicInteger specialCaseCounter = new AtomicInteger(0);
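/** Thrift address of this scheduler, used in audit log messages and in task reservations. */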
private THostPort address;
/** Socket addresses for each frontend. */
HashMap<String, InetSocketAddress> frontendSockets =
new HashMap<String, InetSocketAddress>();
/**
* Service that handles cancelling outstanding reservations for jobs that have already been
* scheduled. Only instantiated if {@code SparrowConf.CANCELLATION} is set to true.
*/
private CancellationService cancellationService;
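/** Whether to cancel outstanding reservations once all of a request's tasks have been placed. */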
private boolean useCancellation;
/** Thrift client pool for communicating with node monitors */
ThriftClientPool<InternalService.AsyncClient> nodeMonitorClientPool =
new ThriftClientPool<InternalService.AsyncClient>(
new ThriftClientPool.InternalServiceMakerFactory());
/** Thrift client pool for communicating with front ends. */
private ThriftClientPool<FrontendService.AsyncClient> frontendClientPool =
new ThriftClientPool<FrontendService.AsyncClient>(
new ThriftClientPool.FrontendServiceMakerFactory());
/** Information about cluster workload due to other schedulers. */
private SchedulerState state;
/** Probe ratios to use if the probe ratio is not explicitly set in the request. */
private double defaultProbeRatioUnconstrained;
private double defaultProbeRatioConstrained;
/**
* For each request, the task placer that should be used to place the request's tasks. Indexed
* by the request ID.
*/
private ConcurrentMap<String, TaskPlacer> requestTaskPlacers;
/**
* When a job includes SPREAD_EVENLY in its description and has exactly this number of tasks,
* Sparrow spreads the tasks evenly over machines so that data is cached evenly across the
* cluster. We need this (in addition to the SPREAD_EVENLY descriptor) because only the
* reduce phase -- not the map phase -- should be spread.
*/
private int spreadEvenlyTaskSetSize;
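/** Sparrow configuration; used to determine this scheduler's IP address when creating request IDs. */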
private Configuration conf;
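/**
* Initializes the scheduler: creates the scheduler state for the configured deployment mode,
* reads the default probe ratios, and starts the cancellation service if cancellation is
* enabled.
*/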
public void initialize(Configuration conf, InetSocketAddress socket) throws IOException {
address = Network.socketAddressToThrift(socket);
String mode = conf.getString(SparrowConf.DEPLOYMENT_MODE, "unspecified");
this.conf = conf;
if (mode.equals("standalone")) {
state = new StandaloneSchedulerState();
} else if (mode.equals("configbased")) {
state = new ConfigSchedulerState();
} else {
throw new RuntimeException("Unsupported deployment mode: " + mode);
}
state.initialize(conf);
defaultProbeRatioUnconstrained = conf.getDouble(SparrowConf.SAMPLE_RATIO,
SparrowConf.DEFAULT_SAMPLE_RATIO);
defaultProbeRatioConstrained = conf.getDouble(SparrowConf.SAMPLE_RATIO_CONSTRAINED,
SparrowConf.DEFAULT_SAMPLE_RATIO_CONSTRAINED);
requestTaskPlacers = Maps.newConcurrentMap();
useCancellation = conf.getBoolean(SparrowConf.CANCELLATION, SparrowConf.DEFAULT_CANCELLATION);
if (useCancellation) {
LOG.debug("Initializing cancellation service");
cancellationService = new CancellationService(nodeMonitorClientPool);
new Thread(cancellationService).start();
} else {
LOG.debug("Not using cancellation");
}
spreadEvenlyTaskSetSize = conf.getInt(SparrowConf.SPREAD_EVENLY_TASK_SET_SIZE,
SparrowConf.DEFAULT_SPREAD_EVENLY_TASK_SET_SIZE);
}
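/**
* Registers a frontend for the given application so that messages can later be relayed to it
* via sendFrontendMessage(). Returns false if the given address cannot be parsed; otherwise
* returns the result of asking the scheduler state to watch the application.
*/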
public boolean registerFrontend(String appId, String addr) {
LOG.debug(Logging.functionCall(appId, addr));
Optional<InetSocketAddress> socketAddress = Serialization.strToSocket(addr);
if (!socketAddress.isPresent()) {
LOG.error("Bad address from frontend: " + addr);
return false;
}
frontendSockets.put(appId, socketAddress.get());
return state.watchApplication(appId);
}
/**
* Callback for enqueueTaskReservations() that returns the client to the client pool.
*/
private class EnqueueTaskReservationsCallback
implements AsyncMethodCallback<enqueueTaskReservations_call> {
String requestId;
InetSocketAddress nodeMonitorAddress;
long startTimeMillis;
public EnqueueTaskReservationsCallback(String requestId, InetSocketAddress nodeMonitorAddress) {
this.requestId = requestId;
this.nodeMonitorAddress = nodeMonitorAddress;
this.startTimeMillis = System.currentTimeMillis();
}
public void onComplete(enqueueTaskReservations_call response) {
AUDIT_LOG.debug(Logging.auditEventString(
"scheduler_complete_enqueue_task", requestId,
nodeMonitorAddress.getAddress().getHostAddress()));
long totalTime = System.currentTimeMillis() - startTimeMillis;
LOG.debug("Enqueue Task RPC to " + nodeMonitorAddress.getAddress().getHostAddress() +
" for request " + requestId + " completed in " + totalTime + "ms");
try {
nodeMonitorClientPool.returnClient(nodeMonitorAddress, (AsyncClient) response.getClient());
} catch (Exception e) {
LOG.error("Error returning client to node monitor client pool: " + e);
}
}
public void onError(Exception exception) {
// Do not return error client to pool
LOG.error("Error executing enqueueTaskReservation RPC:" + exception);
}
}
/** Adds constraints such that tasks in the job will be spread evenly across the cluster.
*
* We expect three of these special jobs to be submitted; 3 sequential calls to this
* method will result in spreading the tasks for the 3 jobs across the cluster such that no
* more than 1 task is assigned to each machine.
*/
private TSchedulingRequest addConstraintsToSpreadTasks(TSchedulingRequest req)
throws TException {
LOG.info("Handling spread tasks request: " + req);
int specialCaseIndex = specialCaseCounter.incrementAndGet();
if (specialCaseIndex < 1 || specialCaseIndex > 3) {
LOG.error("Invalid special case index: " + specialCaseIndex);
}
// No tasks have preferences and we have the magic number of tasks
TSchedulingRequest newReq = new TSchedulingRequest();
newReq.user = req.user;
newReq.app = req.app;
newReq.probeRatio = req.probeRatio;
List<InetSocketAddress> allBackends = Lists.newArrayList();
List<InetSocketAddress> backends = Lists.newArrayList();
// We assume the below always returns the same order (invalid assumption?)
for (InetSocketAddress backend : state.getBackends(req.app)) {
allBackends.add(backend);
}
// Each time this is called, we restrict to 1/3 of the nodes in the cluster
for (int i = 0; i < allBackends.size(); i++) {
if (i % 3 == specialCaseIndex - 1) {
backends.add(allBackends.get(i));
}
}
Collections.shuffle(backends);
if (allBackends.size() < req.getTasks().size() * 3) {
LOG.error("Special case expects at least three times as many machines as tasks.");
return null;
}
LOG.info(backends);
for (int i = 0; i < req.getTasksSize(); i++) {
TTaskSpec task = req.getTasks().get(i);
TTaskSpec newTask = new TTaskSpec();
newTask.message = task.message;
newTask.taskId = task.taskId;
newTask.preference = new TPlacementPreference();
newTask.preference.addToNodes(backends.get(i).getHostName());
newReq.addToTasks(newTask);
}
LOG.info("New request: " + newReq);
return newReq;
}
/** Checks whether we should add constraints to this job to evenly spread tasks over machines.
*
* This is a hack used to force Spark to cache data in 3 locations: we run 3 select * queries
* on the same table and spread the tasks for those queries evenly across the cluster such that
* the input data for the query is triple replicated and spread evenly across the cluster.
*
* We signal that Sparrow should use this hack by adding SPREAD_EVENLY to the job's description.
*/
private boolean isSpreadTasksJob(TSchedulingRequest request) {
if ((request.getDescription() != null) &&
(request.getDescription().indexOf("SPREAD_EVENLY") != -1)) {
// Check whether any task already has 3 placement preferences; if so, this is the map phase
// of the first job that reads the data from HDFS, so we shouldn't override the constraints.
for (TTaskSpec t: request.getTasks()) {
if (t.getPreference() != null && (t.getPreference().getNodes() != null) &&
(t.getPreference().getNodes().size() == 3)) {
LOG.debug("Not special case: one of request's tasks had 3 preferences");
return false;
}
}
if (request.getTasks().size() != spreadEvenlyTaskSetSize) {
LOG.debug("Not special case: job had " + request.getTasks().size() +
" tasks rather than the expected " + spreadEvenlyTaskSetSize);
return false;
}
if (specialCaseCounter.get() >= 3) {
LOG.error("Not using special case because special case code has already been " +
" called 3 more more times!");
return false;
}
LOG.debug("Spreading tasks for job with " + request.getTasks().size() + " tasks");
return true;
}
LOG.debug("Not special case: description did not contain SPREAD_EVENLY");
return false;
}
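/**
* Entry point for job submissions. Requests with zero tasks (used for liveness checking)
* return immediately; jobs flagged with SPREAD_EVENLY are rewritten by
* addConstraintsToSpreadTasks() before being passed to handleJobSubmission().
*/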
public void submitJob(TSchedulingRequest request) throws TException {
// Short-circuit case that is used for liveness checking
if (request.tasks.size() == 0) { return; }
if (isSpreadTasksJob(request)) {
handleJobSubmission(addConstraintsToSpreadTasks(request));
} else {
handleJobSubmission(request);
}
}
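/**
* Schedules the tasks in the given request: selects a constrained or unconstrained TaskPlacer
* (depending on whether any task has placement preferences) and enqueues task reservations on
* the node monitors chosen by the placer.
*/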
public void handleJobSubmission(TSchedulingRequest request) throws TException {
LOG.debug(Logging.functionCall(request));
long start = System.currentTimeMillis();
String requestId = getRequestId();
String user = "";
if (request.getUser() != null && request.getUser().getUser() != null) {
user = request.getUser().getUser();
}
String description = "";
if (request.getDescription() != null) {
description = request.getDescription();
}
String app = request.getApp();
List<TTaskSpec> tasks = request.getTasks();
Set<InetSocketAddress> backends = state.getBackends(app);
LOG.debug("NumBackends: " + backends.size());
boolean constrained = false;
for (TTaskSpec task : tasks) {
constrained = constrained || (
task.preference != null &&
task.preference.nodes != null &&
!task.preference.nodes.isEmpty());
}
// Logging the address here is somewhat redundant, since all of the
// messages in this particular log file come from the same address.
// However, it simplifies the process of aggregating the logs, and will
// also be useful when we support multiple daemons running on a single
// machine.
AUDIT_LOG.info(Logging.auditEventString("arrived", requestId,
request.getTasks().size(),
address.getHost(), address.getPort(),
user, description, constrained));
TaskPlacer taskPlacer;
if (constrained) {
if (request.isSetProbeRatio()) {
taskPlacer = new ConstrainedTaskPlacer(requestId, request.getProbeRatio());
} else {
taskPlacer = new ConstrainedTaskPlacer(requestId, defaultProbeRatioConstrained);
}
} else {
if (request.isSetProbeRatio()) {
taskPlacer = new UnconstrainedTaskPlacer(requestId, request.getProbeRatio());
} else {
taskPlacer = new UnconstrainedTaskPlacer(requestId, defaultProbeRatioUnconstrained);
}
}
requestTaskPlacers.put(requestId, taskPlacer);
Map<InetSocketAddress, TEnqueueTaskReservationsRequest> enqueueTaskReservationsRequests;
enqueueTaskReservationsRequests = taskPlacer.getEnqueueTaskReservationsRequests(
request, requestId, backends, address);
// Request to enqueue a task at each of the selected nodes.
for (Entry<InetSocketAddress, TEnqueueTaskReservationsRequest> entry :
enqueueTaskReservationsRequests.entrySet()) {
try {
InternalService.AsyncClient client = nodeMonitorClientPool.borrowClient(entry.getKey());
LOG.debug("Launching enqueueTask for request " + requestId + "on node: " + entry.getKey());
AUDIT_LOG.debug(Logging.auditEventString(
"scheduler_launch_enqueue_task", entry.getValue().requestId,
entry.getKey().getAddress().getHostAddress()));
client.enqueueTaskReservations(
entry.getValue(), new EnqueueTaskReservationsCallback(requestId, entry.getKey()));
} catch (Exception e) {
LOG.error("Error enqueuing task on node " + entry.getKey().toString() + ":" + e);
}
}
long end = System.currentTimeMillis();
LOG.debug("All tasks enqueued for request " + requestId + "; returning. Total time: " +
(end - start) + " milliseconds");
}
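/**
* Called by a node monitor when a reservation for the given request reaches the front of its
* queue. Returns a list containing at most one task to launch; the list is empty if all of
* the request's tasks have already been placed.
*/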
public List<TTaskLaunchSpec> getTask(
String requestId, THostPort nodeMonitorAddress) {
/* TODO: Consider making this synchronized to avoid the need for synchronization in
* the task placers (although then we'd lose the ability to parallelize over task placers). */
LOG.debug(Logging.functionCall(requestId, nodeMonitorAddress));
TaskPlacer taskPlacer = requestTaskPlacers.get(requestId);
if (taskPlacer == null) {
LOG.debug("Received getTask() request for request " + requestId + ", which had no more " +
"unplaced tasks");
return Lists.newArrayList();
}
synchronized(taskPlacer) {
List<TTaskLaunchSpec> taskLaunchSpecs = taskPlacer.assignTask(nodeMonitorAddress);
if (taskLaunchSpecs == null || taskLaunchSpecs.size() > 1) {
LOG.error("Received invalid task placement for request " + requestId + ": " +
taskLaunchSpecs);
return Lists.newArrayList();
} else if (taskLaunchSpecs.size() == 1) {
AUDIT_LOG.info(Logging.auditEventString("scheduler_assigned_task", requestId,
taskLaunchSpecs.get(0).taskId,
nodeMonitorAddress.getHost()));
} else {
AUDIT_LOG.info(Logging.auditEventString("scheduler_get_task_no_task", requestId,
nodeMonitorAddress.getHost()));
}
if (taskPlacer.allTasksPlaced()) {
LOG.debug("All tasks placed for request " + requestId);
requestTaskPlacers.remove(requestId);
if (useCancellation) {
Set<THostPort> outstandingNodeMonitors =
taskPlacer.getOutstandingNodeMonitorsForCancellation();
for (THostPort nodeMonitorToCancel : outstandingNodeMonitors) {
cancellationService.addCancellation(requestId, nodeMonitorToCancel);
}
}
}
return taskLaunchSpecs;
}
}
/**
* Returns an ID that identifies a request uniquely (across all Sparrow schedulers).
*
* This should only be called once for each request (it will return a different
* identifier if called a second time).
*
* TODO: Include the port number, so this works when there are multiple schedulers
* running on a single machine.
*/
private String getRequestId() {
/* The request id is a string that includes the IP address of this scheduler followed
* by the counter. We use a counter rather than a hash of the request because there
* may be multiple requests to run an identical job. */
return String.format("%s_%d", Network.getIPAddress(conf), counter.getAndIncrement());
}
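/**
* Callback for frontendMessage() that returns the client to the frontend client pool.
*/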
private class SendFrontendMessageCallback implements
AsyncMethodCallback<frontendMessage_call> {
private InetSocketAddress frontendSocket;
private FrontendService.AsyncClient client;
public SendFrontendMessageCallback(InetSocketAddress socket, FrontendService.AsyncClient client) {
frontendSocket = socket;
this.client = client;
}
public void onComplete(frontendMessage_call response) {
try { frontendClientPool.returnClient(frontendSocket, client); }
catch (Exception e) { LOG.error(e); }
}
public void onError(Exception exception) {
// Do not return error client to pool
LOG.error("Error sending frontend message callback: " + exception);
}
}
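/**
* Sends a message to the frontend registered for the given application on behalf of the
* given task.
*/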
public void sendFrontendMessage(String app, TFullTaskId taskId,
int status, ByteBuffer message) {
LOG.debug(Logging.functionCall(app, taskId, message));
InetSocketAddress frontend = frontendSockets.get(app);
if (frontend == null) {
LOG.error("Request to send message to unregistered app: " + app);
return;
}
try {
FrontendService.AsyncClient client = frontendClientPool.borrowClient(frontend);
client.frontendMessage(taskId, status, message,
new SendFrontendMessageCallback(frontend, client));
} catch (Exception e) {
LOG.error("Error sending message to frontend: " + app, e);
}
}
}