Package mikera.matrixx.solve.impl

Source Code of mikera.matrixx.solve.impl.CholeskyLDUSolver

/*
* Copyright (c) 2009-2013, Peter Abeles. All Rights Reserved.
*
* This file is part of Efficient Java Matrix Library (EJML).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package mikera.matrixx.solve.impl;

import mikera.matrixx.AMatrix;
import mikera.matrixx.Matrix;
import mikera.matrixx.decompose.ICholeskyLDUResult;
import mikera.matrixx.decompose.impl.chol.CholeskyLDU;

/**
* @author Peter Abeles
*/
public class CholeskyLDUSolver {
 
  protected Matrix A;
    protected int numRows;
    protected int numCols;

    private ICholeskyLDUResult ans;
   
    private int n;
    private double vv[];
    private double el[];
    private double d[];

    public boolean setA(AMatrix _A) {
//        _setA(A);
     
      this.A = Matrix.create(_A);
        this.numRows = A.rowCount();
        this.numCols = A.columnCount();

        ans = CholeskyLDU.decompose(A);
        if( ans != null ){
            n = A.columnCount();
//          vv = decomp._getVV();
            vv = new double[A.rowCount()];
            el = ans.getL().toMatrix().data;
            d = ans.getD().getLeadingDiagonal().toDoubleArray();
            return true;
        } else {
            return false;
        }
    }

    public double quality() {
        return Math.abs(diagProd(ans.getL()));
    }

    private double diagProd(AMatrix m) {
      double prod = 1.0;
      int diagonalLength = m.rowCount();
      for(int i=0; i<diagonalLength; i++) {
        prod *= m.get(i, i);
      }
    return prod;
  }

  /**
     * <p>
     * Using the decomposition, finds the value of 'X' in the linear equation below:<br>
     *
     * A*x = b<br>
     *
     * where A has dimension of n by n, x and b are n by m dimension.
     * </p>
     * <p>
     * *Note* that 'b' and 'x' can be the same matrix instance.
     * </p>
     *
     * @param B A matrix that is n by m.  Not modified.
     * @param X An n by m matrix where the solution is writen to.  Modified.
     */
    public AMatrix solve(AMatrix B) {
      Matrix X = Matrix.create(B.rowCount(), B.columnCount());
        if( B.columnCount() != X.columnCount() && B.rowCount() != n && X.rowCount() != n) {
            throw new IllegalArgumentException("Unexpected matrix size");
        }

        int numCols = B.columnCount();

        double dataB[] = B.toMatrix().data;
        double dataX[] = X.data;

        for( int j = 0; j < numCols; j++ ) {
            for( int i = 0; i < n; i++ ) vv[i] = dataB[i*numCols+j];
            solveInternal();
            for( int i = 0; i < n; i++ ) dataX[i*numCols+j] = vv[i];
        }
        return X;
    }

    /**
     * Used internally to find the solution to a single column vector.
     */
    private void solveInternal() {
        // solve L*s=b storing y in x
        TriangularSolver.solveL(el,vv,n);

        // solve D*y=s
        for( int i = 0; i < n; i++ ) {
            vv[i] /= d[i];
        }

        // solve L^T*x=y
        TriangularSolver.solveTranL(el,vv,n);
    }

    /**
     * returns the matrix 'inv' equal to the inverse of the matrix that was decomposed.
     *
     * @return inverse of matrix that was decomposed
     */
    public AMatrix invert() {
      Matrix inv = Matrix.create(numRows, numCols);
        if( inv.rowCount() != n || inv.columnCount() != n ) {
            throw new RuntimeException("Unexpected matrix dimension");
        }

        double a[] = inv.data;

        // solve L*z = b
        for( int i =0; i < n; i++ ) {
            for( int j = 0; j <= i; j++ ) {
                double sum = (i==j) ? 1.0 : 0.0;
                for( int k=i-1; k >=j; k-- ) {
                    sum -= el[i*n+k]*a[j*n+k];
                }
                a[j*n+i] = sum;
            }
        }

        // solve D*y=z
        for( int i =0; i < n; i++ ) {
            double inv_d = 1.0/d[i];
            for( int j = 0; j <= i; j++ ) {
                a[j*n+i] *= inv_d;
            }
        }

        // solve L^T*x = y
        for( int i=n-1; i>=0; i-- ) {
            for( int j = 0; j <= i; j++ ) {
                double sum = (i<j) ? 0 : a[j*n+i];
                for( int k=i+1;k<n;k++) {
                    sum -= el[k*n+i]*a[j*n+k];
                }
                a[i*n+j] = a[j*n+i] = sum;
            }
        }
        return inv;
    }
}
TOP

Related Classes of mikera.matrixx.solve.impl.CholeskyLDUSolver

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.