/*
* This file is part of ALOE.
*
* ALOE is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
* ALOE is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
* You should have received a copy of the GNU General Public License
* along with ALOE. If not, see <http://www.gnu.org/licenses/>.
*
* Copyright (c) 2012 SCCL, University of Washington (http://depts.washington.edu/sccl)
*/
package etc.aloe.data;
import com.csvreader.CsvWriter;
import etc.aloe.processes.Saving;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.List;
/**
* Class for storing a ROC curve.
*
* @author Michael Brooks <mjbrooks@uw.edu>
*/
public class ROC implements Saving {
private final List<Double> falsePositiveRates = new ArrayList<Double>();
private final List<Double> truePositiveRates = new ArrayList<Double>();
private final List<Double> thresholdValues = new ArrayList<Double>();
private final String name;
public ROC(String name) {
this.name = name;
}
public String getName() {
return name;
}
public int size() {
return this.falsePositiveRates.size();
}
public double getFalsePositiveRate(int index) {
return falsePositiveRates.get(index);
}
public double getTruePositiveRate(int index) {
return truePositiveRates.get(index);
}
public double getThresholdValue(int index) {
return thresholdValues.get(index);
}
/**
* Record a data point on the ROC curve.
*
* @param fpRate
* @param tpRate
* @param threshold
*/
public void record(double fpRate, double tpRate, double threshold) {
this.falsePositiveRates.add(fpRate);
this.truePositiveRates.add(tpRate);
this.thresholdValues.add(threshold);
}
/**
* Clear the recorded curves.
*/
public void clear() {
this.falsePositiveRates.clear();
this.truePositiveRates.clear();
this.thresholdValues.clear();
}
/**
* Generate the ROC curve from the given predictions.
*
* @param predictions
*/
public void calculateCurve(Predictions predictions) {
clear();
predictions = predictions.sortByConfidence();
int truePositives = 0;
int falsePositives = 0;
int totalPositives = predictions.getTruePositiveCount() + predictions.getFalseNegativeCount();
int totalNegatives = predictions.getTrueNegativeCount() + predictions.getFalsePositiveCount();
for (int i = 0; i < predictions.size(); i++) {
Boolean trueLabel = predictions.getTrueLabel(i);
Double confidence = predictions.getPredictionConfidence(i);
if (trueLabel == null) {
continue;
} else if (trueLabel) {
truePositives++;
} else {
falsePositives++;
}
double tpRate = (double) truePositives / totalPositives;
double fpRate = (double) falsePositives / totalNegatives;
record(fpRate, tpRate, confidence);
}
}
@Override
public boolean save(OutputStream destination) throws IOException {
CsvWriter out = new CsvWriter(destination, ',', Charset.forName("UTF-8"));
out.write("Threshold");
out.write("True Positive Rate");
out.write("False Positive Rate");
out.endRecord();
for (int i = 0; i < size(); i++) {
double threshold = getThresholdValue(i);
double fpRate = getFalsePositiveRate(i);
double tpRate = getTruePositiveRate(i);
out.write("" + threshold);
out.write("" + tpRate);
out.write("" + fpRate);
out.endRecord();
}
out.flush();
return true;
}
}