/*
* Copyright (c) 2010 Pentaho Corporation. All rights reserved.
* This software was developed by Pentaho Corporation and is provided under the terms
* of the GNU Lesser General Public License, Version 2.1. You may not use
* this file except in compliance with the license. If you need a copy of the license,
* please go to http://www.gnu.org/licenses/lgpl-2.1.txt. The Original Code is Time Series
* Forecasting. The Initial Developer is Pentaho Corporation.
*
* Software distributed under the GNU Lesser Public License is distributed on an "AS IS"
* basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. Please refer to
* the license for the specific language governing your rights and limitations.
*/
/*
* ErrorModule.java
* Copyright (C) 2010 Pentaho Corporation
*/
package weka.classifiers.timeseries.eval;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import weka.classifiers.evaluation.NumericPrediction;
import weka.core.Instance;
import weka.core.Utils;
/**
* Superclass of error-based evaluation modules. Stores the predictions for each
* target along with the actual values. Computes the sum of errors for each
* target.
*
* @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
* @version $Revision: 49983 $
*
*/
public class ErrorModule extends TSEvalModule {
/** The predictions for each target. Outer list indexes targets */
protected List<List<NumericPrediction>> m_predictions;
/** The counts of each valid target prediction */
protected double[] m_counts;
/**
* Reset this module
*/
public void reset() {
if (m_targetFieldNames != null) {
m_predictions = new ArrayList<List<NumericPrediction>>();
m_counts = new double[m_targetFieldNames.size()];
for (int i = 0; i < m_targetFieldNames.size(); i++) {
ArrayList<NumericPrediction> predsForTarget =
new ArrayList<NumericPrediction>();
m_predictions.add(predsForTarget);
}
}
}
/**
* Return the short identifying name of this evaluation module
*
* @return the short identifying name of this evaluation module
*/
public String getEvalName() {
return "Error";
}
/**
* Return the longer (single sentence) description
* of this evaluation module
*
* @return the longer description of this module
*/
public String getDescription() {
return "Sum of errors";
}
/**
* Return the mathematical formula that this
* evaluation module computes.
*
* @return the mathematical formula that this module
* computes.
*/
public String getDefinition() {
return "sum(predicted - actual)";
}
/**
* Gets a textual description of this module : getDescription() + getEvalName()
*/
public String toString() {
return getDescription() + " (" + getEvalName() + ")";
}
/**
* Evaluate the given forecast(s) with respect to the given
* test instance. Targets with missing values are ignored.
*
* @param forecasts a List of forecasted values. Each element
* corresponds to one of the targets and is assumed to be in the same
* order as the list of targets supplied to the setTargetFields() method.
* @throws Exception if the evaluation can't be completed for some
* reason.
*/
public void evaluateForInstance(List<NumericPrediction> forecasts, Instance inst)
throws Exception {
if (m_predictions == null) {
throw new Exception("Target fields haven't been set yet!");
}
if (forecasts.size() != m_targetFieldNames.size()) {
throw new Exception("The number of forecasted values does not match the" +
" number of target fields!");
}
for (int i = 0; i < m_targetFieldNames.size(); i++) {
double actualValue = getTargetValue(m_targetFieldNames.get(i), inst);
double predictedValue = forecasts.get(i).predicted();
//System.err.println("Actual: " + actualValue + " Predicted: " + predictedValue);
double[][] intervals = forecasts.get(i).predictionIntervals();
NumericPrediction pred = new NumericPrediction(actualValue, predictedValue, 1, intervals);
m_predictions.get(i).add(pred);
if (!Utils.isMissingValue(predictedValue) &&
!Utils.isMissingValue(actualValue)) {
m_counts[i]++;
}
}
}
/**
* Calculate the measure that this module represents.
*
* @return the value of the measure for this module for each
* of the target(s).
* @throws Exception if the measure can't be computed for some reason.
*/
public double[] calculateMeasure() throws Exception {
if (m_predictions == null || m_predictions.get(0).size() == 0) {
throw new Exception("No predictions have been seen yet!");
}
double[] result = new double[m_targetFieldNames.size()];
for (int i = 0; i < m_targetFieldNames.size(); i++) {
List<NumericPrediction> preds = m_predictions.get(i);
double sumOfE = 0;
for (NumericPrediction p : preds) {
if (!Utils.isMissingValue(p.error())) {
sumOfE += p.error();
}
}
result[i] = sumOfE;
}
return result;
}
/**
* Gets the number of predicted, actual pairs for each target. Only
* entries that are non-missing for both actual and predicted contribute
* to the overall count.
*
* @return the number of predicted, actual pairs for each target.
* @throws Exception
*/
public double[] countsForTargets() throws Exception {
if (m_predictions == null || m_predictions.get(0).size() == 0) {
throw new Exception("No predictions have been seen yet!");
}
return m_counts;
}
/**
* Get a list of the errors for the supplied target
*
* @param targetName the target to get the errors for
* @return the errors as a list of Double
* @throws IllegalArgumentException if the target name is unknown
*/
public List<Double> getErrorsForTarget(String targetName)
throws IllegalArgumentException {
for (int i = 0; i < m_targetFieldNames.size(); i++) {
if (m_targetFieldNames.get(i).equals(targetName)) {
ArrayList<Double> errors = new ArrayList<Double>();
List<NumericPrediction> preds = m_predictions.get(i);
for (int j = 0; j < preds.size(); j++) {
Double err = new Double(preds.get(j).error());
errors.add(err);
}
return errors;
}
}
throw new IllegalArgumentException("Unknown target: " + targetName);
}
/**
* Get a list of predictions (plus actuals if known) for the supplied target
*
* @param targetName the target to get predictions for
* @return a list of predictions for the supplied target
* @throws IllegalArgumentException if the target name is unknown
*/
public List<NumericPrediction> getPredictionsForTarget(String targetName)
throws IllegalArgumentException {
for (int i = 0; i < m_targetFieldNames.size(); i++) {
if (m_targetFieldNames.get(i).equals(targetName)) {
return m_predictions.get(i);
}
}
throw new IllegalArgumentException("Unknown target: " + targetName);
}
/**
* Gets the predictions for all targets
*
* @return the predictions for all targets as a list of lists the outer list
* indexes targets.
*/
public List<List<NumericPrediction>> getPredictionsForAllTargets() {
return m_predictions;
}
public String toSummaryString() throws Exception {
StringBuffer result = new StringBuffer();
double[] measures = calculateMeasure();
for (int i = 0; i < m_targetFieldNames.size(); i++) {
result.append(getDescription() + " (" + m_targetFieldNames.get(i) + "): "
+ Utils.doubleToString(measures[i], 4) + " (n = " + m_counts[i] + ")");
result.append("\n");
}
return result.toString();
}
}