Package com.nr.test.test_chapter16

Source Code of com.nr.test.test_chapter16.Test_HMM

package com.nr.test.test_chapter16;

import static com.nr.NRUtil.buildMatrix;
import static org.junit.Assert.fail;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import com.nr.ci.HMM;
import com.nr.ran.Ran;

public class Test_HMM {

  @Before
  public void setUp() throws Exception {
  }

  @After
  public void tearDown() throws Exception {
  }

  @Test
  public void test() {
    int i,j,N=1000,M=5,K=5;
    double r,sum=0;
    double aatrans[]={      // Transition matrix
      0.0,0.7,0.1,0.0,0.2,
      0.2,0.4,0.0,0.2,0.2,
      0.0,1.0,0.0,0.0,0.0,
      0.0,0.3,0.0,0.7,0.0,
      0.1,0.1,0.0,0.0,0.8
    };
    double bb[]={         // Symbol probabilities for each state
      0.2,0.0,0.0,0.8,0.0,
      0.2,0.0,0.6,0.2,0.0,
      0.0,1.0,0.0,0.0,0.0,
      0.3,0.2,0.4,0.1,0.0,
      0.5,0.0,0.0,0.0,0.5
    };
    int[] state= new int[N],symbols= new int[N];
    double[][] atrans=buildMatrix(M,M,aatrans),b=buildMatrix(M,K,bb);
    boolean localflag, globalflag=false;

   

    // Test HMM
    System.out.println("Testing HMM");

    // Generate the Markov sequence of states
    HMM.markovgen(atrans,state,0,17);

    // Generate the sequence of symbols emitted
    Ran myran=new Ran(17);
    for (i=0;i<N;i++) {
      r=myran.doub();
      sum=0.0;
      for (j=0;j<K;j++) {
        sum += b[state[i]][j];
        if (r < sum) {
          symbols[i]=j;
          break;
        }
      }
    }

    // Try to discover the model, given the symbols
    HMM hmm=new HMM(atrans,b,symbols);
    hmm.forwardbackward();

    // Inspect results
    int jmax=0,ncorrect=0;
    double test;
    for (i=0;i<N;i++) {
      test=0;
      for (j=0;j<M;j++) {
        if (hmm.pstate[i][j] > test) {
          test=hmm.pstate[i][j];
          jmax=j;
        }
      }
//      System.out.printf(state[i] << " %f\n", jmax);
      if (jmax == state[i]) ncorrect++;
    }
    System.out.printf("Fraction correct: %f\n", (double)(ncorrect)/N);
    localflag = (double)(ncorrect)/N < 0.75;
    globalflag = globalflag || localflag;
    if (localflag) {
      fail("*** HMM: Actual state was not the top probability more than 25% of the time");
     
    }

    // Inspect reconstructed transition matrix
    int jpen=0;
    ncorrect=0;
    for (i=0;i<N;i++) {
      test=0.0;
      for (j=0;j<M;j++) {
        if (hmm.pstate[i][j] > test) {
          test=hmm.pstate[i][j];
          jmax=j;
        }
      }

      // Find second largest
      test=0.0;
      for (j=0;j<M;j++) {
        if (j != jmax) {
          if (hmm.pstate[i][j] > test) {
            test=hmm.pstate[i][j];
            jpen=j;
          }
        }
      }
   
//      System.out.printf(state[i] << " %f\n", jmax);
      if (jmax == state[i] || jpen == state[i]) ncorrect++;
    }
    System.out.printf("Fraction correct: %f\n", (double)(ncorrect)/N);
    double beforeBW=(double)(ncorrect)/N;
    localflag = (double)(ncorrect)/N < 0.95;
    globalflag = globalflag || localflag;
    if (localflag) {
      fail("*** HMM: Actual state was not in top 2 probabilities more than 5% of the time");
     
    }
   
//    System.out.println("Log-likelihood: %f\n", hmm.loglikelihood());

    // Test Baum-Welch reestimation
    for (i=0;i<100;i++) {
      hmm.baumwelch();
      hmm.forwardbackward();
//      System.out.println("Log-likelihood: %f\n", hmm.loglikelihood());
    }

    // Inspect reconstructed transition matrix
    ncorrect=0;
    for (i=0;i<N;i++) {
      test=0.0;
      for (j=0;j<M;j++) {
        if (hmm.pstate[i][j] > test) {
          test=hmm.pstate[i][j];
          jmax=j;
        }
      }

      // Find second largest
      test=0.0;
      for (j=0;j<M;j++) {
        if (j != jmax) {
          if (hmm.pstate[i][j] > test) {
            test=hmm.pstate[i][j];
            jpen=j;
          }
        }
      }
   
//      System.out.printf(state[i] << " %f\n", jmax);
      if (jmax == state[i] || jpen == state[i]) ncorrect++;
    }
    System.out.printf("Fraction correct after Baum-Welch: %f\n", (double)(ncorrect)/N);
    double afterBW=(double)(ncorrect)/N;
    localflag = afterBW <= beforeBW;
    globalflag = globalflag || localflag;
    if (localflag) {
      fail("*** HMM: Baum-Welch reestimation did not improve model");
     
    }

    if (globalflag) System.out.println("Failed\n");
    else System.out.println("Passed\n");
  }

}
TOP

Related Classes of com.nr.test.test_chapter16.Test_HMM

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.