00001 
00002 
00003 
00004 
00005 
00006 
00007 
00008 
00009 
00010 
00011 
00012 
00013 
00014 
00015 
00016 
00017 
00018 
00019 
00020 
00021 
00022 
00023 
00024 
00025 
00026 
00027 
00028 
00029 
00030 
00031 #include "SundanceUserDefOpEvaluator.hpp"
00032 #include "SundanceUserDefOpCommonEvaluator.hpp"
00033 #include "SundanceUserDefOpElement.hpp"
00034 #include "SundanceEvalManager.hpp"
00035 
00036 #include "PlayaTabs.hpp"
00037 #include "SundanceOut.hpp"
00038 #include "SundanceUserDefOp.hpp"
00039 
00040 using namespace Sundance;
00041 using namespace Sundance;
00042 using namespace Sundance;
00043 using namespace Teuchos;
00044 
00045 
00046 
00047 
00048 
00049 UserDefOpEvaluator
00050 ::UserDefOpEvaluator(const UserDefOpElement* expr,
00051                      const RCP<const UserDefOpCommonEvaluator>& commonEval,
00052                      const EvalContext& context)
00053   : ChainRuleEvaluator(expr, context),
00054     argValueIndex_(expr->numChildren()),
00055     argValueIsConstant_(expr->numChildren()),
00056     functor_(expr->functorElement()),
00057     commonEval_(commonEval),
00058     maxOrder_(0),
00059     numVarArgDerivs_(0),
00060     numConstArgDerivs_(0),
00061     allArgsAreConstant_(true)
00062 {
00063   Tabs tab1;
00064   SUNDANCE_VERB_LOW(tab1 << "initializing user defined op evaluator for " 
00065                     << expr->toString());
00066   Array<int> orders = findRequiredOrders(expr, context);
00067 
00068   for (int i=0; i<orders.size(); i++) 
00069     {
00070       if (orders[i] > maxOrder_) maxOrder_ = orders[i];
00071     }
00072   commonEval->updateMaxOrder(maxOrder_);
00073 
00074   SUNDANCE_VERB_HIGH(tab1 << "setting arg deriv indices");
00075   
00076   
00077   
00078 
00079   Map<MultiSet<int>, int> varArgDerivs;
00080   Map<MultiSet<int>, int> constArgDerivs;
00081   expr->getArgDerivIndices(orders, varArgDerivs, constArgDerivs);
00082   numVarArgDerivs_ = varArgDerivs.size();
00083   numConstArgDerivs_ = constArgDerivs.size();
00084   typedef Map<MultiSet<int>, int>::const_iterator iter;
00085   for (iter i=varArgDerivs.begin(); i!=varArgDerivs.end(); i++)
00086     {
00087       Tabs tab2;
00088       SUNDANCE_VERB_EXTREME(tab2 << "variable arg deriv " << i->first 
00089                             << " will be at index "
00090                             << i->second);
00091       addVarArgDeriv(i->first, i->second);
00092     }
00093   
00094   for (iter i=constArgDerivs.begin(); i!=constArgDerivs.end(); i++)
00095     {
00096       Tabs tab2;
00097       SUNDANCE_VERB_EXTREME(tab2 << "constant arg deriv " << i->first 
00098                             << " will be at index "
00099                             << i->second);
00100       addConstArgDeriv(i->first, i->second);
00101     }
00102 
00103   
00104   
00105   for (int i=0; i<expr->numChildren(); i++)
00106     {
00107       const SparsitySuperset* sArg = childSparsity(i);
00108       int numConst=0;
00109       int numVec=0;
00110       for (int j=0; j<sArg->numDerivs(); j++)
00111         {
00112           if (sArg->deriv(j).order() == 0) 
00113             {
00114               if (sArg->state(j)==VectorDeriv)
00115                 {
00116                   argValueIndex_[i] = numVec;              
00117                   allArgsAreConstant_ = false;
00118                 }
00119               else
00120                 {
00121                   argValueIndex_[i] = numConst;              
00122                 }
00123               break;
00124             }
00125           if (sArg->state(j) == VectorDeriv) 
00126             {
00127               numVec++;
00128             }
00129           else
00130             {
00131               numConst++;
00132             }
00133         }
00134     }
00135   
00136   
00137   init(expr, context);
00138 }
00139 
00140 
00141 
00142 
00143 void UserDefOpEvaluator::resetNumCalls() const
00144 {
00145   commonEval()->markCacheAsInvalid();
00146   ChainRuleEvaluator::resetNumCalls();
00147 }
00148 
00149 
00150 
00151 
00152 Array<int> UserDefOpEvaluator::findRequiredOrders(const ExprWithChildren* expr, 
00153                                                   const EvalContext& context)
00154 {
00155   Tabs tab0;
00156   SUNDANCE_VERB_HIGH(tab0 << "finding required arg deriv orders");
00157 
00158   Set<int> orders;
00159   
00160   const Set<MultipleDeriv>& R = expr->findR(context);
00161   typedef Set<MultipleDeriv>::const_iterator iter;
00162 
00163   for (iter md=R.begin(); md!=R.end(); md++)
00164     {
00165       Tabs tab1;
00166       
00167       int N = md->order();
00168       if (N > maxOrder_) maxOrder_ = N;
00169       if (N==0) orders.put(N);
00170       for (int n=1; n<=N; n++)
00171         {
00172           const Set<MultiSet<int> >& QW = expr->findQ_W(n, context);
00173           for (Set<MultiSet<int> >::const_iterator q=QW.begin(); q!=QW.end(); q++)
00174             {
00175               orders.put(q->size());
00176             }
00177         }
00178     }
00179   SUNDANCE_VERB_HIGH(tab0 << "arg deriv orders=" << orders);
00180   return orders.elements();
00181 }
00182 
00183 
00184 
00185 
00186 void UserDefOpEvaluator
00187 ::evalArgDerivs(const EvalManager& mgr,
00188                 const Array<RCP<Array<double> > >& constArgVals,
00189                 const Array<RCP<Array<RCP<EvalVector> > > >& varArgVals,
00190                 Array<double>& constArgDerivs,
00191                 Array<RCP<EvalVector> >& varArgDerivs) const
00192 {
00193   if (!commonEval()->cacheIsValid())
00194     {
00195       commonEval()->evalAllComponents(mgr, constArgVals, varArgVals);
00196     }
00197   if (allArgsAreConstant_)
00198     {
00199       constArgDerivs = commonEval()->constArgDerivCache(myIndex());
00200     }
00201   else
00202     {
00203       varArgDerivs = commonEval()->varArgDerivCache(myIndex());
00204     }
00205 }
00206