Package eu.stratosphere.test.broadcastvars

Source Code of eu.stratosphere.test.broadcastvars.KMeansIterativeNepheleITCase

/***********************************************************************************************************************
*
* Copyright (C) 2010 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*
**********************************************************************************************************************/
package eu.stratosphere.test.broadcastvars;

import eu.stratosphere.nephele.jobgraph.DistributionPattern;
import org.apache.log4j.Level;

import eu.stratosphere.api.common.operators.util.UserCodeObjectWrapper;
import eu.stratosphere.api.common.typeutils.TypeComparatorFactory;
import eu.stratosphere.api.common.typeutils.TypeSerializerFactory;
import eu.stratosphere.api.java.record.io.CsvInputFormat;
import eu.stratosphere.configuration.Configuration;
import eu.stratosphere.core.fs.Path;
import eu.stratosphere.nephele.jobgraph.JobGraph;
import eu.stratosphere.nephele.jobgraph.JobGraphDefinitionException;
import eu.stratosphere.nephele.jobgraph.JobInputVertex;
import eu.stratosphere.nephele.jobgraph.JobOutputVertex;
import eu.stratosphere.nephele.jobgraph.JobTaskVertex;
import eu.stratosphere.pact.runtime.iterative.task.IterationHeadPactTask;
import eu.stratosphere.pact.runtime.iterative.task.IterationIntermediatePactTask;
import eu.stratosphere.pact.runtime.iterative.task.IterationTailPactTask;
import eu.stratosphere.api.java.typeutils.runtime.record.RecordComparatorFactory;
import eu.stratosphere.api.java.typeutils.runtime.record.RecordSerializerFactory;
import eu.stratosphere.pact.runtime.shipping.ShipStrategyType;
import eu.stratosphere.pact.runtime.task.CollectorMapDriver;
import eu.stratosphere.pact.runtime.task.DriverStrategy;
import eu.stratosphere.pact.runtime.task.NoOpDriver;
import eu.stratosphere.pact.runtime.task.GroupReduceDriver;
import eu.stratosphere.pact.runtime.task.chaining.ChainedCollectorMapDriver;
import eu.stratosphere.pact.runtime.task.util.LocalStrategy;
import eu.stratosphere.pact.runtime.task.util.TaskConfig;
import eu.stratosphere.runtime.io.channels.ChannelType;
import eu.stratosphere.test.iterative.nephele.JobGraphUtils;
import eu.stratosphere.test.recordJobs.kmeans.KMeansBroadcast.PointBuilder;
import eu.stratosphere.test.recordJobs.kmeans.KMeansBroadcast.RecomputeClusterCenter;
import eu.stratosphere.test.recordJobs.kmeans.KMeansBroadcast.SelectNearestCenter;
import eu.stratosphere.test.recordJobs.kmeans.KMeansBroadcast.PointOutFormat;
import eu.stratosphere.test.testdata.KMeansData;
import eu.stratosphere.test.util.RecordAPITestBase;
import eu.stratosphere.types.DoubleValue;
import eu.stratosphere.types.IntValue;
import eu.stratosphere.util.LogUtils;


public class KMeansIterativeNepheleITCase extends RecordAPITestBase {

  private static final int ITERATION_ID = 42;
 
  private static final int MEMORY_PER_CONSUMER = 2;
 
  protected String dataPath;
  protected String clusterPath;
  protected String resultPath;

 
  public KMeansIterativeNepheleITCase() {
    LogUtils.initializeDefaultConsoleLogger(Level.ERROR);
  }
 
  @Override
  protected void preSubmit() throws Exception {
    dataPath = createTempFile("datapoints.txt", KMeansData.DATAPOINTS);
    clusterPath = createTempFile("initial_centers.txt", KMeansData.INITIAL_CENTERS);
    resultPath = getTempDirPath("result");
  }
 
  @Override
  protected void postSubmit() throws Exception {
    compareResultsByLinesInMemory(KMeansData.CENTERS_AFTER_20_ITERATIONS_SINGLE_DIGIT, resultPath);
  }

  @Override
  protected JobGraph getJobGraph() throws Exception {
    return createJobGraph(dataPath, clusterPath, this.resultPath, 4, 20);
  }

  // -------------------------------------------------------------------------------------------------------------
  // Job vertex builder methods
  // -------------------------------------------------------------------------------------------------------------

  private static JobInputVertex createPointsInput(JobGraph jobGraph, String pointsPath, int numSubTasks, TypeSerializerFactory<?> serializer) {
    @SuppressWarnings("unchecked")
    CsvInputFormat pointsInFormat = new CsvInputFormat('|', IntValue.class, DoubleValue.class, DoubleValue.class, DoubleValue.class);
    JobInputVertex pointsInput = JobGraphUtils.createInput(pointsInFormat, pointsPath, "[Points]", jobGraph, numSubTasks, numSubTasks);
    {
      TaskConfig taskConfig = new TaskConfig(pointsInput.getConfiguration());
      taskConfig.addOutputShipStrategy(ShipStrategyType.FORWARD);
      taskConfig.setOutputSerializer(serializer);
     
      TaskConfig chainedMapper = new TaskConfig(new Configuration());
      chainedMapper.setDriverStrategy(DriverStrategy.COLLECTOR_MAP);
      chainedMapper.setStubWrapper(new UserCodeObjectWrapper<PointBuilder>(new PointBuilder()));
      chainedMapper.addOutputShipStrategy(ShipStrategyType.FORWARD);
      chainedMapper.setOutputSerializer(serializer);
     
      taskConfig.addChainedTask(ChainedCollectorMapDriver.class, chainedMapper, "Build points");
    }

    return pointsInput;
  }

  private static JobInputVertex createCentersInput(JobGraph jobGraph, String centersPath, int numSubTasks, TypeSerializerFactory<?> serializer) {
    @SuppressWarnings("unchecked")
    CsvInputFormat modelsInFormat = new CsvInputFormat('|', IntValue.class, DoubleValue.class, DoubleValue.class, DoubleValue.class);
    JobInputVertex modelsInput = JobGraphUtils.createInput(modelsInFormat, centersPath, "[Models]", jobGraph, numSubTasks, numSubTasks);

    {
      TaskConfig taskConfig = new TaskConfig(modelsInput.getConfiguration());
      taskConfig.addOutputShipStrategy(ShipStrategyType.FORWARD);
      taskConfig.setOutputSerializer(serializer);

      TaskConfig chainedMapper = new TaskConfig(new Configuration());
      chainedMapper.setDriverStrategy(DriverStrategy.COLLECTOR_MAP);
      chainedMapper.setStubWrapper(new UserCodeObjectWrapper<PointBuilder>(new PointBuilder()));
      chainedMapper.addOutputShipStrategy(ShipStrategyType.FORWARD);
      chainedMapper.setOutputSerializer(serializer);
     
      taskConfig.addChainedTask(ChainedCollectorMapDriver.class, chainedMapper, "Build centers");
    }

    return modelsInput;
  }

  private static JobOutputVertex createOutput(JobGraph jobGraph, String resultPath, int numSubTasks, TypeSerializerFactory<?> serializer) {
   
    JobOutputVertex output = JobGraphUtils.createFileOutput(jobGraph, "Output", numSubTasks, numSubTasks);

    {
      TaskConfig taskConfig = new TaskConfig(output.getConfiguration());
      taskConfig.addInputToGroup(0);
      taskConfig.setInputSerializer(serializer, 0);

      PointOutFormat outFormat = new PointOutFormat();
      outFormat.setOutputFilePath(new Path(resultPath));
     
      taskConfig.setStubWrapper(new UserCodeObjectWrapper<PointOutFormat>(outFormat));
    }

    return output;
  }
 
  private static JobTaskVertex createIterationHead(JobGraph jobGraph, int numSubTasks, TypeSerializerFactory<?> serializer) {
    JobTaskVertex head = JobGraphUtils.createTask(IterationHeadPactTask.class, "Iteration Head", jobGraph, numSubTasks, numSubTasks);

    TaskConfig headConfig = new TaskConfig(head.getConfiguration());
    headConfig.setIterationId(ITERATION_ID);
   
    // initial input / partial solution
    headConfig.addInputToGroup(0);
    headConfig.setIterationHeadPartialSolutionOrWorksetInputIndex(0);
    headConfig.setInputSerializer(serializer, 0);
   
    // back channel / iterations
    headConfig.setBackChannelMemory(MEMORY_PER_CONSUMER * JobGraphUtils.MEGABYTE);
   
    // output into iteration. broadcasting the centers
    headConfig.setOutputSerializer(serializer);
    headConfig.addOutputShipStrategy(ShipStrategyType.BROADCAST);
   
    // final output
    TaskConfig headFinalOutConfig = new TaskConfig(new Configuration());
    headFinalOutConfig.setOutputSerializer(serializer);
    headFinalOutConfig.addOutputShipStrategy(ShipStrategyType.FORWARD);
    headConfig.setIterationHeadFinalOutputConfig(headFinalOutConfig);
   
    // the sync
    headConfig.setIterationHeadIndexOfSyncOutput(2);
   
    // the driver
    headConfig.setDriver(NoOpDriver.class);
    headConfig.setDriverStrategy(DriverStrategy.UNARY_NO_OP);
   
    return head;
  }
 
  private static JobTaskVertex createMapper(JobGraph jobGraph, int numSubTasks, TypeSerializerFactory<?> inputSerializer,
      TypeSerializerFactory<?> broadcastVarSerializer, TypeSerializerFactory<?> outputSerializer,
      TypeComparatorFactory<?> outputComparator)
  {
    JobTaskVertex mapper = JobGraphUtils.createTask(IterationIntermediatePactTask.class,
      "Map (Select nearest center)", jobGraph, numSubTasks, numSubTasks);
   
    TaskConfig intermediateConfig = new TaskConfig(mapper.getConfiguration());
    intermediateConfig.setIterationId(ITERATION_ID);
   
    intermediateConfig.setDriver(CollectorMapDriver.class);
    intermediateConfig.setDriverStrategy(DriverStrategy.COLLECTOR_MAP);
    intermediateConfig.addInputToGroup(0);
    intermediateConfig.setInputSerializer(inputSerializer, 0);
   
    intermediateConfig.setOutputSerializer(outputSerializer);
    intermediateConfig.addOutputShipStrategy(ShipStrategyType.PARTITION_HASH);
    intermediateConfig.setOutputComparator(outputComparator, 0);

    intermediateConfig.setBroadcastInputName("centers", 0);
    intermediateConfig.addBroadcastInputToGroup(0);
    intermediateConfig.setBroadcastInputSerializer(broadcastVarSerializer, 0);
   
    // the udf
    intermediateConfig.setStubWrapper(new UserCodeObjectWrapper<SelectNearestCenter>(new SelectNearestCenter()));
   
    return mapper;
  }
 
  private static JobTaskVertex createReducer(JobGraph jobGraph, int numSubTasks, TypeSerializerFactory<?> inputSerializer,
      TypeComparatorFactory<?> inputComparator, TypeSerializerFactory<?> outputSerializer)
  {
    // ---------------- the tail (co group) --------------------
   
    JobTaskVertex tail = JobGraphUtils.createTask(IterationTailPactTask.class, "Reduce / Iteration Tail", jobGraph,
      numSubTasks, numSubTasks);
   
    TaskConfig tailConfig = new TaskConfig(tail.getConfiguration());
    tailConfig.setIterationId(ITERATION_ID);
    tailConfig.setIsWorksetUpdate();
   
    // inputs and driver
    tailConfig.setDriver(GroupReduceDriver.class);
    tailConfig.setDriverStrategy(DriverStrategy.SORTED_GROUP_REDUCE);
    tailConfig.addInputToGroup(0);
    tailConfig.setInputSerializer(inputSerializer, 0);   
    tailConfig.setDriverComparator(inputComparator, 0);

    tailConfig.setInputLocalStrategy(0, LocalStrategy.SORT);
    tailConfig.setInputComparator(inputComparator, 0);
    tailConfig.setMemoryInput(0, MEMORY_PER_CONSUMER * JobGraphUtils.MEGABYTE);
    tailConfig.setFilehandlesInput(0, 128);
    tailConfig.setSpillingThresholdInput(0, 0.9f);
   
    // output
    tailConfig.addOutputShipStrategy(ShipStrategyType.FORWARD);
    tailConfig.setOutputSerializer(outputSerializer);
   
    // the udf
    tailConfig.setStubWrapper(new UserCodeObjectWrapper<RecomputeClusterCenter>(new RecomputeClusterCenter()));
   
    return tail;
  }
 
  private static JobOutputVertex createSync(JobGraph jobGraph, int numIterations, int dop) {
    JobOutputVertex sync = JobGraphUtils.createSync(jobGraph, dop);
    TaskConfig syncConfig = new TaskConfig(sync.getConfiguration());
    syncConfig.setNumberOfIterations(numIterations);
    syncConfig.setIterationId(ITERATION_ID);
    return sync;
  }

  // -------------------------------------------------------------------------------------------------------------
  // Unified solution set and workset tail update
  // -------------------------------------------------------------------------------------------------------------

  private static JobGraph createJobGraph(String pointsPath, String centersPath, String resultPath, int numSubTasks, int numIterations) throws JobGraphDefinitionException {

    // -- init -------------------------------------------------------------------------------------------------
    final TypeSerializerFactory<?> serializer = RecordSerializerFactory.get();
    @SuppressWarnings("unchecked")
    final TypeComparatorFactory<?> int0Comparator = new RecordComparatorFactory(new int[] { 0 }, new Class[] { IntValue.class });

    JobGraph jobGraph = new JobGraph("KMeans Iterative");

    // -- vertices ---------------------------------------------------------------------------------------------
    JobInputVertex points = createPointsInput(jobGraph, pointsPath, numSubTasks, serializer);
    JobInputVertex centers = createCentersInput(jobGraph, centersPath, numSubTasks, serializer);
   
    JobTaskVertex head = createIterationHead(jobGraph, numSubTasks, serializer);
    JobTaskVertex mapper = createMapper(jobGraph, numSubTasks, serializer, serializer, serializer, int0Comparator);
   
    JobTaskVertex reducer = createReducer(jobGraph, numSubTasks, serializer, int0Comparator, serializer);
   
    JobOutputVertex fakeTailOutput = JobGraphUtils.createFakeOutput(jobGraph, "FakeTailOutput", numSubTasks, numSubTasks);
   
    JobOutputVertex sync = createSync(jobGraph, numIterations, numSubTasks);
   
    JobOutputVertex output = createOutput(jobGraph, resultPath, numSubTasks, serializer);

    // -- edges ------------------------------------------------------------------------------------------------
    JobGraphUtils.connect(points, mapper, ChannelType.NETWORK, DistributionPattern.POINTWISE);
   
    JobGraphUtils.connect(centers, head, ChannelType.NETWORK, DistributionPattern.POINTWISE);
   
    JobGraphUtils.connect(head, mapper, ChannelType.NETWORK, DistributionPattern.BIPARTITE);
    new TaskConfig(mapper.getConfiguration()).setBroadcastGateIterativeWithNumberOfEventsUntilInterrupt(0, numSubTasks);
    new TaskConfig(mapper.getConfiguration()).setInputCached(0, true);
    new TaskConfig(mapper.getConfiguration()).setInputMaterializationMemory(0, MEMORY_PER_CONSUMER * JobGraphUtils.MEGABYTE);

    JobGraphUtils.connect(mapper, reducer, ChannelType.NETWORK, DistributionPattern.BIPARTITE);
    new TaskConfig(reducer.getConfiguration()).setGateIterativeWithNumberOfEventsUntilInterrupt(0, numSubTasks);
   
    JobGraphUtils.connect(reducer, fakeTailOutput, ChannelType.NETWORK, DistributionPattern.POINTWISE);
   
    JobGraphUtils.connect(head, output, ChannelType.NETWORK, DistributionPattern.POINTWISE);
   
    JobGraphUtils.connect(head, sync, ChannelType.NETWORK, DistributionPattern.BIPARTITE);

    // -- instance sharing -------------------------------------------------------------------------------------
    points.setVertexToShareInstancesWith(output);
    centers.setVertexToShareInstancesWith(output);
    head.setVertexToShareInstancesWith(output);
    mapper.setVertexToShareInstancesWith(output);
    reducer.setVertexToShareInstancesWith(output);
    fakeTailOutput.setVertexToShareInstancesWith(output);
    sync.setVertexToShareInstancesWith(output);

    return jobGraph;
  }
}
TOP

Related Classes of eu.stratosphere.test.broadcastvars.KMeansIterativeNepheleITCase

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.