/*
* Copyright (C) 2012 Michigan State University <rdpstaff at msu.edu>
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package edu.msu.cme.rdp.multicompare;
import edu.msu.cme.rdp.classifier.*;
import edu.msu.cme.rdp.classifier.rrnaclassifier.ClassificationParser;
import edu.msu.cme.rdp.classifier.io.ClassificationResultFormatter;
import edu.msu.cme.rdp.classifier.utils.ClassifierFactory;
import edu.msu.cme.rdp.classifier.utils.ClassifierSequence;
import edu.msu.cme.rdp.multicompare.taxon.MCTaxon;
import edu.msu.cme.rdp.readseq.readers.Sequence;
import edu.msu.cme.rdp.taxatree.ConcretRoot;
import edu.msu.cme.rdp.taxatree.Taxon;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import org.apache.commons.io.output.NullWriter;
/**
*
* @author fishjord
*/
/**
 * Classifies batches of sequence samples against an RDP naive Bayesian classifier
 * and aggregates the per-taxon assignment counts into a shared taxonomy tree.
 * Optionally rewrites a BIOM file, replacing its taxonomy / sample metadata with
 * the freshly computed classifications.
 */
public class MultiClassifier {

    // Factory supplying the classifier and reference taxonomy for the chosen gene.
    // NOTE: this was a static field assigned in the constructor, so constructing a
    // second MultiClassifier for a different gene silently clobbered the first one's
    // factory. It is private and only read by instance methods, so it is now an
    // instance field.
    private final ClassifierFactory classifierFactory;
    private static final float DEFAULT_CONF = 0.8f;
    private static final PrintWriter DEFAULT_ASSIGN_WRITER = new PrintWriter(new NullWriter());
    private static final ClassificationResultFormatter.FORMAT DEFAULT_FORMAT = ClassificationResultFormatter.FORMAT.allRank;

    private File biomFile = null;
    // sample name -> (metadata column header -> value), parsed from the metadata file
    private HashMap<String, HashMap<String, String>> metadataMap = null;
    private String[] ranks = ClassificationResultFormatter.RANKS; // default ranks
    // cache of taxon -> rRNA copy number so the hierarchy walk in findCopyNumber runs once per taxon
    private HashMap<Taxon, Double> cachedCopynumberMap = new HashMap<Taxon, Double>();
    private boolean hasCopyNumber;

    /**
     * Creates a classifier for the given gene.
     *
     * @param propfile optional classifier property file; when null the factory default is used
     * @param gene     gene name understood by ClassifierFactory (fungal ITS genes switch on species-level ranks)
     * @throws RuntimeException wrapping any factory initialization failure
     */
    public MultiClassifier(String propfile, String gene) {
        if (propfile != null) {
            ClassifierFactory.setDataProp(propfile, false);
        }
        try {
            classifierFactory = ClassifierFactory.getFactory(gene);
        } catch (Exception ex) {
            throw new RuntimeException(ex);
        }
        // Warcup/UNITE fungal ITS training sets carry species-level assignments.
        if (gene != null && (gene.equalsIgnoreCase(ClassifierFactory.FUNGALITS_warcup_GENE)
                || gene.equalsIgnoreCase(ClassifierFactory.FUNGALITS_unite_GENE))) {
            ranks = ClassificationResultFormatter.RANKS_WITHSPECIES;
        }
        hasCopyNumber = classifierFactory.getRoot().hasCopyNumberInfo();
    }

    /**
     * Same as {@link #MultiClassifier(String, String)} but also records a BIOM file to
     * rewrite and an optional tab-delimited sample metadata file.
     */
    public MultiClassifier(String propfile, String gene, File biomFile, File metadataFile) throws IOException {
        this(propfile, gene);
        this.biomFile = biomFile;
        if (metadataFile != null) {
            metadataMap = readMetaData(metadataFile);
        }
    }

    /** @return true if the training set includes rRNA copy number information */
    public boolean hasCopyNumber() {
        return hasCopyNumber;
    }

    /** Classifies with default confidence, writer, format, and bootstrap word count. */
    public MultiClassifierResult multiCompare(List<MCSample> samples) throws IOException {
        return multiCompare(samples, DEFAULT_CONF, DEFAULT_ASSIGN_WRITER, DEFAULT_FORMAT, Classifier.MIN_BOOTSTRSP_WORDS);
    }

    /** Classifies with default confidence, writer, and format. */
    public MultiClassifierResult multiCompare(List<MCSample> samples, int min_bootstrap_words) throws IOException {
        return multiCompare(samples, DEFAULT_CONF, DEFAULT_ASSIGN_WRITER, DEFAULT_FORMAT, min_bootstrap_words);
    }

    /** Classifies with default writer and format. */
    public MultiClassifierResult multiCompare(List<MCSample> samples, float conf, int min_bootstrap_words) throws IOException {
        return multiCompare(samples, conf, DEFAULT_ASSIGN_WRITER, DEFAULT_FORMAT, min_bootstrap_words);
    }

    /** Classifies with default confidence and format, writing assignments to the given writer. */
    public MultiClassifierResult multiCompare(List<MCSample> samples, PrintWriter assign_out, int min_bootstrap_words) throws IOException {
        return multiCompare(samples, DEFAULT_CONF, assign_out, DEFAULT_FORMAT, min_bootstrap_words);
    }

    /**
     * Classifies every sequence in every sample and accumulates counts into a shared taxonomy tree.
     * Input files are sequence files.
     *
     * @param samples             samples whose sequences will be classified
     * @param confidence          minimum bootstrap confidence for an assignment to be accepted
     * @param assign_out          sink for per-sequence assignment lines (or the BIOM output)
     * @param format              output format; biom defers printing until all sequences are classified
     * @param min_bootstrap_words minimum number of words per bootstrap trial
     * @return aggregated result tree, plus sequences too short to classify
     */
    public MultiClassifierResult multiCompare(List<MCSample> samples, float confidence, PrintWriter assign_out,
            ClassificationResultFormatter.FORMAT format, int min_bootstrap_words) throws IOException {
        HierarchyTree sampleTreeRoot = classifierFactory.getRoot();
        ConcretRoot<MCTaxon> root = new ConcretRoot<MCTaxon>(
                new MCTaxon(sampleTreeRoot.getTaxid(), sampleTreeRoot.getName(), sampleTreeRoot.getRank()));
        Classifier classifier = classifierFactory.createClassifier();
        List<String> badSequences = new ArrayList<String>();
        Map<String, Long> seqCountMap = new HashMap<String, Long>();
        // holds the classification results to replace the biom metadata
        Map<String, String> seqClassificationMap = new HashMap<String, String>();

        if (format.equals(ClassificationResultFormatter.FORMAT.filterbyconf)) {
            // header row listing the rank columns
            for (int i = 0; i <= ranks.length - 1; i++) {
                assign_out.print("\t" + ranks[i]);
            }
            assign_out.println();
        }
        for (MCSample sample : samples) {
            Sequence seq;
            while ((seq = sample.getNextSeq()) != null) {
                try {
                    ClassificationResult result = classifier.classify(new ClassifierSequence(seq), min_bootstrap_words);
                    if (!format.equals(ClassificationResultFormatter.FORMAT.biom)) {
                        printClassificationResult(result, assign_out, format, confidence);
                    } else {
                        // biom output is emitted after the loop, keyed by sequence (cluster) name
                        seqClassificationMap.put(result.getSequence().getSeqName(),
                                ClassificationResultFormatter.getOutput(result, format, confidence, ranks));
                    }
                    processClassificationResult(result, sample, root, confidence, seqCountMap);
                    sample.addRankCount(result);
                } catch (ShortSequenceException e) {
                    // too short to classify; reported to the caller, not fatal
                    badSequences.add(seq.getSeqName());
                }
            }
        }
        if (format.equals(ClassificationResultFormatter.FORMAT.biom)) {
            printBiom(assign_out, seqClassificationMap);
        }
        return new MultiClassifierResult(root, samples, badSequences, seqCountMap);
    }

    /**
     * Copies the template BIOM file to assign_out, replacing each row's taxonomy metadata
     * with the new classification and (when a metadata file was supplied) each column's
     * sample metadata. The template is parsed line-by-line on the "rows"/"columns" markers,
     * so it assumes the pretty-printed BIOM layout — one row/column object per line.
     */
    private void printBiom(PrintWriter assign_out, Map<String, String> seqClassificationMap) throws IOException {
        // try-with-resources: the original leaked the reader if any line threw
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(biomFile)))) {
            String line = null;
            boolean replaceTaxonomy = false;
            boolean replaceSampleMetadata = false;
            boolean continueColumns = false;
            while ((line = reader.readLine()) != null) {
                if (line.startsWith("\"rows")) {
                    replaceTaxonomy = true;
                    assign_out.println(line);
                    continue;
                }
                if (line.startsWith("\"columns\"")) {
                    replaceTaxonomy = false;
                    if (metadataMap != null) {
                        replaceSampleMetadata = true;
                    }
                    assign_out.println(line);
                    continue;
                }
                if (line.trim().startsWith("],")) {  // end of the rows or columns array
                    if (continueColumns) { // finish printing the previous sample
                        assign_out.println("");
                    }
                    replaceTaxonomy = false;
                    replaceSampleMetadata = false;
                }
                if (replaceTaxonomy) { // replace the taxonomy metadata
                    // values[3] is the row id, i.e. the cluster/sequence name
                    String[] values = line.split("\"");
                    String cluster_id = values[3];
                    String cluster_classification = seqClassificationMap.get(cluster_id);
                    if (cluster_classification == null) {
                        throw new IllegalArgumentException("Can not find the cluster_id " + cluster_id
                                + " in the classification result of the input sequences");
                    }
                    String[] classification = cluster_classification.split("\\t");
                    assign_out.print("\t {\"id\" : \"" + cluster_id + "\", \"metadata\" : {\"taxonomy\":[");
                    assign_out.print("\"" + classification[1] + "\"");
                    if (line.endsWith(",")) {
                        assign_out.println("]}},");
                    } else {
                        assign_out.println("]}}");
                    }
                } else if (replaceSampleMetadata) { // replace the sample metadata
                    if (line.trim().contains("{\"id\"")) {
                        if (continueColumns) { // finish printing the previous sample
                            assign_out.println(",");
                        }
                        String[] values = line.split("\"");
                        HashMap<String, String> sampleMap = metadataMap.get(values[3]);
                        if (sampleMap == null) {
                            throw new IllegalArgumentException("Sample " + values[3]
                                    + " does not have metadata in the metadata file.");
                        }
                        // re-emit the id portion of the line, then the replacement metadata object
                        assign_out.println(values[0] + "\"" + values[1] + "\"" + values[2] + "\"" + values[3]
                                + "\"" + values[4] + "\"" + values[5] + "\" : {");
                        Object[] tempList = sampleMap.keySet().toArray();
                        for (int i = 0; i < tempList.length; i++) {
                            if (i < tempList.length - 1) {
                                assign_out.println("\t\t\"" + tempList[i] + "\":\"" + sampleMap.get(tempList[i]) + "\",");
                            } else {
                                assign_out.print("\t\t\"" + tempList[i] + "\":\"" + sampleMap.get(tempList[i]) + "\"}}");
                            }
                        }
                        continueColumns = true;
                    }
                } else {
                    assign_out.println(line);
                }
            }
        }
    }

    /**
     * Parses a tab-delimited sample metadata file. The first line is the header; each
     * subsequent line starts with the sample name followed by one value per header column.
     *
     * @return map of sample name -> (header column -> value)
     * @throws IOException if the file cannot be read or is empty
     */
    private HashMap<String, HashMap<String, String>> readMetaData(File metadataFile) throws IOException {
        // try-with-resources: the original leaked the reader on a parse failure
        try (BufferedReader reader = new BufferedReader(new FileReader(metadataFile))) {
            String line = reader.readLine();
            if (line == null) {
                // the original fell into an opaque NullPointerException here
                throw new IOException("Metadata file " + metadataFile + " is empty");
            }
            String[] header = line.split("\\t");
            HashMap<String, HashMap<String, String>> metadataMap = new HashMap<String, HashMap<String, String>>();
            while ((line = reader.readLine()) != null) {
                String[] vals = line.split("\\t");
                HashMap<String, String> sampleMap = new HashMap<String, String>();
                metadataMap.put(vals[0].trim(), sampleMap);
                for (int i = 1; i < header.length; i++) {
                    sampleMap.put(header[i], vals[i]);
                }
            }
            return metadataMap;
        }
    }

    /**
     * Aggregates counts from previously computed classification result files.
     * Input files are the classification results.
     *
     * @param printRank   which rank to filter by the confidence for the detail assignment output;
     *                    when null, the lowest rank of the first record is used
     * @param taxonFilter which taxa to match for the detail assignment output; null matches everything
     */
    public MultiClassifierResult multiClassificationParser(List<MCSample> samples, float confidence, PrintWriter assign_out,
            ClassificationResultFormatter.FORMAT format, String printRank, HashSet<String> taxonFilter) throws IOException {
        HierarchyTree sampleTreeRoot = classifierFactory.getRoot();
        ConcretRoot<MCTaxon> root = new ConcretRoot<MCTaxon>(
                new MCTaxon(sampleTreeRoot.getTaxid(), sampleTreeRoot.getName(), sampleTreeRoot.getRank()));
        List<String> badSequences = new ArrayList<String>();
        Map<String, Long> seqCountMap = new HashMap<String, Long>();
        for (MCSample sample : samples) {
            ClassificationParser parser = ((MCSampleResult) sample).getClassificationParser(classifierFactory);
            ClassificationResult result;
            while ((result = parser.next()) != null) {
                processClassificationResult(result, sample, root, confidence, seqCountMap);
                List<RankAssignment> assignList = result.getAssignments();
                if (printRank == null) {
                    // default to the lowest assigned rank of the first record seen
                    printRank = assignList.get(assignList.size() - 1).getRank();
                }
                boolean match = false;
                if (taxonFilter == null) {
                    match = true;
                } else {
                    for (RankAssignment assign : assignList) {
                        if (taxonFilter.contains(assign.getBestClass().getName())) {
                            match = true;
                            break;
                        }
                    }
                }
                if (match) {
                    // print only records whose assignment at printRank clears the confidence cutoff
                    for (RankAssignment assign : assignList) {
                        if (assign.getRank().equalsIgnoreCase(printRank) && assign.getConfidence() >= confidence) {
                            printClassificationResult(result, assign_out, format, confidence);
                            break;
                        }
                    }
                }
            }
            parser.close();
        }
        return new MultiClassifierResult(root, samples, badSequences, seqCountMap);
    }

    /**
     * Returns the MCTaxon for the given assignment under root, creating it (and bumping
     * the per-rank creation count in seqCountMap) on first sight. When unclassified is
     * true the taxon id is remapped to the synthetic "unclassified" id.
     */
    private MCTaxon findOrCreateTaxon(ConcretRoot<MCTaxon> root, RankAssignment assignment, int parentId,
            boolean unclassified, Map<String, Long> seqCountMap, String lineage) {
        int taxid = assignment.getTaxid();
        if (unclassified) {
            taxid = Taxon.getUnclassifiedId(taxid);
        }
        MCTaxon ret = root.getChildTaxon(taxid);
        if (ret == null) {
            ret = new MCTaxon(assignment.getTaxid(), assignment.getName(), assignment.getRank(), unclassified);
            root.addChild(ret, parentId);
            // count how many taxa were created at each rank
            Long val = seqCountMap.get(ret.getRank());
            if (val == null) {
                val = 0L;
            }
            seqCountMap.put(ret.getRank(), val + 1);
            ret.setLineage(lineage + ret.getName() + ";" + ret.getRank() + ";");
        }
        return ret;
    }

    /** Formats one classification result and writes it to assign_out. */
    private void printClassificationResult(ClassificationResult result, PrintWriter assign_out,
            ClassificationResultFormatter.FORMAT format, float confidence) throws IOException {
        String assignmentStr = ClassificationResultFormatter.getOutput(result, format, confidence, ranks);
        assign_out.print(assignmentStr);
    }

    /**
     * Walks the ranked assignments of one result from root downward, incrementing counts
     * on each taxon visited. The walk stops at the first assignment below the confidence
     * cutoff; that sequence is attributed to the synthetic "unclassified" child of the
     * last confident taxon. When copy number data is available, counts are adjusted by
     * the copy number of the lowest confident (real) taxon.
     */
    private void processClassificationResult(ClassificationResult result, MCSample sample, ConcretRoot<MCTaxon> root,
            float conf, Map<String, Long> seqCountMap) {
        RankAssignment lastAssignment = null;
        RankAssignment twoAgo = null;
        StringBuilder lineage = new StringBuilder();
        MCTaxon taxon = null;
        MCTaxon cntaxon = null;
        HashSet<MCTaxon> tempTaxonSet = new HashSet<MCTaxon>();
        int parentId = root.getRootTaxid();
        int count = sample.getDupCount(result.getSequence().getSeqName());
        for (RankAssignment assignment : (List<RankAssignment>) result.getAssignments()) {
            boolean stop = false;
            if (assignment.getConfidence() < conf) {
                // below cutoff: attach an "unclassified" taxon under the grandparent
                // (the parent of the last confident assignment).
                // NOTE(review): assumes the root assignment always clears the cutoff;
                // if the very first assignment were below conf, lastAssignment would be null here.
                parentId = root.getRootTaxid();
                if (twoAgo != null) {
                    parentId = twoAgo.getTaxid();
                }
                cntaxon = taxon; // we only use the real taxon, not the unclassified taxon, to find copy number
                taxon = findOrCreateTaxon(root, lastAssignment, parentId, true, seqCountMap, lineage.toString());
                stop = true;
            } else {
                if (lastAssignment != null) {
                    parentId = lastAssignment.getTaxid();
                }
                taxon = findOrCreateTaxon(root, assignment, parentId, false, seqCountMap, lineage.toString());
                cntaxon = taxon;
            }
            tempTaxonSet.add(taxon);
            twoAgo = lastAssignment;
            lastAssignment = assignment;
            if (stop) {
                break;
            }
            lineage.append(assignment.getName()).append(";").append(assignment.getRank()).append(";");
        }
        if (hasCopyNumber()) {
            // need to get the copy number for the lowest level taxon above the conf cutoff, and add the count up
            double copynumber = findCopyNumber(result.getAssignments().get(result.getAssignments().size() - 1), cntaxon);
            for (MCTaxon t : tempTaxonSet) {
                t.incCount(sample, count, copynumber);
            }
        } else {
            for (MCTaxon t : tempTaxonSet) {
                t.incCount(sample, count);
            }
        }
    }

    /**
     * Looks up the rRNA copy number for taxon by walking up from the best-class node of
     * the given assignment until a node matching taxon's name and rank is found; results
     * are cached per taxon.
     *
     * @throws IllegalStateException if no matching ancestor carries the taxon
     *         (the original code unboxed a null Double here, producing an opaque NPE)
     */
    private double findCopyNumber(RankAssignment assignment, Taxon taxon) {
        Double copyNumber = this.cachedCopynumberMap.get(taxon);
        if (copyNumber == null) {
            HierarchyTree curTaxon = assignment.getBestClass();
            while (curTaxon != null) {
                if (curTaxon.getName().equalsIgnoreCase(taxon.getName())
                        && curTaxon.getRank().equalsIgnoreCase(taxon.getRank())) {
                    copyNumber = curTaxon.getCopyNumber();
                    this.cachedCopynumberMap.put(taxon, copyNumber);
                    break;
                }
                curTaxon = curTaxon.getParent();
            }
            if (copyNumber == null) {
                throw new IllegalStateException("No copy number found for taxon "
                        + taxon.getName() + " (rank " + taxon.getRank() + ")");
            }
        }
        return copyNumber;
    }
}