Package com.github.pmerienne.trident.ml.evaluation

Source Code of com.github.pmerienne.trident.ml.evaluation.EvaluationUpdaterTridentIntegrationTest$DRPCArgsToInstance

/**
* Copyright 2013-2015 Pierre Merienne
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.github.pmerienne.trident.ml.evaluation;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import org.junit.Test;

import storm.trident.Stream;
import storm.trident.TridentState;
import storm.trident.TridentTopology;
import storm.trident.operation.BaseFunction;
import storm.trident.operation.TridentCollector;
import storm.trident.testing.MemoryMapState;
import storm.trident.tuple.TridentTuple;
import backtype.storm.Config;
import backtype.storm.LocalCluster;
import backtype.storm.LocalDRPC;
import backtype.storm.tuple.Fields;
import backtype.storm.tuple.Values;

import com.github.pmerienne.trident.ml.classification.ClassifierUpdater;
import com.github.pmerienne.trident.ml.classification.ClassifyQuery;
import com.github.pmerienne.trident.ml.classification.PerceptronClassifier;
import com.github.pmerienne.trident.ml.core.Instance;
import com.github.pmerienne.trident.ml.testing.NANDSpout;

public class EvaluationUpdaterTridentIntegrationTest {

  @Test
  public void should_test_then_train() throws InterruptedException {
    // Start local cluster
    LocalCluster cluster = new LocalCluster();
    LocalDRPC localDRPC = new LocalDRPC();

    try {
      TridentTopology toppology = new TridentTopology();
      MemoryMapState.Factory evaluationStateFactory = new MemoryMapState.Factory();
      MemoryMapState.Factory perceptronModelStateFactory = new MemoryMapState.Factory();
      TridentState perceptronModel = toppology.newStaticState(perceptronModelStateFactory);
      TridentState perceptronEvaluation = toppology.newStaticState(evaluationStateFactory);

      // Predict
      Stream predictionStream = toppology.newStream("nandsamples", new NANDSpout()) //
          .stateQuery(perceptronModel, new Fields("instance"), new ClassifyQuery<Boolean>("perceptron"), new Fields("prediction"));

      // Update evaluation
      predictionStream //
          .persistentAggregate(evaluationStateFactory, new Fields("instance", "prediction"), new AccuracyAggregator<Boolean>(), new Fields("accuracy"));

      // Update model
      predictionStream.partitionPersist(perceptronModelStateFactory, new Fields("instance"), new ClassifierUpdater<Boolean>("perceptron", new PerceptronClassifier()));

      // Classification stream
      toppology.newDRPCStream("predict", localDRPC)
      // convert DRPC args to instance
          .each(new Fields("args"), new DRPCArgsToInstance(), new Fields("instance"))

          // Query classifier to classify instance
          .stateQuery(perceptronModel, new Fields("instance"), new ClassifyQuery<Boolean>("perceptron"), new Fields("prediction")).project(new Fields("prediction"));

      // Evaluation stream
      toppology.newDRPCStream("evaluate", localDRPC) //
          .stateQuery(perceptronEvaluation, new EvaluationQuery<Boolean>(), new Fields("eval")) //
          .project(new Fields("eval"));

      cluster.submitTopology(this.getClass().getSimpleName(), new Config(), toppology.build());
      Thread.sleep(4000);

      assertEquals(Boolean.TRUE, extractPrediction(localDRPC.execute("predict", "1.0 0.0 0.0")));
      assertEquals(Boolean.TRUE, extractPrediction(localDRPC.execute("predict", "1.0 0.0 1.0")));
      assertEquals(Boolean.TRUE, extractPrediction(localDRPC.execute("predict", "1.0 1.0 0.0")));
      assertEquals(Boolean.FALSE, extractPrediction(localDRPC.execute("predict", "1.0 1.0 1.0")));

      Double evaluation = extractEvaluation(localDRPC.execute("evaluate", ""));
      assertTrue(evaluation > 0.9);
      assertTrue(evaluation < 1);
    } finally {
      cluster.shutdown();
      localDRPC.shutdown();
    }
  }

  protected static Boolean extractPrediction(String drpcResult) {
    return Boolean.parseBoolean(drpcResult.replaceAll("\\[", "").replaceAll("\\]", ""));
  }

  protected static Double extractEvaluation(String drpcResult) {
    return Double.parseDouble(drpcResult.replaceAll("\\[", "").replaceAll("\\]", ""));
  }

  public static class DRPCArgsToInstance extends BaseFunction {

    private static final long serialVersionUID = -2932222000448806586L;

    @SuppressWarnings("rawtypes")
    @Override
    public void execute(TridentTuple tuple, TridentCollector collector) {
      String[] args = tuple.getString(0).split(" ");
      double[] features = new double[args.length];
      for (int i = 0; i < args.length; i++) {
        features[i] = Double.parseDouble(args[i]);
      }

      Instance<?> instance = new Instance(features);
      collector.emit(new Values(instance));
    }
  }

}
TOP

Related Classes of com.github.pmerienne.trident.ml.evaluation.EvaluationUpdaterTridentIntegrationTest$DRPCArgsToInstance

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.