00001
00002
00003
00004
00005 #ifndef PLAYA_BLOCKTRIANGULARSOLVER_IMPL_HPP
00006 #define PLAYA_BLOCKTRIANGULARSOLVER_IMPL_HPP
00007
00008 #include "PlayaDefs.hpp"
00009 #include "PlayaLinearSolverDecl.hpp"
00010 #include "PlayaLinearCombinationImpl.hpp"
00011 #include "PlayaSimpleZeroOpDecl.hpp"
00012 #include "PlayaBlockTriangularSolverDecl.hpp"
00013
00014
00015 #ifndef HAVE_TEUCHOS_EXPLICIT_INSTANTIATION
00016 #include "PlayaLinearSolverImpl.hpp"
00017 #include "PlayaSimpleZeroOpDecl.hpp"
00018 #endif
00019
00020 namespace Playa
00021 {
00022 using namespace PlayaExprTemplates;
00023
00024 template <class Scalar> inline
00025 BlockTriangularSolver<Scalar>
00026 ::BlockTriangularSolver(const LinearSolver<Scalar>& solver)
00027 : LinearSolverBase<Scalar>(ParameterList()), solvers_(tuple(solver)) {;}
00028
00029 template <class Scalar> inline
00030 BlockTriangularSolver<Scalar>
00031 ::BlockTriangularSolver(const Array<LinearSolver<Scalar> >& solvers)
00032 : LinearSolverBase<Scalar>(ParameterList()), solvers_(solvers) {;}
00033
00034 template <class Scalar> inline
00035 SolverState<Scalar> BlockTriangularSolver<Scalar>
00036 ::solve(const LinearOperator<Scalar>& op,
00037 const Vector<Scalar>& rhs,
00038 Vector<Scalar>& soln) const
00039 {
00040 int nRows = op.numBlockRows();
00041 int nCols = op.numBlockCols();
00042
00043 soln = op.domain().createMember();
00044
00045
00046 TEUCHOS_TEST_FOR_EXCEPTION(nRows != rhs.space().numBlocks(), std::runtime_error,
00047 "number of rows in operator " << op
00048 << " not equal to number of blocks on RHS "
00049 << rhs);
00050
00051 TEUCHOS_TEST_FOR_EXCEPTION(nRows != nCols, std::runtime_error,
00052 "nonsquare block structure in block triangular "
00053 "solver: nRows=" << nRows << " nCols=" << nCols);
00054
00055 bool isUpper = false;
00056 bool isLower = false;
00057
00058 for (int r=0; r<nRows; r++)
00059 {
00060 for (int c=0; c<nCols; c++)
00061 {
00062 if (op.getBlock(r,c).ptr().get() == 0 ||
00063 dynamic_cast<const SimpleZeroOp<Scalar>* >(op.getBlock(r,c).ptr().get()))
00064 {
00065 TEUCHOS_TEST_FOR_EXCEPTION(r==c, std::runtime_error,
00066 "zero diagonal block (" << r << ", " << c
00067 << " detected in block "
00068 "triangular solver. Operator is " << op);
00069 continue;
00070 }
00071 else
00072 {
00073 if (r < c) isUpper = true;
00074 if (c < r) isLower = true;
00075 }
00076 }
00077 }
00078
00079 TEUCHOS_TEST_FOR_EXCEPTION(isUpper && isLower, std::runtime_error,
00080 "block triangular solver detected non-triangular operator "
00081 << op);
00082
00083 bool oneSolverFitsAll = false;
00084 if ((int) solvers_.size() == 1 && nRows != 1)
00085 {
00086 oneSolverFitsAll = true;
00087 }
00088
00089 for (int i=0; i<nRows; i++)
00090 {
00091 int r = i;
00092 if (isUpper) r = nRows - 1 - i;
00093 Vector<Scalar> rhs_r = rhs.getBlock(r);
00094 for (int j=0; j<i; j++)
00095 {
00096 int c = j;
00097 if (isUpper) c = nCols - 1 - j;
00098 if (op.getBlock(r,c).ptr().get() != 0)
00099 {
00100 rhs_r = rhs_r - op.getBlock(r,c) * soln.getBlock(c);
00101 }
00102 }
00103
00104 SolverState<Scalar> state;
00105 Vector<Scalar> soln_r;
00106 if (oneSolverFitsAll)
00107 {
00108 state = solvers_[0].solve(op.getBlock(r,r), rhs_r, soln_r);
00109 }
00110 else
00111 {
00112 state = solvers_[r].solve(op.getBlock(r,r), rhs_r, soln_r);
00113 }
00114 if (nRows > 1) soln.setBlock(r, soln_r);
00115 else soln = soln_r;
00116 if (state.finalState() != SolveConverged)
00117 {
00118 return state;
00119 }
00120 }
00121
00122 return SolverState<Scalar>(SolveConverged, "block solves converged",
00123 0, ScalarTraits<Scalar>::zero());
00124 }
00125
00126 }
00127
00128 #endif