Package org.apache.flink.runtime.operators.testutils

Source Code of org.apache.flink.runtime.operators.testutils.MockEnvironment$MockInputGate

/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements.  See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership.  The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/


package org.apache.flink.runtime.operators.testutils;

import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.FutureTask;

import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.fs.Path;
import org.apache.flink.core.io.IOReadableWritable;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.io.disk.iomanager.IOManager;
import org.apache.flink.runtime.io.network.Buffer;
import org.apache.flink.runtime.io.network.bufferprovider.BufferAvailabilityListener;
import org.apache.flink.runtime.io.network.bufferprovider.BufferProvider;
import org.apache.flink.runtime.io.network.bufferprovider.GlobalBufferPool;
import org.apache.flink.runtime.io.network.bufferprovider.LocalBufferPoolOwner;
import org.apache.flink.runtime.io.network.channels.ChannelID;
import org.apache.flink.runtime.io.network.gates.GateID;
import org.apache.flink.runtime.io.network.gates.InputChannelResult;
import org.apache.flink.runtime.io.network.gates.InputGate;
import org.apache.flink.runtime.io.network.gates.OutputGate;
import org.apache.flink.runtime.io.network.gates.RecordAvailabilityListener;
import org.apache.flink.runtime.io.network.serialization.AdaptiveSpanningRecordDeserializer;
import org.apache.flink.runtime.io.network.serialization.RecordDeserializer;
import org.apache.flink.runtime.io.network.serialization.RecordDeserializer.DeserializationResult;
import org.apache.flink.runtime.jobgraph.JobID;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider;
import org.apache.flink.runtime.memorymanager.DefaultMemoryManager;
import org.apache.flink.runtime.memorymanager.MemoryManager;
import org.apache.flink.runtime.protocols.AccumulatorProtocol;
import org.apache.flink.types.Record;
import org.apache.flink.util.MutableObjectIterator;

public class MockEnvironment implements Environment, BufferProvider, LocalBufferPoolOwner {
 
  private final MemoryManager memManager;

  private final IOManager ioManager;

  private final InputSplitProvider inputSplitProvider;

  private final Configuration jobConfiguration;

  private final Configuration taskConfiguration;

  private final List<InputGate<Record>> inputs;

  private final List<OutputGate> outputs;

  private final JobID jobID = new JobID();

  private final Buffer mockBuffer;
 
  private final BroadcastVariableManager bcVarManager = new BroadcastVariableManager();
 

  public MockEnvironment(long memorySize, MockInputSplitProvider inputSplitProvider, int bufferSize) {
    this.jobConfiguration = new Configuration();
    this.taskConfiguration = new Configuration();
    this.inputs = new LinkedList<InputGate<Record>>();
    this.outputs = new LinkedList<OutputGate>();

    this.memManager = new DefaultMemoryManager(memorySize, 1);
    this.ioManager = new IOManager(System.getProperty("java.io.tmpdir"));
    this.inputSplitProvider = inputSplitProvider;
    this.mockBuffer = new Buffer(new MemorySegment(new byte[bufferSize]), bufferSize, null);
  }

  public void addInput(MutableObjectIterator<Record> inputIterator) {
    int id = inputs.size();
    inputs.add(new MockInputGate(id, inputIterator));
  }

  public void addOutput(List<Record> outputList) {
    int id = outputs.size();
    outputs.add(new MockOutputGate(id, outputList));
  }

  @Override
  public Configuration getTaskConfiguration() {
    return this.taskConfiguration;
  }

  @Override
  public MemoryManager getMemoryManager() {
    return this.memManager;
  }
 
  @Override
  public IOManager getIOManager() {
    return this.ioManager;
  }

  @Override
  public JobID getJobID() {
    return this.jobID;
  }

  @Override
  public Buffer requestBuffer(int minBufferSize) throws IOException {
    return mockBuffer;
  }

  @Override
  public Buffer requestBufferBlocking(int minBufferSize) throws IOException, InterruptedException {
    return mockBuffer;
  }

  @Override
  public int getBufferSize() {
    return this.mockBuffer.size();
  }

  @Override
  public BufferAvailabilityRegistration registerBufferAvailabilityListener(BufferAvailabilityListener listener) {
    return BufferAvailabilityRegistration.FAILED_BUFFER_POOL_DESTROYED;
  }

  @Override
  public int getNumberOfChannels() {
    return 1;
  }

  @Override
  public void setDesignatedNumberOfBuffers(int numBuffers) {

  }

  @Override
  public void clearLocalBufferPool() {

  }

  @Override
  public void registerGlobalBufferPool(GlobalBufferPool globalBufferPool) {

  }

  @Override
  public void logBufferUtilization() {

  }

  @Override
  public void reportAsynchronousEvent() {

  }

  private static class MockInputGate extends InputGate<Record> {
   
    private MutableObjectIterator<Record> it;

    public MockInputGate(int id, MutableObjectIterator<Record> it) {
      super(new JobID(), new GateID(), id);
      this.it = it;
    }

    @Override
    public void registerRecordAvailabilityListener(final RecordAvailabilityListener<Record> listener) {
      super.registerRecordAvailabilityListener(listener);
      this.notifyRecordIsAvailable(0);
    }
   
    @Override
    public InputChannelResult readRecord(Record target) throws IOException, InterruptedException {

      if ((target = it.next(target)) != null) {
        // everything comes from the same source channel and buffer in this mock
        notifyRecordIsAvailable(0);
        return InputChannelResult.INTERMEDIATE_RECORD_FROM_BUFFER;
      } else {
        return InputChannelResult.END_OF_STREAM;
      }
    }
  }

  private class MockOutputGate extends OutputGate {
   
    private List<Record> out;

    private RecordDeserializer<Record> deserializer;

    private Record record;

    public MockOutputGate(int index, List<Record> outList) {
      super(new JobID(), new GateID(), index);
      this.out = outList;
      this.deserializer = new AdaptiveSpanningRecordDeserializer<Record>();
      this.record = new Record();
    }

    @Override
    public void sendBuffer(Buffer buffer, int targetChannel) throws IOException, InterruptedException {

      this.deserializer.setNextMemorySegment(MockEnvironment.this.mockBuffer.getMemorySegment(), MockEnvironment.this.mockBuffer.size());

      while (this.deserializer.hasUnfinishedData()) {
        DeserializationResult result = this.deserializer.getNextRecord(this.record);

        if (result.isFullRecord()) {
          this.out.add(this.record.createCopy());
        }

        if (result == DeserializationResult.LAST_RECORD_FROM_BUFFER ||
          result == DeserializationResult.PARTIAL_RECORD) {
          break;
        }
      }
    }

    @Override
    public int getNumChannels() {
      return 1;
    }
  }

  @Override
  public Configuration getJobConfiguration() {
    return this.jobConfiguration;
  }

  @Override
  public int getCurrentNumberOfSubtasks() {
    return 1;
  }

  @Override
  public int getIndexInSubtaskGroup() {
    return 0;
  }

  @Override
  public InputSplitProvider getInputSplitProvider() {
    return this.inputSplitProvider;
  }

  @Override
  public String getTaskName() {
    return null;
  }

  @Override
  public GateID getNextUnboundInputGateID() {
    return null;
  }

  @Override
  public int getNumberOfOutputGates() {
    return this.outputs.size();
  }

  @Override
  public int getNumberOfInputGates() {
    return this.inputs.size();
  }

  @Override
  public Set<ChannelID> getOutputChannelIDs() {
    throw new IllegalStateException("getOutputChannelIDs called on MockEnvironment");
  }

  @Override
  public Set<ChannelID> getInputChannelIDs() {
    throw new IllegalStateException("getInputChannelIDs called on MockEnvironment");
  }

  @Override
  public Set<GateID> getOutputGateIDs() {
    throw new IllegalStateException("getOutputGateIDs called on MockEnvironment");
  }

  @Override
  public Set<GateID> getInputGateIDs() {
    throw new IllegalStateException("getInputGateIDs called on MockEnvironment");
  }

  @Override
  public Set<ChannelID> getOutputChannelIDsOfGate(final GateID gateID) {
    throw new IllegalStateException("getOutputChannelIDsOfGate called on MockEnvironment");
  }

  @Override
  public Set<ChannelID> getInputChannelIDsOfGate(final GateID gateID) {
    throw new IllegalStateException("getInputChannelIDsOfGate called on MockEnvironment");
  }

  @Override
  public OutputGate createAndRegisterOutputGate() {
    return this.outputs.remove(0);
  }

  @SuppressWarnings("unchecked")
  @Override
  public <T extends IOReadableWritable> InputGate<T> createAndRegisterInputGate() {
    return (InputGate<T>) this.inputs.remove(0);
  }

  @Override
  public int getNumberOfOutputChannels() {
    return this.outputs.size();
  }

  @Override
  public int getNumberOfInputChannels() {
    return this.inputs.size();
  }
 
  @Override
  public AccumulatorProtocol getAccumulatorProtocolProxy() {
    throw new UnsupportedOperationException(
        "getAccumulatorProtocolProxy() is not supported by MockEnvironment");
  }

  @Override
  public ClassLoader getUserClassLoader() {
    return getClass().getClassLoader();
  }

  @Override
  public BufferProvider getOutputBufferProvider() {
    return this;
  }

  @Override
  public Map<String, FutureTask<Path>> getCopyTask() {
    return null;
  }

  @Override
  public JobVertexID getJobVertexId() {
    return new JobVertexID(new byte[16]);
  }
 
  @Override
  public BroadcastVariableManager getBroadcastVariableManager() {
    return this.bcVarManager;
  }
}
TOP

Related Classes of org.apache.flink.runtime.operators.testutils.MockEnvironment$MockInputGate

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.