// NOTE(review): this text appears to be an extracted source listing with the
// original line numbers embedded and many lines elided (braces, access
// specifiers, several typedefs); it is not compilable as-is. Comments in this
// region describe only what the visible lines show.
//
// Include guard, TSQR-related includes (all under HAVE_TPETRA_TSQR), and the
// head of a partial specialization of Tpetra::TsqrAdaptor for
// Tpetra::MultiVector whose scalar type is Sacado::UQ::PCE<Storage>.
42 #ifndef TPETRA_TSQR_ADAPTOR_UQ_PCE_HPP 43 #define TPETRA_TSQR_ADAPTOR_UQ_PCE_HPP 45 #include <Tpetra_ConfigDefs.hpp> 47 #ifdef HAVE_TPETRA_TSQR 51 # include <Tsqr_NodeTsqrFactory.hpp> 53 # include <Tsqr_DistTsqr.hpp> 56 # include <Tsqr_TeuchosMessenger.hpp> 57 # include <Tpetra_MultiVector.hpp> 58 # include <Teuchos_ParameterListAcceptorDefaultBase.hpp> 62 # include <Tpetra_TsqrAdaptor.hpp> 72 template <
// Template parameters of the specialization: the PCE storage type, the local
// and global ordinal types, and the node type. Presumably all of them are
// forwarded to Tpetra::MultiVector, but the rest of its argument list is elided.
class Storage,
class LO,
class GO,
class Node>
73 class TsqrAdaptor<
Tpetra::MultiVector< Sacado::UQ::PCE<Storage>,
// The adaptor gets its parameter-list handling from
// Teuchos::ParameterListAcceptorDefaultBase.
75 public Teuchos::ParameterListAcceptorDefaultBase {
// dense_matrix_type is the type of the R-factor argument of factorExplicit and
// of the rank-revealing method. magnitude_type is the magnitude type of
// scalar_type. scalar_type, ordinal_type, node_type and MV are defined on
// elided lines.
85 typedef Teuchos::SerialDenseMatrix<ordinal_type, scalar_type> dense_matrix_type;
86 typedef typename Teuchos::ScalarTraits<scalar_type>::magnitudeType magnitude_type;
// TSQR component types: a node-level TSQR type obtained from the
// NodeTsqrFactory, a distributed TSQR type, and the combined Tsqr type, which
// is parameterized on the node-level type.
90 typedef TSQR::NodeTsqrFactory<node_type, scalar_type, ordinal_type> node_tsqr_factory_type;
91 typedef typename node_tsqr_factory_type::node_tsqr_type node_tsqr_type;
92 typedef TSQR::DistTsqr<ordinal_type, scalar_type> dist_tsqr_type;
93 typedef TSQR::Tsqr<ordinal_type, scalar_type, node_tsqr_type> tsqr_type;
// Constructors. Each allocates a fresh node-level TSQR object and a fresh
// distributed TSQR object, builds the combined tsqr_ from those two, and then
// calls setParameterList. The initializer lists and bodies are only partially
// visible here.
//
// Constructor taking a parameter list: it forwards plist to setParameterList.
102 TsqrAdaptor (
const Teuchos::RCP<Teuchos::ParameterList>& plist) :
103 nodeTsqr_ (new node_tsqr_type),
104 distTsqr_ (new dist_tsqr_type),
105 tsqr_ (new tsqr_type (nodeTsqr_, distTsqr_)),
108 setParameterList (plist);
// Second constructor, whose signature line is elided (presumably the default
// constructor). It passes Teuchos::null, so setParameterList falls back to a
// copy of the default parameters.
113 nodeTsqr_ (new node_tsqr_type),
114 distTsqr_ (new dist_tsqr_type),
115 tsqr_ (new tsqr_type (nodeTsqr_, distTsqr_)),
118 setParameterList (Teuchos::null);
// Returns the default parameter list, building it on the first call and
// returning the cached defaultParams_ afterwards.
// The list is named "TSQR implementation" and holds two sublists, "NodeTsqr"
// and "DistTsqr". Each sublist is a copy of the valid parameters reported by
// the node-level or distributed TSQR object.
// defaultParams_ is declared mutable, which lets this const method fill it.
// NOTE(review): RCP is used unqualified; a `using Teuchos::RCP` presumably
// sits on an elided line.
121 Teuchos::RCP<const Teuchos::ParameterList>
122 getValidParameters ()
const 126 using Teuchos::ParameterList;
127 using Teuchos::parameterList;
129 if (defaultParams_.is_null()) {
130 RCP<ParameterList> params = parameterList (
"TSQR implementation");
131 params->set (
"NodeTsqr", *(nodeTsqr_->getValidParameters ()));
132 params->set (
"DistTsqr", *(distTsqr_->getValidParameters ()));
133 defaultParams_ = params;
135 return defaultParams_;
// Applies a parameter list to the adaptor and its TSQR components.
// - If plist is null, a fresh copy of getValidParameters() is used instead.
// - If plist is non-null, the caller's list object itself is used and stored,
//   not a copy.
// The "NodeTsqr" sublist goes to nodeTsqr_ and the "DistTsqr" sublist goes to
// distTsqr_. The whole list is then recorded with setMyParamList.
// NOTE(review): Teuchos::sublist on a non-const RCP normally creates a missing
// sublist instead of throwing. Confirm against the Teuchos documentation if a
// missing sublist should be an error.
139 setParameterList (
const Teuchos::RCP<Teuchos::ParameterList>& plist)
141 using Teuchos::ParameterList;
142 using Teuchos::parameterList;
144 using Teuchos::sublist;
146 RCP<ParameterList> params = plist.is_null() ?
147 parameterList (*getValidParameters ()) : plist;
148 nodeTsqr_->setParameterList (sublist (params,
"NodeTsqr"));
149 distTsqr_->setParameterList (sublist (params,
"DistTsqr"));
151 this->setMyParamList (params);
// Factors A with TSQR, writing the explicit Q factor's data into Q and the R
// factor into R (through R.values() and R.stride()).
// The declaration of the Q parameter and the local variables (numRows,
// numCols, pointers, leading dimensions) are on elided lines.
// forceNonnegativeDiagonal defaults to false and is passed unchanged to
// tsqr_->factorExplicitRaw.
// NOTE(review): prepareTsqr(...) is presumably called on an elided line before
// factoring. Confirm, because distTsqr_ is initialized only there.
176 factorExplicit (MV& A,
178 dense_matrix_type& R,
179 const bool forceNonnegativeDiagonal=
false)
// Get raw pointer and leading-dimension views of A and Q. Both views write
// into the same numRows/numCols variables, so Q is presumably expected to
// have A's shape.
190 getNonConstView (numRows, numCols, A_ptr, LDA, A);
191 getNonConstView (numRows, numCols, Q_ptr, LDQ, Q);
// contiguousCacheBlocks is passed as false to factorExplicitRaw.
192 const bool contiguousCacheBlocks =
false;
193 tsqr_->factorExplicitRaw (numRows, numCols, A_ptr, LDA,
194 Q_ptr, LDQ, R.values (), R.stride (),
195 contiguousCacheBlocks,
196 forceNonnegativeDiagonal);
// Rank-revealing step on a Q/R pair. The method's name line and its leading
// parameters (including Q) are elided in this extract.
// Q is exposed as raw data with getNonConstView. That data, R, and the
// tolerance tol (of type magnitude_type) are passed to tsqr_->revealRankRaw,
// and its result is returned directly.
// contiguousCacheBlocks is passed as false.
231 dense_matrix_type& R,
232 const magnitude_type& tol)
244 getNonConstView (numRows, numCols, Q_ptr, LDQ, Q);
245 const bool contiguousCacheBlocks =
false;
246 return tsqr_->revealRankRaw (numRows, numCols, Q_ptr, LDQ,
247 R.values (), R.stride (), tol,
248 contiguousCacheBlocks);
// Node-level TSQR object. It is allocated in the constructors and prepared
// with the multivector's node in prepareNodeTsqr.
253 Teuchos::RCP<node_tsqr_type> nodeTsqr_;
// Distributed TSQR object. It is allocated in the constructors and initialized
// with a messenger in prepareDistTsqr.
256 Teuchos::RCP<dist_tsqr_type> distTsqr_;
// Combined TSQR object, constructed from nodeTsqr_ and distTsqr_.
259 Teuchos::RCP<tsqr_type> tsqr_;
// Cached default parameter list. It is mutable so the const
// getValidParameters() can build it on first use.
262 mutable Teuchos::RCP<const Teuchos::ParameterList> defaultParams_;
// Prepares both TSQR components for the given multivector: first the
// distributed part (communicator from mv's map), then the node-level part.
288 prepareTsqr (
const MV& mv)
291 prepareDistTsqr (mv);
292 prepareNodeTsqr (mv);
// Has the NodeTsqrFactory prepare nodeTsqr_, using the node object returned by
// mv.getMap()->getNode().
301 prepareNodeTsqr (
const MV& mv)
303 node_tsqr_factory_type::prepareNodeTsqr (nodeTsqr_, mv.getMap()->getNode());
// Initializes distTsqr_ with a messenger built on mv's communicator.
// The communicator comes from mv.getMap()->getComm(). It is wrapped in a
// TSQR::TeuchosMessenger<scalar_type>, which is upcast to its
// TSQR::MessengerBase<scalar_type> interface before being passed to
// distTsqr_->init.
// A new messenger is created on every call.
313 prepareDistTsqr (
const MV& mv)
316 using Teuchos::rcp_implicit_cast;
317 typedef TSQR::TeuchosMessenger<scalar_type> mess_type;
318 typedef TSQR::MessengerBase<scalar_type> base_mess_type;
320 RCP<const Teuchos::Comm<int> > comm = mv.getMap()->getComm();
321 RCP<mess_type> mess (
new mess_type (comm));
322 RCP<base_mess_type> messBase = rcp_implicit_cast<base_mess_type> (mess);
323 distTsqr_->init (messBase);
// Extracts the raw dimensions, data pointer and strides of A for TSQR.
// The signature is elided here. The call sites use the form
// getNonConstView(numRows, numCols, ptr, LD, mv).
// Throws std::invalid_argument if A does not have constant stride.
344 TEUCHOS_TEST_FOR_EXCEPTION
345 (! A.isConstantStride(), std::invalid_argument,
346 "TSQR does not currently support Tpetra::MultiVector " 347 "inputs that do not have constant stride.");
// The PCE-valued dual view's device view (d_view) is assigned to its flat
// array_type. The reported sizes and pointer are those of that flat array.
// NOTE(review): for Sacado::UQ::PCE views the flat array presumably folds the
// PCE coefficients into one dimension, so numRows/numCols may differ from A's
// logical size. Confirm against the Stokhos view specializations.
// NOTE(review): only d_view is read, and no host/device sync is visible in
// this extract. Confirm that the device copy is current before this is called.
357 typedef typename MV::dual_view_type view_type;
358 typedef typename view_type::t_dev::array_type flat_array_type;
364 view_type pce_mv = A.getDualView();
365 flat_array_type flat_mv = pce_mv.d_view;
// NOTE(review): dimension_0()/dimension_1()/ptr_on_device() are legacy Kokkos
// View accessors; current Kokkos uses extent(i)/data().
367 numRows =
static_cast<ordinal_type> (flat_mv.dimension_0 ());
368 numCols =
static_cast<ordinal_type> (flat_mv.dimension_1 ());
369 A_ptr = flat_mv.ptr_on_device ();
// `strides` is declared on an elided line. The leading dimension is
// presumably taken from it on a later elided line; confirm.
372 flat_mv.stride (strides);
378 #endif // HAVE_TPETRA_TSQR 380 #endif // TPETRA_TSQR_ADAPTOR_UQ_PCE_HPP
// NOTE(review): the next line is a stray fragment outside any definition. It
// looks like part of a template-parameter declaration (presumably a default
// Node argument from an elided line) and is not valid C++ at this position.
KokkosClassic::DefaultNode::DefaultNodeType Node