Package cc.mallet.topics

Source Code of cc.mallet.topics.HDP

/* Hierarchical Dirichlet Process for Mallet
* Version:0.1
*
* Author: CHyi-Kwei Yau
*
* HDP implementation on Mallet
* Basic structure & Code form "Implementing the HDP with minimum code complexity" by Gregor Heinrich.
*
* ==================================================================================================
* TODO List:
* - export HDP result- top words, topic distribution
* - export Inference result - same
*
*
* Update History:
*
* 2012/10/01 version 0.1
*  - bug fix: print correct topic number in training
*  - add cross validation in HDP
*  - add inferencer class
*  - add preplexity calculation in inferencer
*
* 2012/09/29 Version 0.1
*  - bug fix: printed result are correct now
*  - bug fix: empty topic are caused by initial topic number > 0
*  - change initial topic assignment to uniform distribution and remove empty topics.
*
* 2012/09/28 Version 0.1
*   - bug: Topic number and total word count not match in printed result
*  - bug: some topics are empty but not removed 
*
* 2012/09/27 Version 0.1
*   - main algorithm work. not finished all function,
*  - bug: auto update hyper-parameter doesn't work well. Disable it for now.
* */

package cc.mallet.topics;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.IDSorter;
import cc.mallet.types.InstanceList;

import org.knowceans.util.ArrayUtils;
import org.knowceans.util.DirichletEstimation;
import org.knowceans.util.IndexQuickSort;
import org.knowceans.util.Samplers;
import org.knowceans.util.Vectors;

public class HDP {
 
   private InstanceList instances;
  
   private int currentIter;
  
   private Random rand;
   /*root base measure*/
   private double beta;
   private double gamma;
  
   /*2nd level Dirichlet Process parameter*/
   private double alpha;//precision
   private ArrayList<Double> tau; //mean
  
   /*hyper parameter*/
   private double a_alpha;
   private double a_beta;
   private double a_gamma;
   private double b_alpha;
   private double b_beta;
   private double b_gamma;
  
   /*Training Instance*/
   private int numDocuments; //num of doc
   private int numTypes; //num of vocabulary
   private int totalWord=0;
   private int topWordNum = 6; //display number of top word
  
   /*algorithm parameter*/
   /*for training*/
   private List<Integer>[] nmk;
   private List<int[]> nkt;
   private List<Integer> nk;
 
   private int[][] z; //topic indicator
   private double[] pp; //topic distribution
   private final int ppStep =10;
  
   private double tables;
  
   //number of samples for parameter samplers
   private int parameterSampleNum=10;
  
   private int K; //initial topic number
   private List<Integer> kgaps;
   private List<Integer> kactive;
  
   /*non-empty parameter*/
   private boolean initialized = false;
   private boolean fixedK = false;
   private boolean fixedHyper = true;
   private boolean showResult = true;
  
   /*set parameter*/
   public HDP(){
     this(1.0, 0.2, 1.0, 5);
   }
  
   public HDP(HDP model){
      this(model.alpha, model.beta, model.gamma, model.K);
   }
  
   public HDP(double alpha, double beta, double gamma, int initialK){
     this.alpha = alpha;
     this.beta = beta;
     this.gamma = gamma;
    
     this.a_alpha=1.0;
     this.a_beta =1.0;
     this.a_gamma=5.0;
     this.b_alpha=1.0;
     this.b_beta=1.0;
     this.b_gamma=5.0;
 
     this.K = initialK;  
     this.currentIter=0;
    
     rand = new Random();
    
     if(gamma== 0){
       this.fixedK=true;
     }
   }
  
   /*set parameters*/
   public void setAlpha(double a){
     this.alpha=a;
    
     //once parameters changed, need to re-initialize
     initialized=false;
   }
  
   /*initialization*/
   public void initialize(InstanceList instances){
    
     if (! (instances.get(0).getData() instanceof FeatureSequence)) {
        throw new IllegalArgumentException("Input must be a FeatureSequence, using the --feature-sequence option when impoting data, for example");
      }
    
     this.instances=instances;
     numDocuments=instances.size();
     numTypes=instances.getDataAlphabet().size();
    
     FeatureSequence fs;
     for(int i=0;i<numDocuments ; i++){
       fs =(FeatureSequence) instances.get(i).getData();
       totalWord+=fs.getLength();
     }
    
     /*initialize algorithm parameter*/
     init();
    
   }
  
   /*initialize algorithm parameter*/
   private void init(){
    
     nmk = new ArrayList[numDocuments];
     nkt = new ArrayList<int[]>();
     nk = new ArrayList<Integer>();
     z = new int[numDocuments][];
    
     FeatureSequence fs;
     for(int m=0; m < numDocuments; m++){
       nmk[m] = new ArrayList<Integer>();
      
       for(int k=0; k<K ; k++){
         nmk[m].add(0);
       }
       fs =(FeatureSequence) instances.get(m).getData();
       z[m]= new int[fs.getLength()];
     }
    
     /*topic index*/
     kgaps = new ArrayList<Integer>();
     kactive = new ArrayList<Integer>();
     tau = new ArrayList<Double>();
    
     for(int k=0 ; k<K ; k++){
       kactive.add(k);
       nkt.add(new int[numTypes]);
       nk.add(0);
       tau.add(1.0/K);
     }
    
     //add one more topic
     tau.add(1.0/K);
     pp = new double[(K+ppStep)];
    
     //initialize
     randomAssignTopics();
    
     if(!fixedK){
       updateTau();
     }
    
     initialized=true;
   }
  
  /*estimate*/
   public void estimate(int iterations){
     estimate(iterations, showResult);
   }
  
   private void estimate(int iterations, boolean printResult){
    
     for(int iter=0; iter < iterations ; iter++){
      
       if(!initialized){
         throw new IllegalStateException("Initialize HDP first!");
       }
      
       for(int i=0 ; i < numDocuments ; i++){    
         updateDocs(i);
       }
      
       if(!fixedK){
         updateTau();
       }
      
       if(iter > 10 && !fixedHyper){
         updateHyper();
       }
        
       //print current status
       if( iter !=0 && (iter % 50) == 0){
         if(!fixedK){
           System.out.println("Iteration=" + iter +"\tTopic=" + K );
         }
         else{
           System.out.print(iter + " ");
         }
         if(!fixedHyper){
           printParameter();
         }
       }
     }//end iter
    
     //accumulate iterations
     currentIter+=iterations;
    
     if(printResult){
       //print a summary of estimation
       System.out.println();
       printParameter();
       System.out.println();
       printTopWord(topWordNum);
    }
    
   }
  
   /*random initialize topics at first*/
   private void randomAssignTopics(){
    
     //uniform multinomial distribution for initial assignment
     for(int kk=0; kk<K ; kk++){  
       //equal probability for each topic
       pp[kk] = 1.0/K;
     }
    
     for(int m=0 ; m < numDocuments ; m++){    
      
       FeatureSequence fs = (FeatureSequence) instances.get(m).getData();
       int seqLen = fs.getLength();
       int type, token, k;
       double sum;
      
       for(token=0 ; token < seqLen ; token++){
        
         type = fs.getIndexAtPosition(token);
        
         int u = rand.nextInt(K);
         //assign topics
         k = kactive.get(u);
         z[m][token]=k;
         //add z back
         nmk[m].set(k, nmk[m].get(k)+1);
         nkt.get(k)[type]++;
         nk.set(k, nk.get(k)+1);
       }
     }
    
     //remove empty topic if topic number are not fixed
     if(!fixedK){
       for(int k=0 ; k<nk.size(); k++)
       {
         if(nk.get(k)==0){
           kactive.remove((Integer)k);
           kgaps.add(k);
           assert(Vectors.sum(nkt.get(k))==0);
           K--;
           updateTau();
         }
       }
     }
   }
  
   private void updateDocs(int m) {
    
     FeatureSequence fs = (FeatureSequence) instances.get(m).getData();
     int seqLen = fs.getLength();
     int type, token;
     double sum;
    
     for(token=0 ; token < seqLen ; token++){
      
       type = fs.getIndexAtPosition(token);
       int k, kold = -1;
      
       if(initialized){
         //get old topic
         k=z[m][token];
         //decrement
         nmk[m].set(k, nmk[m].get(k)-1);
         nkt.get(k)[type]--;
         nk.set(k, nk.get(k)-1);
         kold=k;
       }
      
       sum=0.0;
       for(int kk=0; kk<K ; kk++){
         k=kactive.get(kk);
         pp[kk] = (nmk[m].get(k) + alpha*tau.get(k))*
             (nkt.get(k)[type]+beta) / (nk.get(k)+ numTypes*beta);
        
         sum+=pp[kk];
       }
      
       if(!fixedK){
         pp[K] = alpha * tau.get(K) / numTypes;
         sum+=pp[K];
       }
      
       //sampling
       double u = rand.nextDouble();
       u *= sum;
       sum=0.0;
       int kk;
      
       for(kk=0 ; kk<=K ; kk++){
         sum+=pp[kk];
         if(u <= sum) break;
       }
      
       //check kk is old or new topic
       if(kk < K) //in old topic
       {
         k = kactive.get(kk);
         z[m][token]=k;
         //add z back
         nmk[m].set(k, nmk[m].get(k)+1);
         nkt.get(k)[type]++;
         nk.set(k, nk.get(k)+1);
       }
       else //add new topic
       {
         assert(!fixedK);
         z[m][token]=addTopic(m,type);
         updateTau();
         //System.out.println("add K="+K);
       }
      
       //disable empty topic
       if(initialized && !fixedK && nk.get(kold)==0)
       {
         kactive.remove((Integer)kold);
         kgaps.add(kold);
         assert(Vectors.sum(nkt.get(kold))==0 &&
             nk.get(kold)==0 && nmk[m].get(kold)==0);
         K--;
         //System.out.println("remove K="+K);
        
         updateTau();
       }
     }
   }

  /*auto update hyper parameter need refine*/
  private void updateHyper(){
    for(int r=0 ; r<parameterSampleNum ; r++){
     
      //check paper!
      //gamma: root level
      double eta = Samplers.randBeta(gamma+1 , tables);
      double bloge =b_gamma - Math.log(eta);
      double pie = (1.0 / (tables * bloge/ (gamma+K-1)));
      int u = Samplers.randBernoulli(pie);
      gamma = Samplers.randGamma(a_gamma+K-1+u, 1.0/bloge);
     
      //alpha: document level
      double qs = 0.0;
      double qw = 0.0;
     
      FeatureSequence fs;
      int seqLen;
      for(int m=0 ; m< numDocuments ; m++){
        fs = (FeatureSequence) instances.get(m).getData();
        seqLen = fs.getLength();
        qs += Samplers.randBernoulli(seqLen/(seqLen+alpha));
        qw += Math.log( Samplers.randBeta(alpha+1, seqLen));
      }
      alpha = Samplers.randGamma(a_alpha+ tables -qs, 1.0/(b_alpha-qw));
    }
   
   
    //estimate beta
    //convert nk & akt to array
    int[] ak = (int[]) ArrayUtils.asPrimitiveArray(nk);
    int[][] akt =new int[K][numTypes];
    for(int k=0; k<K; k++){
      akt[k] = nkt.get(k);
    }
    beta = DirichletEstimation.estimateAlphaMap(akt, ak, beta, a_beta, b_beta);
   
   }
  
  private void updateTau() {
    double[] mk = new double[K+1];
   
    for(int kk=0 ; kk<K ; kk++){
      int k = kactive.get(kk);
     
      for(int m=0 ; m < numDocuments ; m++){
       
        if(nmk[m].get(k) > 1){
          //sample number of tables
          mk[kk] +=Samplers.randAntoniak(alpha * tau.get(k),
              nmk[m].get(k));
        }
        else //nmk[m].get(k) = 0 or 1
        {  
          mk[kk] +=nmk[m].get(k);
        }
       
      }
    }// end outter for loop
   
    //get number of tables
    tables = Vectors.sum(mk);
    mk[K] = gamma;
   
    double[] tt =Samplers.randDir(mk);
   
    for(int kk=0 ; kk < K ; kk++){
     
      int k=kactive.get(kk);
      tau.set(k, tt[kk]);
    }
    tau.set(K, tt[K]);
  }
 
  private int addTopic(int m, int type) {
    int k;
    if(kgaps.size()>0){ //reuse gaps
      k=kgaps.remove(kgaps.size()-1);
      kactive.add(k);
      nmk[m].set(k,1);
      nkt.get(k)[type]=1;
      nk.set(k,1);
    }
    else{
      k=K;
     
      for(int i=0 ; i< numDocuments ; i++){
        nmk[i].add(0);
      }
     
      kactive.add(K);
      nmk[m].set(K,1);
      nkt.add(new int[numTypes]);
      nk.add(1);
      tau.add(0.0);
    }
    //add topic number
    K++;
    if(pp.length <=K){
      pp = new double[K + ppStep];
    }
    return k;
  }
 
   /*print*/
   public void printTopWord(int numWords){
    
    //sort topic from largest to smallest
    trimTopics();
   
    int wordCount=0;
   
    for(int k=0 ; k<nk.size(); k++){
       
      if(nk.get(k)!=0){
       
        int count=nk.get(k);
       
        //check word count
        wordCount+=count;
       
        IDSorter[] sortedTypes = new IDSorter[numTypes];
        //sort word in topic k
        for (int type=0; type < numTypes; type++) {
          sortedTypes[type] = new IDSorter(type, nkt.get(k)[type]);
        }
       
        Arrays.sort(sortedTypes);
     
        Alphabet alphabet = instances.getDataAlphabet();
        StringBuffer out = new StringBuffer();
        out.append("topic"+k + ": ");
        out.append("word:"+ count + ", ");
        if(k< kactive.size()){
          out.append("matched topic "+kactive.get(k) + ", ");
        }
        double prop = (double)count/totalWord;
        out.append(String.format("prop:%2.4f, ", prop));
     
        for (int i=0; i<numWords; i++) {
          out.append(alphabet.lookupObject(sortedTypes[i].getID()) + " ");
        }
        System.out.println(out);
      }
      else{
        if(k < kactive.size() )
          System.out.println("Topic"+k+": matched topic " + kactive.get(k));
        else
          System.out.println("Topic"+k+": empty");
      }
    }
    System.out.println("Total Word count: "+ wordCount );
   }
  
   /*this is just used for test & debug*/
   public void printTopWordTest(int numWords){
    
      //trimTopics();
     
      //int[] ak = (int[]) ArrayUtils.asPrimitiveArray(nk);
      //Arrays.sort(ak);
     
      int wordCount=0;
     
      for(int k=0 ; k<nk.size(); k++){
         
        if(nk.get(k)!=0){
         
          //check word count
          wordCount+=nk.get(k);
         
          IDSorter[] sortedTypes = new IDSorter[numTypes];
          //sort word in topic k
          for (int type=0; type < numTypes; type++) {
            sortedTypes[type] = new IDSorter(type, nkt.get(k)[type]);
          }
         
          Arrays.sort(sortedTypes);
       
          Alphabet alphabet = instances.getDataAlphabet();
          StringBuffer out = new StringBuffer();
          out.append("topic"+k + ": ");
          if(k< kactive.size()){
            out.append("matched topic "+kactive.get(k) + ", ");
          }
          out.append("word:"+ nk.get(k) + ", ");
          double prop = (double)nk.get(k)/totalWord;
          out.append(String.format("prop:%2.4f, ", prop));
       
          for (int i=0; i<numWords; i++) {
            out.append(alphabet.lookupObject(sortedTypes[i].getID()) + " ");
          }
          System.out.println(out);
        }
        else{
          if(k < kactive.size() )
            System.out.println("Topic"+k+": matched topic " + kactive.get(k));
          else
            System.out.println("Topic"+k+": empty");
        }
      }
      System.out.println("Total Word count: "+ wordCount );
     }
  
   public void printParameter(){
    
     String out = String.format("Summary: Docs=%d, topics=%d, totalWords=%d, alpha=%2.3f, beta=%2.3f, gamma=%2.3f",numDocuments, K,totalWord, alpha, beta, gamma);
     System.out.println(out);
  
   }
  
   private void trimTopics(){
    
     //System.out.println("start trim");
    
     int[] new_nk = IndexQuickSort.sort(nk);
     IndexQuickSort.reverse(new_nk);
    
     //remove empty topic
     IndexQuickSort.reorder(nk, new_nk);
     IndexQuickSort.reorder(nkt, new_nk);
    
     for(int i=0; i < kgaps.size(); i++){
       nk.remove(nk.size()-1 );
       nkt.remove(nkt.size()-1 );
     }
    
     for(int m=0 ; m < numDocuments; m++){
       IndexQuickSort.reorder(nmk[m], new_nk);
       //remove gaps
       for(int i=0 ; i <kgaps.size() ; i++){
         nmk[m].remove(nmk[m].size()-1);
       }
     }
     //clean kgaps
     kgaps.clear();
     int[] k2knew = IndexQuickSort.inverse(new_nk);
    
     //rewrite topic
     for(int i=0; i<K; i++){
       kactive.set(i, k2knew[kactive.get(i)]);
     }
    
     for(int m=0; m<numDocuments ; m++){
       FeatureSequence fs = (FeatureSequence) instances.get(m).getData();
      
       for(int n=0 ; n<fs.getLength() ; n++){
         z[m][n]=k2knew[z[m][n]];
       }
     }
    
   }
  
   private void exportTopWords(String outFileName){
     //TODO
   }
  
   private void exportTopicDistribution(String outFileName){
     //TODO
   }
  
     public double[] topicDistribution(int m){
    
     if(m >= numDocuments){
       throw new ArrayIndexOutOfBoundsException();
     }
    
     double[] distr = new double[K];
     double alphaSum = K * alpha;
     int totalCount=0;
     for(int k=0; k<K ;k++){
       totalCount+=nmk[m].get(k);
     }
    
    
     for(int k=0 ; k<K ; k++){
       //distr[k] = (alpha + nmk[m].get(k))/(alphaSum + totalCount);
       distr[k] = nmk[m].get(k);
     }
    
     return distr;
   }
  
   /*n-fold cross validation*/
   public void runCrossValidation(int nfolds, int iterations){
    
     InstanceList.CrossValidationIterator iter = instances.crossValidationIterator(nfolds);
     double[] prepResult = new  double[nfolds];
     InstanceList[] splitedList;
     InstanceList trainList, testList;
    
     int fold=0;
     while(iter.hasNext()){
       splitedList = iter.next();
       trainList = splitedList[0];
       testList = splitedList[1];
      
       HDP model = new HDP(this);
       model.initialize(trainList);
       model.printParameter();
       model.estimate(iterations, false);
       HDPInferencer infer = model.getInferencer();
       infer.setInstance(testList);
       infer.estimate(iterations);
       prepResult[fold] = infer.getPreplexity();
       //System.out.println("n="+fold+", preplexity="+prepResult[fold]);
       fold++;
       System.out.println();
     }
    
     //print result
     double sum=0.0;
     for(int i=0; i< nfolds ; i++){
       System.out.println("n="+i+", preplexity="+prepResult[i]);
       sum+=prepResult[i];
     }
     System.out.println("CV result: average preplexity=" + (sum/nfolds));
   }
  
   /*Inferencer*/
   public HDPInferencer getInferencer(){
    
     if(!initialized){
       throw new IllegalStateException("HDP model is not iniitalized.");
     }
    
     //get phi
     double[][] phi = new double[K][numTypes];
     for(int kk=0 ; kk<K ; kk++){
       //int k = kactive.get(kk);
       for(int t=0 ; t < numTypes ; t++){
         phi[kk][t]=(nkt.get(kk)[t]+beta) / (nk.get(kk)+ numTypes*beta);
       }
     }
    
     return new HDPInferencer(alpha, K, phi, instances.getDataAlphabet(), rand);
   }

   /*display result or not*/
   public void showResult(boolean b){
     this.showResult=b;
   }
 
}
TOP

Related Classes of cc.mallet.topics.HDP

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.