Package etc.aloe.cscw2013

Source Code of etc.aloe.cscw2013.WekaModelTest

/*
* This file is part of ALOE.
*
* ALOE is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.

* ALOE is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
* GNU General Public License for more details.

* You should have received a copy of the GNU General Public License
* along with ALOE.  If not, see <http://www.gnu.org/licenses/>.
*
* Copyright (c) 2012 SCCL, University of Washington (http://depts.washington.edu/sccl)
*/
package etc.aloe.cscw2013;

import etc.aloe.cscw2013.WekaModel;
import etc.aloe.data.ExampleSet;
import etc.aloe.data.Predictions;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.List;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import static org.junit.Assert.*;
import weka.classifiers.Classifier;
import weka.classifiers.trees.J48;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instances;

/**
*
* @author Michael Brooks <mjbrooks@uw.edu>
*/
public class WekaModelTest {

    private Instances instances;
    private Instances testInstances;

    public WekaModelTest() {
    }

    @BeforeClass
    public static void setUpClass() {
    }

    @AfterClass
    public static void tearDownClass() {
    }

    @Before
    public void setUp() {
        ArrayList<Attribute> attributes = new ArrayList<Attribute>();

        attributes.add(new Attribute("id"));

        ArrayList<String> classValues = new ArrayList<String>();
        classValues.add("false");
        classValues.add("true");
        attributes.add(new Attribute("class", classValues));

        this.instances = new Instances("TrainInstances", attributes, 12);
        this.instances.setClassIndex(1);

        this.instances.add(new DenseInstance(1.0, new double[]{1.0, 1.0}));
        this.instances.add(new DenseInstance(1.0, new double[]{2.0, 1.0}));
        this.instances.add(new DenseInstance(1.0, new double[]{3.0, 1.0}));
        this.instances.add(new DenseInstance(1.0, new double[]{4.0, 1.0}));
        this.instances.add(new DenseInstance(1.0, new double[]{5.0, 1.0}));
        this.instances.add(new DenseInstance(1.0, new double[]{6.0, 1.0}));
        this.instances.add(new DenseInstance(1.0, new double[]{7.0, 0.0}));
        this.instances.add(new DenseInstance(1.0, new double[]{8.0, 0.0}));
        this.instances.add(new DenseInstance(1.0, new double[]{9.0, 0.0}));
        this.instances.add(new DenseInstance(1.0, new double[]{10.0, 0.0}));
        this.instances.add(new DenseInstance(1.0, new double[]{11.0, 0.0}));
        this.instances.add(new DenseInstance(1.0, new double[]{12.0, 0.0}));

        this.testInstances = new Instances("TestInstances", attributes, 4);
        this.testInstances.setClassIndex(1);

        this.testInstances.add(new DenseInstance(1.0, new double[]{1.0, 1.0}));
        this.testInstances.add(new DenseInstance(1.0, new double[]{6.0, 1.0}));
        this.testInstances.add(new DenseInstance(1.0, new double[]{7.0, 0.0}));
        this.testInstances.add(new DenseInstance(1.0, new double[]{12.0, 0.0}));
    }

    @After
    public void tearDown() {
    }

    /**
     * Test of save method, of class Model.
     */
    @Test
    public void testSave() throws Exception {
        System.out.println("save");

        J48 classifier = new J48();
        classifier.setNumFolds(456);
        WekaModel model = new WekaModel(classifier);

        ByteArrayOutputStream out = new ByteArrayOutputStream();
        assertTrue(model.save(out));
        out.close();
        byte[] serializedStr = out.toByteArray();

        //It wrote something
        assertTrue(serializedStr.length > 0);

        //It wrote a J48 classifier
        ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(serializedStr));
        try {
            J48 serialized = (J48) in.readObject();
            in.close();
            assertEquals(classifier.getNumFolds(), serialized.getNumFolds());
        } catch (ClassNotFoundException e) {
            assertTrue(e.getMessage(), false);
        } catch (IOException e) {
            assertTrue(e.getMessage(), false);
        }
    }

    /**
     * Test of load method, of class Model.
     */
    @Test
    public void testLoad() throws Exception {
        System.out.println("load");

        J48 serializedClassifier = new J48();
        serializedClassifier.setNumFolds(456);
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        ObjectOutputStream serializer = new ObjectOutputStream(out);
        serializer.writeObject(serializedClassifier);
        out.close();
        byte[] serializedStr = out.toByteArray();

        ByteArrayInputStream in = new ByteArrayInputStream(serializedStr);
        WekaModel model = new WekaModel();
        assertTrue(model.load(in));
        in.close();

        Classifier classifier = model.getClassifier();
        assertTrue(classifier instanceof J48);
        J48 j48 = (J48) classifier;
        assertEquals(serializedClassifier.getNumFolds(), j48.getNumFolds());
    }

    /**
     * Test of getPredictedLabels method, of class Model.
     */
    @Test
    public void testGetPredictedLabels() {
        System.out.println("getPredictedLabels");

        Classifier classifier = new J48();
        try {
            classifier.buildClassifier(instances);
        } catch (Exception e) {
            assertTrue("Classifier could not be trained", false);
        }

        WekaModel instance = new WekaModel(classifier);

        Boolean[] expResult = new Boolean[]{true, true, false, false};
        Double[] expConfidence = new Double[]{1.0, 1.0, 0.0, 0.0};

        ExampleSet examples = new ExampleSet(testInstances);

        Predictions predictions = instance.getPredictions(examples);

        assertEquals(expResult.length, predictions.size());

        for (int i = 0; i < expResult.length; i++) {
            assertEquals(expResult[i], predictions.getPredictedLabel(i));
            assertEquals(expConfidence[i], predictions.getPredictionConfidence(i));
            assertEquals(testInstances.get(i).classValue(), predictions.getTrueLabel(i) ? 1.0 : 0.0, 0.0);
        }
    }
}
TOP

Related Classes of etc.aloe.cscw2013.WekaModelTest

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.