Package org.apache.flink.spargel.java.record

Source Code of org.apache.flink.spargel.java.record.SpargelIteration$MessagingDriver

/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements.  See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership.  The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.spargel.java.record;

import java.io.IOException;
import java.util.Iterator;

import org.apache.flink.api.common.aggregators.AggregatorRegistry;
import org.apache.flink.api.common.operators.Operator;
import org.apache.flink.api.java.record.functions.CoGroupFunction;
import org.apache.flink.api.java.record.functions.FunctionAnnotation.ConstantFieldsFirst;
import org.apache.flink.api.java.record.operators.CoGroupOperator;
import org.apache.flink.api.java.record.operators.DeltaIteration;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.types.Key;
import org.apache.flink.types.Record;
import org.apache.flink.types.Value;
import org.apache.flink.util.Collector;
import org.apache.flink.util.InstantiationUtil;
import org.apache.flink.util.ReflectionUtil;

public class SpargelIteration {
 
  private static final String DEFAULT_NAME = "<unnamed vertex-centric iteration>";
 
  private final DeltaIteration iteration;
 
  private final Class<? extends Key<?>> vertexKey;
  private final Class<? extends Value> vertexValue;
  private final Class<? extends Value> messageType;
  private final Class<? extends Value> edgeValue;
 
  private final CoGroupOperator vertexUpdater;
  private final CoGroupOperator messager;
 
 
  // ----------------------------------------------------------------------------------
 
  public <VertexKey extends Key<VertexKey>, VertexValue extends Value, Message extends Value, EdgeValue extends Value>
      SpargelIteration(MessagingFunction<VertexKey, VertexValue, Message, EdgeValue> mf,
      VertexUpdateFunction<VertexKey, VertexValue, Message> uf)
  {
    this(mf, uf, DEFAULT_NAME);
  }
 
  public <VertexKey extends Key<VertexKey>, VertexValue extends Value, Message extends Value, EdgeValue extends Value> SpargelIteration(
      MessagingFunction<VertexKey, VertexValue, Message, EdgeValue> mf, VertexUpdateFunction<VertexKey, VertexValue, Message> uf,
      String name)
  {
    // get the types
    this.vertexKey = ReflectionUtil.getTemplateType1(mf.getClass());
    this.vertexValue = ReflectionUtil.getTemplateType2(mf.getClass());
    this.messageType = ReflectionUtil.getTemplateType3(mf.getClass());
    this.edgeValue = ReflectionUtil.getTemplateType4(mf.getClass());
   
    if (vertexKey == null || vertexValue == null || messageType == null || edgeValue == null) {
      throw new RuntimeException();
    }
 
    // instantiate the data flow
    this.iteration = new DeltaIteration(0, name);
   
    this.messager = CoGroupOperator.builder(MessagingDriver.class, vertexKey, 0, 0)
      .input2(iteration.getWorkset())
      .name("Message Sender")
      .build();
    this.vertexUpdater = CoGroupOperator.builder(VertexUpdateDriver.class, vertexKey, 0, 0)
      .input1(messager)
      .input2(iteration.getSolutionSet())
      .name("Vertex Updater")
      .build();
   
    iteration.setNextWorkset(vertexUpdater);
    iteration.setSolutionSetDelta(vertexUpdater);
   
    // parameterize the data flow
    try {
      Configuration vertexUdfParams = vertexUpdater.getParameters();
      InstantiationUtil.writeObjectToConfig(uf, vertexUdfParams, VertexUpdateDriver.UDF_PARAM);
      vertexUdfParams.setClass(VertexUpdateDriver.KEY_PARAM, vertexKey);
      vertexUdfParams.setClass(VertexUpdateDriver.VALUE_PARAM, vertexValue);
      vertexUdfParams.setClass(VertexUpdateDriver.MESSAGE_PARAM, messageType);
     
      Configuration messageUdfParams = messager.getParameters();
      InstantiationUtil.writeObjectToConfig(mf, messageUdfParams, MessagingDriver.UDF_PARAM);
      messageUdfParams.setClass(MessagingDriver.KEY_PARAM, vertexKey);
      messageUdfParams.setClass(MessagingDriver.VALUE_PARAM, vertexValue);
      messageUdfParams.setClass(MessagingDriver.MESSAGE_PARAM, messageType);
      messageUdfParams.setClass(MessagingDriver.EDGE_PARAM, edgeValue);
    }
    catch (IOException e) {
      throw new RuntimeException("Could not serialize the UDFs for distribution" +
          (e.getMessage() == null ? '.' : ": " + e.getMessage()), e);
    }
  }
 
  // ----------------------------------------------------------------------------------
  //  inputs and outputs
  // ----------------------------------------------------------------------------------
 
  public void setVertexInput(Operator<Record> c) {
    this.iteration.setInitialSolutionSet(c);
    this.iteration.setInitialWorkset(c);
  }
 
  public void setEdgesInput(Operator<Record> c) {
    this.messager.setFirstInput(c);
  }
 
  public Operator<?> getOutput() {
    return this.iteration;
  }
 
  public void setDegreeOfParallelism(int dop) {
    this.iteration.setDegreeOfParallelism(dop);
  }
 
  public void setNumberOfIterations(int iterations) {
    this.iteration.setMaximumNumberOfIterations(iterations);
  }
 
  public AggregatorRegistry getAggregators() {
    return this.iteration.getAggregators();
  }
 
  // --------------------------------------------------------------------------------------------
  //  Wrapping UDFs
  // --------------------------------------------------------------------------------------------
 
  @ConstantFieldsFirst(0)
  public static final class VertexUpdateDriver<K extends Key<K>, V extends Value, M extends Value> extends CoGroupFunction {
   
    private static final long serialVersionUID = 1L;
   
    private static final String UDF_PARAM = "spargel.udf";
    private static final String KEY_PARAM = "spargel.key-type";
    private static final String VALUE_PARAM = "spargel.value-type";
    private static final String MESSAGE_PARAM = "spargel.message-type";
   
    private VertexUpdateFunction<K, V, M> vertexUpdateFunction;
   
    private K vertexKey;
    private V vertexValue;
    private MessageIterator<M> messageIter;

    @Override
    public void coGroup(Iterator<Record> messages, Iterator<Record> vertex, Collector<Record> out) throws Exception {
     
      if (vertex.hasNext()) {
        Record first = vertex.next();
        first.getFieldInto(0, vertexKey);
        first.getFieldInto(1, vertexValue);
        messageIter.setSource(messages);
        vertexUpdateFunction.setOutput(first, out);
        vertexUpdateFunction.updateVertex(vertexKey, vertexValue, messageIter);
      } else {
        if (messages.hasNext()) {
          String message = "Target vertex does not exist!.";
          try {
            Record next = messages.next();
            next.getFieldInto(0, vertexKey);
            message = "Target vertex '" + vertexKey + "' does not exist!.";
          } catch (Throwable t) {}
          throw new Exception(message);
        } else {
          throw new Exception();
        }
      }
    }
   
    @SuppressWarnings("unchecked")
    @Override
    public void open(Configuration parameters) throws Exception {
      // instantiate only the first time
      if (vertexUpdateFunction == null) {
        ClassLoader cl = getRuntimeContext().getUserCodeClassLoader();
       
        Class<K> vertexKeyClass = parameters.getClass(KEY_PARAM, null, cl);
        Class<V> vertexValueClass = parameters.getClass(VALUE_PARAM, null, cl);
        Class<M> messageClass = parameters.getClass(MESSAGE_PARAM, null, cl);
       
        vertexKey = InstantiationUtil.instantiate(vertexKeyClass, Key.class);
        vertexValue = InstantiationUtil.instantiate(vertexValueClass, Value.class);
        messageIter = new MessageIterator<M>(InstantiationUtil.instantiate(messageClass, Value.class));
       
        ClassLoader ucl = getRuntimeContext().getUserCodeClassLoader();
       
        try {
          this.vertexUpdateFunction = (VertexUpdateFunction<K, V, M>) InstantiationUtil.readObjectFromConfig(parameters, UDF_PARAM, ucl);
        } catch (Exception e) {
          String message = e.getMessage() == null ? "." : ": " + e.getMessage();
          throw new Exception("Could not instantiate VertexUpdateFunction" + message, e);
        }
       
        this.vertexUpdateFunction.init(getIterationRuntimeContext());
        this.vertexUpdateFunction.setup(parameters);
      }
      this.vertexUpdateFunction.preSuperstep();
    }
   
    @Override
    public void close() throws Exception {
      this.vertexUpdateFunction.postSuperstep();
    }
  }
 
  public static final class MessagingDriver<K extends Key<K>, V extends Value, M extends Value, E extends Value> extends CoGroupFunction {

    private static final long serialVersionUID = 1L;
   
    private static final String UDF_PARAM = "spargel.udf";
    private static final String KEY_PARAM = "spargel.key-type";
    private static final String VALUE_PARAM = "spargel.value-type";
    private static final String MESSAGE_PARAM = "spargel.message-type";
    private static final String EDGE_PARAM = "spargel.edge-value";
   
   
    private MessagingFunction<K, V, M, E> messagingFunction;
   
    private K vertexKey;
    private V vertexValue;
   
    @Override
    public void coGroup(Iterator<Record> edges, Iterator<Record> state, Collector<Record> out) throws Exception {
      if (state.hasNext()) {
        Record first = state.next();
        first.getFieldInto(0, vertexKey);
        first.getFieldInto(1, vertexValue);
        messagingFunction.set(edges, out);
        messagingFunction.sendMessages(vertexKey, vertexValue);
      }
    }
   
    @SuppressWarnings("unchecked")
    @Override
    public void open(Configuration parameters) throws Exception {
      // instantiate only the first time
      if (messagingFunction == null) {
        ClassLoader cl = getRuntimeContext().getUserCodeClassLoader();
       
        Class<K> vertexKeyClass = parameters.getClass(KEY_PARAM, null, cl);
        Class<V> vertexValueClass = parameters.getClass(VALUE_PARAM, null, cl);
//        Class<M> messageClass = parameters.getClass(MESSAGE_PARAM, null, Value.class);
        Class<E> edgeClass = parameters.getClass(EDGE_PARAM, null, cl);
       
        vertexKey = InstantiationUtil.instantiate(vertexKeyClass, Key.class);
        vertexValue = InstantiationUtil.instantiate(vertexValueClass, Value.class);
       
        K edgeKeyHolder = InstantiationUtil.instantiate(vertexKeyClass, Key.class);
        E edgeValueHolder = InstantiationUtil.instantiate(edgeClass, Value.class);
       
        ClassLoader ucl = getRuntimeContext().getUserCodeClassLoader();
       
        try {
          this.messagingFunction = (MessagingFunction<K, V, M, E>) InstantiationUtil.readObjectFromConfig(parameters, UDF_PARAM, ucl);
        } catch (Exception e) {
          String message = e.getMessage() == null ? "." : ": " + e.getMessage();
          throw new Exception("Could not instantiate MessagingFunction" + message, e);
        }
       
        this.messagingFunction.init(getIterationRuntimeContext(), edgeKeyHolder, edgeValueHolder);
        this.messagingFunction.setup(parameters);
      }
      this.messagingFunction.preSuperstep();
    }
   
    @Override
    public void close() throws Exception {
      this.messagingFunction.postSuperstep();
    }
  }
}
TOP

Related Classes of org.apache.flink.spargel.java.record.SpargelIteration$MessagingDriver

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., now owned by Oracle Inc. Contact coftware#gmail.com.