Package com.facebook.presto.execution

Source Code of com.facebook.presto.execution.SqlTaskManager

/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.execution;

import com.facebook.presto.OutputBuffers;
import com.facebook.presto.TaskSource;
import com.facebook.presto.client.FailureInfo;
import com.facebook.presto.event.query.QueryMonitor;
import com.facebook.presto.execution.SharedBuffer.QueueState;
import com.facebook.presto.operator.TaskContext;
import com.facebook.presto.sql.analyzer.Session;
import com.facebook.presto.sql.planner.LocalExecutionPlanner;
import com.facebook.presto.sql.planner.PlanFragment;
import com.facebook.presto.sql.planner.plan.PlanNodeId;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.util.concurrent.Uninterruptibles;
import io.airlift.concurrent.ThreadPoolExecutorMBean;
import io.airlift.log.Logger;
import io.airlift.stats.CounterStat;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;
import org.joda.time.DateTime;
import org.weakref.jmx.Managed;
import org.weakref.jmx.Nested;

import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import javax.inject.Inject;

import java.net.URI;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

import static com.facebook.presto.util.Threads.threadsNamed;
import static com.google.common.base.Preconditions.checkNotNull;

public class SqlTaskManager
        implements TaskManager
{
    private static final Logger log = Logger.get(SqlTaskManager.class);

    private final DataSize maxBufferSize;

    private final ExecutorService taskNotificationExecutor;
    private final ThreadPoolExecutorMBean taskNotificationExecutorMBean;

    private final TaskExecutor taskExecutor;

    private final ScheduledExecutorService taskManagementExecutor;
    private final ThreadPoolExecutorMBean taskManagementExecutorMBean;

    private final LocalExecutionPlanner planner;
    private final LocationFactory locationFactory;
    private final QueryMonitor queryMonitor;
    private final DataSize maxTaskMemoryUsage;
    private final DataSize operatorPreAllocatedMemory;
    private final Duration infoCacheTime;
    private final Duration clientTimeout;
    private final boolean cpuTimerEnabled;

    private final ConcurrentMap<TaskId, TaskInfo> taskInfos = new ConcurrentHashMap<>();
    private final ConcurrentMap<TaskId, TaskExecution> tasks = new ConcurrentHashMap<>();

    private final CounterStat inputDataSize = new CounterStat();
    private final CounterStat finishedInputDataSize = new CounterStat();

    private final CounterStat inputPositions = new CounterStat();
    private final CounterStat finishedInputPositions = new CounterStat();

    private final CounterStat outputDataSize = new CounterStat();
    private final CounterStat finishedOutputDataSize = new CounterStat();

    private final CounterStat outputPositions = new CounterStat();
    private final CounterStat finishedOutputPositions = new CounterStat();

    @Inject
    public SqlTaskManager(
            LocalExecutionPlanner planner,
            LocationFactory locationFactory,
            TaskExecutor taskExecutor,
            QueryMonitor queryMonitor,
            TaskManagerConfig config)
    {
        this.planner = checkNotNull(planner, "planner is null");
        this.locationFactory = checkNotNull(locationFactory, "locationFactory is null");
        this.taskExecutor = checkNotNull(taskExecutor, "taskExecutor is null");
        this.queryMonitor = checkNotNull(queryMonitor, "queryMonitor is null");

        checkNotNull(config, "config is null");
        this.maxBufferSize = config.getSinkMaxBufferSize();
        this.maxTaskMemoryUsage = config.getMaxTaskMemoryUsage();
        this.operatorPreAllocatedMemory = config.getOperatorPreAllocatedMemory();
        this.infoCacheTime = config.getInfoMaxAge();
        this.clientTimeout = config.getClientTimeout();
        this.cpuTimerEnabled = config.isTaskCpuTimerEnabled();

        taskNotificationExecutor = Executors.newCachedThreadPool(threadsNamed("task-notification-%d"));
        taskNotificationExecutorMBean = new ThreadPoolExecutorMBean((ThreadPoolExecutor) taskNotificationExecutor);

        taskManagementExecutor = Executors.newScheduledThreadPool(5, threadsNamed("task-management-%d"));
        taskManagementExecutorMBean = new ThreadPoolExecutorMBean((ThreadPoolExecutor) taskManagementExecutor);
    }

    @PostConstruct
    public void start()
    {
        taskManagementExecutor.scheduleAtFixedRate(new Runnable()
        {
            @Override
            public void run()
            {
                try {
                    removeOldTasks();
                }
                catch (Throwable e) {
                    log.warn(e, "Error removing old tasks");
                }
                try {
                    failAbandonedTasks();
                }
                catch (Throwable e) {
                    log.warn(e, "Error canceling abandoned tasks");
                }
            }
        }, 200, 200, TimeUnit.MILLISECONDS);

        taskManagementExecutor.scheduleAtFixedRate(new Runnable()
        {
            @Override
            public void run()
            {
                try {
                    updateStats();
                }
                catch (Throwable e) {
                    log.warn(e, "Error updating stats");
                }
            }
        }, 0, 1, TimeUnit.SECONDS);
    }

    @PreDestroy
    public void stop()
    {
        taskNotificationExecutor.shutdownNow();
        taskManagementExecutor.shutdownNow();
    }

    @Managed
    @Nested
    public CounterStat getInputDataSize()
    {
        return inputDataSize;
    }

    @Managed
    @Nested
    public CounterStat getInputPositions()
    {
        return inputPositions;
    }

    @Managed
    @Nested
    public CounterStat getOutputDataSize()
    {
        return outputDataSize;
    }

    @Managed
    @Nested
    public CounterStat getOutputPositions()
    {
        return outputPositions;
    }

    @Managed(description = "Task notification executor")
    @Nested
    public ThreadPoolExecutorMBean getTaskNotificationExecutor()
    {
        return taskNotificationExecutorMBean;
    }

    @Managed(description = "Task garbage collector executor")
    @Nested
    public ThreadPoolExecutorMBean getTaskManagementExecutor()
    {
        return taskManagementExecutorMBean;
    }

    @Override
    public List<TaskInfo> getAllTaskInfo(boolean full)
    {
        Map<TaskId, TaskInfo> taskInfos = new HashMap<>();
        for (TaskExecution taskExecution : tasks.values()) {
            taskInfos.put(taskExecution.getTaskId(), getTaskInfo(taskExecution, full));
        }
        taskInfos.putAll(this.taskInfos);
        return ImmutableList.copyOf(taskInfos.values());
    }

    @Override
    public void waitForStateChange(TaskId taskId, TaskState currentState, Duration maxWait)
            throws InterruptedException
    {
        checkNotNull(taskId, "taskId is null");
        checkNotNull(maxWait, "maxWait is null");

        TaskExecution taskExecution = tasks.get(taskId);
        if (taskExecution == null) {
            return;
        }

        taskExecution.recordHeartbeat();
        taskExecution.waitForStateChange(currentState, maxWait);
    }

    @Override
    public TaskInfo getTaskInfo(TaskId taskId, boolean full)
    {
        checkNotNull(taskId, "taskId is null");

        TaskExecution taskExecution = tasks.get(taskId);
        if (taskExecution != null) {
            taskExecution.recordHeartbeat();
            return getTaskInfo(taskExecution, full);
        }

        TaskInfo taskInfo = taskInfos.get(taskId);
        if (taskInfo == null) {
            throw new NoSuchElementException("Unknown query task " + taskId);
        }
        return taskInfo;
    }

    private TaskInfo getTaskInfo(TaskExecution taskExecution, boolean full)
    {
        TaskInfo taskInfo = taskExecution.getTaskInfo(full);
        if (taskInfo.getState().isDone()) {
            if (taskInfo.getStats().getEndTime() == null) {
                log.warn("Task %s is in done state %s but does not have an end time", taskInfo.getTaskId(), taskInfo.getState());
            }

            // cache task info
            taskInfos.putIfAbsent(taskInfo.getTaskId(), taskInfo);

            // record input and output stats
            TaskContext taskContext = taskExecution.getTaskContext();
            finishedInputDataSize.merge(taskContext.getInputDataSize());
            finishedInputPositions.merge(taskContext.getInputPositions());
            finishedOutputDataSize.merge(taskContext.getOutputDataSize());
            finishedOutputPositions.merge(taskContext.getOutputPositions());

            // remove task (after caching the task info)
            tasks.remove(taskInfo.getTaskId());
        }
        return taskInfo;
    }

    @Override
    public TaskInfo updateTask(Session session, TaskId taskId, PlanFragment fragment, List<TaskSource> sources, OutputBuffers outputBuffers)
    {
        URI location = locationFactory.createLocalTaskLocation(taskId);

        TaskExecution taskExecution;
        synchronized (this) {
            taskExecution = tasks.get(taskId);
            if (taskExecution == null) {
                // is task already complete?
                TaskInfo taskInfo = taskInfos.get(taskId);
                if (taskInfo != null) {
                    return taskInfo;
                }

                taskExecution = SqlTaskExecution.createSqlTaskExecution(session,
                        taskId,
                        location,
                        fragment,
                        sources,
                        outputBuffers,
                        planner,
                        maxBufferSize,
                        taskExecutor,
                        taskNotificationExecutor,
                        maxTaskMemoryUsage,
                        operatorPreAllocatedMemory,
                        queryMonitor,
                        cpuTimerEnabled
                );
                tasks.put(taskId, taskExecution);
            }
        }

        taskExecution.recordHeartbeat();
        taskExecution.addSources(sources);
        taskExecution.addResultQueue(outputBuffers);

        return getTaskInfo(taskExecution, false);
    }

    @Override
    public BufferResult getTaskResults(TaskId taskId, String outputName, long startingSequenceId, DataSize maxSize, Duration maxWaitTime)
            throws InterruptedException
    {
        checkNotNull(taskId, "taskId is null");
        checkNotNull(outputName, "outputName is null");
        Preconditions.checkArgument(maxSize.toBytes() > 0, "maxSize must be at least 1 byte");
        checkNotNull(maxWaitTime, "maxWaitTime is null");

        TaskExecution taskExecution = tasks.get(taskId);
        if (taskExecution == null) {
            TaskInfo taskInfo = taskInfos.get(taskId);
            if (taskInfo == null) {
                throw new NoSuchElementException("Unknown query task " + taskId);
            }
            else if (taskInfo.getState() == TaskState.FAILED) {
                // for a failed query, do not return a closed buffer as a
                // closed buffer signals to upstream tasks that everything
                // finished cleanly
                Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS);
                return BufferResult.emptyResults(startingSequenceId, false);
            }
            else {
                // query is finished
                return BufferResult.emptyResults(taskInfo.getOutputBuffers().getMasterSequenceId(), true);
            }
        }
        taskExecution.recordHeartbeat();
        return taskExecution.getResults(outputName, startingSequenceId, maxSize, maxWaitTime);
    }

    @Override
    public TaskInfo abortTaskResults(TaskId taskId, String outputId)
    {
        checkNotNull(taskId, "taskId is null");
        checkNotNull(outputId, "outputId is null");

        TaskExecution taskExecution = tasks.get(taskId);
        if (taskExecution == null) {
            TaskInfo taskInfo = taskInfos.get(taskId);
            if (taskInfo != null) {
                // todo this is not safe since task can be expired at any time
                // task was finished early, so the new split should be ignored
                return taskInfo;
            }
            else {
                throw new NoSuchElementException("Unknown query task " + taskId);
            }
        }
        log.debug("Aborting task %s output %s", taskId, outputId);
        taskExecution.abortResults(outputId);

        return getTaskInfo(taskExecution, false);
    }

    @Override
    public TaskInfo cancelTask(TaskId taskId)
    {
        checkNotNull(taskId, "taskId is null");

        TaskExecution taskExecution = tasks.get(taskId);
        if (taskExecution == null) {
            TaskInfo taskInfo = taskInfos.get(taskId);
            if (taskInfo == null) {
                // task does not exist yet, mark the task as canceled, so later if a late request
                // comes in to create the task, the task remains canceled
                TaskContext taskContext = new TaskContext(
                        new TaskStateMachine(taskId, taskNotificationExecutor),
                        taskManagementExecutor,
                        null,
                        maxTaskMemoryUsage,
                        operatorPreAllocatedMemory,
                        cpuTimerEnabled);

                taskInfo = new TaskInfo(taskId,
                        Long.MAX_VALUE,
                        TaskState.CANCELED,
                        URI.create("unknown"),
                        DateTime.now(),
                        new SharedBufferInfo(QueueState.FINISHED, 0, 0, ImmutableList.<BufferInfo>of()),
                        ImmutableSet.<PlanNodeId>of(),
                        taskContext.getTaskStats(),
                        ImmutableList.<FailureInfo>of());
                TaskInfo existingTaskInfo = taskInfos.putIfAbsent(taskId, taskInfo);
                if (existingTaskInfo != null) {
                    taskInfo = existingTaskInfo;
                }
            }
            return taskInfo;
        }

        // make sure task is finished
        taskExecution.cancel();

        return getTaskInfo(taskExecution, false);
    }

    public void removeOldTasks()
    {
        DateTime oldestAllowedTask = DateTime.now().minus(infoCacheTime.toMillis());
        for (TaskInfo taskInfo : taskInfos.values()) {
            try {
                DateTime endTime = taskInfo.getStats().getEndTime();
                if (endTime != null && endTime.isBefore(oldestAllowedTask)) {
                    taskInfos.remove(taskInfo.getTaskId());
                }
            }
            catch (RuntimeException e) {
                log.warn(e, "Error while inspecting age of complete task %s", taskInfo.getTaskId());
            }
        }
    }

    public void failAbandonedTasks()
    {
        DateTime now = DateTime.now();
        DateTime oldestAllowedHeartbeat = now.minus(clientTimeout.toMillis());
        for (TaskExecution taskExecution : tasks.values()) {
            try {
                TaskInfo taskInfo = taskExecution.getTaskInfo(false);
                if (taskInfo.getState().isDone()) {
                    continue;
                }
                DateTime lastHeartbeat = taskInfo.getLastHeartbeat();
                if (lastHeartbeat != null && lastHeartbeat.isBefore(oldestAllowedHeartbeat)) {
                    log.info("Failing abandoned task %s", taskExecution.getTaskId());
                    taskExecution.fail(new AbandonedException("Task " + taskInfo.getTaskId(), lastHeartbeat, now));

                    // trigger caching
                    getTaskInfo(taskExecution, false);
                }
            }
            catch (RuntimeException e) {
                log.warn(e, "Error while inspecting age of task %s", taskExecution.getTaskId());
            }
        }
    }

    //
    // Jmxutils only calls nested getters once, so we are forced to maintain a single
    // instance and periodically recalculate the stats.
    //
    @SuppressWarnings("deprecation")
    private void updateStats()
    {
        CounterStat temp;

        temp = new CounterStat();
        temp.merge(finishedInputDataSize);
        for (TaskExecution taskExecution : tasks.values()) {
            TaskContext taskContext = taskExecution.getTaskContext();
            temp.merge(taskContext.getInputDataSize());
        }
        inputDataSize.resetTo(temp);

        temp = new CounterStat();
        temp.merge(finishedInputPositions);
        for (TaskExecution taskExecution : tasks.values()) {
            TaskContext taskContext = taskExecution.getTaskContext();
            temp.merge(taskContext.getInputPositions());
        }
        inputPositions.resetTo(temp);

        temp = new CounterStat();
        temp.merge(finishedOutputDataSize);
        for (TaskExecution taskExecution : tasks.values()) {
            TaskContext taskContext = taskExecution.getTaskContext();
            temp.merge(taskContext.getOutputDataSize());
        }
        outputDataSize.resetTo(temp);

        temp = new CounterStat();
        temp.merge(finishedOutputPositions);
        for (TaskExecution taskExecution : tasks.values()) {
            TaskContext taskContext = taskExecution.getTaskContext();
            temp.merge(taskContext.getOutputPositions());
        }
        outputPositions.resetTo(temp);
    }
}
TOP

Related Classes of com.facebook.presto.execution.SqlTaskManager

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.