Package net.sf.cram.select

Source Code of net.sf.cram.select.SamRecordComparision$Params

package net.sf.cram.select;

import java.io.File;
import java.io.PrintStream;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;

import net.sf.cram.AlignmentSliceQuery;
import net.sf.cram.Bam2Cram;
import net.sf.cram.CramTools.LevelConverter;
import net.sf.picard.util.Log;
import net.sf.picard.util.Log.LogLevel;
import net.sf.samtools.SAMFileReader;
import net.sf.samtools.SAMRecord;
import net.sf.samtools.SAMFileReader.ValidationStringency;
import net.sf.samtools.SAMRecord.SAMTagAndValue;
import net.sf.samtools.SAMRecordIterator;

import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import com.beust.jcommander.Parameters;
import com.beust.jcommander.converters.FileConverter;

public class SamRecordComparision {
  // Maximum number of characters of a value shown by log() before it is truncated.
  private int maxValueLen = 15;
  // Logger instance for this class.
  private static Log log = Log.getInstance(SamRecordComparision.class);

  // Record fields that compareRecords() iterates over; all field types by default.
  public EnumSet<FIELD_TYPE> fields = EnumSet.allOf(FIELD_TYPE.class);
  // Tag ids that compareTags() skips.
  public Set<String> tagsToIgnore = new TreeSet<String>();
  // When non-empty, compareTags() only looks at the tag ids listed here.
  public Set<String> tagsToCompare = new TreeSet<String>();
  // Tags are compared only when this is true.
  public boolean compareTags = false;
  // Bit mask of FLAG bits that are cleared on both records before flags are compared.
  public int ignoreFlags = 0;
  // TLEN values whose absolute difference is at most this value are treated as equal.
  public int ignoreTLENDiff = 0;

  public static class SamRecordDiscrepancy {
    public FIELD_TYPE field;
    public String tagId;
    public long recordCounter;

    public SAMRecord record1, record2;

    public int prematureEnd = 0;

  }

  public static Object getValue(SAMRecord record, FIELD_TYPE field,
      String tagId) {
    if (field == null)
      throw new IllegalArgumentException("Record field is null.");

    switch (field) {
    case QNAME:
      return record.getReadName();
    case FLAG:
      return Integer.toString(record.getFlags());
    case RNAME:
      return record.getReferenceName();
    case POS:
      return Integer.toString(record.getAlignmentStart());
    case MAPQ:
      return Integer.toString(record.getMappingQuality());
    case CIGAR:
      return record.getCigarString();
    case RNEXT:
      return record.getMateReferenceName();
    case PNEXT:
      return Integer.toString(record.getMateAlignmentStart());
    case TLEN:
      return Integer.toString(record.getInferredInsertSize());
    case SEQ:
      return record.getReadString();
    case QUAL:
      return record.getBaseQualityString();

    case TAG:
      if (tagId == null)
        throw new IllegalArgumentException(
            "Tag mismatch reqiues tag id. ");
      return record.getAttribute(tagId);

    default:
      throw new IllegalArgumentException("Unknown record field: "
          + field.name());
    }
  }

  public boolean compareFieldValue(SAMRecord r1, SAMRecord r2,
      FIELD_TYPE field, String tagId) {
    if (field == null)
      throw new IllegalArgumentException("Record field is null.");

    if (field == FIELD_TYPE.FLAG) {
      int f1 = r1.getFlags() & ~ignoreFlags;
      int f2 = r2.getFlags() & ~ignoreFlags;

      return f1 == f2;
    }

    if (field == FIELD_TYPE.TLEN) {
      int t1 = r1.getInferredInsertSize();
      int t2 = r2.getInferredInsertSize();

      return Math.abs(t1 - t2) <= ignoreTLENDiff;
    }

    Object value1 = getValue(r1, field, tagId);
    Object value2 = getValue(r2, field, tagId);
    return compareObjects(value1, value2);
  }

  private static boolean compareObjects(Object o1, Object o2) {
    if (o1 == null && o2 == null)
      return true;
    if (o1 == null || o2 == null)
      return false;

    if (o1.equals(o2))
      return true;

    if (o1.getClass().isArray() && o2.getClass().isArray()) {
      if (o1 instanceof byte[] && o2 instanceof byte[])
        return Arrays.equals((byte[]) o1, (byte[]) o2);

      if (o1 instanceof short[] && o2 instanceof short[])
        return Arrays.equals((short[]) o1, (short[]) o2);

      return Arrays.equals((Object[]) o1, (Object[]) o2);
    }

    if (o1 instanceof SAMTagAndValue && o2 instanceof SAMTagAndValue) {
      SAMTagAndValue t1 = (SAMTagAndValue) o1;
      SAMTagAndValue t2 = (SAMTagAndValue) o2;

      return t1.tag.equals(t2.tag) && compareObjects(t1.value, t2.value);
    }

    return false;
  }

  private boolean compareTags(SAMRecord r1, SAMRecord r2, long recordCounter,
      List<SamRecordDiscrepancy> list) {
    if (!compareTags)
      return true;

    Map<String, SAMTagAndValue> m1 = new TreeMap<String, SAMRecord.SAMTagAndValue>();
    for (SAMTagAndValue t : r1.getAttributes())
      m1.put(t.tag, t);

    Map<String, SAMTagAndValue> m2 = new TreeMap<String, SAMRecord.SAMTagAndValue>();
    for (SAMTagAndValue t : r2.getAttributes())
      m2.put(t.tag, t);

    boolean equal = true;
    for (String id : m1.keySet()) {
      if (tagsToIgnore.contains(id))
        continue;
      if (!tagsToCompare.isEmpty() && !tagsToCompare.contains(id))
        continue;

      if (m2.containsKey(id) && compareObjects(m1.get(id), m2.get(id)))
        continue;

      SamRecordDiscrepancy d = new SamRecordDiscrepancy();
      d.record1 = r1;
      d.record2 = r2;
      d.field = FIELD_TYPE.TAG;
      d.tagId = id;
      d.recordCounter = recordCounter;
      list.add(d);

      equal = false;
    }

    for (String id : m2.keySet()) {
      if (tagsToIgnore.contains(id))
        continue;
      if (!tagsToCompare.isEmpty() && !tagsToCompare.contains(id))
        continue;

      if (m1.containsKey(id) && compareObjects(m1.get(id), m2.get(id)))
        continue;

      SamRecordDiscrepancy d = new SamRecordDiscrepancy();
      d.record1 = r1;
      d.record2 = r2;
      d.field = FIELD_TYPE.TAG;
      d.tagId = id;
      d.recordCounter = recordCounter;
      list.add(d);

      equal = false;
    }

    return equal;
  }

  public void compareRecords(SAMRecord r1, SAMRecord r2, long recordCounter,
      List<SamRecordDiscrepancy> list) {
    // if (!r1.getReadName().equals(r2.getReadName())
    // || r1.getAlignmentStart() != r2.getAlignmentStart()) {
    // System.err.println("Name mismatch: ");
    // System.err.println("\t"+r1.getSAMString());
    // System.err.println("\t"+r2.getSAMString());
    // }
    for (FIELD_TYPE field : fields) {
      String tagId = null;
      if (field == FIELD_TYPE.TAG) {
        compareTags(r1, r2, recordCounter, list);
      } else {
        if (!compareFieldValue(r1, r2, field, tagId)) {
          SamRecordDiscrepancy d = new SamRecordDiscrepancy();
          d.record1 = r1;
          d.record2 = r2;
          d.field = field;
          d.tagId = null;
          d.recordCounter = recordCounter;
          list.add(d);
        }
      }
    }

  }

  public List<SamRecordDiscrepancy> compareRecords(SAMRecordIterator it1,
      SAMRecordIterator it2, int maxDiscrepandcies) {
    List<SamRecordDiscrepancy> discrepancies = new ArrayList<SamRecordComparision.SamRecordDiscrepancy>();
    long recordCounter = 0;

    while (it1.hasNext() && it2.hasNext()
        && discrepancies.size() < maxDiscrepandcies) {
      recordCounter++;
      SAMRecord record1 = it1.next();
      SAMRecord record2 = it2.next();

      compareRecords(record1, record2, recordCounter, discrepancies);
    }

    if (it1.hasNext() && !it2.hasNext()) {
      SamRecordDiscrepancy d = new SamRecordDiscrepancy();
      d.record1 = it1.next();
      d.prematureEnd = 2;
      discrepancies.add(d);
    } else if (it2.hasNext() && !it1.hasNext()) {
      SamRecordDiscrepancy d = new SamRecordDiscrepancy();
      d.record2 = it2.next();
      d.prematureEnd = 1;
      discrepancies.add(d);
    }

    return discrepancies;
  }

  private static String print(String value, int maxLen) {
    if (value == null)
      return "!NULL!";
    if (value.length() <= maxLen)
      return value;
    else
      return value.substring(0, Math.min(maxLen, value.length())) + "...";
  }

  private static void createDiscrepancyTable(String tableName, Connection c)
      throws SQLException {
    System.out.println(tableName);
    PreparedStatement ps = c
        .prepareStatement("CREATE TABLE "
            + tableName
            + "(counter INT PRIMARY KEY, field VARCHAR, tag VARCHAR, premature int, value1 VARCHAR, value2 VARCHAR, name1 VARCHAR, name2 VARCHAR, record1 VARCHAR, record2 VARCHAR);");
    ps.executeUpdate();
    c.commit();
  }

  private static void dbLog(String tableName,
      Iterator<SamRecordDiscrepancy> it, Connection c)
      throws SQLException {
    PreparedStatement ps = c.prepareStatement("INSERT INTO " + tableName
        + " VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?);");
    while (it.hasNext()) {
      SamRecordDiscrepancy d = it.next();
      int column = 1;
      ps.setLong(column++, d.recordCounter);
      ps.setString(column++, d.field.name());
      ps.setString(column++, d.tagId);
      ps.setInt(column++, d.prematureEnd);

      String value1 = null;
      String value2 = null;
      switch (d.prematureEnd) {
      case 0:
        value1 = SAMRecordField.toString(getValue(d.record1, d.field,
            d.tagId));
        ps.setString(column++, value1);

        value2 = SAMRecordField.toString(getValue(d.record2, d.field,
            d.tagId));
        ps.setString(column++, value2);

        ps.setString(column++, d.record1.getReadName());
        ps.setString(column++, d.record2.getReadName());

        ps.setString(column++, d.record1.getSAMString());
        ps.setString(column++, d.record2.getSAMString());
        break;
      case 1:
        ps.setString(column++, null);

        value2 = SAMRecordField.toString(getValue(d.record2,
            FIELD_TYPE.TAG, d.tagId));
        ps.setString(column++, value2);

        ps.setString(column++, null);
        ps.setString(column++, d.record2.getReadName());

        ps.setString(column++, null);
        ps.setString(column++, d.record2.getSAMString());
        break;
      case 2:
        value1 = SAMRecordField.toString(getValue(d.record1,
            FIELD_TYPE.TAG, d.tagId));
        ps.setString(column++, value1);

        ps.setString(column++, null);

        ps.setString(column++, d.record1.getReadName());
        ps.setString(column++, null);

        ps.setString(column++, d.record1.getSAMString());
        ps.setString(column++, null);
        break;

      default:
        break;
      }

      ps.addBatch();
    }
    ps.executeBatch();
    c.commit();
  }

  public void log(SamRecordDiscrepancy d, boolean dumpRecords, PrintStream ps) {
    switch (d.prematureEnd) {
    case 0:
      if (d.field != FIELD_TYPE.TAG) {
        String value1 = SAMRecordField.toString(getValue(d.record1,
            d.field, null));
        String value2 = SAMRecordField.toString(getValue(d.record2,
            d.field, null));
        ps.println(String.format("FIELD:\t%d\t%s\t%s\t%s",
            d.recordCounter, d.field.name(),
            print(value1, maxValueLen), print(value2, maxValueLen)));
      } else {
        String value1 = SAMRecordField.toString(getValue(d.record1,
            FIELD_TYPE.TAG, d.tagId));
        String value2 = SAMRecordField.toString(getValue(d.record2,
            FIELD_TYPE.TAG, d.tagId));
        ps.println(String.format("TAG:\t%d\t%s\t%s\t%s\t%s",
            d.recordCounter, d.field.name(), d.tagId,
            print(value1, maxValueLen), print(value2, maxValueLen)));
      }
      if (dumpRecords) {
        ps.println("\t" + d.record1.getSAMString());
        ps.println("\t" + d.record2.getSAMString());
      }
      break;
    case 1:
      ps.println(String.format("PREMATURE:\t%d\t%d", d.recordCounter,
          d.prematureEnd));
      if (dumpRecords)
        ps.println("\t" + d.record2.getSAMString());
      break;
    case 2:
      ps.println(String.format("PREMATURE:\t%d\t%d", d.recordCounter,
          d.prematureEnd));
      if (dumpRecords)
        ps.println("\t" + d.record1.getSAMString());
      break;

    default:
      throw new IllegalArgumentException("Unknown premature end value: "
          + d.prematureEnd);
    }

  }

  private static class MutableInt {
    int value;

    public MutableInt(int value) {
      this.value = value;
    }
  }

  public void summary(List<SamRecordDiscrepancy> list, PrintStream ps) {
    Map<String, MutableInt> map = new HashMap<String, SamRecordComparision.MutableInt>();
    for (SamRecordDiscrepancy d : list) {
      String id = d.field == FIELD_TYPE.TAG ? d.tagId : d.field.name();
      MutableInt m = map.get(id);
      if (m == null) {
        m = new MutableInt(0);
        map.put(id, m);
      }
      m.value++;
    }

    for (FIELD_TYPE f : FIELD_TYPE.values()) {
      if (f == FIELD_TYPE.TAG)
        continue;
      MutableInt m = map.remove(f.name());
      if (m == null)
        continue;
      ps.printf("%s: %d\n", f.name(), m.value);
    }

    for (String id : map.keySet()) {
      MutableInt m = map.get(id);
      ps.printf("%s: %d\n", id, m.value);
    }
  }

  /**
   * Counts mismatches in mate flags only.
   *
   * @param list
   * @return
   */
  private int detectCorrectedMateFlagsInSecondMember(
      List<SamRecordDiscrepancy> list) {
    int count = 0;
    for (SamRecordDiscrepancy d : list) {
      if (d.record1.getMateNegativeStrandFlag() != d.record2
          .getMateNegativeStrandFlag()
          || d.record1.getMateUnmappedFlag() != d.record2
              .getMateUnmappedFlag())
        count++;
    }
    return count;
  }

  /**
   * This is supposed to check if the mates have valid pairing flags.
   *
   * @param r1
   * @param r2
   * @return
   */
  private boolean checkMateFlags(SAMRecord r1, SAMRecord r2) {
    if (!r1.getReadPairedFlag() || !r2.getReadPairedFlag())
      return false;

    if (r1.getReadUnmappedFlag() != r2.getMateUnmappedFlag())
      return false;
    if (r1.getReadNegativeStrandFlag() != r2.getMateNegativeStrandFlag())
      return false;
    if (r1.getProperPairFlag() != r2.getProperPairFlag())
      return false;
    if (r1.getFirstOfPairFlag() && r2.getFirstOfPairFlag())
      return false;
    if (r1.getSecondOfPairFlag() && r2.getSecondOfPairFlag())
      return false;

    if (r2.getReadUnmappedFlag() != r1.getMateUnmappedFlag())
      return false;
    if (r2.getReadNegativeStrandFlag() != r1.getMateNegativeStrandFlag())
      return false;

    return true;
  }

  private static void printUsage(JCommander jc) {
    StringBuilder sb = new StringBuilder();
    sb.append("\n");
    jc.usage(sb);

    System.out.println("Version "
        + Bam2Cram.class.getPackage().getImplementationVersion());
    System.out.println(sb.toString());
  }

  public static void main(String[] args) throws SQLException {
    Params params = new Params();
    JCommander jc = new JCommander(params);
    try {
      jc.parse(args);
    } catch (Exception e) {
      System.out
          .println("Failed to parse parameteres, detailed message below: ");
      System.out.println(e.getMessage());
      System.out.println();
      System.out.println("See usage: -h");
      System.exit(1);
    }

    if (args.length == 0 || params.help) {
      printUsage(jc);
      System.exit(1);
    }

    Log.setGlobalLogLevel(params.logLevel);

    if (params.referenceFasta != null)
      System.setProperty("reference",
          params.referenceFasta.getAbsolutePath());

    SAMFileReader.setDefaultValidationStringency(ValidationStringency.SILENT) ;
    SAMFileReader r1 = new SAMFileReader(params.file1);
    SAMFileReader r2 = new SAMFileReader(params.file2);

    SAMRecordIterator it1, it2;
    if (params.location != null) {
      AlignmentSliceQuery query = new AlignmentSliceQuery(params.location);
      if (SAMRecord.NO_ALIGNMENT_REFERENCE_NAME.equals(query.sequence)) {
        it1 = r1.queryUnmapped();
        it2 = r2.queryUnmapped();
      } else {
        it1 = r1.queryContained(query.sequence, query.start, query.end);
        it2 = r2.queryContained(query.sequence, query.start, query.end);
      }
    } else {
      it1 = r1.iterator();
      it2 = r2.iterator();
    }

    SamRecordComparision c = new SamRecordComparision();
    c.maxValueLen = params.maxValueLength;
    c.compareTags = params.compareTags;
    c.ignoreFlags = params.ignoreFalgs;
    c.ignoreTLENDiff = params.ignoreTLENDiff;
    c.maxValueLen = params.maxValueLength ;

    if (params.ignoreTags != null) {
      String chunks[] = params.ignoreTags.split(":");
      for (String tagId : chunks) {
        if (!tagId.matches("^[A-Z]{2}$"))
          throw new RuntimeException(
              "Expecting tag id to match ^[A-Z]{2}$ but got this: "
                  + tagId);
        c.tagsToIgnore.add(tagId);
      }
    }

    if (params.ignoreFields != null) {
      String chunks[] = params.ignoreFields.split(":");
      for (String fieldName : chunks) {
        FIELD_TYPE type = FIELD_TYPE.valueOf(fieldName);
        c.fields.remove(type);
      }
    }

    List<SamRecordDiscrepancy> discrepancies = c.compareRecords(it1, it2,
        params.maxDiscrepancies);

    if (params.countOnly)
      System.out.println(discrepancies.size());
    else if (params.dbDumpFile == null) {
      if (discrepancies.isEmpty())
        System.out.println("No discrepancies found");
      else
        System.out.println("Found discrepansies: "
            + discrepancies.size());
      if (params.dumpDiscrepancies)
        for (SamRecordDiscrepancy d : discrepancies)
          c.log(d, params.dumpRecords, System.out);

      c.summary(discrepancies, System.out);
    } else {
      db(params.dbDumpFile, "discrepancy".toUpperCase(),
          discrepancies.iterator());
    }

    r1.close();
    r2.close();

  }

  private static void db(File dbFile, String tableName,
      Iterator<SamRecordDiscrepancy> it) throws SQLException {
    // Server server = Server.createTcpServer("").start();
    Connection connection = DriverManager.getConnection("jdbc:h2:"
        + dbFile.getAbsolutePath());
    createDiscrepancyTable(tableName, connection);
    dbLog(tableName, it, connection);
    connection.commit();
    connection.close();
  }

  @Parameters(commandDescription = "Compare SAM/BAM/CRAM files.")
  static class Params {
    @Parameter(names = { "-l", "--log-level" }, description = "Change log level: DEBUG, INFO, WARNING, ERROR.", converter = LevelConverter.class)
    LogLevel logLevel = LogLevel.ERROR;

    @Parameter(names = { "--file1" }, converter = FileConverter.class, description = "First input file. ")
    File file1;
    @Parameter(names = { "--file2" }, converter = FileConverter.class, description = "Second input file. ")
    File file2;

    @Parameter(names = { "--reference-fasta-file", "-R" }, converter = FileConverter.class, description = "The reference fasta file if required.")
    File referenceFasta;

    @Parameter(names = { "--max-discrepancies" }, description = "Stop after this many discrepancies found.")
    int maxDiscrepancies = Integer.MAX_VALUE;

    @Parameter(names = { "--max-value-len" }, description = "Trim all values to this length when reporting discrepancies.")
    int maxValueLength = 20;

    @Parameter(names = { "--location" }, description = "Compare reads only for this location, expected pattern: <seq name>:<from pos>-<to pos>")
    String location;

    @Parameter(names = { "--ignore-tags" }, description = "List of tags to ignore, for example: MD:NM:AM")
    String ignoreTags;

    @Parameter(names = { "--ignore-fields" }, description = "List of tags to ignore, for example: TLEN:CIGAR")
    String ignoreFields;

    @Parameter(names = { "-h", "--help" }, description = "Print help and quit")
    boolean help = false;

    @Parameter(names = { "--count-only", "-c" }, description = "Report number of discrepancies only.")
    boolean countOnly = false;

    @Parameter(names = { "--compare-tags" }, description = "Compare tags.")
    boolean compareTags = false;

    @Parameter(names = { "--print-discrepancies" }, description = "Print out the discrepancies found, one per line.")
    boolean dumpDiscrepancies = false;

    @Parameter(names = { "--dump-conflicting-records" }, description = "Print out the records that differ.")
    boolean dumpRecords = false;

    @Parameter(names = { "--dump-to-db" }, description = "Dump the results into the specified database instead of the standard output. ")
    File dbDumpFile;

    @Parameter(names = { "--ignore-flags" }, description = "Ignore some bit flags. This should be an integer mask.")
    int ignoreFalgs = 0;

    @Parameter(names = { "--ignore-tlen-diff" }, description = "Ignore TLEN differences less of equal to this value.")
    int ignoreTLENDiff = 0;
  }
}
TOP

Related Classes of net.sf.cram.select.SamRecordComparision$Params

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.