00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031 #include "SundanceChainRuleSum.hpp"
00032 #include "SundanceEvalManager.hpp"
00033 #include "SundanceEvalVector.hpp"
00034 #include "PlayaExceptions.hpp"
00035 #include "SundanceSet.hpp"
00036 #include "PlayaTabs.hpp"
00037 #include "SundanceOut.hpp"
00038
00039 using namespace Sundance;
00040 using namespace Sundance;
00041
00042 using namespace Sundance;
00043 using namespace Teuchos;
00044
00045
00046 ChainRuleSum::ChainRuleSum(const MultipleDeriv& md,
00047 int resultIndex,
00048 bool resultIsConstant)
00049 : md_(md),
00050 resultIndex_(resultIndex),
00051 resultIsConstant_(resultIsConstant),
00052 argDerivIndex_(),
00053 argDerivIsConstant_(),
00054 terms_()
00055 {;}
00056
00057
00058 void ChainRuleSum::addTerm(int argDerivIndex,
00059 bool argDerivIsConstant,
00060 const Array<DerivProduct>& sum)
00061 {
00062 argDerivIndex_.append(argDerivIndex);
00063 argDerivIsConstant_.append(argDerivIsConstant);
00064 terms_.append(sum);
00065 }
00066
00067
00068 void ChainRuleSum
00069 ::evalConstant(const EvalManager& mgr,
00070 const Array<RCP<Array<double> > >& constantArgResults,
00071 const Array<double>& constantArgDerivs,
00072 double& result) const
00073 {
00074 Tabs tabs;
00075 SUNDANCE_VERB_HIGH(tabs << "ChainRuleSum::evalConstant()");
00076 result = 0.0;
00077 for (int i=0; i<numTerms(); i++)
00078 {
00079 const double& argDeriv = constantArgDerivs[argDerivIndex(i)];
00080 const Array<DerivProduct>& sumOfDerivProducts = terms(i);
00081 double innerSum = 0.0;
00082 for (int j=0; j<sumOfDerivProducts.size(); j++)
00083 {
00084 double prod = 1.0;
00085 const DerivProduct& p = sumOfDerivProducts[j];
00086 for (int k=0; k<p.numConstants(); k++)
00087 {
00088 const IndexPair& ip = p.constant(k);
00089 prod *= (*(constantArgResults[ip.argIndex()]))[ip.valueIndex()];
00090 }
00091 innerSum += prod;
00092 }
00093 result += innerSum*argDeriv;
00094 }
00095 }
00096
00097
00098 void ChainRuleSum
00099 ::evalVar(const EvalManager& mgr,
00100 const Array<RCP<Array<double> > >& constantArgResults,
00101 const Array<RCP<Array<RCP<EvalVector> > > > & vArgResults,
00102 const Array<double>& constantArgDerivs,
00103 const Array<RCP<EvalVector> >& varArgDerivs,
00104 RCP<EvalVector>& varResult) const
00105 {
00106 Tabs tabs;
00107 SUNDANCE_VERB_HIGH(tabs << "ChainRuleSum::evalVar()");
00108 int vecSize=-1;
00109 for (int i=0; i<varArgDerivs.size(); i++)
00110 {
00111 int s = varArgDerivs[i]->length();
00112 TEUCHOS_TEST_FOR_EXCEPTION(vecSize != -1 && s != vecSize, std::logic_error,
00113 "inconsistent vector sizes " << vecSize
00114 << " and " << s);
00115 vecSize = s;
00116 }
00117 for (int i=0; i<vArgResults.size(); i++)
00118 {
00119 for (int j=0; j<vArgResults[i]->size(); j++)
00120 {
00121 int s = (*(vArgResults[i]))[j]->length();
00122 TEUCHOS_TEST_FOR_EXCEPTION(vecSize != -1 && s != vecSize, std::logic_error,
00123 "inconsistent vector sizes " << vecSize
00124 << " and " << s);
00125 vecSize = s;
00126 }
00127 }
00128 TEUCHOS_TEST_FOR_EXCEPT(vecSize==-1);
00129
00130 varResult = mgr.popVector();
00131 varResult->resize(vecSize);
00132 varResult->setToConstant(0.0);
00133
00134 for (int i=0; i<numTerms(); i++)
00135 {
00136 Tabs tab1;
00137 SUNDANCE_VERB_HIGH(tab1 << "term=" << i << " of " << numTerms());
00138 RCP<EvalVector> innerSum = mgr.popVector();
00139 innerSum->resize(vecSize);
00140 innerSum->setToConstant(0.0);
00141 const Array<DerivProduct>& sumOfDerivProducts = terms(i);
00142
00143 SUNDANCE_VERB_HIGH(tab1 << "inner sum init = " << *innerSum
00144 << ", num terms = " << terms(i).size());
00145
00146 for (int j=0; j<sumOfDerivProducts.size(); j++)
00147 {
00148 Tabs tab2;
00149 SUNDANCE_VERB_HIGH(tab2 << "dp=" << j << " of " << sumOfDerivProducts.size());
00150 const DerivProduct& p = sumOfDerivProducts[j];
00151 double cc = p.coeff();
00152 SUNDANCE_VERB_HIGH(tab2 << "multiplicity=" << cc);
00153 for (int k=0; k<p.numConstants(); k++)
00154 {
00155 const IndexPair& ip = p.constant(k);
00156 cc *= (*(constantArgResults[ip.argIndex()]))[ip.valueIndex()];
00157 }
00158 if (p.numVariables()==0)
00159 {
00160 innerSum->add_S(cc);
00161 }
00162 else if (p.numVariables()==1)
00163 {
00164 const IndexPair& ip = p.variable(0);
00165 const EvalVector* v
00166 = (*(vArgResults[ip.argIndex()]))[ip.valueIndex()].get();
00167 if (cc==1.0) innerSum->add_V(v);
00168 else innerSum->add_SV(cc, v);
00169 }
00170 else if (p.numVariables()==2)
00171 {
00172 const IndexPair& ip0 = p.variable(0);
00173 const EvalVector* v0
00174 = (*(vArgResults[ip0.argIndex()]))[ip0.valueIndex()].get();
00175 const IndexPair& ip1 = p.variable(1);
00176 const EvalVector* v1
00177 = (*(vArgResults[ip1.argIndex()]))[ip1.valueIndex()].get();
00178 if (cc==1.0) innerSum->add_VV(v0, v1);
00179 else innerSum->add_SVV(cc, v0, v1);
00180 }
00181 else
00182 {
00183 const IndexPair& ip0 = p.variable(0);
00184 const EvalVector* v0
00185 = (*(vArgResults[ip0.argIndex()]))[ip0.valueIndex()].get();
00186 RCP<EvalVector> tmp = v0->clone();
00187 for (int k=1; k<p.numVariables(); k++)
00188 {
00189 const IndexPair& ip1 = p.variable(k);
00190 const EvalVector* v1
00191 = (*(vArgResults[ip1.argIndex()]))[ip1.valueIndex()].get();
00192 tmp->multiply_V(v1);
00193 }
00194 if (cc==1.0) innerSum->add_V(tmp.get());
00195 else innerSum->add_SV(cc, tmp.get());
00196 }
00197 SUNDANCE_VERB_HIGH(tab2 << "inner sum=" << *innerSum);
00198 }
00199
00200 int adi = argDerivIndex(i);
00201 if (argDerivIsConstant(i))
00202 {
00203 const double& df_dq = constantArgDerivs[adi];
00204 varResult->add_SV(df_dq, innerSum.get());
00205 }
00206 else
00207 {
00208 const EvalVector* df_dq = varArgDerivs[adi].get();
00209 SUNDANCE_VERB_HIGH(tab1 << "arg deriv=" << *df_dq);
00210 varResult->add_VV(df_dq, innerSum.get());
00211 SUNDANCE_VERB_HIGH(tab1 << "outer sum=" << *varResult);
00212 }
00213 }
00214 }
00215
00216