PlayaBICGSTABSolverImpl.hpp

00001 /* @HEADER@ */
00002 //   
00003  /* @HEADER@ */
00004 
00005 #ifndef PLAYA_BICGSTABSOLVER_IMPL_HPP
00006 #define PLAYA_BICGSTABSOLVER_IMPL_HPP
00007 
00008 #include "PlayaBICGSTABSolverDecl.hpp"
00009 #include "PlayaLinearCombinationImpl.hpp"
00010 #include "PlayaSimpleScaledOpDecl.hpp"
00011 #include "PlayaSimpleComposedOpDecl.hpp"
00012 
00013 #ifndef HAVE_TEUCHOS_EXPLICIT_INSTANTIATION
00014 #include "PlayaLinearOperatorImpl.hpp"
00015 #include "PlayaLinearSolverBaseImpl.hpp"
00016 #include "PlayaSimpleScaledOpImpl.hpp"
00017 #include "PlayaSimpleComposedOpImpl.hpp"
00018 #include "PlayaSimpleTransposedOpImpl.hpp"
00019 #endif
00020 
00021 
00022 
00023 namespace Playa
00024 {
00025 using namespace Teuchos;
00026 
00027 
00028 /* */
00029 template <class Scalar> inline
00030 BICGSTABSolver<Scalar>
00031 ::BICGSTABSolver(const ParameterList& params)
00032   : KrylovSolver<Scalar>(params) {;}
00033 
00034 /* */
00035 template <class Scalar> inline
00036 BICGSTABSolver<Scalar>::BICGSTABSolver(const ParameterList& params,
00037   const PreconditionerFactory<Scalar>& precond)
00038   : KrylovSolver<Scalar>(params, precond) {;}
00039 
00040 /* Write to a stream  */
00041 template <class Scalar> inline
00042 void BICGSTABSolver<Scalar>::print(std::ostream& os) const 
00043 {
00044   os << description() << "[" << std::endl;
00045   os << this->parameters() << std::endl;
00046   os << "]" << std::endl;
00047 }
00048 
00049     
00050 template <class Scalar> inline
00051 SolverState<Scalar> BICGSTABSolver<Scalar>
00052 ::solveUnprec(const LinearOperator<Scalar>& op,
00053   const Vector<Scalar>& b,
00054   Vector<Scalar>& soln) const
00055 {
00056   int maxiters = this->getMaxiters();
00057   Scalar tol = this->getTol();
00058   int verbosity = this->verb();
00059 
00060   Scalar normOfB = sqrt(b.dot(b));
00061 
00062   /* check for trivial case of zero rhs */
00063   if (normOfB < tol) 
00064   {
00065     soln = b.space().createMember();
00066     soln.zero();
00067     return SolverState<Scalar>(SolveConverged, "RHS was zero", 0, 0.0);
00068   }
00069 
00070   /* check for initial zero residual */
00071   Vector<Scalar> x0 = b.copy();
00072   Vector<Scalar> r0 = b.space().createMember();
00073   Vector<Scalar> tmp = b.space().createMember();
00074 
00075   // r0 =  b - op*x0;
00076   op.apply(x0, tmp);
00077   r0 = b - tmp;
00078   
00079   if (sqrt(r0.dot(r0)) < tol*normOfB) 
00080   {
00081     soln = x0;
00082     return SolverState<Scalar>(SolveConverged, "initial resid was zero", 
00083       0, 0.0);
00084   }
00085 
00086   Vector<Scalar> p0 = r0.copy();
00087   //    p0.randomize();
00088   Vector<Scalar> r0Hat = r0.copy();
00089   Vector<Scalar> xMid = b.space().createMember();
00090   Vector<Scalar> rMid = b.space().createMember();
00091   Vector<Scalar> ArMid = b.space().createMember();
00092   Vector<Scalar> x = b.space().createMember();
00093   Vector<Scalar> r = b.space().createMember();
00094   Vector<Scalar> s = b.space().createMember();
00095   Vector<Scalar> ap = b.space().createMember();
00096 
00097   int myRank = MPIComm::world().getRank();
00098 
00099   Scalar resid = -1.0;
00100 
00101   for (int k=1; k<=maxiters; k++)
00102   {
00103     // ap = A*p0
00104     op.apply(p0, ap);
00105 
00106     Scalar den = ap.dot(r0Hat);
00107     if (Utils::chop(sqrt(fabs(den))/normOfB)==0) 
00108     {
00109       SolverState<Scalar> rtn(SolveCrashed, 
00110         "BICGSTAB failure mode 1", k, resid);
00111       return rtn;
00112     }
00113       
00114     Scalar a0 = r0.dot(r0Hat)/den;
00115       
00116     xMid = x0 + a0*p0;
00117     //xMid.axpy(a0, p0, x0);
00118 
00119     rMid = r0 - a0*ap;
00120     //rMid.axpy(-a0, ap, r0);
00121 
00122     // check for convergence
00123     Scalar resid = rMid.norm2()/normOfB;
00124     if (resid < tol) 
00125     {
00126       soln = xMid; 
00127       SolverState<Scalar> rtn(SolveConverged, "yippee!!", k, resid);
00128       return rtn;
00129     }
00130 
00131     // ArMid = A*rMid
00132     op.apply(rMid, ArMid);
00133 
00134     den = ArMid.dot(ArMid);
00135     if (Utils::chop(sqrt(fabs(den))/normOfB)==0)  
00136     {
00137       SolverState<Scalar> rtn(SolveCrashed, 
00138         "BICGSTAB failure mode 2", k, resid);
00139       return rtn;
00140     }
00141 
00142     Scalar w = rMid.dot(ArMid)/den;
00143       
00144     x = xMid + w*rMid;
00145     //x.axpy(w, rMid, xMid);
00146       
00147     r = rMid - w*ArMid;
00148     //r.axpy(-w, ArMid, rMid);
00149 
00150     // check for convergence
00151     resid = sqrt(r.dot(r))/normOfB;
00152     if (resid < tol) 
00153     {
00154       soln = x;
00155       SolverState<Scalar> rtn(SolveConverged, "yippee!!", k, resid);
00156       return rtn;
00157     }
00158 
00159     den = w*(r0.dot(r0Hat));
00160     if (Utils::chop(sqrt(fabs(den))/normOfB)==0) 
00161     {
00162       SolverState<Scalar> rtn(SolveCrashed, 
00163         "BICGSTAB failure mode 3", k, resid);
00164       return rtn;
00165     }
00166     Scalar beta = a0*(r.dot(r0Hat))/den;
00167 
00168     s = p0 - w*ap;
00169     p0 = r + beta*s;
00170 //    p0 = r + beta*p0 - beta*w*ap;
00171     //s.axpy(-w, ap, p0);
00172     //p0.axpy(beta, s, r);
00173 
00174     r0 = r.copy();
00175     x0 = x.copy();
00176 
00177     if (myRank==0 && verbosity > 1 ) 
00178     {
00179       Out::os() << "BICGSTAB: iteration=";
00180       Out::os().width(8);
00181       Out::os() << k;
00182       Out::os().width(20);
00183       Out::os() << " resid=" << resid << std::endl;
00184     }
00185   }
00186     
00187   SolverState<Scalar> rtn(SolveFailedToConverge, 
00188     "BICGSTAB failed to converge", 
00189     maxiters, resid);
00190   return rtn;
00191 }
00192 
00193 
00194 }
00195 
00196 #endif

doxygen