00001
00002
00003
00004
00005 #ifndef PLAYA_BLOCKTRIANGULARSOLVER_HPP
00006 #define PLAYA_BLOCKTRIANGULARSOLVER_HPP
00007
00008 #include "PlayaDefs.hpp"
00009 #include "PlayaLinearSolverDecl.hpp"
00010 #include "PlayaLinearCombinationDecl.hpp"
00011 #include "PlayaCommonOperatorsDecl.hpp"
00012
00013
00014 namespace Playa
00015 {
00016
00017 template <class Scalar>
00018 class BlockTriangularSolver : public LinearSolverBase<Scalar>,
00019 public Playa::Handleable<LinearSolverBase<Scalar> >
00020 {
00021 public:
00022
00023 BlockTriangularSolver(const LinearSolver<Scalar>& solver)
00024 : LinearSolverBase<Scalar>(ParameterList()), solvers_(tuple(solver)) {;}
00025
00026
00027 BlockTriangularSolver(const Array<LinearSolver<Scalar> >& solvers)
00028 : LinearSolverBase<Scalar>(ParameterList()), solvers_(solvers) {;}
00029
00030
00031 virtual ~BlockTriangularSolver(){;}
00032
00033
00034 virtual SolverState<Scalar> solve(const LinearOperator<Scalar>& op,
00035 const Vector<Scalar>& rhs,
00036 Vector<Scalar>& soln) const ;
00037
00038
00039 GET_RCP(LinearSolverBase<Scalar>);
00040 private:
00041 Array<LinearSolver<Scalar> > solvers_;
00042 };
00043
00044
00045 template <class Scalar> inline
00046 SolverState<Scalar> BlockTriangularSolver<Scalar>
00047 ::solve(const LinearOperator<Scalar>& op,
00048 const Vector<Scalar>& rhs,
00049 Vector<Scalar>& soln) const
00050 {
00051 int nRows = op.numBlockRows();
00052 int nCols = op.numBlockCols();
00053
00054 soln = op.domain().createMember();
00055
00056
00057 TEUCHOS_TEST_FOR_EXCEPTION(nRows != rhs.space().numBlocks(), std::runtime_error,
00058 "number of rows in operator " << op
00059 << " not equal to number of blocks on RHS "
00060 << rhs);
00061
00062 TEUCHOS_TEST_FOR_EXCEPTION(nRows != nCols, std::runtime_error,
00063 "nonsquare block structure in block triangular "
00064 "solver: nRows=" << nRows << " nCols=" << nCols);
00065
00066 bool isUpper = false;
00067 bool isLower = false;
00068
00069 for (int r=0; r<nRows; r++)
00070 {
00071 for (int c=0; c<nCols; c++)
00072 {
00073 if (op.getBlock(r,c).ptr().get() == 0 ||
00074 dynamic_cast<const SimpleZeroOp<Scalar>* >(op.getBlock(r,c).ptr().get()))
00075 {
00076 TEUCHOS_TEST_FOR_EXCEPTION(r==c, std::runtime_error,
00077 "zero diagonal block (" << r << ", " << c
00078 << " detected in block "
00079 "triangular solver. Operator is " << op);
00080 continue;
00081 }
00082 else
00083 {
00084 if (r < c) isUpper = true;
00085 if (c < r) isLower = true;
00086 }
00087 }
00088 }
00089
00090 TEUCHOS_TEST_FOR_EXCEPTION(isUpper && isLower, std::runtime_error,
00091 "block triangular solver detected non-triangular operator "
00092 << op);
00093
00094 bool oneSolverFitsAll = false;
00095 if ((int) solvers_.size() == 1 && nRows != 1)
00096 {
00097 oneSolverFitsAll = true;
00098 }
00099
00100 for (int i=0; i<nRows; i++)
00101 {
00102 int r = i;
00103 if (isUpper) r = nRows - 1 - i;
00104 Vector<Scalar> rhs_r = rhs.getBlock(r);
00105 for (int j=0; j<i; j++)
00106 {
00107 int c = j;
00108 if (isUpper) c = nCols - 1 - j;
00109 if (op.getBlock(r,c).ptr().get() != 0)
00110 {
00111 rhs_r = rhs_r - op.getBlock(r,c) * soln.getBlock(c);
00112 }
00113 }
00114
00115 SolverState<Scalar> state;
00116 Vector<Scalar> soln_r;
00117 if (oneSolverFitsAll)
00118 {
00119 state = solvers_[0].solve(op.getBlock(r,r), rhs_r, soln_r);
00120 }
00121 else
00122 {
00123 state = solvers_[r].solve(op.getBlock(r,r), rhs_r, soln_r);
00124 }
00125 if (nRows > 1) soln.setBlock(r, soln_r);
00126 else soln = soln_r;
00127 if (state.finalState() != SolveConverged)
00128 {
00129 return state;
00130 }
00131 }
00132
00133 return SolverState<Scalar>(SolveConverged, "block solves converged",
00134 0, ScalarTraits<Scalar>::zero());
00135 }
00136
00137 }
00138
00139 #endif