// Package edu.msu.cme.rdp.multicompare
//
// Source code of edu.msu.cme.rdp.multicompare.MultiClassifier

/*
* Copyright (C) 2012 Michigan State University <rdpstaff at msu.edu>
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/
package edu.msu.cme.rdp.multicompare;

import edu.msu.cme.rdp.classifier.*;
import edu.msu.cme.rdp.classifier.rrnaclassifier.ClassificationParser;
import edu.msu.cme.rdp.classifier.io.ClassificationResultFormatter;
import edu.msu.cme.rdp.classifier.utils.ClassifierFactory;
import edu.msu.cme.rdp.classifier.utils.ClassifierSequence;
import edu.msu.cme.rdp.multicompare.taxon.MCTaxon;
import edu.msu.cme.rdp.readseq.readers.Sequence;
import edu.msu.cme.rdp.taxatree.ConcretRoot;
import edu.msu.cme.rdp.taxatree.Taxon;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import org.apache.commons.io.output.NullWriter;

/**
*
* @author fishjord
*/
public class MultiClassifier {

    private static ClassifierFactory classifierFactory;
    private static final float DEFAULT_CONF = 0.8f;
    private static final PrintWriter DEFAULT_ASSIGN_WRITER = new PrintWriter(new NullWriter());
    private static final ClassificationResultFormatter.FORMAT DEFAULT_FORMAT = ClassificationResultFormatter.FORMAT.allRank;
    private File biomFile = null;
    private HashMap<String, HashMap<String, String>> metadataMap = null;
    private String[] ranks = ClassificationResultFormatter.RANKS; // default ranks;
    private HashMap<Taxon, Double> cachedCopynumberMap = new HashMap<Taxon, Double>();
    private boolean hasCopyNumber;

    public MultiClassifier(String propfile, String gene){

        if (propfile != null) {
            ClassifierFactory.setDataProp(propfile, false);
        }
        try {
            classifierFactory = ClassifierFactory.getFactory(gene);
        } catch (Exception ex) {
            throw new RuntimeException(ex);
        }
        if ( gene != null && (gene.equalsIgnoreCase(ClassifierFactory.FUNGALITS_warcup_GENE) || gene.equalsIgnoreCase(ClassifierFactory.FUNGALITS_unite_GENE)) ){
            ranks = ClassificationResultFormatter.RANKS_WITHSPECIES;
        }
        hasCopyNumber = classifierFactory.getRoot().hasCopyNumberInfo();
    }

    public MultiClassifier(String propfile, String gene, File biomFile, File metadataFile) throws IOException{
        this(propfile, gene);
        this.biomFile = biomFile;
        if ( metadataFile != null){
            metadataMap = readMetaData(metadataFile);
        }       
    }
   
    public boolean hasCopyNumber(){
        return hasCopyNumber;
    }
   
    public MultiClassifierResult multiCompare(List<MCSample> samples) throws IOException {
        return multiCompare(samples, DEFAULT_CONF, DEFAULT_ASSIGN_WRITER, DEFAULT_FORMAT, Classifier.MIN_BOOTSTRSP_WORDS);
    }
   
    public MultiClassifierResult multiCompare(List<MCSample> samples, int min_bootstrap_words) throws IOException {
        return multiCompare(samples, DEFAULT_CONF, DEFAULT_ASSIGN_WRITER, DEFAULT_FORMAT, min_bootstrap_words);
    }

    public MultiClassifierResult multiCompare(List<MCSample> samples, float conf, int min_bootstrap_words) throws IOException {
        return multiCompare(samples, conf, DEFAULT_ASSIGN_WRITER, DEFAULT_FORMAT, min_bootstrap_words);
    }

    public MultiClassifierResult multiCompare(List<MCSample> samples, PrintWriter assign_out, int min_bootstrap_words) throws IOException {
        return multiCompare(samples, DEFAULT_CONF, assign_out, DEFAULT_FORMAT, min_bootstrap_words);
    }

    /**
     * Input files are sequence files
     */
    public MultiClassifierResult multiCompare(List<MCSample> samples, float confidence, PrintWriter assign_out,
            ClassificationResultFormatter.FORMAT format, int min_bootstrap_words) throws IOException {
        HierarchyTree sampleTreeRoot  = classifierFactory.getRoot();
        ConcretRoot<MCTaxon> root = new ConcretRoot<MCTaxon>(new MCTaxon(sampleTreeRoot.getTaxid(), sampleTreeRoot.getName(), sampleTreeRoot.getRank()) );

        Classifier classifier = classifierFactory.createClassifier();
        List<String> badSequences = new ArrayList();
        Map<String, Long> seqCountMap = new HashMap();
        Map<String, String> seqClassificationMap = new HashMap(); // holds the classification results to replace the biom metadata
        if ( format.equals(ClassificationResultFormatter.FORMAT.filterbyconf) ){
            for (int i = 0; i <= ranks.length -1; i++) {
                assign_out.print("\t" + ranks[i]);
            }
            assign_out.println();
        }
        for (MCSample sample : samples) {
            Sequence seq;

            while ((seq = sample.getNextSeq()) != null) {
                try {
                    ClassificationResult result = classifier.classify(new ClassifierSequence(seq), min_bootstrap_words);
                    if ( !format.equals(ClassificationResultFormatter.FORMAT.biom)){
                        printClassificationResult(result, assign_out, format, confidence);
                    }else {
                        seqClassificationMap.put(result.getSequence().getSeqName(), ClassificationResultFormatter.getOutput(result, format, confidence, ranks));
                    }
                    processClassificationResult(result, sample, root, confidence, seqCountMap);
                    sample.addRankCount(result);

                } catch (ShortSequenceException e) {
                    badSequences.add(seq.getSeqName());
                }

            }
        }

        if ( format.equals(ClassificationResultFormatter.FORMAT.biom)){
            printBiom(assign_out, seqClassificationMap);
        }
        return new MultiClassifierResult(root, samples, badSequences, seqCountMap);
    }
   
   
    private void printBiom(PrintWriter assign_out, Map<String, String> seqClassificationMap) throws IOException {
        BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(biomFile)));
        String line = null;
        boolean replaceTaxonomy = false;
        boolean replaceSampleMetadata = false;
        boolean continueColumns = false;
        while ( (line=reader.readLine())!= null){
            if ( line.startsWith("\"rows")){
                replaceTaxonomy = true;
                assign_out.println(line);
                continue;
            }
            if (line.startsWith("\"columns\"")){               
                replaceTaxonomy = false;           
                if ( metadataMap != null){               
                    replaceSampleMetadata = true;
                }
                assign_out.println(line);
                continue;
            }
            if (line.trim().startsWith("],")){
                if ( continueColumns){ // finish printing the previous sample
                    assign_out.println("");
                }
                replaceTaxonomy = false;
                replaceSampleMetadata = false;               
            }
            if (replaceTaxonomy) { // replace the taxonomy metadata   
              
                String[] values = line.split("\"");
                String cluster_id = values[3];
                String cluster_classification = seqClassificationMap.get(cluster_id);
                if ( cluster_classification == null ){
                    throw new IllegalArgumentException("Can not find the cluster_id " + cluster_id + " in the classification result of the input sequences");
                }
                String[] classification = cluster_classification.split("\\t");
                assign_out.print("\t {\"id\" : \"" + cluster_id +"\", \"metadata\" : {\"taxonomy\":[");

                assign_out.print( "\"" + classification[1] + "\"");

                if ( line.endsWith(",")){
                    assign_out.println( "]}},");
                }else {
                    assign_out.println( "]}}");
                }
                              
            }else if (replaceSampleMetadata){  // replace the sample metadata
                if ( line.trim().contains("{\"id\"")){
                    if ( continueColumns){ // finish printing the previous sample
                        assign_out.println(",");
                    }
                    String[] values = line.split("\"");
                    HashMap<String, String> sampleMap = metadataMap.get( values[3]);
                    if ( sampleMap== null){
                        throw new IllegalArgumentException("Sample " +  values[3] + " does not have metadata in the metadata file.");
                    }
                    assign_out.println( values[0] + "\"" + values[1] + "\"" + values[2]+ "\"" + values[3]
                            + "\"" + values[4] + "\"" + values[5] +"\" : {");
                    Object[] tempList = sampleMap.keySet().toArray();
                    for ( int i = 0; i < tempList.length; i++){
                        if ( i < tempList.length -1){
                            assign_out.println("\t\t\"" + tempList[i] + "\":\"" + sampleMap.get(tempList[i]) + "\",");
                        }else {
                            assign_out.print("\t\t\"" + tempList[i] + "\":\"" + sampleMap.get(tempList[i]) + "\"}}");
                        }
                    }
                    continueColumns = true;
                }
            }else {
                assign_out.println(line);
            }
        }
       
        reader.close();
    }
   
    private HashMap<String, HashMap<String, String>> readMetaData(File metadataFile) throws IOException{
        BufferedReader reader = new BufferedReader(new FileReader (metadataFile));
        String line = reader.readLine();       
        String[] header = line.split("\\t");
        HashMap<String, HashMap<String, String>> metadataMap = new HashMap<String, HashMap<String, String>>();
               
        while ( (line= reader.readLine())!= null){
            String[] vals = line.split("\\t");
            HashMap<String, String> sampleMap = new HashMap<String, String>();
            metadataMap.put(vals[0].trim(), sampleMap);
            for ( int i = 1; i < header.length; i++){
                sampleMap.put(header[i], vals[i]);
            }
        }
       
        reader.close();
        return metadataMap;
    }

    /**
     * Input files are the classification results
     * printRank indicates which rank to filter by the confidence for the detail assignment output
     * taxonFilter indicates which taxon to match for the detail assignment output
     */
    public MultiClassifierResult multiClassificationParser(List<MCSample> samples, float confidence, PrintWriter assign_out,
            ClassificationResultFormatter.FORMAT format, String printRank, HashSet<String> taxonFilter) throws IOException {
        HierarchyTree sampleTreeRoot  = classifierFactory.getRoot();
        ConcretRoot<MCTaxon> root = new ConcretRoot<MCTaxon>(new MCTaxon(sampleTreeRoot.getTaxid(), sampleTreeRoot.getName(), sampleTreeRoot.getRank()) );
        List<String> badSequences = new ArrayList();
        Map<String, Long> seqCountMap = new HashMap();

        for (MCSample sample : samples) {
            ClassificationParser parser = ((MCSampleResult) sample).getClassificationParser(classifierFactory);
            ClassificationResult result;

            while ((result = parser.next()) != null) {              
                processClassificationResult(result, sample, root, confidence, seqCountMap);
                List<RankAssignment> assignList = result.getAssignments();
                if ( printRank == null){
                    printRank = assignList.get(assignList.size() -1).getRank();
                }
                boolean match = false;
                if ( taxonFilter == null){
                    match = true;
                }else {
                    for ( RankAssignment assign: assignList){
                        if (taxonFilter.contains(assign.getBestClass().getName()) ){
                            match = true;
                            break;
                        }  
                    }
                }
                if ( match){
                    for ( RankAssignment assign: assignList){
                        if ( assign.getRank().equalsIgnoreCase(printRank) && assign.getConfidence() >= confidence ){
                            printClassificationResult(result, assign_out, format, confidence);
                            break;
                        }
                    }
                }
            }
            parser.close();
        }
        return new MultiClassifierResult(root, samples, badSequences, seqCountMap);
    }

    private MCTaxon findOrCreateTaxon(ConcretRoot<MCTaxon> root, RankAssignment assignment, int parentId, boolean unclassified, Map<String, Long> seqCountMap, String lineage) {
        int taxid = assignment.getTaxid();
        if (unclassified) {
            taxid = Taxon.getUnclassifiedId(taxid);
        }

        MCTaxon ret = root.getChildTaxon(taxid);
        if (ret == null) {
            ret = new MCTaxon(assignment.getTaxid(), assignment.getName(), assignment.getRank(), unclassified);
            root.addChild(ret, parentId);

            Long val = seqCountMap.get(ret.getRank());
            if (val == null) {
                val = 0L;
            }
            seqCountMap.put(ret.getRank(), val + 1);
            ret.setLineage(lineage.toString() + ret.getName() + ";" + ret.getRank() + ";");
        }

        return ret;
    }

    private void printClassificationResult(ClassificationResult result, PrintWriter assign_out, ClassificationResultFormatter.FORMAT format, float confidence) throws IOException {
        String assignmentStr = ClassificationResultFormatter.getOutput(result, format, confidence, ranks);
        assign_out.print(assignmentStr);
    }

   
    private void processClassificationResult(ClassificationResult result, MCSample sample, ConcretRoot<MCTaxon> root, float conf, Map<String, Long> seqCountMap) {
        RankAssignment lastAssignment = null;
        RankAssignment twoAgo = null;
        StringBuffer lineage = new StringBuffer();       
        MCTaxon taxon = null;
        MCTaxon cntaxon = null;
        HashSet<MCTaxon> tempTaxonSet = new HashSet<MCTaxon>();
        int parentId = root.getRootTaxid();   
        int count = sample.getDupCount(result.getSequence().getSeqName());
        for (RankAssignment assignment : (List<RankAssignment>) result.getAssignments()) {
            boolean stop = false;
            if (assignment.getConfidence() < conf) {

                parentId = root.getRootTaxid();
                if (twoAgo != null) {
                    parentId = twoAgo.getTaxid();
                }
                cntaxon = taxon;  // we only used the real taxon, not the unclassified taxon to find copy number
                taxon = findOrCreateTaxon(root, lastAssignment, parentId, true, seqCountMap, lineage.toString());
                stop = true;
            } else {
                if (lastAssignment != null) {
                    parentId = lastAssignment.getTaxid();
                }
                taxon = findOrCreateTaxon(root, assignment, parentId, false, seqCountMap, lineage.toString());
                cntaxon = taxon;
           
            tempTaxonSet.add(taxon);           
            twoAgo = lastAssignment;
            lastAssignment = assignment;

            if (stop) {
                break;
            }
            lineage.append(assignment.getName()).append(";").append(assignment.getRank()).append(";");
        }
        if ( hasCopyNumber()){
            // need to get the copy number for lowest level taxon above the conf cutoff, and add the count up
            double copynumber = findCopyNumber(result.getAssignments().get(result.getAssignments().size()-1), cntaxon);
            for ( MCTaxon t: tempTaxonSet) {
                t.incCount(sample, count, copynumber);
            }
        }else {
            for ( MCTaxon t: tempTaxonSet) {
                t.incCount(sample, count);
            }
        }
    }
   
    private double findCopyNumber(RankAssignment assignment, Taxon taxon){
        Double copyNumber = this.cachedCopynumberMap.get(taxon);
        if (copyNumber == null){          
            HierarchyTree curTaxon = assignment.getBestClass();
            while ( curTaxon != null){               
                if ( curTaxon.getName().equalsIgnoreCase(taxon.getName()) && curTaxon.getRank().equalsIgnoreCase(taxon.getRank())){
                    copyNumber = curTaxon.getCopyNumber();
                    this.cachedCopynumberMap.put(taxon, copyNumber);
                    break;
                }
                curTaxon = curTaxon.getParent();
            }
        }
        return copyNumber;
    }
}
// TOP
//
// Related Classes of edu.msu.cme.rdp.multicompare.MultiClassifier
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., and is owned by Oracle Inc. Contact coftware#gmail.com.