Package etc.aloe.data

Source Code of etc.aloe.data.ROC

/*
* This file is part of ALOE.
*
* ALOE is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ALOE is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ALOE.  If not, see <http://www.gnu.org/licenses/>.
*
* Copyright (c) 2012 SCCL, University of Washington (http://depts.washington.edu/sccl)
*/
package etc.aloe.data;

import com.csvreader.CsvWriter;
import etc.aloe.processes.Saving;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.List;

/**
* Class for storing a ROC curve.
*
* @author Michael Brooks <mjbrooks@uw.edu>
*/
public class ROC implements Saving {

    private final List<Double> falsePositiveRates = new ArrayList<Double>();
    private final List<Double> truePositiveRates = new ArrayList<Double>();
    private final List<Double> thresholdValues = new ArrayList<Double>();
    private final String name;

    public ROC(String name) {
        this.name = name;
    }

    public String getName() {
        return name;
    }

    public int size() {
        return this.falsePositiveRates.size();
    }

    public double getFalsePositiveRate(int index) {
        return falsePositiveRates.get(index);
    }

    public double getTruePositiveRate(int index) {
        return truePositiveRates.get(index);
    }

    public double getThresholdValue(int index) {
        return thresholdValues.get(index);
    }

    /**
     * Record a data point on the ROC curve.
     *
     * @param fpRate
     * @param tpRate
     * @param threshold
     */
    public void record(double fpRate, double tpRate, double threshold) {
        this.falsePositiveRates.add(fpRate);
        this.truePositiveRates.add(tpRate);
        this.thresholdValues.add(threshold);
    }

    /**
     * Clear the recorded curves.
     */
    public void clear() {
        this.falsePositiveRates.clear();
        this.truePositiveRates.clear();
        this.thresholdValues.clear();
    }

    /**
     * Generate the ROC curve from the given predictions.
     *
     * @param predictions
     */
    public void calculateCurve(Predictions predictions) {
        clear();
        predictions = predictions.sortByConfidence();

        int truePositives = 0;
        int falsePositives = 0;
        int totalPositives = predictions.getTruePositiveCount() + predictions.getFalseNegativeCount();
        int totalNegatives = predictions.getTrueNegativeCount() + predictions.getFalsePositiveCount();
        for (int i = 0; i < predictions.size(); i++) {
            Boolean trueLabel = predictions.getTrueLabel(i);
            Double confidence = predictions.getPredictionConfidence(i);

            if (trueLabel == null) {
                continue;
            } else if (trueLabel) {
                truePositives++;
            } else {
                falsePositives++;
            }

            double tpRate = (double) truePositives / totalPositives;
            double fpRate = (double) falsePositives / totalNegatives;
            record(fpRate, tpRate, confidence);
        }
    }

    @Override
    public boolean save(OutputStream destination) throws IOException {
        CsvWriter out = new CsvWriter(destination, ',', Charset.forName("UTF-8"));

        out.write("Threshold");
        out.write("True Positive Rate");
        out.write("False Positive Rate");
        out.endRecord();

        for (int i = 0; i < size(); i++) {
            double threshold = getThresholdValue(i);
            double fpRate = getFalsePositiveRate(i);
            double tpRate = getTruePositiveRate(i);

            out.write("" + threshold);
            out.write("" + tpRate);
            out.write("" + fpRate);
            out.endRecord();
        }

        out.flush();
        return true;


    }
}
TOP

Related Classes of etc.aloe.data.ROC

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.