/*
* Copyright (c) 2009-2013, Peter Abeles. All Rights Reserved.
*
* This file is part of Efficient Java Matrix Library (EJML).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package mikera.matrixx.solve.impl.lu;
import mikera.matrixx.AMatrix;
import mikera.matrixx.Matrix;
import mikera.matrixx.decompose.impl.lu.AltLU;
import mikera.matrixx.decompose.impl.lu.LUPResult;
import mikera.matrixx.impl.ADenseArrayMatrix;
/**
 * Solves square linear systems {@code A * x = b} and computes inverses using the
 * LU decomposition produced by {@link AltLU}.
 *
 * <p>Typical use: call {@link #setA(AMatrix)} once to decompose the matrix, then call
 * {@link #solve(AMatrix)} and/or {@link #invert()} as often as needed.</p>
 *
 * <p>Instances share the decomposition's internal work vector between calls, so a single
 * instance should not be used from several threads at once.</p>
 *
 * @author Peter Abeles
 */
public class LUSolver {

    /** Determinants with an absolute value below this are treated as singular by {@link #solve(AMatrix)}. */
    private static final double SINGULARITY_THRESHOLD = 1e-10;

    protected AltLU decomp;
    private LUPResult result;
    boolean doImprove = false;
    protected AMatrix A;
    protected int numRows;
    protected int numCols;

    /**
     * Creates a solver.
     *
     * @param improve if true, {@link #solve(AMatrix)} runs {@link #improveSol(AMatrix, AMatrix)}
     *                on every solution before returning it.
     */
    public LUSolver( boolean improve ) {
        this.doImprove = improve;
    }

    /**
     * Creates a solver that does not refine its solutions.
     */
    public LUSolver() {
        this.doImprove = false;
    }

    /**
     * Returns the matrix most recently passed to {@link #setA(AMatrix)}.
     *
     * @return the stored matrix, or null if none has been set.
     */
    public AMatrix getA() {
        return A;
    }

    /**
     * Stores the system matrix and computes its LU decomposition.
     *
     * @param A the square matrix to decompose. A reference is kept (no copy is made here).
     * @return the decomposition result.
     * @throws IllegalArgumentException if {@code A} is not square.
     */
    public LUPResult setA(AMatrix A) {
        if (!A.isSquare()) {
            throw new IllegalArgumentException("Input must be a square matrix.");
        }

        this.A = A;
        this.numRows = A.rowCount();
        this.numCols = A.columnCount();

        decomp = new AltLU();
        result = decomp._decompose(A);
        return result;
    }

    /**
     * Returns the quality measure reported by the underlying decomposition.
     *
     * @return the value of {@code decomp.quality()}.
     * @throws IllegalStateException if {@link #setA(AMatrix)} has not been called.
     */
    public double quality() {
        requireDecomposition();
        return decomp.quality();
    }

    /**
     * Computes the inverse of the decomposed matrix by solving against each column
     * of the identity matrix in turn.
     *
     * @return a newly allocated matrix holding the inverse.
     * @throws IllegalStateException    if {@link #setA(AMatrix)} has not been called.
     * @throws IllegalArgumentException if the stored matrix is not square.
     */
    public AMatrix invert() {
        requireDecomposition();
        if (!A.isSquare()) {
            throw new IllegalArgumentException("Matrix must be square for inverse!");
        }

        final double[] vv = decomp._getVV();
        final AMatrix LU = decomp.getLU();
        final Matrix A_inv = Matrix.create(LU.rowCount(), LU.columnCount());
        final int n = A.columnCount();
        final double[] dataInv = A_inv.data;

        for (int j = 0; j < n; j++) {
            // Load the j-th unit vector into the work buffer; the result matrix
            // therefore never needs to be initialised to the identity.
            for (int i = 0; i < n; i++) {
                vv[i] = (i == j) ? 1 : 0;
            }
            decomp._solveVectorInternal(vv);

            // Scatter the solved vector into column j of the row-major result.
            int index = j;
            for (int i = 0; i < n; i++, index += n) {
                dataInv[index] = vv[i];
            }
        }
        return A_inv;
    }

    /**
     * Solves {@code A * x = b} for {@code x}, one column of {@code b} at a time.
     *
     * @param b right-hand side; must have as many rows as the decomposed matrix. Not modified.
     * @return a newly allocated solution matrix with the same column count as {@code b},
     *         or null if the absolute determinant of the decomposed matrix is below 1e-10.
     * @throws IllegalArgumentException if {@code b} has the wrong number of rows.
     */
    public ADenseArrayMatrix solve(AMatrix b) {
        if (b.rowCount() != numCols) {
            throw new IllegalArgumentException("Unexpected matrix size");
        }
        if (Math.abs(result.computeDeterminant()) < SINGULARITY_THRESHOLD) {
            return null;
        }

        final int n = numCols;              // dimension of the system matrix
        final int bCols = b.columnCount();  // number of right-hand sides
        final Matrix x = Matrix.create(n, bCols);

        final double[] dataB = denseData(b);
        final double[] dataX = x.data;
        final double[] vv = decomp._getVV();

        for (int j = 0; j < bCols; j++) {
            // Gather column j of b (row-major, stride bCols) into the work buffer.
            int index = j;
            for (int i = 0; i < n; i++, index += bCols) {
                vv[i] = dataB[index];
            }

            decomp._solveVectorInternal(vv);

            // Scatter the solution into column j of x.
            index = j;
            for (int i = 0; i < n; i++, index += bCols) {
                dataX[index] = vv[i];
            }
        }

        if (doImprove) {
            improveSol(b, x);
        }
        return x;
    }

    /**
     * Attempts to improve a solution by accounting for numerical imprecision
     * (one step of iterative refinement; see Numerical Recipes for more information).
     * For every column the residual {@code A*x - b} is computed, the system is solved
     * for that residual, and the correction is subtracted from {@code x}.
     * It is assumed that solve has already been run on 'b' and 'x' at least once.
     *
     * @param b A matrix with as many rows as the decomposed matrix. Not modified.
     * @param x A matrix with the same shape as {@code b}. Modified.
     * @throws IllegalArgumentException if the shapes of {@code b} and {@code x} are
     *                                  inconsistent with each other or with the decomposed matrix.
     */
    public void improveSol( AMatrix b , AMatrix x )
    {
        if (b.columnCount() != x.columnCount()) {
            throw new IllegalArgumentException("bad shapes");
        }

        // n is the dimension of the system matrix, nc the number of right-hand sides.
        // (Using b's column count for n is only correct when b happens to be n-by-n.)
        final int n = numCols;
        final int nc = b.columnCount();
        if (b.rowCount() != n || x.rowCount() != n) {
            throw new IllegalArgumentException("bad shapes");
        }

        final double[] dataA = denseData(A);
        final double[] dataB = denseData(b);

        // x must be updated in place. When it exposes no backing array, work on a
        // copy and push every corrected element back through set().
        double[] dataX = x.asDoubleArray();
        final boolean writeBack = (dataX == null);
        if (writeBack) {
            dataX = x.toDoubleArray();
        }

        final double[] vv = decomp._getVV();

        for (int k = 0; k < nc; k++) {
            for (int i = 0; i < n; i++) {
                // *NOTE* in the book this is a long double. extra precision might be required
                double sdp = -dataB[i * nc + k];
                for (int j = 0; j < n; j++) {
                    sdp += dataA[i * n + j] * dataX[j * nc + k];
                }
                vv[i] = sdp;
            }

            decomp._solveVectorInternal(vv);

            for (int i = 0; i < n; i++) {
                final int index = i * nc + k;
                dataX[index] -= vv[i];
                if (writeBack) {
                    x.set(i, k, dataX[index]);
                }
            }
        }
    }

    /** Returns the matrix's backing array when it exposes one, otherwise a copy of its elements. */
    private static double[] denseData(AMatrix m) {
        final double[] data = m.asDoubleArray();
        return (data != null) ? data : m.toDoubleArray();
    }

    /** Fails fast with a descriptive message when no matrix has been decomposed yet. */
    private void requireDecomposition() {
        if (decomp == null) {
            throw new IllegalStateException("setA must be called before using the solver");
        }
    }
}