Package de.jungblut.math.squashing

Source Code of de.jungblut.math.squashing.HammingLossFunction

package de.jungblut.math.squashing;

import gnu.trove.set.hash.TIntHashSet;

import java.util.Iterator;

import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.DoubleVector.DoubleVectorElement;

/**
* The HammingLoss is the fraction of labels that are incorrectly predicted.
*
* @author thomasjungblut
*
*/
public final class HammingLossFunction implements ErrorFunction {

  private final double activationThreshold;

  /**
   * @param activationThreshold the threshold which denotes from where on up a
   *          numerical value "turns a unit of activation on". Turning on the
   *          activation means that it is interpreted as value 1.
   */
  public HammingLossFunction(double activationThreshold) {
    this.activationThreshold = activationThreshold;
  }

  @Override
  public double calculateError(DoubleMatrix y, DoubleMatrix hypothesis) {
    double hammingSum = 0d;
    // we now loop row-wise over the matrices
    for (int row = 0; row < y.getRowCount(); row++) {
      DoubleVector yRow = y.getRowVector(row);
      DoubleVector hypRow = hypothesis.getRowVector(row);
      hammingSum += calculateError(yRow, hypRow);
    }

    return hammingSum / y.getRowCount();
  }

  @Override
  public double calculateError(DoubleVector y, DoubleVector hypothesis) {
    double sum = 0d;
    TIntHashSet visitedColumns = new TIntHashSet(y.getLength());
    Iterator<DoubleVectorElement> iterateNonZero = hypothesis.iterateNonZero();
    while (iterateNonZero.hasNext()) {
      DoubleVectorElement next = iterateNonZero.next();
      visitedColumns.add(next.getIndex());
      if (next.getValue() > activationThreshold ^ y.get(next.getIndex()) == 1d) {
        sum += 1d;
      }
    }

    iterateNonZero = y.iterateNonZero();
    while (iterateNonZero.hasNext()) {
      DoubleVectorElement next = iterateNonZero.next();
      if (!visitedColumns.contains(next.getIndex())) {
        visitedColumns.add(next.getIndex());
        if (hypothesis.get(next.getIndex()) > activationThreshold
            ^ next.getValue() == 1d) {
          sum += 1d;
        }
      }
    }
    return sum / visitedColumns.size();
  }

}
TOP

Related Classes of de.jungblut.math.squashing.HammingLossFunction

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.