Package edu.msu.cme.rdp.classifier

Source Code of edu.msu.cme.rdp.classifier.Classifier

/*
* Classifier.java
*
* Copyright 2006 Michigan State University Board of Trustees
*
* Created on September 18, 2003, 6:05 PM
*/
package edu.msu.cme.rdp.classifier;

import edu.msu.cme.rdp.classifier.utils.ClassifierSequence;
import edu.msu.cme.rdp.readseq.readers.Sequence;
import java.util.*;


/**
* This is the class to do the classification.
* @author  wangqion
* @version 1.0
*/
public class Classifier {

    private TrainingInfo trainingInfo;
    /** The number of bootstrap trials. Initially set to 100. */
    private final int NUM_OF_RUNS = 100;
    /** The assumed maximum number of words per sequence. Initially set to 5000. */
    private final int MAX_NUM_OF_WORDS = 5000;
    /** The minimum number of bases per sequence. Initially set to 200. */
    public static final int MIN_SEQ_LEN = 50;
    public static final int MAX_SEQ_LEN = 5000;

    public static final int MIN_GOOD_WORDS = MIN_SEQ_LEN - ClassifierSequence.WORDSIZE;
    public static final int MIN_BOOTSTRSP_WORDS = 5;   // mininum number of words needs for bootstrap

    private float[][] querySeq_wordProbArr;  // 2-D array for the query sequence
    // the 1st dim is the MAX_NUM_OF_WORDS, the 2nd dim is the number of genus nodes
    private float[] accumulateProbArr;   // an array to temporary store the accumulative probability
    private long seed = 1;
    private Random randomGenerator = new Random(seed);
    private Random randomSelectGenera = new Random();
    private int[] wordIndexArr = new int[1024];

    /** Creates new Classifier.  */
    Classifier(TrainingInfo t) {
        trainingInfo = t;
        int nodeListSize = trainingInfo.getGenusNodeListSize();
        querySeq_wordProbArr = new float[MAX_NUM_OF_WORDS][nodeListSize];
        accumulateProbArr = new float[nodeListSize];
    }

    public String getTrainRank(){
        return trainingInfo.getTrainRank();
    }
   
    /**Takes a query sequence, returns the classification result.
     * For each query sequence, first assign it to a genus node using all the words for calculation.
     * Then randomly chooses one-eighth of the all overlapping words in the query
     * to calculate the joint probability. The number of times a genus was selected out of
     * the number of bootstrap trials was used as an estimate of confidence in the assignment to that genus.
     * @throws ShortSequenceException if the sequence length is less than the minimum sequence length.
     */
    public ClassificationResult classify(Sequence seq) {
        return classify(new ClassifierSequence(seq));
    }

    public ClassificationResult classify(ClassifierSequence seq) {
        return classify(seq, MIN_BOOTSTRSP_WORDS );
    }

    /**Takes a query sequence, returns the classification result.
     * For each query sequence, first assign it to a genus node using all the words for calculation.
     * Then randomly chooses one-eighth of the all overlapping words in the query
     * to calculate the joint probability. The number of times a genus was selected out of
     * the number of bootstrap trials was used as an estimate of confidence in the assignment to that genus.
     * @throws ShortSequenceException if the sequence length is less than the minimum sequence length.
     */
    public ClassificationResult classify(ClassifierSequence seq, int min_bootstrap_words) {
        GenusWordConditionalProb gProb = null;
        int nodeListSize = trainingInfo.getGenusNodeListSize();
        boolean reversed = false;
        int wordCount = seq.getSeqString().length();

        if(wordIndexArr.length < seq.getSeqString().length()) {
            wordIndexArr = new int[seq.getSeqString().length()];
        }

        seq.createWordIndexArr(wordIndexArr);

        if (trainingInfo.isSeqReversed(wordIndexArr, wordCount)) {
            seq = seq.getReversedSeq();
            seq.createWordIndexArr(wordIndexArr);
            reversed = true;
        }

        int goodWordCount = seq.getGoodWordCount();

        if (seq.getSeqString().length() < MIN_SEQ_LEN) {
            throw new ShortSequenceException(seq.getSeqName(), "ShortSequenceException: The length of sequence with recordID="
                    + seq.getSeqName() + " is less than " + MIN_SEQ_LEN);
        }

        if (goodWordCount < MIN_GOOD_WORDS) {
            throw new ShortSequenceException(seq.getSeqName(), "ShortSequenceException: The sequence with recordID="
                    + seq.getSeqName() + " does not have enough valid words, minimum " + MIN_GOOD_WORDS + " are required");
        }

        if (goodWordCount > MAX_NUM_OF_WORDS) {
            querySeq_wordProbArr = new float[goodWordCount][nodeListSize];
            System.err.println("increase the array size to " + goodWordCount);
        }

        int NUM_OF_SELECTIONS = Math.max( goodWordCount / ClassifierSequence.WORDSIZE, min_bootstrap_words);

        for (int offset = 0; offset < goodWordCount; offset++) {
            int wordIndex = wordIndexArr[offset];
            float wordPrior = trainingInfo.getLogWordPrior(wordIndex);
            // first set the value for zero occurrence word
            for (int node = 0; node < nodeListSize; node++) {
                querySeq_wordProbArr[offset][node] = wordPrior - trainingInfo.getLogLeaveCount(node);
            }

            // find the non-zero occurrence value
            int start = trainingInfo.getStartIndex(wordIndex);
            int stop = trainingInfo.getStopIndex(wordIndex);
            for (int n = start; n < stop; n++) {
                gProb = trainingInfo.getWordConditionalProbObject(n);
                querySeq_wordProbArr[offset][gProb.getGenusIndex()] = gProb.getProbability();
            }
        }

        // first use all the word to determine the genus assignment
        for (int j = 0; j < goodWordCount; j++) {
            for (int node = 0; node < nodeListSize; node++) {
                accumulateProbArr[node] += querySeq_wordProbArr[j][node];
            }
        }

        float maxPosteriorProb = Float.NEGATIVE_INFINITY;
        int determinedNodeIndex = 0;
        for (int node = 0; node < nodeListSize; node++) {
            if (accumulateProbArr[node] > maxPosteriorProb) {
                determinedNodeIndex = node;
                maxPosteriorProb = accumulateProbArr[node];
            }
            accumulateProbArr[node] = 0; // reset to 0
        }

        HashMap<HierarchyTree, RankAssignment> determinedMap = new HashMap<HierarchyTree, RankAssignment>();
        HierarchyTree aNode = trainingInfo.getGenusNodebyIndex(determinedNodeIndex);
        while (aNode != null) {
            determinedMap.put(aNode, new RankAssignment(aNode, 0));
            aNode = aNode.getParent();
        }

        randomGenerator.setSeed(seed)// use the seed to choose the same random words for the same sequence

        // for each run, pick the genus node which has the highest posterior probability
        boolean tied = false;
        for (int i = 0; i < NUM_OF_RUNS; i++) {
            for (int j = 0; j < NUM_OF_SELECTIONS; j++) {
                int randomIndex = randomGenerator.nextInt(goodWordCount);

                for (int node = 0; node < nodeListSize; node++) {
                    accumulateProbArr[node] += querySeq_wordProbArr[randomIndex][node];
                }
            }

            maxPosteriorProb = Float.NEGATIVE_INFINITY;
            int bestNodeIndex = 0;

            for (int node = 0; node < nodeListSize; node++) {
                if (accumulateProbArr[node] >= maxPosteriorProb) {
                    if (accumulateProbArr[node] > maxPosteriorProb){
                    // if the maxPosteriorProb tied between multiple genera, we need save all genera with tied score and radomly pick one later
                        bestNodeIndex = node;
                        maxPosteriorProb = accumulateProbArr[node];
                        tied = false;
                    }else{
                        tied = true;
                    }
                }
            }

            // if the maxPosteriorProb tied between multiple genera, we need to radomly pick one
            if (tied){
                ArrayList<Integer> possibleSet = new ArrayList<Integer>();
                for (int node = 0; node < nodeListSize; node++) {
                    if (accumulateProbArr[node] == maxPosteriorProb) {
                        possibleSet.add(node);
                    }
                }

                bestNodeIndex = possibleSet.get(randomSelectGenera.nextInt(possibleSet.size()));
            }

             for (int node = 0; node < nodeListSize; node++) {
                 accumulateProbArr[node] = 0; // reset to 0
             }

            addConfidence(trainingInfo.getGenusNodebyIndex(bestNodeIndex), determinedMap);
        }


        List finalAssigns = getFinalResultList(determinedMap, trainingInfo.getGenusNodebyIndex(determinedNodeIndex));
        ClassificationResult finalResult = new ClassificationResult(seq, reversed, finalAssigns, trainingInfo.getHierarchyInfo());

        return finalResult;
    }

    /**
     * increase the count of the RankAssignment in the map if match that node or any ancestor of that node.
     * @param node
     * @param map
     */
    public void addConfidence(HierarchyTree node, HashMap map) {
        while (node != null) {
            RankAssignment assign = (RankAssignment) map.get(node);
            if (assign != null) {
                assign.setConfidence(assign.getConfidence() + 1.0f);
            }
            node = node.getParent();
        }
    }

    private List getFinalResultList(HashMap map, HierarchyTree aNode) {
        List<RankAssignment> finalAssigns = new ArrayList<RankAssignment>();
        while (aNode != null) {
            RankAssignment assign = (RankAssignment) map.get(aNode);
            assign.setConfidence(assign.getConfidence() / (float) NUM_OF_RUNS);
            finalAssigns.add(0, assign);
            aNode = aNode.getParent();
        }
        return finalAssigns;
    }

    /** Adds a single RankAssignment to the list of the classification results
     *  if a treenode is already included in the list, simply increases the
     *  confidence for that RankAssignment by 1, for easy calculation.
     *   Later this confidence will be divided by NUM_OF_RUNS.
     */
    private void addBestGenusNode(HierarchyTree node, List resultList) {
        int node_taxid = node.getTaxid();
        Iterator it = resultList.iterator();
        boolean found = false;
        while (it.hasNext()) {
            RankAssignment result = (RankAssignment) it.next();
            // if this result is already in the list, increase the confidence
            if (node_taxid == result.getBestClass().getTaxid()) {
                result.setConfidence(result.getConfidence() + 1);
                found = true;
                break;
            }
        }

        if (!found) {
            RankAssignment result = new RankAssignment(node, 1);
            resultList.add(result);
        }
    }

    /** Finds the ancestors of the given node up to the root.
     * The ancestors are kept in a list in which the root is the first item.
     */
    private void findAncestor(HierarchyTree root, RankAssignment curNode, List ancestorList) {
        if (curNode == null) {
            return;
        }
        if (curNode.getBestClass().getTaxid() == root.getTaxid()) {
            ancestorList.add(curNode);
            return;
        }

        RankAssignment parent = new RankAssignment(curNode.getBestClass().getParent(), curNode.getConfidence());
        findAncestor(root, parent, ancestorList);
        ancestorList.add(curNode);
    }

    /** Returns a list of RankAssignment in which root is the first item.
     * Each node in the list is the one which has the highest confidence
     * among the children of the previous node.
     * Algorithm: From the top rank, select the child that has the highest votes,
     * traverse down to the bottom rank level.
     */
    private List getGreedyPath(HierarchyTree root, List resultList) {
        List votingList = new ArrayList()// a list of ancestorLists,
        // each ancestorList cotains a list RankAssignment (from root to genus)

        // first gets all the ancestors for each genus node in the resultList
        Iterator genusIt = resultList.iterator();
        while (genusIt.hasNext()) {
            RankAssignment aResult = (RankAssignment) genusIt.next();
            List ancestorList = new ArrayList();
            findAncestor(root, aResult, ancestorList);
            votingList.add(ancestorList);
        }

        if (votingList.size() == 0) {
            return null;
        }

        // calculates the confidence for the common treenode at the votingList
        // then finds the RankAssignment which has highest confidence at each level
        int ancestorIndex = 0;
        List greedyList = new ArrayList();
        // we assume that every voter will agree with the root assignment
        greedyList.add(((List) votingList.get(0)).get(ancestorIndex));
        ((RankAssignment) ((List) votingList.get(0)).get(ancestorIndex)).setConfidence(1);

        while (root.getSizeofSubclasses() > 0) {
            ancestorIndex++;
            RankAssignment bestNode = null;
            List tempList = new ArrayList();

            for (int i = 0; i < votingList.size(); i++) {
                boolean found = false;
                Iterator tmpIt = tempList.iterator();
                RankAssignment aResult = null;
                if (((List) votingList.get(i)).size() > ancestorIndex) {
                    aResult = (RankAssignment) (((List) votingList.get(i)).get(ancestorIndex));
                } else {
                    continue;
                }

                while (tmpIt.hasNext()) {
                    RankAssignment tmp = (RankAssignment) tmpIt.next();
                    if (aResult.getBestClass().getTaxid() == tmp.getBestClass().getTaxid()) {
                        tmp.setConfidence(tmp.getConfidence() + aResult.getConfidence());
                        if (tmp.getBestClass().getParent().getTaxid() == root.getTaxid()
                                && tmp.getConfidence() > bestNode.getConfidence()) {
                            bestNode = tmp;
                        }
                        found = true;
                        break;
                    }
                }

                if (!found) {
                    tempList.add(aResult);
                    if (aResult.getBestClass().getParent().getTaxid() == root.getTaxid()) {
                        if (bestNode == null) {
                            bestNode = aResult;
                        } else if (aResult.getConfidence() > bestNode.getConfidence()) {
                            bestNode = aResult;
                        }
                    }
                }
            }
            bestNode.setConfidence(bestNode.getConfidence() / (float) NUM_OF_RUNS);
            greedyList.add(bestNode);
            root = bestNode.getBestClass();

        }
        return greedyList;
    }
}
TOP

Related Classes of edu.msu.cme.rdp.classifier.Classifier

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.