PlayaBlockTriangularSolver.hpp

00001 /* @HEADER@ */
00002 //   
00003  /* @HEADER@ */
00004 
00005 #ifndef PLAYA_BLOCKTRIANGULARSOLVER_HPP
00006 #define PLAYA_BLOCKTRIANGULARSOLVER_HPP
00007 
00008 #include "PlayaDefs.hpp"
00009 #include "PlayaLinearSolverDecl.hpp" 
00010 #include "PlayaLinearCombinationDecl.hpp" 
00011 #include "PlayaCommonOperatorsDecl.hpp" 
00012 
00013 
00014 namespace Playa
00015 {
00017   template <class Scalar>
00018   class BlockTriangularSolver : public LinearSolverBase<Scalar>,
00019                                 public Playa::Handleable<LinearSolverBase<Scalar> >
00020   {
00021   public:
00023     BlockTriangularSolver(const LinearSolver<Scalar>& solver)
00024       : LinearSolverBase<Scalar>(ParameterList()), solvers_(tuple(solver)) {;}
00025 
00027     BlockTriangularSolver(const Array<LinearSolver<Scalar> >& solvers)
00028       : LinearSolverBase<Scalar>(ParameterList()), solvers_(solvers) {;}
00029 
00031     virtual ~BlockTriangularSolver(){;}
00032 
00034     virtual SolverState<Scalar> solve(const LinearOperator<Scalar>& op,
00035                                       const Vector<Scalar>& rhs,
00036                                       Vector<Scalar>& soln) const ;
00037 
00038     /* */
00039     GET_RCP(LinearSolverBase<Scalar>);
00040   private:
00041     Array<LinearSolver<Scalar> > solvers_;
00042   };
00043 
00044 
00045   template <class Scalar> inline
00046   SolverState<Scalar> BlockTriangularSolver<Scalar>
00047   ::solve(const LinearOperator<Scalar>& op,
00048           const Vector<Scalar>& rhs,
00049           Vector<Scalar>& soln) const
00050   {
00051     int nRows = op.numBlockRows();
00052     int nCols = op.numBlockCols();
00053 
00054     soln = op.domain().createMember();
00055     //    bool converged = false;
00056 
00057     TEUCHOS_TEST_FOR_EXCEPTION(nRows != rhs.space().numBlocks(), std::runtime_error,
00058                        "number of rows in operator " << op
00059                        << " not equal to number of blocks on RHS "
00060                        << rhs);
00061 
00062     TEUCHOS_TEST_FOR_EXCEPTION(nRows != nCols, std::runtime_error,
00063                        "nonsquare block structure in block triangular "
00064                        "solver: nRows=" << nRows << " nCols=" << nCols);
00065 
00066     bool isUpper = false;
00067     bool isLower = false;
00068 
00069     for (int r=0; r<nRows; r++)
00070       {
00071         for (int c=0; c<nCols; c++)
00072           {
00073             if (op.getBlock(r,c).ptr().get() == 0 ||
00074                 dynamic_cast<const SimpleZeroOp<Scalar>* >(op.getBlock(r,c).ptr().get()))
00075               {
00076                 TEUCHOS_TEST_FOR_EXCEPTION(r==c, std::runtime_error,
00077                                    "zero diagonal block (" << r << ", " << c 
00078                                    << " detected in block "
00079                                    "triangular solver. Operator is " << op);
00080                 continue;
00081               }
00082             else
00083               {
00084                 if (r < c) isUpper = true;
00085                 if (c < r) isLower = true;
00086               }
00087           }
00088       }
00089 
00090     TEUCHOS_TEST_FOR_EXCEPTION(isUpper && isLower, std::runtime_error, 
00091                        "block triangular solver detected non-triangular operator "
00092                        << op);
00093 
00094     bool oneSolverFitsAll = false;
00095     if ((int) solvers_.size() == 1 && nRows != 1) 
00096       {
00097         oneSolverFitsAll = true;
00098       }
00099 
00100     for (int i=0; i<nRows; i++)
00101       {
00102         int r = i;
00103         if (isUpper) r = nRows - 1 - i;
00104         Vector<Scalar> rhs_r = rhs.getBlock(r);
00105         for (int j=0; j<i; j++)
00106           {
00107             int c = j;
00108             if (isUpper) c = nCols - 1 - j;
00109             if (op.getBlock(r,c).ptr().get() != 0)
00110               {
00111                 rhs_r = rhs_r - op.getBlock(r,c) * soln.getBlock(c);
00112               }
00113           }
00114 
00115         SolverState<Scalar> state;
00116         Vector<Scalar> soln_r;
00117         if (oneSolverFitsAll)
00118           {
00119             state = solvers_[0].solve(op.getBlock(r,r), rhs_r, soln_r);
00120           }
00121         else
00122           {
00123             state = solvers_[r].solve(op.getBlock(r,r), rhs_r, soln_r);
00124           }
00125         if (nRows > 1) soln.setBlock(r, soln_r);
00126         else soln = soln_r;
00127         if (state.finalState() != SolveConverged)
00128           {
00129             return state;
00130           }
00131       }
00132 
00133     return SolverState<Scalar>(SolveConverged, "block solves converged",
00134                                0, ScalarTraits<Scalar>::zero());
00135   }
00136   
00137 }
00138 
00139 #endif

doxygen