// Include guard, TSQR-related includes (only when HAVE_TPETRA_TSQR is
// defined), and the start of class TsqrAdaptor, which derives publicly from
// Teuchos::ParameterListAcceptorDefaultBase.
//
// NOTE(review): the numbers embedded in this text look like the line numbers
// of the original header, and they have gaps (e.g. 88 -> 91 -> 93).  Some
// declarations are therefore not visible here, including the template header
// for MV and the typedefs of scalar_type and node_type that the typedefs
// below rely on.  Confirm against the complete header before editing.
42 #ifndef __Tpetra_TsqrAdaptor_hpp 43 #define __Tpetra_TsqrAdaptor_hpp 49 #include "Tpetra_ConfigDefs.hpp" 51 #ifdef HAVE_TPETRA_TSQR 52 # include "Tsqr_NodeTsqrFactory.hpp" 54 # include "Tsqr_DistTsqr.hpp" 57 # include "Tsqr_TeuchosMessenger.hpp" 58 # include "Tpetra_MultiVector.hpp" 59 # include "Teuchos_ParameterListAcceptorDefaultBase.hpp" 88 class TsqrAdaptor :
public Teuchos::ParameterListAcceptorDefaultBase {
// ordinal_type: the local ordinal type exported by the vector type MV.
91 typedef typename MV::local_ordinal_type ordinal_type;
// dense_matrix_type: type of the R argument of factorExplicit() and
// revealRank().
93 typedef Teuchos::SerialDenseMatrix<ordinal_type, scalar_type> dense_matrix_type;
// magnitude_type: magnitude type associated with scalar_type; type of the
// tolerance argument of revealRank().
94 typedef typename Teuchos::ScalarTraits<scalar_type>::magnitudeType magnitude_type;
// node_tsqr_factory_type: factory whose static prepareNodeTsqr() is called
// from this class's prepareNodeTsqr().
98 typedef TSQR::NodeTsqrFactory<node_type, scalar_type, ordinal_type> node_tsqr_factory_type;
// node_tsqr_type: pointee type of the nodeTsqr_ member, as chosen by the
// factory above.
99 typedef typename node_tsqr_factory_type::node_tsqr_type node_tsqr_type;
// dist_tsqr_type: pointee type of the distTsqr_ member.
100 typedef TSQR::DistTsqr<ordinal_type, scalar_type> dist_tsqr_type;
// tsqr_type: pointee type of the tsqr_ member; the constructors build it
// from nodeTsqr_ and distTsqr_.
101 typedef TSQR::Tsqr<ordinal_type, scalar_type, node_tsqr_type> tsqr_type;
// Constructor taking a parameter list.
//
// plist: handed to setParameterList() after the members are built.  A null
//   plist is accepted: setParameterList() then substitutes a copy of the
//   list returned by getValidParameters().
110 TsqrAdaptor (
const Teuchos::RCP<Teuchos::ParameterList>& plist) :
// nodeTsqr_ and distTsqr_ are declared before tsqr_ in the class, so they
// are already initialized when tsqr_'s initializer reads them.
111 nodeTsqr_ (new node_tsqr_type),
112 distTsqr_ (new dist_tsqr_type),
113 tsqr_ (new tsqr_type (nodeTsqr_, distTsqr_)),
// NOTE(review): the trailing comma above implies at least one more member
// initializer (and the opening brace of the body) on lines not visible
// here; confirm against the complete header.
116 setParameterList (plist);
// Member initializers and body of a second constructor.
// NOTE(review): its declarator is not visible here.  Because it passes
// Teuchos::null instead of a caller-supplied list, this is presumably the
// no-argument constructor; confirm against the complete header.
121 nodeTsqr_ (new node_tsqr_type),
122 distTsqr_ (new dist_tsqr_type),
123 tsqr_ (new tsqr_type (nodeTsqr_, distTsqr_)),
// A null list makes setParameterList() fall back to a copy of the list
// returned by getValidParameters().
// NOTE(review): as in the other constructor, the trailing comma above
// implies a further initializer that is not visible here.
126 setParameterList (Teuchos::null);
// Returns the parameter list cached in defaultParams_, building it on the
// first call.
//
// The list is named "TSQR implementation" and has two entries, "NodeTsqr"
// and "DistTsqr", each set to a copy of the list returned by the
// corresponding member's getValidParameters().
//
// NOTE(review): RCP is used unqualified below; the matching
// using-declaration and the braces of this function are presumably on
// lines not visible here.
130 Teuchos::RCP<const Teuchos::ParameterList>
131 getValidParameters ()
const 135 using Teuchos::ParameterList;
136 using Teuchos::parameterList;
// Build the list only while the cache is still null.  defaultParams_ is
// declared mutable, which is what allows assigning it in this const method.
138 if (defaultParams_.is_null()) {
139 RCP<ParameterList> params = parameterList (
"TSQR implementation");
140 params->set (
"NodeTsqr", *(nodeTsqr_->getValidParameters ()));
141 params->set (
"DistTsqr", *(distTsqr_->getValidParameters ()));
142 defaultParams_ = params;
144 return defaultParams_;
// Applies a parameter list to this object and to its two TSQR members.
//
// plist: the list to use.  If it is null, a new list copied from
//   getValidParameters() is used instead.
//
// NOTE(review): the return type, the RCP using-declaration and the braces
// are presumably on lines not visible here.
173 setParameterList (
const Teuchos::RCP<Teuchos::ParameterList>& plist)
175 using Teuchos::ParameterList;
176 using Teuchos::parameterList;
178 using Teuchos::sublist;
// Choose the effective list: a copy of the valid parameters when the
// caller passed null, otherwise the caller's own list.
180 RCP<ParameterList> params = plist.is_null() ?
181 parameterList (*getValidParameters ()) : plist;
// Each member receives the sublist stored under its own name.
182 nodeTsqr_->setParameterList (sublist (params,
"NodeTsqr"));
183 distTsqr_->setParameterList (sublist (params,
"DistTsqr"));
// Record the effective list through the base-class setter.
185 this->setMyParamList (params);
// Runs tsqr_->factorExplicitRaw() on the host data of A and Q.
//
// A: input multivector; must have constant stride.
// Q: second multivector; must have constant stride.
//   NOTE(review): Q is used throughout, but its parameter declaration is on
//   a line not visible here.
// R: dense matrix whose values() pointer and stride() are passed through.
// forceNonnegativeDiagonal: passed through unchanged; defaults to false.
//
// Throws std::invalid_argument if A or Q does not have constant stride.
//
// NOTE(review): the return type, the braces, and any statements between the
// stride checks and the sync calls (numbers 221-224 are skipped) are not
// visible here.
210 factorExplicit (MV& A,
212 dense_matrix_type& R,
213 const bool forceNonnegativeDiagonal=
false)
215 TEUCHOS_TEST_FOR_EXCEPTION
216 (! A.isConstantStride (), std::invalid_argument,
"TsqrAdaptor::" 217 "factorExplicit: Input MultiVector A must have constant stride.");
218 TEUCHOS_TEST_FOR_EXCEPTION
219 (! Q.isConstantStride (), std::invalid_argument,
"TsqrAdaptor::" 220 "factorExplicit: Input MultiVector Q must have constant stride.");
// Both multivectors get sync<HostSpace>() followed by modify<HostSpace>()
// before their host views are taken.  Presumably this brings the host copy
// up to date and flags it as changed -- confirm against the MultiVector
// documentation.
225 A.template sync<Kokkos::HostSpace> ();
226 A.template modify<Kokkos::HostSpace> ();
227 Q.template sync<Kokkos::HostSpace> ();
228 Q.template modify<Kokkos::HostSpace> ();
229 auto A_view = A.template getLocalView<Kokkos::HostSpace> ();
230 auto Q_view = Q.template getLocalView<Kokkos::HostSpace> ();
// The views' raw pointers are reinterpreted as scalar_type*.
// NOTE(review): the left-hand sides of these two statements (presumably the
// declarations of A_ptr and Q_ptr used in the call below) are on lines not
// visible here.
232 reinterpret_cast<scalar_type*
> (A_view.ptr_on_device ());
234 reinterpret_cast<scalar_type*
> (Q_view.ptr_on_device ());
// contiguousCacheBlocks is always passed as false.
235 const bool contiguousCacheBlocks =
false;
// The row and column counts come from A's host view; the strides come from
// the multivectors themselves.
236 tsqr_->factorExplicitRaw (A_view.dimension_0 (),
237 A_view.dimension_1 (),
238 A_ptr, A.getStride (),
239 Q_ptr, Q.getStride (),
240 R.values (), R.stride (),
241 contiguousCacheBlocks,
242 forceNonnegativeDiagonal);
// Tail of the parameter list and body of a function that returns the result
// of tsqr_->revealRankRaw() applied to the host data of Q.
// NOTE(review): the function's return type, name and first parameter are on
// lines not visible here; the error message below names it "revealRank" and
// the body uses a multivector Q.
//
// R: dense matrix whose values() pointer and stride() are passed through.
// tol: tolerance, passed through unchanged.
//
// Throws std::invalid_argument if Q does not have constant stride.
277 dense_matrix_type& R,
278 const magnitude_type& tol)
280 TEUCHOS_TEST_FOR_EXCEPTION
281 (! Q.isConstantStride (), std::invalid_argument,
"TsqrAdaptor::" 282 "revealRank: Input MultiVector Q must have constant stride.");
// sync<HostSpace>() followed by modify<HostSpace>() before the host view is
// taken.  Presumably this brings the host copy up to date and flags it as
// changed -- confirm against the MultiVector documentation.
288 Q.template sync<Kokkos::HostSpace> ();
289 Q.template modify<Kokkos::HostSpace> ();
290 auto Q_view = Q.template getLocalView<Kokkos::HostSpace> ();
// NOTE(review): the left-hand side of this statement (presumably the
// declaration of Q_ptr used in the call below) is on a line not visible
// here.
292 reinterpret_cast<scalar_type*
> (Q_view.ptr_on_device ());
// contiguousCacheBlocks is always passed as false.
293 const bool contiguousCacheBlocks =
false;
// The row and column counts come from Q's host view; the stride comes from
// Q itself.
294 return tsqr_->revealRankRaw (Q_view.dimension_0 (),
295 Q_view.dimension_1 (),
296 Q_ptr, Q.getStride (),
297 R.values (), R.stride (),
298 tol, contiguousCacheBlocks);
// Data members.  The constructors initialize nodeTsqr_ and distTsqr_ first
// and then pass both to tsqr_'s constructor, so this declaration order
// matters.

// Object configured by the "NodeTsqr" sublist and by prepareNodeTsqr().
303 Teuchos::RCP<node_tsqr_type> nodeTsqr_;
// Object configured by the "DistTsqr" sublist and by prepareDistTsqr().
306 Teuchos::RCP<dist_tsqr_type> distTsqr_;
// Object on which factorExplicit() and revealRank() invoke the raw
// routines.
309 Teuchos::RCP<tsqr_type> tsqr_;
// Cache filled on the first call to getValidParameters(); mutable so that
// the const method can assign it.
312 mutable Teuchos::RCP<const Teuchos::ParameterList> defaultParams_;
// Calls prepareDistTsqr(mv) and then prepareNodeTsqr(mv).
//
// mv: multivector forwarded to both helpers.
//
// NOTE(review): the numbering jumps from 338 to 341 and stops at 342, so
// the return type, the braces and any condition wrapped around these two
// calls are not visible here.
338 prepareTsqr (
const MV& mv)
341 prepareDistTsqr (mv);
342 prepareNodeTsqr (mv);
// Hands nodeTsqr_ and the node object of mv's map to the factory's static
// prepareNodeTsqr().
//
// mv: multivector whose map supplies the node (mv.getMap()->getNode()).
//
// NOTE(review): the return type and braces are on lines not visible here.
351 prepareNodeTsqr (
const MV& mv)
353 node_tsqr_factory_type::prepareNodeTsqr (nodeTsqr_, mv.getMap()->getNode());
// Initializes distTsqr_ with a messenger built from mv's communicator.
//
// mv: multivector whose map supplies the communicator
//   (mv.getMap()->getComm()).
//
// NOTE(review): the return type, the RCP using-declaration and the braces
// are presumably on lines not visible here.
363 prepareDistTsqr (
const MV& mv)
366 using Teuchos::rcp_implicit_cast;
367 typedef TSQR::TeuchosMessenger<scalar_type> mess_type;
368 typedef TSQR::MessengerBase<scalar_type> base_mess_type;
370 RCP<const Teuchos::Comm<int> > comm = mv.getMap()->getComm();
// Wrap the communicator in a TeuchosMessenger, then convert the pointer to
// the MessengerBase type that is passed to init().
371 RCP<mess_type> mess (
new mess_type (comm));
372 RCP<base_mess_type> messBase = rcp_implicit_cast<base_mess_type> (mess);
373 distTsqr_->init (messBase);
379 #endif // HAVE_TPETRA_TSQR 381 #endif // __Tpetra_TsqrAdaptor_hpp Namespace Tpetra contains the class and methods constituting the Tpetra library.
KokkosClassic::DefaultNode::DefaultNodeType node_type
Default value of the Node template parameter.
double scalar_type
Default value of the Scalar template parameter.