Package edu.berkeley.spectralHMM.matrix

Source Code of edu.berkeley.spectralHMM.matrix.LinearSolver

/*
    This file is part of spectralHMM.

    spectralHMM is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    spectralHMM is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with spectralHMM.  If not, see <http://www.gnu.org/licenses/>.
  */

package edu.berkeley.spectralHMM.matrix;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.math.MathContext;
import java.math.RoundingMode;
import java.util.Random;

/**
 * Linear-system solvers over {@link BigDecimal} based on Gaussian elimination with
 * partial pivoting (GEPP), for dense and for banded square matrices.
 *
 * <p>All solvers work in place: the right-hand side is overwritten with the solution,
 * and every arithmetic operation is rounded according to the supplied
 * {@link MathContext}.
 */
public class LinearSolver {

  /**
   * Solves the square system {@code A x = B} using Gaussian elimination with partial
   * pivoting.
   *
   * <p>NOTE: both {@code A} and {@code B} are modified, so pass in copies if the
   * originals are still needed. On return {@code B} holds the solution {@code x}, and
   * {@code A} holds the eliminated system: unit diagonal, reduced entries above the
   * diagonal. Entries below the diagonal are not cleared.
   *
   * @param A the n x n coefficient matrix (overwritten)
   * @param B the right-hand side of length n (overwritten with the solution)
   * @param mc precision and rounding mode used for every arithmetic operation
   * @throws IllegalArgumentException if {@code A[0]} or {@code B} does not have length n
   * @throws ArithmeticException if a pivot is exactly zero, i.e. the matrix is singular
   *     at the working precision
   */
  public static void geppLinearSolve(BigDecimal[][] A, BigDecimal[] B, MathContext mc)  {
    checkSquareSystem(A, B);
    if (A.length == 0)  return;
    // A dense matrix is a banded matrix whose band covers every row and column.
    solveInPlace(A, B, A.length - 1, mc);
  }

  /**
   * Solves the square system {@code A x = B} using Gaussian elimination with partial
   * pivoting, where {@code A} is banded: {@code bandwidth} is the number of sub- and
   * super-diagonals that may hold nonzero entries.
   *
   * <p>Only entries inside the band are visited. Pivot candidates for column i are
   * taken from rows i .. i+bandwidth; because a row swap can pull a row whose nonzeros
   * reach further right, row operations cover columns up to i + 2*bandwidth.
   *
   * <p>NOTE: both {@code A} and {@code B} are modified, so pass in copies if the
   * originals are still needed. On return {@code B} holds the solution {@code x}.
   *
   * @param A the n x n banded coefficient matrix (overwritten)
   * @param B the right-hand side of length n (overwritten with the solution)
   * @param bandwidth number of off-diagonals on each side of the main diagonal; values
   *     of n-1 or more make this equivalent to the dense solver
   * @param mc precision and rounding mode used for every arithmetic operation
   * @throws IllegalArgumentException if the dimensions do not match or
   *     {@code bandwidth} is negative
   * @throws ArithmeticException if a pivot is exactly zero
   */
  public static void geppLinearSolveBanded(BigDecimal[][] A, BigDecimal[] B, int bandwidth, MathContext mc) {
    checkSquareSystem(A, B);
    if (bandwidth < 0)  {
      throw new IllegalArgumentException("bandwidth must be non-negative, got " + bandwidth);
    }
    if (A.length == 0)  return;
    // Clamping keeps i + 2*bandwidth from overflowing int for very large bandwidths.
    solveInPlace(A, B, Math.min(bandwidth, A.length - 1), mc);
  }

  /**
   * Solves {@code x*A = b} for the row vector x by solving the transposed system
   * {@code A^T x^T = b^T}.
   *
   * <p>The transpose is built in a fresh array, so {@code A} itself is not modified;
   * {@code b} is overwritten with the solution.
   *
   * @param A the n x n coefficient matrix
   * @param b the right-hand side of length n (overwritten with the solution)
   * @param mc precision and rounding mode used for every arithmetic operation
   */
  public static void geppLinearSolveTransposed (BigDecimal[][] A, BigDecimal[] b, MathContext mc) {
    geppLinearSolve(transpose(A), b, mc);
  }

  /** Rejects systems where the first row of A or the vector B does not have length A.length. */
  private static void checkSquareSystem(BigDecimal[][] A, BigDecimal[] B)  {
    int n = A.length;
    if (B.length != n)  {
      throw new IllegalArgumentException(
          "Right-hand side has length " + B.length + " but the matrix has " + n + " rows");
    }
    if (n > 0 && A[0].length != n)  {
      throw new IllegalArgumentException(
          "Matrix must be square: " + n + " rows but " + A[0].length + " columns");
    }
  }

  /** Returns a newly allocated transpose of A; the column count is taken from the first row. */
  private static BigDecimal[][] transpose(BigDecimal[][] A)  {
    int rows = A.length;
    int cols = (rows == 0) ? 0 : A[0].length;
    BigDecimal[][] t = new BigDecimal[cols][rows];
    for (int r = 0; r < rows; r++)  {
      for (int c = 0; c < cols; c++)  {
        t[c][r] = A[r][c];
      }
    }
    return t;
  }

  /**
   * Shared GEPP core: forward elimination followed by back substitution, restricted to
   * the band. Expects a validated, non-empty system and 0 <= bandwidth <= n-1.
   */
  private static void solveInPlace(BigDecimal[][] A, BigDecimal[] B, int bandwidth, MathContext mc)  {
    int n = A.length;
    int last = n - 1;

    for (int i = 0; i < n; i++)  {
      int lastPivotRow = Math.min(i + bandwidth, last);
      int lastRowCol = Math.min(i + 2*bandwidth, last);

      //pivot: pick the row with the largest magnitude in column i
      int besti = i;
      BigDecimal bestAbs = A[i][i].abs(mc);
      for (int j = i+1; j <= lastPivotRow; j++)  {
        BigDecimal candidate = A[j][i].abs(mc);
        if (candidate.compareTo(bestAbs) > 0)  {
          besti = j;
          bestAbs = candidate;
        }
      }

      //swap rows besti and i (columns left of i are no longer used)
      if (besti != i)  {
        for (int j = i; j <= lastRowCol; j++)  {
          BigDecimal temp = A[besti][j];
          A[besti][j] = A[i][j];
          A[i][j] = temp;
        }
        BigDecimal temp = B[besti];
        B[besti] = B[i];
        B[i] = temp;
      }

      //scale i-th row so that the pivot becomes 1
      BigDecimal div = A[i][i];
      if (div.signum() == 0)  {
        throw new ArithmeticException("Zero pivot in column " + i + ": matrix is singular");
      }
      for (int j = i; j <= lastRowCol; j++)  {
        A[i][j] = A[i][j].divide(div, mc);
      }
      B[i] = B[i].divide(div, mc);

      //subtract multiples of the i-th row from the rows below it
      for (int j = i+1; j <= lastPivotRow; j++)  {
        BigDecimal m = A[j][i];
        if (m.signum() == 0)  continue;
        int lastCol = Math.min(j + 2*bandwidth, last);
        for (int k = i+1; k <= lastCol; k++)  {
          if (A[i][k].signum() == 0)  continue;
          A[j][k] = A[j][k].subtract(A[i][k].multiply(m, mc), mc);
        }
        B[j] = B[j].subtract(B[i].multiply(m, mc), mc);
      }
    }

    //back substitution
    for (int i = last; i >= 0; i--)  {
      int lastRowCol = Math.min(i + 2*bandwidth, last);
      for (int j = i+1; j <= lastRowCol; j++)  {
        B[i] = B[i].subtract(A[i][j].multiply(B[j], mc), mc);
      }
    }
  }


  /**
   * Benchmarks the dense solver against the banded solver on a random banded system,
   * then prints the largest difference between the two solutions and the largest
   * relative residual of the dense solution. Also serves as example usage.
   *
   * @param args unused
   */
  public static void main(String[] args)  {
    //working precision, in significant decimal digits
    int precision = 150;
    MathContext mc = new MathContext(precision, RoundingMode.HALF_EVEN);

    //size of the random system, its half-bandwidth, and benchmark repetitions
    int dim = 300;
    int band = 4;
    int runs = 1200;

    BigDecimal[][] A = new BigDecimal[dim][dim];
    BigDecimal[][] A2 = new BigDecimal[dim][dim];
    BigDecimal[][] Acopy = new BigDecimal[dim][dim];
    BigDecimal[] B = new BigDecimal[dim];
    BigDecimal[] B2 = new BigDecimal[dim];
    BigDecimal[] Bcopy = new BigDecimal[dim];

    //generate a random banded matrix (exact zeros outside the band) and right-hand side
    Random rnd = new Random();
    for (int i = 0; i < dim; i++)  {
      for (int j = 0; j < dim; j++)  {
        double x = (Math.abs(i - j) <= band) ? rnd.nextDouble() : 0.;
        Acopy[i][j] = BigDecimal.valueOf(x);
      }
      Bcopy[i] = BigDecimal.valueOf(rnd.nextDouble());
    }

    long start = System.currentTimeMillis();
    for (int run = 0; run < runs; run++)  {
      //the solver overwrites its inputs, so restore them before every run
      for (int i = 0; i < dim; i++)  {
        System.arraycopy(Acopy[i], 0, A[i], 0, dim);
      }
      System.arraycopy(Bcopy, 0, B, 0, dim);

      //Solve A*x = B, where the output x will be stored in B
      LinearSolver.geppLinearSolve(A, B, mc);
    }
    long end = System.currentTimeMillis();
    System.out.printf("Time taken for full solver = %d ms\n", (end - start));

    start = System.currentTimeMillis();
    for (int run = 0; run < runs; run++)  {
      for (int i = 0; i < dim; i++)  {
        System.arraycopy(Acopy[i], 0, A2[i], 0, dim);
      }
      System.arraycopy(Bcopy, 0, B2, 0, dim);

      //Solve A2*x = B2, where the output x will be stored in B2
      LinearSolver.geppLinearSolveBanded(A2, B2, band, mc);
    }
    end = System.currentTimeMillis();
    System.out.printf("Time taken for banded solver = %d ms\n", end - start);


    System.out.printf("Worst error between banded and full solver\n");
    BigDecimal worstErr = BigDecimal.ZERO.setScale(precision);
    for (int i = 0; i < dim; i++)  {
      worstErr = worstErr.max(B[i].subtract(B2[i], mc).abs(mc));
    }
    System.out.println(worstErr);

    System.out.println("Worst relative error between A*x and b, where x is the result of solving A*x = b, up to "
        + precision + " digits of precision");
    BigDecimal worstRelErr = BigDecimal.ZERO.setScale(precision);
    //check if the output is correct by comparing A*x against Bcopy
    for (int i = 0; i < dim; i++)  {
      BigDecimal tot = BigDecimal.ZERO;
      for (int j = 0; j < dim; j++)  {
        tot = tot.add(Acopy[i][j].multiply(B[j], mc), mc);
      }
      BigDecimal denom = tot.min(Bcopy[i]);
      if (denom.signum() == 0)  continue;  //relative error is undefined for this row
      //magnitude of the relative error, so that negative deviations count as well
      BigDecimal relErr = tot.subtract(Bcopy[i], mc).divide(denom, mc).abs();
      worstRelErr = worstRelErr.max(relErr);
    }
    System.out.println(worstRelErr);

  }

}
TOP

Related Classes of edu.berkeley.spectralHMM.matrix.LinearSolver

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.