// Package edu.berkeley.spectralHMM.oneD
//
// Source code of edu.berkeley.spectralHMM.oneD.SelectionHMM$TimedSample

/*
    This file is part of spectralHMM.

    spectralHMM is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    spectralHMM is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with spectralHMM.  If not, see <http://www.gnu.org/licenses/>.
  */

package edu.berkeley.spectralHMM.oneD;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.LineNumberReader;
import java.io.PrintStream;
import java.io.Reader;
import java.math.BigDecimal;
import java.math.MathContext;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.Formatter;

import org.nevec.rjm.BigDecimalMath;

import com.martiansoftware.jsap.FlaggedOption;
import com.martiansoftware.jsap.JSAP;
import com.martiansoftware.jsap.JSAPException;
import com.martiansoftware.jsap.JSAPResult;
import com.martiansoftware.jsap.Parameter;
import com.martiansoftware.jsap.SimpleJSAP;
import com.martiansoftware.jsap.Switch;
import com.martiansoftware.jsap.stringparsers.FileStringParser;

import edu.berkeley.spectralHMM.algorithms.Combinatorics;
import edu.berkeley.spectralHMM.matrix.MatrixPower;
import edu.berkeley.spectralHMM.matrix.TDFEigenSolverBanded.EigenRefineError;
import edu.berkeley.spectralHMM.matrix.TDFEigenSolverBanded.EigenRefineException;

public class SelectionHMM {

  public static BigDecimal run (BigDecimal alpha, BigDecimal beta, BigDecimal hetF, BigDecimal homF, InitialConditionEnum initCondition, BigDecimal initFrequency, int matrixCutoff, int maxM, int maxN, TimedSample sample, Boolean ignoreBinomials, boolean condOnLastSegregating, MathContext mc) throws EigenRefineException, EigenRefineError 

    // get the sample
    ArrayList<BigDecimal> times = sample.times;
    ArrayList<Integer> N = sample.N;
    ArrayList<Integer> S = sample.S;
   
    // this the number of samples
    int numSamples = N.size();

    // everything all right?
    assert (times.size() == N.size() + 1);
    assert (times.size() == S.size() + 1);

    // a and b
    ArrayList<ArrayList<BigDecimal>> a = new ArrayList<ArrayList<BigDecimal>> (numSamples + 1);
    ArrayList<ArrayList<BigDecimal>> b = new ArrayList<ArrayList<BigDecimal>> (numSamples + 1);
   
    // initialize for all timepoints
    for (int i = 0; i <= numSamples; i++) {
      ArrayList<BigDecimal> va = new ArrayList<BigDecimal>(maxN + 1);
      ArrayList<BigDecimal> vb = new ArrayList<BigDecimal>(maxN + 1);
      a.add(va);
      b.add(vb);
    }

   
    // get constants
    BigDecimal C = SelectionEigenfunktion.getC (alpha, beta, hetF, homF, matrixCutoff, maxM, mc);
    BigDecimal c0 = SelectionEigenfunktion.evaluateSquaredLength (alpha, beta, hetF, homF, matrixCutoff, maxM, 0, mc);
    // check C
    System.out.println ("# [see]\t" + C + "\t" + c0);
   
    BigDecimal[][] W = SelectionEigenfunktion.getEigenVectorMatrixCopy (alpha, beta, hetF, homF, matrixCutoff, maxN+1, maxM+1, mc);
   
    // transpose the matrix
    BigDecimal[][] tildeW = MatrixPower.transpose (W);
    for (int j = 0; j <= maxN; j++)  {
      BigDecimal cj = SelectionEigenfunktion.evaluateSquaredLength(alpha, beta, hetF, homF, matrixCutoff, maxM, j, mc);
      for (int i = 0; i <= maxM; i++)  {
        tildeW[i][j] = tildeW[i][j].divide(cj, mc);
      }
    }
    for (int i = 0; i <= maxM; i++)  {
      BigDecimal ciSquaredLength = JacobiPolynomials.evaluateSquaredLength (alpha, beta, i, mc);
      for (int j = 0; j <= maxN; j++)  {
        tildeW[i][j] = tildeW[i][j].multiply(ciSquaredLength, mc);
      }
    }
   
   
    // copy the initial values
    ArrayList<BigDecimal> vb0 = b.get(0);

    // and initialize them
    if (initCondition == InitialConditionEnum.MutationSelection) {
      // this is initial vector for mutation selection balance
      for (int n = 0; n <= maxN; n++)  {
        // don't forget the normalizing constant
        if (n == 0) vb0.add(C.divide(c0, mc));
        else vb0.add(BigDecimal.ZERO.setScale(mc.getPrecision()));
      }
    }
    else if (initCondition == InitialConditionEnum.InitialFrequency) {
      // this is initial vector for a given initial frequency
      for (int n = 0; n <= maxN; n++)  {
        BigDecimal BnX = SelectionEigenfunktion.evaluate (alpha, beta, hetF, homF, matrixCutoff, maxM, n, initFrequency, mc);
        BigDecimal cn = SelectionEigenfunktion.evaluateSquaredLength (alpha, beta, hetF, homF, matrixCutoff, maxM, n, mc);
        vb0.add (BnX.divide(cn, mc));
      }
    }
    else if (initCondition == InitialConditionEnum.MutationDrift) {
    // this is initial vector for a drift mutation balance
      //implicitly using heterozygous and homozygous fitnesses for hetF and homF
      BigDecimal mHetF = BigDecimal.ZERO.subtract(hetF, mc);
      BigDecimal mHomF = BigDecimal.ZERO.subtract(homF, mc);
     
      // get a copy of the eigenvector
      BigDecimal[] v = SelectionEigenfunktion.getEigenVectorCopy(alpha, beta, mHetF, mHomF, matrixCutoff, maxM, 0, mc);
      BigDecimal Cp = SelectionEigenfunktion.getC (alpha, beta, mHetF, mHomF, matrixCutoff, maxM, mc);
      BigDecimal thisBetaFunction = JacobiPolynomials.betaFunction(alpha, beta, mc);
      for (int i=0; i<v.length; i++) {
        v[i] = v[i].multiply(JacobiPolynomials.evaluateSquaredLength (alpha, beta, i, mc), mc);
      }

      // calculate entries
      for (int n = 0; n <= maxN; n++)  {
        BigDecimal value = BigDecimal.ZERO;
        for (int m=0; m<=maxM; m++) {
          BigDecimal increment = W[n][m].multiply(v[m], mc);
          value = value.add(increment, mc);
        }
        BigDecimal cn = SelectionEigenfunktion.evaluateSquaredLength (alpha, beta, hetF, homF, matrixCutoff, maxM, n, mc);
        value = value.divide(cn, mc).divide(thisBetaFunction, mc).divide(Cp, mc);
        vb0.add(value);
      }   
    }
    else {
      // no other initial condition allowed for now
      assert (false);
    }

   
    // analyze sample
    for (int k = 1; k <= numSamples; k++) {
      Integer Nk = N.get(k-1);
      Integer Sk = S.get(k-1);
     
      //update a's
      ArrayList<BigDecimal> va = a.get(k);
      for (int n = 0; n <= maxN; n++)  {
        BigDecimal LambdaN = SelectionEigenfunktion.getEigenValue (alpha, beta, hetF, homF, matrixCutoff, maxM, n, mc);
        BigDecimal val = BigDecimalMath.exp (BigDecimal.ZERO.setScale(mc.getPrecision()).subtract(times.get(k).subtract(times.get(k-1), mc).multiply(LambdaN , mc), mc), mc);
        val = val.multiply(b.get(k-1).get(n), mc);
        va.add(val);
      }

      //update b's
      ArrayList<BigDecimal> vb = b.get(k);
     
      // get the whole N
      BigDecimal[][] nMatrix = getNMatrix (alpha, beta, Nk, Sk, maxM, mc);
     
      BigDecimal[][] aMat = new BigDecimal[1][maxN+1];
      for (int i = 0; i <= maxN; i++)  {
        aMat[0][i] = a.get(k).get(i);
      }
      BigDecimal[][] ret = MatrixPower.multiplyMatricesBanded(aMat, W, maxN, maxM, mc);
      ret = MatrixPower.multiplyMatricesBanded(ret, nMatrix, maxM, Nk, mc);
      ret = MatrixPower.multiplyMatricesBanded(ret, tildeW, maxM, maxM, mc);
      for (int i = 0; i <= maxN; i++)  {
        if (! ignoreBinomials)  {
          // calculate the binomial
          BigDecimal binom = new BigDecimal (Combinatorics.choose(N.get(k-1), S.get(k-1)).toString());
          ret[0][i] = ret[0][i].multiply(binom, mc);
        }
        vb.add(ret[0][i]);
      }
     
    }
   
    // to get final answer, we see whether we want to condition
    BigDecimal ans = b.get(numSamples).get(0);
   
    // whats the conditioning
    if (!condOnLastSegregating) {
      // just normalize it with the right value
      ans = ans.multiply(c0, mc).divide(C, mc);
    }
    else {
      // if we condition on the last one segregating, we have to normalize our result
      // we have to start with the initial condition in b^(0) and evolve it all the way till the end
     
      // get the total time
      BigDecimal totalTime = times.get(numSamples).subtract (times.get (0), mc);
     
      // and evolve
      ArrayList<BigDecimal> totalA = new ArrayList<BigDecimal>(maxN + 1);
      for (int n = 0; n <= maxN; n++)  {
        BigDecimal LambdaN = SelectionEigenfunktion.getEigenValue (alpha, beta, hetF, homF, matrixCutoff, maxM, n, mc);
        BigDecimal val = BigDecimalMath.exp (BigDecimal.ZERO.subtract(totalTime, mc).multiply(LambdaN , mc), mc);
        val = val.multiply(b.get(0).get(n), mc);
        totalA.add(val);
      }

      //  then compute the probability of finding no derived allele at the end
      // we only need endB_0
      // thus, we only need some matrix product between some endA matrix
      BigDecimal[][] totalAMat = new BigDecimal[1][maxN+1];
      for (int i = 0; i <= maxN; i++)  {
        totalAMat[0][i] = totalA.get(i);
      }
      // end the corresponding sampling matrix
      int lastN = N.get(N.size()-1);
      BigDecimal[][] nMatrix = getNMatrix (alpha, beta, lastN, 0, maxM, mc);
      // with some W involved
      BigDecimal[][] ret = MatrixPower.multiplyMatricesBanded (totalAMat, W, maxN, maxM, mc);
      ret = MatrixPower.multiplyMatricesBanded(ret, nMatrix, maxM, lastN, mc);
      ret = MatrixPower.multiplyMatricesBanded(ret, tildeW, maxM, maxM, mc);

      // copy the last entry
      BigDecimal totalB_0 = ret[0][0];
      if (! ignoreBinomials)  {
        // calculate the binomial
        BigDecimal binom = new BigDecimal (Combinatorics.choose(lastN, 0).toString());
        totalB_0 = totalB_0.multiply(binom, mc);
      }

      // then adjust the final answer (divide by prob of segregating) [put the factor in there]
      ans = ans.divide (C.divide(c0, mc).subtract (totalB_0, mc), mc);

    }
   
    // and return the result
    return ans;
  }

  public static BigDecimal[][] getNMatrix(BigDecimal alpha, BigDecimal beta, Integer Nk, Integer Sk, int maxM, MathContext mc) {
    BigDecimal[][] L = JacobiPolynomials.coefficientMatrix(alpha, beta, Sk, maxM, mc, true);
    BigDecimal[][] M = JacobiPolynomials.coefficientMatrix(alpha, beta, Nk - Sk, maxM, mc, false);
    BigDecimal[][] ret = MatrixPower.multiplyMatricesBanded (L, M, Sk, Nk - Sk, mc);
    return ret;
  }

  @SuppressWarnings("unused")
  private static BigDecimal[][] getNMatrixRescaled(BigDecimal alpha, BigDecimal beta, Integer Nk, Integer Sk, int maxM, MathContext mc) {
    BigDecimal[][] L = JacobiPolynomials.coefficientMatrixRescaled(alpha, beta, Nk, Sk, maxM, mc, true, (Sk > Nk - Sk));
    BigDecimal[][] M = JacobiPolynomials.coefficientMatrixRescaled(alpha, beta, Nk, Nk - Sk, maxM, mc, false, (Sk <= Nk - Sk));
    BigDecimal[][] ret = MatrixPower.multiplyMatricesBanded (L, M, Sk, Nk - Sk, mc);
    return ret;
  }


//  @SuppressWarnings("unused")
//  private static BigDecimal[][] getEigenfunctions(BigDecimal alpha, BigDecimal beta, BigDecimal hetF, BigDecimal homF, int matrixCutoff, int maxM, int maxN, MathContext mc) throws EigenRefineException, EigenRefineError {
//    // make a grid
//    int resolution = 50;
//    BigDecimal bigResolution = new BigDecimal(resolution);
//    BigDecimal[] theGrid = new BigDecimal[resolution + 1];
//    // beginning
//    theGrid[0] = BigDecimal.ZERO;
//    theGrid[0] = theGrid[0].setScale(mc.getPrecision());
//    // the middle
//    for (int i=1; i< resolution; i++) {
//      theGrid[i] = new BigDecimal(i).divide(bigResolution, mc);
//      theGrid[i] = theGrid[i].setScale(mc.getPrecision());
//    }
//    // end
//    theGrid[resolution] = BigDecimal.ONE;
//    theGrid[resolution] = theGrid[resolution].setScale(mc.getPrecision());
//
//    // we also need the weight function on the yGrid
//    BigDecimal[] weightFunction = new BigDecimal[theGrid.length];
//    for (int y=0; y<weightFunction.length; y++) {
////       exp(\bar\sigma(y)) y^(alpha-1) (1-y)^(beta-1)
//      BigDecimal jacobiWeight = BigDecimalMath.pow(theGrid[y], alpha.subtract(BigDecimal.ONE, mc), mc).multiply(BigDecimalMath.pow (BigDecimal.ONE.subtract(theGrid[y], mc), beta.subtract(BigDecimal.ONE, mc), mc), mc);
//      weightFunction[y] = (BigDecimalMath.exp (SelectionEigenfunktion.meanFitness(hetF, homF, theGrid[y], mc), mc)).multiply (jacobiWeight, mc);
//    }
//   
//    // get the forward eigenfunctions on that grid
//    BigDecimal[][] results = new BigDecimal[maxN+1][theGrid.length];
//    for (int n=0; n<=maxN; n++) {
//      for (int y=0; y<theGrid.length; y++) {
//        results[n][y] = weightFunction[y].multiply(SelectionEigenfunktion.evaluate (alpha, beta, hetF, homF, matrixCutoff, maxM, n, theGrid[y], mc), mc);
//      }
//    }
//    return results;
//  }

  public static BigDecimal N(int j, int l, int nk, int sk, BigDecimal alpha, BigDecimal beta, MathContext mc) {
   
    // just to be sure
    assert (nk >= sk);

    // initialize return value
    BigDecimal ret = BigDecimal.ZERO;
   
    int offset = JacobiPolynomials.calculateOffset(alpha, beta);
    for (int mu = Math.max(0, Math.max(j - sk, l - (nk - sk))); mu <= Math.min(j + sk, l + (nk - sk)); mu++){
      ret = ret.add(JacobiPolynomials.L(alpha, beta, j, mu, sk, offset, mc).multiply(JacobiPolynomials.M(alpha, beta, mu, l, (nk-sk), offset, mc), mc), mc);
    }
    // multiply it with length
    BigDecimal clSquaredLength = JacobiPolynomials.evaluateSquaredLength (alpha, beta, l, mc);
    ret = ret.multiply(clSquaredLength, mc);
   
    // return the final thing
    return ret;
  }

  // debugging output of the vector
  @SuppressWarnings("unused")
  private static void printVector (String name, ArrayList<BigDecimal> vb0, int maxN, BigDecimal c0, BigDecimal C, MathContext mc) {
    System.out.println ("# " + name + "\t");
    for (BigDecimal value : vb0) {
      System.out.println (value + "\t");
    }
    System.out.println();
    assert (vb0.size() == maxN + 1);   
  }

 
  // wrapper to catch exceptions a bit better
  public static void main (String[] args) {
    try {
      realMain (args);
    } catch (FileNotFoundException e) {
      System.err.println ("Error: Invalid input file secified:\n\t" + e.getMessage());
    } catch (JSAPException e) {
      System.err.println ("Error while parsing command line arguments (--help for usage):\n\t" + e.getMessage());
    } catch (IOException e) {
      System.err.println ("I/O error:\n\t" + e.getMessage());
    }
  }

 
  public static void realMain (String[] args) throws JSAPException, FileNotFoundException, IOException {
    long startTime = System.currentTimeMillis();

    // build a parser for the input file, with the right parameters
    FileStringParser myFileParser = FileStringParser.getParser();
    myFileParser.setMustBeFile(true);
    myFileParser.setMustExist(true);
   
    // build the jsap object
    SimpleJSAP jsap = new SimpleJSAP(
              "spectralHMM",
              "Analyze temporal data using a spectral HMM method.",
              new Parameter[] {
          new FlaggedOption( "inputFile", myFileParser, JSAP.NO_DEFAULT, JSAP.REQUIRED, 'f', "inputFile", "Specify an input file. See documentation for formatting." ),
          new Switch ("mutSelBalance", 'j', "mutSelBalance", "Set the initial condition to be selection-drift balance."),
          new Switch ("mutDriftBalance", 'd', "mutDriftBalance", "Set the initial condition to be mutation-drift balance."),
          new FlaggedOption( "initFrequency", JSAP.BIGDECIMAL_PARSER, "-0.5", JSAP.NOT_REQUIRED, 'i', "initFrequency", "Set the initial frequency." ),
          new FlaggedOption( "initTime", JSAP.BIGDECIMAL_PARSER, JSAP.NO_DEFAULT, JSAP.NOT_REQUIRED, 't', "initTime", "The initial time. If not given, the initial time is expected to be in the file with each dataset. Also, this only works with multiplexing." ),
          new FlaggedOption( "mutToBenef", JSAP.BIGDECIMAL_PARSER, JSAP.NO_DEFAULT, JSAP.REQUIRED, 'a', "mutToBenef", "The mutation rate from the wild type allele to the selected allele." ),
          new FlaggedOption( "mutFromBenef", JSAP.BIGDECIMAL_PARSER, JSAP.NO_DEFAULT, JSAP.REQUIRED, 'b', "mutFromBenef", "The mutation rate from the selected allele to the wild-type allele." ),
          new FlaggedOption( "selection", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, JSAP.NOT_REQUIRED, 's', "selection", "Selection coefficient s (can be a range: [start:step:stop])." ),
          new FlaggedOption( "dominance", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, JSAP.NOT_REQUIRED, 'h', "dominance", "Dominance parameter h (can be a range: [start:step:stop])." ),
          new FlaggedOption( "hetF", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, JSAP.NOT_REQUIRED, 'w', "hetF", "Fitness of the heterozygote (can be a range: [start:step:stop])." ),
          new FlaggedOption( "homF", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, JSAP.NOT_REQUIRED, 'v', "homF", "Fitness of the homozygote (can be a range: [start:step:stop])." ),
          new FlaggedOption( "effPopSize", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, JSAP.REQUIRED, 'e', "effPopSize", "The effective population size (diploid)." ),
          new FlaggedOption( "yearsPerGen", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, JSAP.REQUIRED, 'y', "yearsPerGen", "Specify how many years a generation takes." ),
          new FlaggedOption( "matrixCutoff", JSAP.INTEGER_PARSER, JSAP.NO_DEFAULT, JSAP.REQUIRED, 'c', "matrixCutoff", "The cutoff for the matrix whose eigenvectors yield the coefficients for the eigenfunctions." ),
          new FlaggedOption( "maxM", JSAP.INTEGER_PARSER, JSAP.NO_DEFAULT, JSAP.REQUIRED, 'm', "maxM", "Specify how many summands to use in the infinite sum for each eigenfunction." ),
          new FlaggedOption( "maxN", JSAP.INTEGER_PARSER, JSAP.NO_DEFAULT, JSAP.REQUIRED, 'n', "maxN", "Specify how many eigenfunctions/-values to use in the computations." ),
          new FlaggedOption( "precision", JSAP.INTEGER_PARSER, JSAP.NO_DEFAULT, JSAP.REQUIRED, 'p', "precision", "Specify a precision to be used for the computations. This gives the number of significant digits used." ),
          new Switch ("ignoreBinomials", 'z', "ignoreBinomials", "If set, the likelihood is calculated without the binomial coefficients required to make it a probability."),
          new Switch ("multiplex", JSAP.NO_SHORTFLAG, "multiplex", "Analyze each of the multiple datasests given in the input file using all the selection parameters specified."),
          new Switch ("condOnLastSegregating", JSAP.NO_SHORTFLAG, "condOnLastSegregating", "When set, condition on the last sample being segregating."),
              }
          );

    JSAPResult config = jsap.parse(args);
    if (jsap.messagePrinted()) { System.exit(1); }
   
    // what about the binomials
    boolean ignoreBinomials = config.getBoolean ("ignoreBinomials");
    // what about multiplexing
    boolean multiplex = config.getBoolean ("multiplex");
    // what about conditioning
    boolean condOnLastSegregating = config.getBoolean ("condOnLastSegregating");
   
    // precision
    int precision = config.getInt("precision");
    int scale = precision;
    // set precision and some math context
    MathContext mc = new MathContext (precision, RoundingMode.HALF_EVEN);

    // what is the parametrization
    boolean paramByHS = false;
    String param1Name = null, param2Name = null;
    // selection parameterization
    if (config.contains("selection") && config.contains("dominance"))  {
      paramByHS = true;
      param1Name = "selection";
      param2Name = "dominance";
      if (config.contains("hetF") || config.contains("homF"))  {
        System.err.println("Specify exactly either the selection and dominance parameters, or the heterozygote and homozygote advantage parameters");
        System.exit(1);
      }
    }
    else if (config.contains("hetF") && config.contains("homF"))  {
      param1Name = "hetF";
      param2Name = "homF";
      if (config.contains("selection") || config.contains("dominance"))  {
        System.err.println("Specify exactly either the selection and dominance parameters, or the heterozygote and homozygote advantage parameters");
        System.exit(1);
      }
    }
    else {
      System.err.println("Specify exactly either the selection and dominance parameters, or the heterozygote and homozygote advantage parameters");
      System.exit(1);
    }
       
   
    // initial time?
    BigDecimal initTime = null;
    boolean initTimesInFile = !config.contains ("initTime");
    if (initTimesInFile) {
      if (!multiplex) {
        System.err.println ("The flag timesInFile only works together with multiplexing.");
        System.exit(-1);
      }
    }
    else {
      //we do have an initial time
      initTime = config.getBigDecimal("initTime");
    }
   
    // get the initial condition right
    BigDecimal initialFrequency = config.getBigDecimal("initFrequency");
    boolean initFreqSet = (initialFrequency.compareTo(BigDecimal.ZERO) >= 0d) && (initialFrequency.compareTo(BigDecimal.ONE) <= 0d);
    boolean mutSelSet = config.getBoolean ("mutSelBalance");
    boolean mutDriftSet = config.getBoolean ("mutDriftBalance");
    // not check whether exactly one is set, and which one it is
    InitialConditionEnum initialCondition = null;
    if (initFreqSet && !mutSelSet && !mutDriftSet) {
      // initial frequency given
      initialCondition = InitialConditionEnum.InitialFrequency;
    }
    else if (!initFreqSet && mutSelSet && !mutDriftSet) {
      // mutation selection balance requested
      initialCondition = InitialConditionEnum.MutationSelection;
    }
    else if (!initFreqSet && !mutSelSet && mutDriftSet) {
      // mutation drift balance requested
      initialCondition = InitialConditionEnum.MutationDrift;
    }
    else {
      System.err.println ("Must specify consistent initial conditions.");
      System.exit(1);
    }
   
    // get input file
    File inputFile = config.getFile("inputFile");
   
    // get raw mutation parameters
    BigDecimal Ne = (new BigDecimal(config.getString("effPopSize"))).setScale(scale);
    BigDecimal preAlpha = config.getBigDecimal("mutToBenef").setScale(scale);
    if (BigDecimal.ZERO.compareTo(preAlpha) >= 0) throw new IOException ("Zero mutation rate not implemented yet.");
    BigDecimal preBeta = config.getBigDecimal("mutFromBenef").setScale(scale);
    if (BigDecimal.ZERO.compareTo(preBeta) >= 0) throw new IOException ("Zero mutation rate not implemented yet.");
    // and rescale them to be population scaled
    BigDecimal alpha = (new BigDecimal("4")).multiply(Ne, mc).multiply(preAlpha, mc);
    BigDecimal beta = (new BigDecimal("4")).multiply(Ne, mc).multiply(preBeta, mc);
   
    BigDecimal yearsPerGen = (new BigDecimal(config.getString("yearsPerGen"))).setScale(scale);
   
    BigDecimal[] param1Range = null;
    BigDecimal[] param2Range = null;
   
    // some raw selection coefficients
    if (config.contains("selection"))  {
      assert(config.contains("dominance"));
      param1Range = parseRange (config.getString("selection"), scale, mc);
      param2Range = parseRange (config.getString("dominance"), scale, mc);
    }
    if (config.contains("hetF"))  {
      assert(config.contains("homF"));
      param1Range = parseRange (config.getString("hetF"), scale, mc);
      param2Range = parseRange (config.getString("homF"), scale, mc);
    }

    // fill the list of values
    ArrayList<BigDecimal> param1s = buildRange (param1Range[0], param1Range[1], param1Range[2], mc);

    // also we need the other parameters
   
    // make a list of other parameters
    ArrayList<BigDecimal> param2s = buildRange (param2Range[0], param2Range[1], param2Range[2], mc);
   
    // the remaining parameters
    int matrixCutoff = config.getInt("matrixCutoff");
    int maxM = config.getInt("maxM");
    int maxN = config.getInt("maxN");
   
    assert(maxM <= matrixCutoff);
    assert(maxN <= maxM);
   
    //print the command line arguments directly, and also the parameters nicely
//    PrintStream[] outStreams = new PrintStream[]{System.out, System.err};
    PrintStream[] outStreams = new PrintStream[]{System.out};
    for (PrintStream outStream: outStreams) {
      outStream.print("# Command-line arguments: ");
      for (String arg: args)  {
        outStream.print(arg + " ");
      }
      outStream.println();
     
      outStream.println("# Parameter values:");
      outStream.println("# mutToBenef = " + bigDecimalToString(preAlpha));
      outStream.println("# alpha = " + bigDecimalToString(alpha));
      outStream.println("# mutFromBenef = " + bigDecimalToString(preBeta));
      outStream.println("# beta = " + bigDecimalToString(beta));
      outStream.println("# selection param 1 (" + param1Name + ") = " + bigDecimalToString(param1Range[0]) + "\t" + bigDecimalToString(param1Range[1]) + "\t" + bigDecimalToString(param1Range[2]));
      outStream.println("# selection param 2 (" + param2Name + ") = " + bigDecimalToString(param2Range[0]) + "\t" + bigDecimalToString(param2Range[1]) + "\t" + bigDecimalToString(param2Range[2]));
      outStream.println("# precision = " + precision);
      outStream.println("# matrixCutoff = " + matrixCutoff);
      outStream.println("# maxM = " + maxM);
      outStream.println("# maxN = " + maxN);
      outStream.println("# effective Population size = " + bigDecimalToString(Ne) + " diploids");
      outStream.println("# Years per generation = " + bigDecimalToString(yearsPerGen));
      outStream.println("# ignore binomials = " + ignoreBinomials);
      outStream.println("# mulitplex = " + multiplex);
      outStream.println("# condition on last segregating = " + condOnLastSegregating);
      outStream.println ("# initial times in input file = " + initTimesInFile);
    }

    ArrayList<TimedSample> listOfSamples = new ArrayList<TimedSample>();
    // get the data
    if (!multiplex) {
      // only one dataset
      // read it
      ArrayList<BigDecimal> rawTimes = new ArrayList<BigDecimal>();
      ArrayList<Integer> N = new ArrayList<Integer>();
      ArrayList<Integer> S = new ArrayList<Integer>();
 
      // parse it from file
      parseInputSingleData (new FileReader(inputFile), rawTimes, N, S, scale);
     
      // add an additional time for the beginning
      assert (!rawTimes.isEmpty());
      rawTimes.add (0, initTime);
      // check here
      if (!ordered  (rawTimes)) {
        throw new IOException ("The times are not increasing.");
      }
 
      // adjust times in the sample for Ne and generation time
      BigDecimal divisor = (new BigDecimal("2")).multiply(Ne, mc).multiply(yearsPerGen, mc);
      ArrayList<BigDecimal> times = new ArrayList<BigDecimal>();
      for (BigDecimal time : rawTimes) {
        times.add (time.divide(divisor, mc));
      }
     
      // put it into a list
      listOfSamples.add (new TimedSample(times, N, S));
    }
    else {
      // a lot of datasets
      listOfSamples = parseInputMultipleSamples  (new FileReader(inputFile), initTimesInFile, initTime, Ne, yearsPerGen, mc);
    }
   
    if (condOnLastSegregating) {
      // see, what's the minimal sample that occurs at the last position
      int minLastPosition = scanLastPosForMin (listOfSamples);
     
      if (minLastPosition < 1) {
        System.err.println ("If we analyze samples conditional on segregation at last timepoint, then all your samples should be segregating.");
        System.exit(-1);
      }
    }
   
    // build some likelihood surface
    for (BigDecimal param1 : param1s) {
      for (BigDecimal param2 : param2s) {
        // some variables
        BigDecimal hetF, homF, h=null, s;
       
        //which parametrization?
        if (paramByHS)  {
          // homF = s
          homF = param1;
          // homF = s*h
          hetF = param2.multiply(param1, mc);
          // get the others for printing
          s = param1;
          h = param2;
        }
        else  {
          hetF = param1;
          homF = param2;
          // get the others to print nicely
          s = homF;
          if (homF.compareTo(BigDecimal.ZERO) != 0)  {
            h = hetF.divide(homF, mc);
          }
        }
       
        BigDecimal hetFScaled = (new BigDecimal("2")).multiply(Ne, mc).multiply(hetF, mc);
        BigDecimal homFScaled = (new BigDecimal("2")).multiply(Ne, mc).multiply(homF, mc);
       
        BigDecimal sigma  = (new BigDecimal("2")).setScale(scale).multiply(Ne, mc).multiply(s, mc);
       
        // print sigma into comment
        System.out.println("# " + param1Name + " = " + bigDecimalToString(hetF) + "\t" + param2Name + " = " + bigDecimalToString(homF));
        System.out.println("# hetF = " + bigDecimalToString(hetF) + "\thomF = " + bigDecimalToString(homF));
        System.out.println("# hetFScaled = " + bigDecimalToString(hetFScaled) + "\thomFScaled = " + bigDecimalToString(homFScaled));
        System.out.println("# s = " + bigDecimalToString(s) + "\th = " + (h == null ? "undefined" : bigDecimalToString(h)));
        System.out.println("# sigma = " + bigDecimalToString(sigma));

        // need a flag
        boolean successful = true;
       
        // now iterate over all samples
        for (int d=0; d<listOfSamples.size(); d++) {
          TimedSample thisSample = listOfSamples.get(d);
                   
          // compute the likelihood
          BigDecimal ans;
          try {
            // have we been successful before
            if (successful) {
              // run it
              ans = run (alpha, beta, hetFScaled, homFScaled, initialCondition, initialFrequency, matrixCutoff, maxM, maxN, thisSample, ignoreBinomials, condOnLastSegregating, mc);
            }
            else {
              // if not, return the dummy
              ans = returnUnsuccessful;
            }
          } catch (EigenRefineException e) {
            // say something about not converged
            System.out.println ("# [EXCEPTION (maybe recoverable)] " + e.getMessage());
            // remember that it was not successful
            successful = false;
            // and give a dummy value
            ans = returnUnsuccessful;
          } catch (EigenRefineError e) {
            // say something about not converged
            System.out.println ("# [ERROR (not recoverable)] " + e.getMessage());
            // remember that it was not successful
            successful = false;
            // and give a dummy value
            ans = returnUnsuccessful;
          }
         
          // print sample
          printSample (thisSample, initialCondition, initialFrequency, System.out);
         
          // print result out
          String outPut = "";
         
          // add sample number if you want
          if (listOfSamples.size() > 1) {
            // plus one, cause we want a more natural numbering
            outPut += (d+1) + "\t";
          }
         
          // add first selection parameter
          outPut += bigDecimalToString(param1) + "\t";
         
          // add second selection parameter
          outPut += bigDecimalToString(param2) + "\t";
         
          // print it
          System.out.println (outPut + ans);
        }
       
        // clear caches
        // only selection, after all samples have been done
        SelectionEigenfunktion.clearCaches();
      }
    }
   
    long endTime = System.currentTimeMillis();
   
    System.out.printf("# Elapsed time = %d ms\n", endTime-startTime);
  }

 
  private static boolean ordered (ArrayList<BigDecimal> times) {
    for (int i=1; i<times.size(); i++) {
      if (times.get(i-1).compareTo(times.get(i)) >= 0) return false;
    }
    return true;
  }

  private static int scanLastPosForMin(ArrayList<TimedSample> listOfSamples) {
    // get the number of sampling times
    // WARNING: we allow this to vary
   
    // and start finding the minimum
    int thisNumSamplingTimes = listOfSamples.get(0).S.size();
    int min = listOfSamples.get(0).S.get(thisNumSamplingTimes-1);
    for (int i=1; i<listOfSamples.size(); i++) {
      thisNumSamplingTimes = listOfSamples.get(i).S.size();
      min = Math.min (min, listOfSamples.get(i).S.get(thisNumSamplingTimes-1));
    }

    // give it away
    return min;
  }

  private static ArrayList<TimedSample> parseInputMultipleSamples (FileReader fileReader, boolean initTimesInFile, BigDecimal initTime, BigDecimal Ne, BigDecimal yearsPerGen, MathContext mc) throws IOException {

    // the list with the results
    ArrayList<TimedSample> listOfSamples = new ArrayList<SelectionHMM.TimedSample>();
   
    // create an object to read lines
    LineNumberReader lineReader = new LineNumberReader (fileReader);
   
    // and go through lines
    String line;
    // each line should contain a sample
    while ((line = lineReader.readLine()) != null) {

      // ignore lines starting with '#'
      if (line.startsWith("#") || line.trim().isEmpty()) continue;
     
      // new containers
      ArrayList<BigDecimal> rawTimes = new ArrayList<BigDecimal>();
      if (!initTimesInFile) {
        // add the standard initTime
        rawTimes.add (initTime);
      }

      ArrayList<Integer> sampleSizes = new ArrayList<Integer>();
      ArrayList<Integer> numDerivedAlleles = new ArrayList<Integer>();
     
      // get the values
      readVectorFromLine (line, initTimesInFile, rawTimes, sampleSizes, numDerivedAlleles);
           
      // there should be some time
      if (rawTimes.isEmpty()) {
        throw new IOException("no times");
      }
     
      // adjust times in the sample for Ne and generation time
      BigDecimal divisor = (new BigDecimal("2")).multiply(Ne, mc).multiply(yearsPerGen, mc);
      ArrayList<BigDecimal> times = new ArrayList<BigDecimal>();
      for (BigDecimal time : rawTimes) {
        times.add (time.divide(divisor, mc));
      }
     
      // put it into a list
      listOfSamples.add (new TimedSample(times, sampleSizes, numDerivedAlleles));
    }
   
    // give away the list of samples read in
    return listOfSamples;
  }

  static void readVectorFromLine(String line, boolean initTimesInFile, ArrayList<BigDecimal> rawTimes, ArrayList<Integer> sampleSizes, ArrayList<Integer> numDerivedAlleles) throws IOException {
    // first clear the lists
    sampleSizes.clear();
    numDerivedAlleles.clear();
   
    // get the three-tuples from the vector
    // first split it into larger pieces
    String[] fields = line.split(";");
   
    int sampleIdx = 0;
    if (initTimesInFile) {
      // the first entry should only be the time
      rawTimes.add (new BigDecimal(fields[0].trim()));
      // set index to skip the first entry in the rest
      sampleIdx = 1;
    }
   
    // now, if we should have a time in the file, then it should be at the first position
    for (; sampleIdx<fields.length; sampleIdx++) {
     
      // then split the fields into hopefully three pieces
      String[] pieces = fields[sampleIdx].split("(\\))|(\\()|(\\,)");
     
      // we have to clean it a bit
      ArrayList<String> cleanPieces = new ArrayList<String>();
      for (String piece : pieces) {
        // throw away empty ones
        if (!piece.trim().isEmpty()) {
          cleanPieces.add(piece);
        }
      }
     
      // now it should be three pieces
      if (cleanPieces.size() != 3) {
        throw new IOException ("Component does not contain the right number of values: " + fields[sampleIdx] + "\nin line: " + line);
      }
     
      // now store them
      rawTimes.add (new BigDecimal(cleanPieces.get(0).trim()));
      sampleSizes.add (new Integer (cleanPieces.get(1).trim()));
      numDerivedAlleles.add (new Integer (cleanPieces.get(2).trim()));
    }
   
    // before we go, check the times
    if (!ordered  (rawTimes)) {
      throw new IOException ("The times are not increasing in at least one dataset.");
    }

  }

  static void printSample (TimedSample thisSample, InitialConditionEnum initialCondition, BigDecimal initialFrequency, PrintStream outStream) {
    // print initial condition
    if (initialCondition == InitialConditionEnum.InitialFrequency) {
      outStream.println("# initial condition: frequency (" + initialFrequency + ")");
    }
    else if (initialCondition == InitialConditionEnum.MutationSelection) {
      outStream.println("# initial condition: mutation selection balance");
    }
    else if (initialCondition == InitialConditionEnum.MutationDrift) {
      outStream.println("# initial condition: mutation drift balance");
     
    }
    else {
      assert (false);
    }
    outStream.println("# Sample");
   
    outStream.print("# ");
    for (BigDecimal time : thisSample.times)  {
      outStream.print(bigDecimalToString(time) + "\t");
    }
    outStream.println();
   
    outStream.print("# \t");
    for (Integer n : thisSample.N)  {
      outStream.print(n + "\t");
    }
    outStream.println();
   
    outStream.print("# \t");
    for (Integer k : thisSample.S)  {
      outStream.print (k + "\t");
    }
    outStream.println();
  }

  private static ArrayList<BigDecimal> buildRange (BigDecimal start, BigDecimal step, BigDecimal stop, MathContext mc) {
    // just be sure
    assert (start.compareTo(stop) < 1);
    assert (step.compareTo(BigDecimal.ZERO) > -1);
   
    ArrayList<BigDecimal> result = new ArrayList<BigDecimal>();
    // initial
    BigDecimal tmp = start;
    while (tmp.compareTo(stop) < 1){
      // add
      result.add(tmp);
      // increase
      tmp = tmp.add(step, mc);
    }

    return result;
  }

 
  private static BigDecimal[] parseRange(String rangeString, int bigDecimalScale, MathContext mc) throws JSAPException {
    if (rangeString.startsWith("[")) {
      // expect a real range
      if (!rangeString.endsWith("]")) { throw new JSAPException ("Range for selection has to be in format [start:step:stop]"); }
     
      // parse the range
      String[] fields = rangeString.substring(1, rangeString.length()-1).split(":");

      // get the array
      if (fields.length != 3) { throw new JSAPException ("Range for selection has to be in format [start:step:stop]"); }
      BigDecimal[] tmp = new BigDecimal[] { (new BigDecimal(fields[0])).setScale(bigDecimalScale), (new BigDecimal(fields[1])).setScale(bigDecimalScale), (new BigDecimal(fields[2])).setScale(bigDecimalScale) };
     
      // now check for the right stepping direction
      if (tmp[2].subtract(tmp[0], mc).multiply(tmp[1], mc).compareTo(BigDecimal.ZERO) < 0) { throw new JSAPException ("Range for selection has to be in format [start:step:stop]"); }
      // and change the order if necessary
      if (tmp[0].compareTo(tmp[2]) > 0) {
        BigDecimal swap = tmp[0];
        tmp[0] = tmp[2];
        tmp[1] = tmp[1].multiply(new BigDecimal("-1"), mc);
        tmp[2] = swap;
      }
      // special case, if boundaries are equal
      if (tmp[0].compareTo(tmp[2]) == 0 && tmp[1].compareTo(BigDecimal.ZERO) < 0) {
        // step is negative, so change it
        tmp[1] = tmp[1].multiply(BigDecimal.ZERO.subtract(BigDecimal.ONE, mc), mc);
      }
     
      // give it away now
      return tmp;
    }
    else {
      // should just be a single value
      BigDecimal val = new BigDecimal(rangeString);
      // give it like this
      return new BigDecimal[] {val, BigDecimal.ONE, val};
    }
  }

  private static void parseInputSingleData (Reader In, ArrayList<BigDecimal> times, ArrayList<Integer> sampleSizes, ArrayList<Integer> numDerivedAlleles, int bigDecimalScale) throws IOException {
   
    // clear containers
    times.clear();
    sampleSizes.clear();
    numDerivedAlleles.clear();
   
    // create an object to read lines
    LineNumberReader lineReader = new LineNumberReader (In);
   
    // and go through lines
    String line;
    while ((line = lineReader.readLine()) != null) {
      // ignore lines starting with '#'
      if (line.startsWith("#") || line.trim().isEmpty()) continue;
     
      // get the values
      String[] fields = line.split("\\s+");
      // should have 3 fields
      if (fields.length != 3) {
        throw new IOException ("Invalid input format (maybe missing multiplex).");
      }
     
      // add stuff
      times.add (new BigDecimal(fields[0]).setScale(bigDecimalScale));
      sampleSizes.add (new Integer(fields[1]));
      numDerivedAlleles.add (new Integer(fields[2]));
    }
    // done
  }
 
 
  public static String bigDecimalToString (BigDecimal x){
    StringBuilder sb = new StringBuilder();
    Formatter formatter = new Formatter(sb);
   
    formatter.format("%g", x);
   
    // close the formatter
    formatter.close();
   
    return sb.toString();
  }
 
  // class for a dataset
  public static class TimedSample {
    // members
    public ArrayList<BigDecimal> times;
    public ArrayList<Integer> N;
    public ArrayList<Integer> S;
   
    // easy constructor
    public TimedSample (ArrayList<BigDecimal> times, ArrayList<Integer> N, ArrayList<Integer> S) {
      // check some
      assert (times.size() == N.size()+1);
      assert (N.size() == S.size());
     
      // remember them
      this.times = times;
      this.N = N;
      this.S = S;
    }
  }
 
  public static int firstObservedIndex (ArrayList<BigDecimal> rawTimes, ArrayList<Integer> numDerivedAlleles) {
    // loop through
    int observationIndex = 0;
    while (numDerivedAlleles.get(observationIndex) == 0) {
      observationIndex++;
    }
    // and return the first time
    return observationIndex;
  }

 
  public static TimedSample buildSample (BigDecimal t0, ArrayList<BigDecimal> rawTimes, ArrayList<Integer> sampleSizes, ArrayList<Integer> numDerivedAlleles, BigDecimal Ne, BigDecimal yearsPerGen, MathContext mc) {
    // get the index of first observation
    int observationIndex = firstObservedIndex (rawTimes, numDerivedAlleles);
    // t0 should be right
    assert (t0.compareTo(rawTimes.get(observationIndex)) < 0);   
   
    // make the right times
    ArrayList<BigDecimal> diffusionTimes = new ArrayList<BigDecimal>();

    // get the divisor
    BigDecimal divisor = (new BigDecimal("2")).multiply(Ne).multiply(yearsPerGen);

    // put in the initial time
    BigDecimal diffusionT0 = t0.divide (divisor, mc);
    diffusionTimes.add (diffusionT0);
    // and rest
    for (int i=observationIndex; i<rawTimes.size(); i++) {
      diffusionTimes.add (rawTimes.get(i).divide(divisor, mc));
    }
    // now copy the sample size and derived number appropriately
    ArrayList<Integer> newSampleSizes = new ArrayList<Integer>();
    ArrayList<Integer> newNumDerivedAlleles = new ArrayList<Integer>();
    for (int i=observationIndex; i<sampleSizes.size(); i++) {
      newSampleSizes.add (sampleSizes.get(i));
      newNumDerivedAlleles.add (numDerivedAlleles.get(i));
    }
   
    // and return a new object
    return new TimedSample (diffusionTimes, newSampleSizes, newNumDerivedAlleles);
  }


  // enumeration for the initial condition
  public enum InitialConditionEnum { MutationDrift, MutationSelection, InitialFrequency};
 
  public static BigDecimal returnUnsuccessful = BigDecimal.ZERO.subtract(BigDecimal.ONE);
}
TOP

Related Classes of edu.berkeley.spectralHMM.oneD.SelectionHMM$TimedSample

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.