PlayaBlockTriangularSolverImpl.hpp

00001 /* @HEADER@ */
00002 //   
00003  /* @HEADER@ */
00004 
00005 #ifndef PLAYA_BLOCKTRIANGULARSOLVER_IMPL_HPP
00006 #define PLAYA_BLOCKTRIANGULARSOLVER_IMPL_HPP
00007 
00008 #include "PlayaDefs.hpp"
00009 #include "PlayaLinearSolverDecl.hpp" 
00010 #include "PlayaLinearCombinationImpl.hpp" 
00011 #include "PlayaSimpleZeroOpDecl.hpp" 
00012 #include "PlayaBlockTriangularSolverDecl.hpp" 
00013 
00014 
00015 #ifndef HAVE_TEUCHOS_EXPLICIT_INSTANTIATION
00016 #include "PlayaLinearSolverImpl.hpp" 
00017 #include "PlayaSimpleZeroOpDecl.hpp" 
00018 #endif
00019 
00020 namespace Playa
00021 {
00022 using namespace PlayaExprTemplates;
00023 
00024 template <class Scalar> inline
00025 BlockTriangularSolver<Scalar>
00026 ::BlockTriangularSolver(const LinearSolver<Scalar>& solver)
00027   : LinearSolverBase<Scalar>(ParameterList()), solvers_(tuple(solver)) {;}
00028 
00029 template <class Scalar> inline
00030 BlockTriangularSolver<Scalar>
00031 ::BlockTriangularSolver(const Array<LinearSolver<Scalar> >& solvers)
00032   : LinearSolverBase<Scalar>(ParameterList()), solvers_(solvers) {;}
00033 
00034 template <class Scalar> inline
00035 SolverState<Scalar> BlockTriangularSolver<Scalar>
00036 ::solve(const LinearOperator<Scalar>& op,
00037   const Vector<Scalar>& rhs,
00038   Vector<Scalar>& soln) const
00039 {
00040   int nRows = op.numBlockRows();
00041   int nCols = op.numBlockCols();
00042 
00043   soln = op.domain().createMember();
00044   //    bool converged = false;
00045 
00046   TEUCHOS_TEST_FOR_EXCEPTION(nRows != rhs.space().numBlocks(), std::runtime_error,
00047     "number of rows in operator " << op
00048     << " not equal to number of blocks on RHS "
00049     << rhs);
00050 
00051   TEUCHOS_TEST_FOR_EXCEPTION(nRows != nCols, std::runtime_error,
00052     "nonsquare block structure in block triangular "
00053     "solver: nRows=" << nRows << " nCols=" << nCols);
00054 
00055   bool isUpper = false;
00056   bool isLower = false;
00057 
00058   for (int r=0; r<nRows; r++)
00059   {
00060     for (int c=0; c<nCols; c++)
00061     {
00062       if (op.getBlock(r,c).ptr().get() == 0 ||
00063         dynamic_cast<const SimpleZeroOp<Scalar>* >(op.getBlock(r,c).ptr().get()))
00064       {
00065         TEUCHOS_TEST_FOR_EXCEPTION(r==c, std::runtime_error,
00066           "zero diagonal block (" << r << ", " << c 
00067           << " detected in block "
00068           "triangular solver. Operator is " << op);
00069         continue;
00070       }
00071       else
00072       {
00073         if (r < c) isUpper = true;
00074         if (c < r) isLower = true;
00075       }
00076     }
00077   }
00078 
00079   TEUCHOS_TEST_FOR_EXCEPTION(isUpper && isLower, std::runtime_error, 
00080     "block triangular solver detected non-triangular operator "
00081     << op);
00082 
00083   bool oneSolverFitsAll = false;
00084   if ((int) solvers_.size() == 1 && nRows != 1) 
00085   {
00086     oneSolverFitsAll = true;
00087   }
00088 
00089   for (int i=0; i<nRows; i++)
00090   {
00091     int r = i;
00092     if (isUpper) r = nRows - 1 - i;
00093     Vector<Scalar> rhs_r = rhs.getBlock(r);
00094     for (int j=0; j<i; j++)
00095     {
00096       int c = j;
00097       if (isUpper) c = nCols - 1 - j;
00098       if (op.getBlock(r,c).ptr().get() != 0)
00099       {
00100         rhs_r = rhs_r - op.getBlock(r,c) * soln.getBlock(c);
00101       }
00102     }
00103 
00104     SolverState<Scalar> state;
00105     Vector<Scalar> soln_r;
00106     if (oneSolverFitsAll)
00107     {
00108       state = solvers_[0].solve(op.getBlock(r,r), rhs_r, soln_r);
00109     }
00110     else
00111     {
00112       state = solvers_[r].solve(op.getBlock(r,r), rhs_r, soln_r);
00113     }
00114     if (nRows > 1) soln.setBlock(r, soln_r);
00115     else soln = soln_r;
00116     if (state.finalState() != SolveConverged)
00117     {
00118       return state;
00119     }
00120   }
00121 
00122   return SolverState<Scalar>(SolveConverged, "block solves converged",
00123     0, ScalarTraits<Scalar>::zero());
00124 }
00125   
00126 }
00127 
00128 #endif

doxygen