00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031 #include "SundanceUserDefOpEvaluator.hpp"
00032 #include "SundanceUserDefOpCommonEvaluator.hpp"
00033 #include "SundanceUserDefOpElement.hpp"
00034 #include "SundanceEvalManager.hpp"
00035
00036 #include "PlayaTabs.hpp"
00037 #include "SundanceOut.hpp"
00038 #include "SundanceUserDefOp.hpp"
00039
00040 using namespace Sundance;
00041 using namespace Sundance;
00042 using namespace Sundance;
00043 using namespace Teuchos;
00044
00045
00046
00047
00048
00049 UserDefOpEvaluator
00050 ::UserDefOpEvaluator(const UserDefOpElement* expr,
00051 const RCP<const UserDefOpCommonEvaluator>& commonEval,
00052 const EvalContext& context)
00053 : ChainRuleEvaluator(expr, context),
00054 argValueIndex_(expr->numChildren()),
00055 argValueIsConstant_(expr->numChildren()),
00056 functor_(expr->functorElement()),
00057 commonEval_(commonEval),
00058 maxOrder_(0),
00059 numVarArgDerivs_(0),
00060 numConstArgDerivs_(0),
00061 allArgsAreConstant_(true)
00062 {
00063 Tabs tab1;
00064 SUNDANCE_VERB_LOW(tab1 << "initializing user defined op evaluator for "
00065 << expr->toString());
00066 Array<int> orders = findRequiredOrders(expr, context);
00067
00068 for (int i=0; i<orders.size(); i++)
00069 {
00070 if (orders[i] > maxOrder_) maxOrder_ = orders[i];
00071 }
00072 commonEval->updateMaxOrder(maxOrder_);
00073
00074 SUNDANCE_VERB_HIGH(tab1 << "setting arg deriv indices");
00075
00076
00077
00078
00079 Map<MultiSet<int>, int> varArgDerivs;
00080 Map<MultiSet<int>, int> constArgDerivs;
00081 expr->getArgDerivIndices(orders, varArgDerivs, constArgDerivs);
00082 numVarArgDerivs_ = varArgDerivs.size();
00083 numConstArgDerivs_ = constArgDerivs.size();
00084 typedef Map<MultiSet<int>, int>::const_iterator iter;
00085 for (iter i=varArgDerivs.begin(); i!=varArgDerivs.end(); i++)
00086 {
00087 Tabs tab2;
00088 SUNDANCE_VERB_EXTREME(tab2 << "variable arg deriv " << i->first
00089 << " will be at index "
00090 << i->second);
00091 addVarArgDeriv(i->first, i->second);
00092 }
00093
00094 for (iter i=constArgDerivs.begin(); i!=constArgDerivs.end(); i++)
00095 {
00096 Tabs tab2;
00097 SUNDANCE_VERB_EXTREME(tab2 << "constant arg deriv " << i->first
00098 << " will be at index "
00099 << i->second);
00100 addConstArgDeriv(i->first, i->second);
00101 }
00102
00103
00104
00105 for (int i=0; i<expr->numChildren(); i++)
00106 {
00107 const SparsitySuperset* sArg = childSparsity(i);
00108 int numConst=0;
00109 int numVec=0;
00110 for (int j=0; j<sArg->numDerivs(); j++)
00111 {
00112 if (sArg->deriv(j).order() == 0)
00113 {
00114 if (sArg->state(j)==VectorDeriv)
00115 {
00116 argValueIndex_[i] = numVec;
00117 allArgsAreConstant_ = false;
00118 }
00119 else
00120 {
00121 argValueIndex_[i] = numConst;
00122 }
00123 break;
00124 }
00125 if (sArg->state(j) == VectorDeriv)
00126 {
00127 numVec++;
00128 }
00129 else
00130 {
00131 numConst++;
00132 }
00133 }
00134 }
00135
00136
00137 init(expr, context);
00138 }
00139
00140
00141
00142
00143 void UserDefOpEvaluator::resetNumCalls() const
00144 {
00145 commonEval()->markCacheAsInvalid();
00146 ChainRuleEvaluator::resetNumCalls();
00147 }
00148
00149
00150
00151
00152 Array<int> UserDefOpEvaluator::findRequiredOrders(const ExprWithChildren* expr,
00153 const EvalContext& context)
00154 {
00155 Tabs tab0;
00156 SUNDANCE_VERB_HIGH(tab0 << "finding required arg deriv orders");
00157
00158 Set<int> orders;
00159
00160 const Set<MultipleDeriv>& R = expr->findR(context);
00161 typedef Set<MultipleDeriv>::const_iterator iter;
00162
00163 for (iter md=R.begin(); md!=R.end(); md++)
00164 {
00165 Tabs tab1;
00166
00167 int N = md->order();
00168 if (N > maxOrder_) maxOrder_ = N;
00169 if (N==0) orders.put(N);
00170 for (int n=1; n<=N; n++)
00171 {
00172 const Set<MultiSet<int> >& QW = expr->findQ_W(n, context);
00173 for (Set<MultiSet<int> >::const_iterator q=QW.begin(); q!=QW.end(); q++)
00174 {
00175 orders.put(q->size());
00176 }
00177 }
00178 }
00179 SUNDANCE_VERB_HIGH(tab0 << "arg deriv orders=" << orders);
00180 return orders.elements();
00181 }
00182
00183
00184
00185
00186 void UserDefOpEvaluator
00187 ::evalArgDerivs(const EvalManager& mgr,
00188 const Array<RCP<Array<double> > >& constArgVals,
00189 const Array<RCP<Array<RCP<EvalVector> > > >& varArgVals,
00190 Array<double>& constArgDerivs,
00191 Array<RCP<EvalVector> >& varArgDerivs) const
00192 {
00193 if (!commonEval()->cacheIsValid())
00194 {
00195 commonEval()->evalAllComponents(mgr, constArgVals, varArgVals);
00196 }
00197 if (allArgsAreConstant_)
00198 {
00199 constArgDerivs = commonEval()->constArgDerivCache(myIndex());
00200 }
00201 else
00202 {
00203 varArgDerivs = commonEval()->varArgDerivCache(myIndex());
00204 }
00205 }
00206