00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031 #include "SundanceChainRuleEvaluator.hpp"
00032 #include "SundanceCombinatorialUtils.hpp"
00033
00034 #include "SundanceUnknownFuncElement.hpp"
00035 #include "SundanceEvalManager.hpp"
00036 #include "PlayaExceptions.hpp"
00037 #include "SundanceSet.hpp"
00038 #include "PlayaTabs.hpp"
00039 #include "SundanceOut.hpp"
00040
00041
00042 using namespace Sundance;
00043 using namespace Sundance;
00044
00045 using namespace Sundance;
00046 using namespace Teuchos;
00047
00048
00049 ChainRuleEvaluator::ChainRuleEvaluator(const ExprWithChildren* expr,
00050 const EvalContext& context)
00051 : SubtypeEvaluator<ExprWithChildren>(expr, context),
00052 expansions_(),
00053 childEvaluators_(expr->numChildren()),
00054 childSparsity_(expr->numChildren()),
00055 constArgDerivMap_(),
00056 varArgDerivMap_(),
00057 zerothDerivResultIndex_(-1),
00058 zerothDerivIsConstant_(false)
00059 {
00060 Tabs tabs;
00061 SUNDANCE_MSG1(context.setupVerbosity(),
00062 tabs << "ChainRuleEvaluator base class ctor for "
00063 << expr->toString());
00064 for (int i=0; i<numChildren(); i++)
00065 {
00066 childEvaluators_[i] = expr->evaluatableChild(i)->evaluator(context);
00067 childEvaluators_[i]->addClient();
00068 childSparsity_[i] = expr->evaluatableChild(i)->sparsitySuperset(context);
00069 }
00070 }
00071
00072 Sundance::Map<OrderedPair<int, int>, Array<Array<int> > >& ChainRuleEvaluator::compMap()
00073 {
00074 static Map<OrderedPair<int, int>, Array<Array<int> > > rtn;
00075 return rtn;
00076 }
00077
00078 void ChainRuleEvaluator::resetNumCalls() const
00079 {
00080 for (int i=0; i<numChildren(); i++)
00081 {
00082 childEvaluators_[i]->resetNumCalls();
00083 }
00084 Evaluator::resetNumCalls();
00085 }
00086
00087
00088 void ChainRuleEvaluator::addConstArgDeriv(const MultiSet<int>& df, int index)
00089 {
00090 constArgDerivMap_.put(df, index);
00091 }
00092
00093 void ChainRuleEvaluator::addVarArgDeriv(const MultiSet<int>& df, int index)
00094 {
00095 varArgDerivMap_.put(df, index);
00096 }
00097
00098 int ChainRuleEvaluator::constArgDerivIndex(const MultiSet<int>& df) const
00099 {
00100 TEUCHOS_TEST_FOR_EXCEPTION(!constArgDerivMap_.containsKey(df), std::logic_error,
00101 "argument derivative " << df << " not found in constant "
00102 "argument derivative map");
00103
00104 return constArgDerivMap_.get(df);
00105 }
00106
00107 int ChainRuleEvaluator::varArgDerivIndex(const MultiSet<int>& df) const
00108 {
00109 TEUCHOS_TEST_FOR_EXCEPTION(!varArgDerivMap_.containsKey(df), std::logic_error,
00110 "argument derivative " << df << " not found in variable "
00111 "argument derivative map");
00112
00113 return varArgDerivMap_.get(df);
00114 }
00115
00116
00117 const Array<Array<int> >& ChainRuleEvaluator::nComps(int N, int n) const
00118 {
00119 OrderedPair<int,int> key(n,N);
00120 if (!compMap().containsKey(key))
00121 {
00122 compMap().put(key, compositions(N)[n-1]);
00123 }
00124 return compMap().get(key);
00125 }
00126
00127
00128 double ChainRuleEvaluator::fact(int n) const
00129 {
00130 TEUCHOS_TEST_FOR_EXCEPTION(n<0, std::logic_error, "negative argument " << n << " to factorial");
00131 if (n==0 || n==1) return 1.0;
00132 return n*fact(n-1);
00133 }
00134
00135 double ChainRuleEvaluator::choose(int N, int n) const
00136 {
00137 return fact(N)/fact(n)/fact(N-n);
00138 }
00139
00140 double ChainRuleEvaluator::stirling2(int n, int k) const
00141 {
00142 if (n < k) return 0;
00143 if (n == k) return 1;
00144 if (k<=0) return 0;
00145 if (k==1) return 1;
00146 if (n-1 == k) return choose(n, 2);
00147 return k*stirling2(n-1, k) + stirling2(n-1, k-1);
00148 }
00149
00150
00151 MultipleDeriv ChainRuleEvaluator::makeMD(const Array<Deriv>& d)
00152 {
00153 MultipleDeriv rtn;
00154 for (int i=0; i<d.size(); i++)
00155 {
00156 rtn.put(d[i]);
00157 }
00158 return rtn;
00159 }
00160
00161
00162 Set<MultiSet<MultipleDeriv> > ChainRuleEvaluator::chainRuleBins(const MultipleDeriv& d,
00163 const MultiSet<int>& q)
00164 {
00165 int n = q.size();
00166 Array<Array<Array<Deriv> > > bins = binnings(d, n);
00167
00168 Set<MultiSet<MultipleDeriv> > rtn;
00169
00170 for (int i=0; i<bins.size(); i++)
00171 {
00172 MultiSet<MultipleDeriv> b;
00173 for (int j=0; j<bins[i].size(); j++)
00174 {
00175 b.put(makeMD(bins[i][j]));
00176 }
00177 rtn.put(b);
00178 }
00179
00180
00181 return rtn;
00182 }
00183
00184
00185 int ChainRuleEvaluator::derivComboMultiplicity(const MultiSet<MultipleDeriv>& b) const
00186 {
00187
00188
00189 MultipleDeriv dTot;
00190 Array<MultiSet<Deriv> > derivSets(b.size());
00191 Array<Array<Deriv> > derivArrays(b.size());
00192 Set<Deriv> allDerivs;
00193 int k=0;
00194 bool allDerivsAreDistinct = true;
00195 bool allDerivsAreIdentical = true;
00196 for (MultiSet<MultipleDeriv>::const_iterator i=b.begin(); i!=b.end(); i++, k++)
00197 {
00198 for (MultipleDeriv::const_iterator j=i->begin(); j!=i->end(); j++)
00199 {
00200 derivSets[k].put(*j);
00201 derivArrays[k].append(*j);
00202 dTot.put(*j);
00203 if (allDerivs.contains(*j)) allDerivsAreDistinct = false;
00204 if (allDerivs.size()>0 && !allDerivs.contains(*j)) allDerivsAreIdentical = false;
00205 allDerivs.put(*j);
00206 }
00207 }
00208 int totOrder = dTot.order();
00209
00210
00211 TEUCHOS_TEST_FOR_EXCEPTION(totOrder > 3, std::logic_error,
00212 "deriv order " << totOrder << " not supported");
00213
00214 if (b.size()==1) return 1;
00215 if (totOrder == (int) b.size()) return 1;
00216
00217
00218
00219
00220
00221
00222
00223
00224
00225
00226
00227 TEUCHOS_TEST_FOR_EXCEPTION(derivArrays.size() != 2, std::logic_error,
00228 "unexpected size=" << derivArrays.size());
00229
00230 if (allDerivsAreIdentical) return 3;
00231 if (allDerivsAreDistinct) return 1;
00232
00233 if (derivArrays[0].size()==1)
00234 {
00235 if (derivSets[1].contains(derivArrays[0][0])) return 2;
00236 return 1;
00237 }
00238 else
00239 {
00240 if (derivSets[0].contains(derivArrays[1][0])) return 2;
00241 return 1;
00242 }
00243 }
00244
00245
00246 void ChainRuleEvaluator::init(const ExprWithChildren* expr,
00247 const EvalContext& context)
00248 {
00249 int verb = context.setupVerbosity();
00250
00251 typedef Array<OrderedPair<Array<MultiSet<int> >, Array<MultipleDeriv> > > CR;
00252 Tabs tabs;
00253 SUNDANCE_MSG1(verb, tabs << "ChainRuleEvaluator::init() for "
00254 << expr->toString());
00255
00256 const Set<MultipleDeriv>& C = expr->findC(context);
00257 const Set<MultipleDeriv>& R = expr->findR(context);
00258
00259 Array<Set<MultipleDeriv> > argV(expr->numChildren());
00260 Array<Set<MultipleDeriv> > argC(expr->numChildren());
00261 Array<Set<MultipleDeriv> > argR(expr->numChildren());
00262
00263 for (int i=0; i<numChildren(); i++)
00264 {
00265 argV[i] = expr->evaluatableChild(i)->findV(context);
00266 argC[i] = expr->evaluatableChild(i)->findC(context);
00267 argR[i] = expr->evaluatableChild(i)->findR(context);
00268 }
00269 SUNDANCE_MSG3(verb, tabs << "sparsity = " << *(this->sparsity()));
00270 typedef Set<MultipleDeriv>::const_iterator iter;
00271
00272 int count=0;
00273 int vecResultIndex = 0;
00274 int constResultIndex = 0;
00275 for (iter md=R.begin(); md!=R.end(); md++, count++)
00276 {
00277 Tabs tab1;
00278 SUNDANCE_MSG3(verb, tab1 << "working out evaluator for " << *md);
00279 int N = md->order();
00280 bool resultIsConstant = C.contains(*md);
00281 int resultIndex;
00282 if (resultIsConstant)
00283 {
00284 Tabs tab2;
00285 SUNDANCE_MSG3(verb, tab2 << "result is constant, const index=" << constResultIndex);
00286 addConstantIndex(count, constResultIndex);
00287 resultIndex = constResultIndex;
00288 constResultIndex++;
00289 }
00290 else
00291 {
00292 Tabs tab2;
00293 SUNDANCE_MSG3(verb, tab2 << "result is variable, vec index=" << vecResultIndex);
00294 addVectorIndex(count, vecResultIndex);
00295 resultIndex = vecResultIndex;
00296 vecResultIndex++;
00297 }
00298
00299 SUNDANCE_MSG3(verb, tab1 << "order=" << N);
00300
00301 if (N==0)
00302 {
00303 Tabs tab2;
00304 SUNDANCE_MSG3(verb, tab2 << "zeroth deriv index=" << resultIndex);
00305 zerothDerivIsConstant_ = resultIsConstant;
00306 zerothDerivResultIndex_ = resultIndex;
00307 continue;
00308 }
00309
00310
00311
00312 RCP<ChainRuleSum> sum
00313 = rcp(new ChainRuleSum(*md, resultIndex, resultIsConstant));
00314
00315 const MultipleDeriv& nu = *md;
00316
00317 for (int n=1; n<=N; n++)
00318 {
00319 Tabs tab2;
00320 SUNDANCE_MSG3(verb, tab2 << "n=" << n);
00321 const Set<MultiSet<int> >& QW = expr->findQ_W(n, context);
00322 const Set<MultiSet<int> >& QC = expr->findQ_C(n, context);
00323 SUNDANCE_MSG3(verb, tab2 << "Q_W=" << QW);
00324 SUNDANCE_MSG3(verb, tab2 << "Q_C=" << QC);
00325 for (Set<MultiSet<int> >::const_iterator
00326 j=QW.begin(); j!=QW.end(); j++)
00327 {
00328 Tabs tab3;
00329 const MultiSet<int>& lambda = *j;
00330 SUNDANCE_MSG3(verb, tab3 << "arg index set =" << lambda);
00331 bool argDerivIsConstant = QC.contains(lambda);
00332 int argDerivIndex = -1;
00333 if (argDerivIsConstant)
00334 {
00335 argDerivIndex = constArgDerivIndex(lambda);
00336 }
00337 else
00338 {
00339 argDerivIndex = varArgDerivIndex(lambda);
00340 }
00341 Array<DerivProduct> pSum;
00342 for (int s=1; s<=N; s++)
00343 {
00344 Tabs tab4;
00345 SUNDANCE_MSG3(verb, tab4 << "preparing chain rule terms for "
00346 "s=" << s << ", lambda=" << lambda << ", nu=" << nu);
00347 CR p = chainRuleTerms(s, lambda, nu);
00348 for (CR::const_iterator j=p.begin(); j!=p.end(); j++)
00349 {
00350 Tabs tab5;
00351 Array<MultiSet<int> > K = j->first();
00352 Array<MultipleDeriv> L = j->second();
00353 SUNDANCE_MSG3(verb, tab5 << "K=" << K << std::endl << tab5 << "L=" << L);
00354 double weight = chainRuleMultiplicity(nu, K, L);
00355 SUNDANCE_MSG3(verb, tab5 << "weight=" << weight);
00356 DerivProduct prod(weight);
00357 bool termIsZero = false;
00358 for (int j=0; j<K.size(); j++)
00359 {
00360 for (MultiSet<int>::const_iterator
00361 k=K[j].begin(); k!=K[j].end(); k++)
00362 {
00363 int argIndex = *k;
00364 const MultipleDeriv& derivOfArg = L[j];
00365 const RCP<SparsitySuperset>& argSp
00366 = childSparsity_[argIndex];
00367 const RCP<Evaluator>& argEv
00368 = childEvaluators_[argIndex];
00369
00370 int rawValIndex = argSp->getIndex(derivOfArg);
00371 SUNDANCE_MSG3(verb, tab5 << "argR="
00372 << argR[argIndex]);
00373 if (argV[argIndex].contains(derivOfArg))
00374 {
00375 SUNDANCE_MSG3(verb, tab5 << "mdArg is variable");
00376 int valIndex
00377 = argEv->vectorIndexMap().get(rawValIndex);
00378 prod.addVariableFactor(IndexPair(argIndex, valIndex));
00379 }
00380 else if (argC[argIndex].contains(derivOfArg))
00381 {
00382 SUNDANCE_MSG3(verb, tab5 << "mdArg is constant");
00383 int valIndex
00384 = argEv->constantIndexMap().get(rawValIndex);
00385 prod.addConstantFactor(IndexPair(argIndex, valIndex));
00386 }
00387 else
00388 {
00389 SUNDANCE_MSG3(verb, tab5 << "mdArg is zero");
00390 termIsZero = true;
00391 break;
00392 }
00393 }
00394 if (termIsZero) break;
00395 }
00396 if (!termIsZero) pSum.append(prod);
00397 }
00398 }
00399 sum->addTerm(argDerivIndex, argDerivIsConstant, pSum);
00400 }
00401 }
00402 TEUCHOS_TEST_FOR_EXCEPTION(sum->numTerms()==0, std::logic_error,
00403 "Empty sum in chain rule expansion for derivative "
00404 << *md);
00405 expansions_.append(sum);
00406 }
00407
00408 SUNDANCE_MSG3(verb, tabs << "num constant results: "
00409 << this->sparsity()->numConstantDerivs());
00410
00411 SUNDANCE_MSG3(verb, tabs << "num var results: "
00412 << this->sparsity()->numVectorDerivs());
00413
00414
00415 }
00416
00417
00418
00419 void ChainRuleEvaluator::internalEval(const EvalManager& mgr,
00420 Array<double>& constantResults,
00421 Array<RCP<EvalVector> >& vectorResults) const
00422 {
00423 TimeMonitor timer(chainRuleEvalTimer());
00424 Tabs tabs(0);
00425
00426 SUNDANCE_MSG1(mgr.verb(), tabs << "ChainRuleEvaluator::eval() expr="
00427 << expr()->toString());
00428
00429
00430 SUNDANCE_MSG2(mgr.verb(), tabs << "max diff order = " << mgr.getRegion().topLevelDiffOrder());
00431 SUNDANCE_MSG2(mgr.verb(), tabs << "return sparsity " << std::endl << tabs << *(this->sparsity()));
00432
00433 constantResults.resize(this->sparsity()->numConstantDerivs());
00434 vectorResults.resize(this->sparsity()->numVectorDerivs());
00435
00436 SUNDANCE_MSG3(mgr.verb(),tabs << "num constant results: "
00437 << this->sparsity()->numConstantDerivs());
00438
00439 SUNDANCE_MSG3(mgr.verb(),tabs << "num var results: "
00440 << this->sparsity()->numVectorDerivs());
00441
00442 Array<RCP<Array<double> > > constantArgResults(numChildren());
00443 Array<RCP<Array<RCP<EvalVector> > > > varArgResults(numChildren());
00444
00445 Array<double> constantArgDerivs;
00446 Array<RCP<EvalVector> > varArgDerivs;
00447
00448 for (int i=0; i<numChildren(); i++)
00449 {
00450 Tabs tab1;
00451 SUNDANCE_MSG3(mgr.verb(), tab1 << "computing results for child #"
00452 << i);
00453
00454 constantArgResults[i] = rcp(new Array<double>());
00455 varArgResults[i] = rcp(new Array<RCP<EvalVector> >());
00456 childEvaluators_[i]->eval(mgr, *(constantArgResults[i]),
00457 *(varArgResults[i]));
00458 if (mgr.verb() > 3)
00459 {
00460 Out::os() << tabs << "constant arg #" << i <<
00461 " results:" << *(constantArgResults[i]) << std::endl;
00462 Out::os() << tabs << "variable arg #" << i << " derivs:" << std::endl;
00463 for (int j=0; j<varArgResults[i]->size(); j++)
00464 {
00465 Tabs tab1;
00466 Out::os() << tab1 << j << " ";
00467 (*(varArgResults[i]))[j]->print(Out::os());
00468 Out::os() << std::endl;
00469 }
00470 }
00471 }
00472
00473 evalArgDerivs(mgr, constantArgResults, varArgResults,
00474 constantArgDerivs, varArgDerivs);
00475
00476
00477 if (mgr.verb() > 2)
00478 {
00479 Out::os() << tabs << "constant arg derivs:" << constantArgDerivs << std::endl;
00480 Out::os() << tabs << "variable arg derivs:" << std::endl;
00481 for (int i=0; i<varArgDerivs.size(); i++)
00482 {
00483 Tabs tab1;
00484 Out::os() << tab1 << i << " ";
00485 varArgDerivs[i]->print(Out::os());
00486 Out::os() << std::endl;
00487 }
00488 }
00489
00490
00491 for (int i=0; i<expansions_.size(); i++)
00492 {
00493 Tabs tab1;
00494 int resultIndex = expansions_[i]->resultIndex();
00495 bool isConstant = expansions_[i]->resultIsConstant();
00496 SUNDANCE_MSG3(mgr.verb(), tab1 << "doing expansion for deriv #" << i
00497 << ", result index="
00498 << resultIndex << std::endl << tab1
00499 << "deriv=" << expansions_[i]->deriv());
00500 if (isConstant)
00501 {
00502 expansions_[i]->evalConstant(mgr, constantArgResults, constantArgDerivs,
00503 constantResults[resultIndex]);
00504 }
00505 else
00506 {
00507 expansions_[i]->evalVar(mgr, constantArgResults, varArgResults,
00508 constantArgDerivs, varArgDerivs,
00509 vectorResults[resultIndex]);
00510 }
00511 }
00512
00513 if (zerothDerivResultIndex_ >= 0)
00514 {
00515 SUNDANCE_MSG3(mgr.verb(), tabs << "processing zeroth-order deriv");
00516 Tabs tab1;
00517 SUNDANCE_MSG3(mgr.verb(), tab1 << "result index = " << zerothDerivResultIndex_);
00518 if (zerothDerivIsConstant_)
00519 {
00520 Tabs tab2;
00521 SUNDANCE_MSG3(mgr.verb(), tab2 << "zeroth-order deriv is constant");
00522 constantResults[zerothDerivResultIndex_] = constantArgDerivs[0];
00523 }
00524 else
00525 {
00526 Tabs tab2;
00527 SUNDANCE_MSG3(mgr.verb(), tab2 << "zeroth-order deriv is variable");
00528 vectorResults[zerothDerivResultIndex_] = varArgDerivs[0];
00529 }
00530 }
00531
00532
00533 if (mgr.verb() > 1)
00534 {
00535 Tabs tab1;
00536 Out::os() << tab1 << "chain rule results " << std::endl;
00537 mgr.showResults(Out::os(), this->sparsity(), vectorResults,
00538 constantResults);
00539 }
00540
00541 SUNDANCE_MSG1(mgr.verb(), tabs << "ChainRuleEvaluator::eval() done");
00542 }
00543
00544
00545
00546
00547 namespace Sundance {
00548
00549 MultipleDeriv makeDeriv(const Expr& a)
00550 {
00551 const UnknownFuncElement* aPtr
00552 = dynamic_cast<const UnknownFuncElement*>(a[0].ptr().get());
00553
00554 TEUCHOS_TEST_FOR_EXCEPT(aPtr==0);
00555
00556 Deriv d = funcDeriv(aPtr);
00557 MultipleDeriv rtn;
00558 rtn.put(d);
00559 return rtn;
00560 }
00561
00562
00563 MultipleDeriv makeDeriv(const Expr& a, const Expr& b)
00564 {
00565 const UnknownFuncElement* aPtr
00566 = dynamic_cast<const UnknownFuncElement*>(a[0].ptr().get());
00567
00568 TEUCHOS_TEST_FOR_EXCEPT(aPtr==0);
00569
00570 const UnknownFuncElement* bPtr
00571 = dynamic_cast<const UnknownFuncElement*>(b[0].ptr().get());
00572
00573 TEUCHOS_TEST_FOR_EXCEPT(bPtr==0);
00574
00575 Deriv da = funcDeriv(aPtr);
00576 Deriv db = funcDeriv(bPtr);
00577 MultipleDeriv rtn;
00578 rtn.put(da);
00579 rtn.put(db);
00580 return rtn;
00581 }
00582
00583
00584
00585 MultipleDeriv makeDeriv(const Expr& a, const Expr& b, const Expr& c)
00586 {
00587 const UnknownFuncElement* aPtr
00588 = dynamic_cast<const UnknownFuncElement*>(a[0].ptr().get());
00589
00590 TEUCHOS_TEST_FOR_EXCEPT(aPtr==0);
00591
00592 const UnknownFuncElement* bPtr
00593 = dynamic_cast<const UnknownFuncElement*>(b[0].ptr().get());
00594
00595 TEUCHOS_TEST_FOR_EXCEPT(bPtr==0);
00596
00597 const UnknownFuncElement* cPtr
00598 = dynamic_cast<const UnknownFuncElement*>(c[0].ptr().get());
00599
00600 TEUCHOS_TEST_FOR_EXCEPT(cPtr==0);
00601
00602 Deriv da = funcDeriv(aPtr);
00603 Deriv db = funcDeriv(bPtr);
00604 Deriv dc = funcDeriv(cPtr);
00605 MultipleDeriv rtn;
00606 rtn.put(da);
00607 rtn.put(db);
00608 rtn.put(dc);
00609 return rtn;
00610 }
00611
00612 }