00001 
00002 
00003 
00004 
00005 
00006 
00007 
00008 
00009 
00010 
00011 
00012 
00013 
00014 
00015 
00016 
00017 
00018 
00019 
00020 
00021 
00022 
00023 
00024 
00025 
00026 
00027 
00028 
00029 
00030 
00031 #include "SundanceChainRuleEvaluator.hpp"
00032 #include "SundanceCombinatorialUtils.hpp"
00033 
00034 #include "SundanceUnknownFuncElement.hpp"
00035 #include "SundanceEvalManager.hpp"
00036 #include "PlayaExceptions.hpp"
00037 #include "SundanceSet.hpp"
00038 #include "PlayaTabs.hpp"
00039 #include "SundanceOut.hpp"
00040 
00041 
00042 using namespace Sundance;
00043 using namespace Sundance;
00044 
00045 using namespace Sundance;
00046 using namespace Teuchos;
00047 
00048 
00049 ChainRuleEvaluator::ChainRuleEvaluator(const ExprWithChildren* expr, 
00050   const EvalContext& context)
00051   : SubtypeEvaluator<ExprWithChildren>(expr, context), 
00052     expansions_(),
00053     childEvaluators_(expr->numChildren()),
00054     childSparsity_(expr->numChildren()),
00055     constArgDerivMap_(),
00056     varArgDerivMap_(),
00057     zerothDerivResultIndex_(-1),
00058     zerothDerivIsConstant_(false)
00059 {
00060   Tabs tabs;
00061   SUNDANCE_MSG1(context.setupVerbosity(),
00062     tabs << "ChainRuleEvaluator base class ctor for " 
00063     << expr->toString());
00064   for (int i=0; i<numChildren(); i++)
00065   {
00066     childEvaluators_[i] = expr->evaluatableChild(i)->evaluator(context);
00067     childEvaluators_[i]->addClient();
00068     childSparsity_[i] = expr->evaluatableChild(i)->sparsitySuperset(context);
00069   }
00070 }
00071 
00072 Sundance::Map<OrderedPair<int, int>, Array<Array<int> > >& ChainRuleEvaluator::compMap()
00073 {
00074   static Map<OrderedPair<int, int>, Array<Array<int> > > rtn;
00075   return rtn;
00076 }
00077 
00078 void ChainRuleEvaluator::resetNumCalls() const
00079 {
00080   for (int i=0; i<numChildren(); i++)
00081   {
00082     childEvaluators_[i]->resetNumCalls();
00083   }
00084   Evaluator::resetNumCalls();
00085 }
00086 
00087 
00088 void ChainRuleEvaluator::addConstArgDeriv(const MultiSet<int>& df, int index)
00089 {
00090   constArgDerivMap_.put(df, index);
00091 }
00092 
00093 void ChainRuleEvaluator::addVarArgDeriv(const MultiSet<int>& df, int index)
00094 {
00095   varArgDerivMap_.put(df, index);
00096 }
00097 
00098 int ChainRuleEvaluator::constArgDerivIndex(const MultiSet<int>& df) const
00099 {
00100   TEUCHOS_TEST_FOR_EXCEPTION(!constArgDerivMap_.containsKey(df), std::logic_error,
00101     "argument derivative " << df << " not found in constant "
00102     "argument derivative map");
00103 
00104   return constArgDerivMap_.get(df);
00105 }
00106 
00107 int ChainRuleEvaluator::varArgDerivIndex(const MultiSet<int>& df) const
00108 {
00109   TEUCHOS_TEST_FOR_EXCEPTION(!varArgDerivMap_.containsKey(df), std::logic_error,
00110     "argument derivative " << df << " not found in variable "
00111     "argument derivative map");
00112 
00113   return varArgDerivMap_.get(df);
00114 }
00115 
00116 
00117 const Array<Array<int> >& ChainRuleEvaluator::nComps(int N, int n) const
00118 {
00119   OrderedPair<int,int> key(n,N);
00120   if (!compMap().containsKey(key))
00121   {
00122     compMap().put(key, compositions(N)[n-1]);
00123   }
00124   return compMap().get(key);
00125 }
00126 
00127 
00128 double ChainRuleEvaluator::fact(int n) const
00129 {
00130   TEUCHOS_TEST_FOR_EXCEPTION(n<0, std::logic_error, "negative argument " << n << " to factorial");
00131   if (n==0 || n==1) return 1.0;
00132   return n*fact(n-1);
00133 }
00134 
00135 double ChainRuleEvaluator::choose(int N, int n) const
00136 {
00137   return fact(N)/fact(n)/fact(N-n);
00138 }
00139 
00140 double ChainRuleEvaluator::stirling2(int n, int k) const
00141 {
00142   if (n < k) return 0;
00143   if (n == k) return 1;
00144   if (k<=0) return 0;
00145   if (k==1) return 1;
00146   if (n-1 == k) return choose(n, 2);
00147   return k*stirling2(n-1, k) + stirling2(n-1, k-1);
00148 }
00149 
00150 
00151 MultipleDeriv ChainRuleEvaluator::makeMD(const Array<Deriv>& d) 
00152 {
00153   MultipleDeriv rtn;
00154   for (int i=0; i<d.size(); i++)
00155   {
00156     rtn.put(d[i]);
00157   }
00158   return rtn;
00159 }
00160 
00161 
00162 Set<MultiSet<MultipleDeriv> > ChainRuleEvaluator::chainRuleBins(const MultipleDeriv& d,
00163   const MultiSet<int>& q)
00164 {
00165   int n = q.size();
00166   Array<Array<Array<Deriv> > > bins = binnings(d, n);
00167 
00168   Set<MultiSet<MultipleDeriv> > rtn;
00169 
00170   for (int i=0; i<bins.size(); i++)
00171   {
00172     MultiSet<MultipleDeriv> b;
00173     for (int j=0; j<bins[i].size(); j++)
00174     {
00175       b.put(makeMD(bins[i][j]));
00176     }
00177     rtn.put(b);
00178   }
00179 
00180 
00181   return rtn;
00182 }
00183 
00184 
00185 int ChainRuleEvaluator::derivComboMultiplicity(const MultiSet<MultipleDeriv>& b) const
00186 {
00187   
00188 
00189   MultipleDeriv dTot;
00190   Array<MultiSet<Deriv> > derivSets(b.size());
00191   Array<Array<Deriv> > derivArrays(b.size());
00192   Set<Deriv> allDerivs;
00193   int k=0;
00194   bool allDerivsAreDistinct = true;
00195   bool allDerivsAreIdentical = true;
00196   for (MultiSet<MultipleDeriv>::const_iterator i=b.begin(); i!=b.end(); i++, k++)
00197   {
00198     for (MultipleDeriv::const_iterator j=i->begin(); j!=i->end(); j++)
00199     {
00200       derivSets[k].put(*j);
00201       derivArrays[k].append(*j);
00202       dTot.put(*j);
00203       if (allDerivs.contains(*j)) allDerivsAreDistinct = false;
00204       if (allDerivs.size()>0 && !allDerivs.contains(*j)) allDerivsAreIdentical = false;
00205       allDerivs.put(*j);
00206     }
00207   }
00208   int totOrder = dTot.order();
00209 
00210   
00211   TEUCHOS_TEST_FOR_EXCEPTION(totOrder > 3, std::logic_error,
00212     "deriv order " << totOrder << " not supported");
00213 
00214   if (b.size()==1) return 1;  
00215   if (totOrder == (int) b.size()) return 1; 
00216 
00217   
00218 
00219 
00220 
00221 
00222 
00223 
00224 
00225 
00226 
00227   TEUCHOS_TEST_FOR_EXCEPTION(derivArrays.size() != 2, std::logic_error,
00228     "unexpected size=" << derivArrays.size());
00229 
00230   if (allDerivsAreIdentical) return 3;
00231   if (allDerivsAreDistinct) return 1;
00232 
00233   if (derivArrays[0].size()==1) 
00234   {
00235     if (derivSets[1].contains(derivArrays[0][0])) return 2;
00236     return 1;
00237   }
00238   else
00239   {
00240     if (derivSets[0].contains(derivArrays[1][0])) return 2;
00241     return 1;
00242   }
00243 }
00244 
00245 
00246 void ChainRuleEvaluator::init(const ExprWithChildren* expr, 
00247   const EvalContext& context)
00248 {
00249   int verb = context.setupVerbosity();
00250 
00251   typedef Array<OrderedPair<Array<MultiSet<int> >, Array<MultipleDeriv> > > CR;
00252   Tabs tabs;
00253   SUNDANCE_MSG1(verb, tabs << "ChainRuleEvaluator::init() for " 
00254     << expr->toString());
00255 
00256   const Set<MultipleDeriv>& C = expr->findC(context);
00257   const Set<MultipleDeriv>& R = expr->findR(context);
00258 
00259   Array<Set<MultipleDeriv> > argV(expr->numChildren());
00260   Array<Set<MultipleDeriv> > argC(expr->numChildren());
00261   Array<Set<MultipleDeriv> > argR(expr->numChildren());
00262 
00263   for (int i=0; i<numChildren(); i++)
00264   {
00265     argV[i] = expr->evaluatableChild(i)->findV(context);
00266     argC[i] = expr->evaluatableChild(i)->findC(context);
00267     argR[i] = expr->evaluatableChild(i)->findR(context);
00268   }
00269   SUNDANCE_MSG3(verb, tabs << "sparsity = " << *(this->sparsity()));
00270   typedef Set<MultipleDeriv>::const_iterator iter;
00271 
00272   int count=0;
00273   int vecResultIndex = 0;
00274   int constResultIndex = 0;
00275   for (iter md=R.begin(); md!=R.end(); md++, count++)
00276   {
00277     Tabs tab1;
00278     SUNDANCE_MSG3(verb, tab1 << "working out evaluator for " << *md);
00279     int N = md->order();
00280     bool resultIsConstant = C.contains(*md);
00281     int resultIndex;
00282     if (resultIsConstant)
00283     {
00284       Tabs tab2;
00285       SUNDANCE_MSG3(verb, tab2 << "result is constant, const index=" << constResultIndex);
00286       addConstantIndex(count, constResultIndex);
00287       resultIndex = constResultIndex;
00288       constResultIndex++;
00289     }
00290     else
00291     {
00292       Tabs tab2;
00293       SUNDANCE_MSG3(verb, tab2 << "result is variable, vec index=" << vecResultIndex);
00294       addVectorIndex(count, vecResultIndex);
00295       resultIndex = vecResultIndex;
00296       vecResultIndex++;
00297     }
00298 
00299     SUNDANCE_MSG3(verb, tab1 << "order=" << N);
00300       
00301     if (N==0)
00302     {
00303       Tabs tab2;
00304       SUNDANCE_MSG3(verb, tab2 << "zeroth deriv index=" << resultIndex);
00305       zerothDerivIsConstant_ = resultIsConstant;
00306       zerothDerivResultIndex_ = resultIndex;
00307       continue;
00308     }
00309 
00310 
00311       
00312     RCP<ChainRuleSum> sum 
00313       = rcp(new ChainRuleSum(*md, resultIndex, resultIsConstant));
00314 
00315     const MultipleDeriv& nu = *md;
00316 
00317     for (int n=1; n<=N; n++)
00318     {
00319       Tabs tab2;
00320       SUNDANCE_MSG3(verb, tab2 << "n=" << n);
00321       const Set<MultiSet<int> >& QW = expr->findQ_W(n, context);
00322       const Set<MultiSet<int> >& QC = expr->findQ_C(n, context);
00323       SUNDANCE_MSG3(verb, tab2 << "Q_W=" << QW);
00324       SUNDANCE_MSG3(verb, tab2 << "Q_C=" << QC);
00325       for (Set<MultiSet<int> >::const_iterator 
00326              j=QW.begin(); j!=QW.end(); j++)
00327       {
00328         Tabs tab3;
00329         const MultiSet<int>& lambda = *j;
00330         SUNDANCE_MSG3(verb, tab3 << "arg index set =" << lambda);
00331         bool argDerivIsConstant = QC.contains(lambda);
00332         int argDerivIndex = -1;
00333         if (argDerivIsConstant) 
00334         {
00335           argDerivIndex = constArgDerivIndex(lambda);
00336         }
00337         else 
00338         {
00339           argDerivIndex = varArgDerivIndex(lambda);
00340         }
00341         Array<DerivProduct> pSum;
00342         for (int s=1; s<=N; s++)
00343         {
00344           Tabs tab4;
00345           SUNDANCE_MSG3(verb, tab4 << "preparing chain rule terms for "
00346             "s=" << s << ", lambda=" << lambda << ", nu=" << nu);
00347           CR p = chainRuleTerms(s, lambda, nu);
00348           for (CR::const_iterator j=p.begin(); j!=p.end(); j++)
00349           {
00350             Tabs tab5;
00351             Array<MultiSet<int> > K = j->first();
00352             Array<MultipleDeriv> L = j->second();
00353             SUNDANCE_MSG3(verb, tab5 << "K=" << K << std::endl << tab5 << "L=" << L);
00354             double weight = chainRuleMultiplicity(nu, K, L);
00355             SUNDANCE_MSG3(verb, tab5 << "weight=" << weight);
00356             DerivProduct prod(weight);
00357             bool termIsZero = false;
00358             for (int j=0; j<K.size(); j++)
00359             {
00360               for (MultiSet<int>::const_iterator 
00361                      k=K[j].begin(); k!=K[j].end(); k++)
00362               {
00363                 int argIndex = *k;
00364                 const MultipleDeriv& derivOfArg = L[j];
00365                 const RCP<SparsitySuperset>& argSp 
00366                   = childSparsity_[argIndex];
00367                 const RCP<Evaluator>& argEv
00368                   = childEvaluators_[argIndex];
00369                                
00370                 int rawValIndex = argSp->getIndex(derivOfArg);
00371                 SUNDANCE_MSG3(verb, tab5 << "argR=" 
00372                   << argR[argIndex]);
00373                 if (argV[argIndex].contains(derivOfArg))
00374                 {
00375                   SUNDANCE_MSG3(verb, tab5 << "mdArg is variable");
00376                   int valIndex 
00377                     = argEv->vectorIndexMap().get(rawValIndex);
00378                   prod.addVariableFactor(IndexPair(argIndex, valIndex));
00379                 }
00380                 else if (argC[argIndex].contains(derivOfArg))
00381                 {
00382                   SUNDANCE_MSG3(verb, tab5 << "mdArg is constant");
00383                   int valIndex 
00384                     = argEv->constantIndexMap().get(rawValIndex);
00385                   prod.addConstantFactor(IndexPair(argIndex, valIndex));
00386                 }
00387                 else
00388                 {
00389                   SUNDANCE_MSG3(verb, tab5 << "mdArg is zero");
00390                   termIsZero = true;
00391                   break;
00392                 }
00393               }
00394               if (termIsZero) break;
00395             }
00396             if (!termIsZero) pSum.append(prod);
00397           }
00398         }
00399         sum->addTerm(argDerivIndex, argDerivIsConstant, pSum);
00400       }
00401     }
00402     TEUCHOS_TEST_FOR_EXCEPTION(sum->numTerms()==0, std::logic_error,
00403       "Empty sum in chain rule expansion for derivative "
00404       << *md);
00405     expansions_.append(sum);
00406   }
00407 
00408   SUNDANCE_MSG3(verb, tabs << "num constant results: " 
00409     << this->sparsity()->numConstantDerivs());
00410 
00411   SUNDANCE_MSG3(verb, tabs << "num var results: " 
00412     << this->sparsity()->numVectorDerivs());
00413 
00414   
00415 }
00416 
00417 
00418 
00419 void ChainRuleEvaluator::internalEval(const EvalManager& mgr,
00420   Array<double>& constantResults,
00421   Array<RCP<EvalVector> >& vectorResults) const 
00422 {
00423   TimeMonitor timer(chainRuleEvalTimer());
00424   Tabs tabs(0);
00425 
00426   SUNDANCE_MSG1(mgr.verb(), tabs << "ChainRuleEvaluator::eval() expr=" 
00427     << expr()->toString());
00428 
00429   
00430   SUNDANCE_MSG2(mgr.verb(), tabs << "max diff order = " << mgr.getRegion().topLevelDiffOrder());
00431   SUNDANCE_MSG2(mgr.verb(), tabs << "return sparsity " << std::endl << tabs << *(this->sparsity()));
00432   
00433   constantResults.resize(this->sparsity()->numConstantDerivs());
00434   vectorResults.resize(this->sparsity()->numVectorDerivs());
00435 
00436   SUNDANCE_MSG3(mgr.verb(),tabs << "num constant results: " 
00437     << this->sparsity()->numConstantDerivs());
00438 
00439   SUNDANCE_MSG3(mgr.verb(),tabs << "num var results: " 
00440     << this->sparsity()->numVectorDerivs());
00441 
00442   Array<RCP<Array<double> > > constantArgResults(numChildren());
00443   Array<RCP<Array<RCP<EvalVector> > > > varArgResults(numChildren());
00444 
00445   Array<double> constantArgDerivs;
00446   Array<RCP<EvalVector> > varArgDerivs;
00447 
00448   for (int i=0; i<numChildren(); i++)
00449   {
00450     Tabs tab1;
00451     SUNDANCE_MSG3(mgr.verb(), tab1 << "computing results for child #"
00452       << i);
00453                          
00454     constantArgResults[i] = rcp(new Array<double>());
00455     varArgResults[i] = rcp(new Array<RCP<EvalVector> >());
00456     childEvaluators_[i]->eval(mgr, *(constantArgResults[i]), 
00457       *(varArgResults[i]));
00458     if (mgr.verb() > 3)
00459     {
00460       Out::os() << tabs << "constant arg #" << i << 
00461         " results:" << *(constantArgResults[i]) << std::endl;
00462       Out::os() << tabs << "variable arg #" << i << " derivs:" << std::endl;
00463       for (int j=0; j<varArgResults[i]->size(); j++)
00464       {
00465         Tabs tab1;
00466         Out::os() << tab1 << j << " ";
00467         (*(varArgResults[i]))[j]->print(Out::os());
00468         Out::os() << std::endl;
00469       }
00470     }
00471   }
00472 
00473   evalArgDerivs(mgr, constantArgResults, varArgResults,
00474     constantArgDerivs, varArgDerivs);
00475 
00476   
00477   if (mgr.verb() > 2)
00478   {
00479     Out::os() << tabs << "constant arg derivs:" << constantArgDerivs << std::endl;
00480     Out::os() << tabs << "variable arg derivs:" << std::endl;
00481     for (int i=0; i<varArgDerivs.size(); i++)
00482     {
00483       Tabs tab1;
00484       Out::os() << tab1 << i << " ";
00485       varArgDerivs[i]->print(Out::os());
00486       Out::os() << std::endl;
00487     }
00488   }
00489   
00490 
00491   for (int i=0; i<expansions_.size(); i++)
00492   {
00493     Tabs tab1;
00494     int resultIndex = expansions_[i]->resultIndex();
00495     bool isConstant = expansions_[i]->resultIsConstant();
00496     SUNDANCE_MSG3(mgr.verb(), tab1 << "doing expansion for deriv #" << i
00497       << ", result index="
00498       << resultIndex << std::endl << tab1
00499       << "deriv=" << expansions_[i]->deriv());
00500     if (isConstant)
00501     {
00502       expansions_[i]->evalConstant(mgr, constantArgResults, constantArgDerivs,
00503         constantResults[resultIndex]);
00504     }
00505     else
00506     {
00507       expansions_[i]->evalVar(mgr, constantArgResults, varArgResults,
00508         constantArgDerivs, varArgDerivs,
00509         vectorResults[resultIndex]);
00510     }
00511   }
00512 
00513   if (zerothDerivResultIndex_ >= 0)
00514   {
00515     SUNDANCE_MSG3(mgr.verb(), tabs << "processing zeroth-order deriv");
00516     Tabs tab1;
00517     SUNDANCE_MSG3(mgr.verb(), tab1 << "result index = " << zerothDerivResultIndex_);
00518     if (zerothDerivIsConstant_)
00519     {
00520       Tabs tab2;
00521       SUNDANCE_MSG3(mgr.verb(), tab2 << "zeroth-order deriv is constant");
00522       constantResults[zerothDerivResultIndex_] = constantArgDerivs[0];
00523     }
00524     else
00525     {
00526       Tabs tab2;
00527       SUNDANCE_MSG3(mgr.verb(), tab2 << "zeroth-order deriv is variable");
00528       vectorResults[zerothDerivResultIndex_] = varArgDerivs[0];
00529     }
00530   }
00531 
00532 
00533   if (mgr.verb() > 1)
00534   {
00535     Tabs tab1;
00536     Out::os() << tab1 << "chain rule results " << std::endl;
00537     mgr.showResults(Out::os(), this->sparsity(), vectorResults,
00538       constantResults);
00539   }
00540 
00541   SUNDANCE_MSG1(mgr.verb(), tabs << "ChainRuleEvaluator::eval() done"); 
00542 }
00543 
00544 
00545 
00546 
00547 namespace Sundance {
00548 
00549 MultipleDeriv makeDeriv(const Expr& a)
00550 {
00551   const UnknownFuncElement* aPtr
00552     = dynamic_cast<const UnknownFuncElement*>(a[0].ptr().get());
00553 
00554   TEUCHOS_TEST_FOR_EXCEPT(aPtr==0);
00555 
00556   Deriv d = funcDeriv(aPtr);
00557   MultipleDeriv rtn;
00558   rtn.put(d);
00559   return rtn;
00560 }
00561 
00562 
00563 MultipleDeriv makeDeriv(const Expr& a, const Expr& b)
00564 {
00565   const UnknownFuncElement* aPtr
00566     = dynamic_cast<const UnknownFuncElement*>(a[0].ptr().get());
00567 
00568   TEUCHOS_TEST_FOR_EXCEPT(aPtr==0);
00569 
00570   const UnknownFuncElement* bPtr
00571     = dynamic_cast<const UnknownFuncElement*>(b[0].ptr().get());
00572 
00573   TEUCHOS_TEST_FOR_EXCEPT(bPtr==0);
00574 
00575   Deriv da = funcDeriv(aPtr);
00576   Deriv db = funcDeriv(bPtr);
00577   MultipleDeriv rtn;
00578   rtn.put(da);
00579   rtn.put(db);
00580   return rtn;
00581 }
00582 
00583 
00584 
00585 MultipleDeriv makeDeriv(const Expr& a, const Expr& b, const Expr& c)
00586 {
00587   const UnknownFuncElement* aPtr
00588     = dynamic_cast<const UnknownFuncElement*>(a[0].ptr().get());
00589 
00590   TEUCHOS_TEST_FOR_EXCEPT(aPtr==0);
00591 
00592   const UnknownFuncElement* bPtr
00593     = dynamic_cast<const UnknownFuncElement*>(b[0].ptr().get());
00594 
00595   TEUCHOS_TEST_FOR_EXCEPT(bPtr==0);
00596 
00597   const UnknownFuncElement* cPtr
00598     = dynamic_cast<const UnknownFuncElement*>(c[0].ptr().get());
00599 
00600   TEUCHOS_TEST_FOR_EXCEPT(cPtr==0);
00601 
00602   Deriv da = funcDeriv(aPtr);
00603   Deriv db = funcDeriv(bPtr);
00604   Deriv dc = funcDeriv(cPtr);
00605   MultipleDeriv rtn;
00606   rtn.put(da);
00607   rtn.put(db);
00608   rtn.put(dc);
00609   return rtn;
00610 }
00611 
00612 }