Package opennlp.tools.ml.maxent.quasinewton

Source Code of opennlp.tools.ml.maxent.quasinewton.NegLogLikelihoodTest

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package opennlp.tools.ml.maxent.quasinewton;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

import opennlp.tools.ml.model.DataIndexer;
import opennlp.tools.ml.model.OnePassRealValueDataIndexer;
import opennlp.tools.ml.model.RealValueFileEventStream;

import org.junit.Test;

public class NegLogLikelihoodTest {
  public final double TOLERANCE01 = 1.0E-06;
  public final double TOLERANCE02 = 1.0E-10;

  @Test
  public void testDomainDimensionSanity() throws IOException {
    // given
    RealValueFileEventStream rvfes1 = new RealValueFileEventStream(
        "src/test/resources/data/opennlp/maxent/real-valued-weights-training-data.txt", "UTF-8")
    DataIndexer testDataIndexer = new OnePassRealValueDataIndexer(rvfes1,1);
    NegLogLikelihood objectFunction = new NegLogLikelihood(testDataIndexer);
    // when
    int correctDomainDimension = testDataIndexer.getPredLabels().length
        * testDataIndexer.getOutcomeLabels().length;
    // then
    assertEquals(correctDomainDimension, objectFunction.getDimension());
  }

  @Test
  public void testInitialSanity() throws IOException {
    // given
    RealValueFileEventStream rvfes1 = new RealValueFileEventStream(
        "src/test/resources/data/opennlp/maxent/real-valued-weights-training-data.txt", "UTF-8")
    DataIndexer testDataIndexer = new OnePassRealValueDataIndexer(rvfes1,1);
    NegLogLikelihood objectFunction = new NegLogLikelihood(testDataIndexer);
    // when
    double[] initial = objectFunction.getInitialPoint();
    // then
    for (int i = 0; i < initial.length; i++) {
      assertEquals(0.0, initial[i], TOLERANCE01);
    }
  }

  @Test
  public void testGradientSanity() throws IOException {
    // given
    RealValueFileEventStream rvfes1 = new RealValueFileEventStream(
        "src/test/resources/data/opennlp/maxent/real-valued-weights-training-data.txt", "UTF-8")
    DataIndexer testDataIndexer = new OnePassRealValueDataIndexer(rvfes1,1);
    NegLogLikelihood objectFunction = new NegLogLikelihood(testDataIndexer);
    // when
    double[] initial = objectFunction.getInitialPoint();
    double[] gradientAtInitial = objectFunction.gradientAt(initial);
    // then
    assertNotNull(gradientAtInitial);
  }

  @Test
  public void testValueAtInitialPoint() throws IOException {
    // given
    RealValueFileEventStream rvfes1 = new RealValueFileEventStream(
        "src/test/resources/data/opennlp/maxent/real-valued-weights-training-data.txt", "UTF-8");
    DataIndexer testDataIndexer = new OnePassRealValueDataIndexer(rvfes1,1);
    NegLogLikelihood objectFunction = new NegLogLikelihood(testDataIndexer);
    // when
    double value = objectFunction.valueAt(objectFunction.getInitialPoint());
    double expectedValue = 13.86294361;
    // then
    assertEquals(expectedValue, value, TOLERANCE01);
  }

  @Test
  public void testValueAtNonInitialPoint01() throws IOException {
    // given
    RealValueFileEventStream rvfes1 = new RealValueFileEventStream(
        "src/test/resources/data/opennlp/maxent/real-valued-weights-training-data.txt", "UTF-8");
    DataIndexer testDataIndexer = new OnePassRealValueDataIndexer(rvfes1,1);
    NegLogLikelihood objectFunction = new NegLogLikelihood(testDataIndexer);
    // when
    double[] nonInitialPoint = new double[] { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 };
    double value = objectFunction.valueAt(nonInitialPoint);
    double expectedValue = 13.862943611198894;
    // then
    assertEquals(expectedValue, value, TOLERANCE01);
  }

  @Test
  public void testValueAtNonInitialPoint02() throws IOException {
    // given
    RealValueFileEventStream rvfes1 = new RealValueFileEventStream(
        "src/test/resources/data/opennlp/maxent/real-valued-weights-training-data.txt", "UTF-8");
    DataIndexer testDataIndexer = new OnePassRealValueDataIndexer(rvfes1,1);
    NegLogLikelihood objectFunction = new NegLogLikelihood(testDataIndexer);
    // when
    double[] nonInitialPoint = new double[] { 3, 2, 3, 2, 3, 2, 3, 2, 3, 2 };
    double value = objectFunction.valueAt(dealignDoubleArrayForTestData(nonInitialPoint,
      testDataIndexer.getPredLabels(),
      testDataIndexer.getOutcomeLabels()));
    double expectedValue = 53.163219721099026;
    // then
    assertEquals(expectedValue, value, TOLERANCE02);
  }

  @Test
  public void testGradientAtInitialPoint() throws IOException {
    // given
    RealValueFileEventStream rvfes1 = new RealValueFileEventStream(
        "src/test/resources/data/opennlp/maxent/real-valued-weights-training-data.txt", "UTF-8");
    DataIndexer testDataIndexer = new OnePassRealValueDataIndexer(rvfes1,1);
    NegLogLikelihood objectFunction = new NegLogLikelihood(testDataIndexer);
    // when
    double[] gradientAtInitialPoint = objectFunction.gradientAt(objectFunction.getInitialPoint());
    double[] expectedGradient = new double[] { -9.0, -14.0, -17.0, 20.0, 8.5, 9.0, 14.0, 17.0, -20.0, -8.5 };
    // then
    assertTrue(compareDoubleArray(expectedGradient, gradientAtInitialPoint,
        testDataIndexer, TOLERANCE01));
  }

  @Test
  public void testGradientAtNonInitialPoint() throws IOException {
    // given
    RealValueFileEventStream rvfes1 = new RealValueFileEventStream(
        "src/test/resources/data/opennlp/maxent/real-valued-weights-training-data.txt", "UTF-8");
    DataIndexer testDataIndexer = new OnePassRealValueDataIndexer(rvfes1,1);
    NegLogLikelihood objectFunction = new NegLogLikelihood(testDataIndexer);
    // when
    double[] nonInitialPoint = new double[] { 0.2, 0.5, 0.2, 0.5, 0.2, 0.5, 0.2, 0.5, 0.2, 0.5 };
    double[] gradientAtNonInitialPoint =
        objectFunction.gradientAt(dealignDoubleArrayForTestData(nonInitialPoint,
            testDataIndexer.getPredLabels(),
            testDataIndexer.getOutcomeLabels()));
    double[] expectedGradient =
            new double[] { -12.755042847945553, -21.227127506102434,
                           -72.57790706276435,   38.03525795198456,
                            15.34865088935492512.755042847945557,
                            21.22712750610244,   72.57790706276438,
                           -38.03525795198456,  -15.348650889354925 };  
    // then
    assertTrue(compareDoubleArray(expectedGradient, gradientAtNonInitialPoint,
        testDataIndexer, TOLERANCE01));
  }
 
  private double[] alignDoubleArrayForTestData(double[] expected,
      String[] predLabels, String[] outcomeLabels) {
    double[] aligned = new double[predLabels.length * outcomeLabels.length];
   
    String[] sortedPredLabels = predLabels.clone();
    String[] sortedOutcomeLabels =  outcomeLabels.clone();
    Arrays.sort(sortedPredLabels);
    Arrays.sort(sortedOutcomeLabels);
   
    Map<String, Integer> invertedPredIndex = new HashMap<String, Integer>();
    Map<String, Integer> invertedOutcomeIndex = new HashMap<String, Integer>();
    for (int i = 0; i < predLabels.length; i++) {
      invertedPredIndex.put(predLabels[i], i);
    }
    for (int i = 0; i < outcomeLabels.length; i++) {
      invertedOutcomeIndex.put(outcomeLabels[i], i);
    }
 
    for (int i = 0; i < sortedOutcomeLabels.length; i++) {
      for (int j = 0; j < sortedPredLabels.length; j++) {
        aligned[i * sortedPredLabels.length + j] = expected[invertedOutcomeIndex
            .get(sortedOutcomeLabels[i])
            * sortedPredLabels.length
            + invertedPredIndex.get(sortedPredLabels[j])];
      }
    }
    return aligned;
  }
 
  private double[] dealignDoubleArrayForTestData(double[] expected,
      String[] predLabels, String[] outcomeLabels) {
    double[] dealigned = new double[predLabels.length * outcomeLabels.length];

    String[] sortedPredLabels = predLabels.clone();
    String[] sortedOutcomeLabels = outcomeLabels.clone();
    Arrays.sort(sortedPredLabels);
    Arrays.sort(sortedOutcomeLabels);

    Map<String, Integer> invertedPredIndex = new HashMap<String, Integer>();
    Map<String, Integer> invertedOutcomeIndex = new HashMap<String, Integer>();
    for (int i = 0; i < predLabels.length; i++) {
      invertedPredIndex.put(predLabels[i], i);
    }
    for (int i = 0; i < outcomeLabels.length; i++) {
      invertedOutcomeIndex.put(outcomeLabels[i], i);
    }

    for (int i = 0; i < sortedOutcomeLabels.length; i++) {
      for (int j = 0; j < sortedPredLabels.length; j++) {
        dealigned[invertedOutcomeIndex.get(sortedOutcomeLabels[i])
            * sortedPredLabels.length
            + invertedPredIndex.get(sortedPredLabels[j])] = expected[i
            * sortedPredLabels.length + j];
      }
    }

    return dealigned;
  }
 
  private boolean compareDoubleArray(double[] expected, double[] actual,
      DataIndexer indexer, double tolerance)
  {
    double[] alignedActual = alignDoubleArrayForTestData(
        actual, indexer.getPredLabels(), indexer.getOutcomeLabels());

    if (expected.length != alignedActual.length) {
      return false;
    }
   
    for (int i = 0; i < alignedActual.length; i++) {
      if (Math.abs(alignedActual[i] - expected[i]) > tolerance) {
        return false;
      }
    }
    return true;
  }
}
TOP

Related Classes of opennlp.tools.ml.maxent.quasinewton.NegLogLikelihoodTest

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.