• Main Page
  • Related Pages
  • Namespaces
  • Classes
  • Files
  • File List
  • File Members

MLAPI_SerialMatrix.h

Go to the documentation of this file.
00001 #ifndef MLAPI_SERIALMATRIX_H
00002 #define MLAPI_SERIALMATRIX_H
00003 
00013 /* ******************************************************************** */
00014 /* See the file COPYRIGHT for a complete copyright notice, contact      */
00015 /* person and disclaimer.                                               */        
00016 /* ******************************************************************** */
00017 
00018 #include "ml_common.h"
00019 
00020 #include "ml_include.h"
00021 //#include "ml_lapack.h"
00022 #include "ml_comm.h"
00023 #include "MLAPI_Error.h"
00024 #include "MLAPI_Space.h"
00025 #include "MLAPI_Operator.h"
00026 #include "Epetra_Vector.h"
00027 #include "Epetra_RowMatrix.h"
00028 #include "Teuchos_RefCountPtr.hpp"
00029 #include <iomanip>
00030 
00031 namespace MLAPI {
00032 
00033 class Epetra_SerialMatrix : public Epetra_RowMatrix {
00034 
00035 public:
00036 
00037   Epetra_SerialMatrix(const Space& RowSpace, const Space& ColSpace)
00038   {
00039     NumMyRows_ = RowSpace.GetNumMyElements();
00040     NumMyCols_ = ColSpace.GetNumMyElements();
00041     
00042     NumMyNonzeros_ = 0;
00043     NumMyDiagonals_ = 0;
00044 
00045     if (GetNumProcs() != 1)
00046       ML_THROW("Class SerialMatrix can only be used for serial computations.", -1);
00047 
00048     RowMap_ = Teuchos::rcp(new Epetra_Map(NumMyRows_,0,GetEpetra_Comm()));
00049     ColMap_ = Teuchos::rcp(new Epetra_Map(NumMyCols_,0,GetEpetra_Comm()));
00050 
00051     ptr_.resize(NumMyRows_);
00052   }
00053 
00054   virtual int NumMyRowEntries(int MyRow, int & NumEntries) const
00055   {
00056 #ifdef MLAPI_CHECK
00057     if (MyRow < 0 || MyRow >= NumMyRows())
00058       ML_THROW("Requested not valid row (" + GetString(MyRow) +").", -1);
00059 #endif
00060     NumEntries = ptr_[MyRow].size();
00061 
00062     return(0);
00063   }
00064 
00065   virtual int MaxNumEntries() const
00066   {
00067     int res = 0, res_i = 0;
00068 
00069     for (int i = 0 ; i < NumMyRows() ; ++i) {
00070       NumMyRowEntries(i, res_i);
00071       if (res_i > res)
00072         res = res_i;
00073     }
00074 
00075     return(res);
00076   }
00077 
00078   virtual int ExtractMyRowCopy(int MyRow, int Length, int & NumEntries, 
00079                                double *Values, int * Indices) const
00080   {
00081     NumMyRowEntries(MyRow, NumEntries);
00082     if (Length < NumEntries) ML_CHK_ERR(-1);
00083     if (MyRow < 0 || MyRow >= NumMyRows()) 
00084       ML_CHK_ERR(-2);
00085 
00086     int count = 0;
00087     for (where_ = ptr_[MyRow].begin() ; where_ != ptr_[MyRow].end() ; ++where_) {
00088       Indices[count] = where_->first;
00089       Values[count] = where_->second;
00090       ++count;
00091     }
00092     return(0);
00093   }
00094 
00095   virtual int ExtractDiagonalCopy(Epetra_Vector & Diagonal) const
00096   {
00097 #ifdef MLAPI_CHECK
00098     if (!Diagonal.Map().SameAs(RowMatrixRowMap()))
00099       ML_CHK_ERR(-1);
00100 #endif
00101 
00102     Diagonal.PutScalar(0.0);
00103                        
00104     for (int i = 0 ; i < NumMyRows() ; ++i) {
00105       for (where_ = ptr_[i].begin() ; where_ != ptr_[i].end() ; ++where_) {
00106         if (where_->first == i) {
00107           Diagonal[i] = where_->second;
00108           break;
00109         }
00110       }
00111     }
00112     return(0);
00113   }
00114 
00115   virtual int Multiply(bool TransA, const Epetra_MultiVector& X, 
00116                        Epetra_MultiVector& Y) const
00117   {
00118 
00119     Y.PutScalar(0.0);
00120 
00121     if (!TransA) {
00122       for (int v = 0 ; v < X.NumVectors() ; ++v) {
00123         for (int i = 0 ; i < NumMyRows() ; ++i) {
00124           for (where_ = ptr_[i].begin() ; where_ != ptr_[i].end() ; ++where_) {
00125             Y[v][i] += (where_->second) * X[v][where_->first];
00126           }
00127         }
00128       }
00129     }
00130     else {
00131       for (int v = 0 ; v < X.NumVectors() ; ++v) {
00132         for (int i = 0 ; i < NumMyRows() ; ++i) {
00133           for (where_ = ptr_[i].begin() ; where_ != ptr_[i].end() ; ++where_) {
00134             Y[v][where_->first] += (where_->second) * X[v][i];
00135           }
00136         }
00137       }
00138     }
00139     
00140     return(0);
00141   }
00142 
00143   virtual int Solve(bool Upper, bool Trans, bool UnitDiagonal, const Epetra_MultiVector& X, 
00144                     Epetra_MultiVector& Y) const
00145   {
00146     ML_CHK_ERR(-1);
00147   }
00148 
00149   virtual int InvRowSums(Epetra_Vector& x) const
00150   {
00151     ML_CHK_ERR(-1);
00152   }
00153 
00154   virtual int LeftScale(const Epetra_Vector& x)
00155   {
00156     ML_CHK_ERR(-1);
00157   }
00158 
00159   virtual int InvColSums(Epetra_Vector& x) const
00160   {
00161     ML_CHK_ERR(-1);
00162   }
00163 
00164   virtual int RightScale(const Epetra_Vector& x)
00165   {
00166     ML_CHK_ERR(-1);
00167   }
00168 
00169   virtual bool Filled() const
00170   {
00171     return(true);
00172   }
00173 
00174   virtual double NormInf() const 
00175   {
00176     ML_CHK_ERR(-1);
00177   }
00178 
00179   virtual double NormOne() const
00180   {
00181     ML_CHK_ERR(-1);
00182   }
00183 
00184   virtual int NumGlobalNonzeros() const
00185   {
00186     return(NumMyNonzeros_);
00187   }
00188 
00189   virtual int NumGlobalRows() const
00190   {
00191      return(NumMyRows_);
00192   }
00193 
00194   virtual int NumGlobalCols() const
00195   {
00196     return(NumMyCols_);
00197   }
00198 
00199   virtual int NumGlobalDiagonals() const
00200   {
00201     return(NumMyDiagonals_);
00202   }
00203 
00204   virtual int NumMyNonzeros() const
00205   {
00206     return(NumMyNonzeros_);
00207   }
00208 
00209   virtual int NumMyRows() const
00210   {
00211     return(NumMyRows_);
00212   }
00213 
00214   virtual int NumMyCols() const
00215   {
00216     return(NumMyCols_);
00217   }
00218 
00219   virtual int NumMyDiagonals() const
00220   {
00221     return(NumMyDiagonals_);
00222   }
00223 
00224   virtual bool LowerTriangular() const
00225   {
00226     return(false);
00227   }
00228 
00229   virtual bool UpperTriangular() const 
00230   {
00231     return(false);
00232   }
00233 
00234   virtual const Epetra_Map & RowMatrixRowMap() const
00235   {
00236     return(*(RowMap_.get()));
00237   }
00238 
00239   virtual const Epetra_Map & RowMatrixColMap() const
00240   {
00241     return(*(ColMap_.get()));
00242   }
00243 
00244   virtual const Epetra_Import * RowMatrixImporter() const
00245   {
00246     return(0);
00247   }
00248 
00249   virtual const Epetra_Map& OperatorDomainMap() const
00250   {
00251     return(*(ColMap_.get()));
00252   }
00253 
00254   virtual const Epetra_Map& OperatorRangeMap() const
00255   {
00256     return(*(RowMap_.get()));
00257   }
00258 
00259   virtual const Epetra_Map& Map() const
00260   {
00261     return(*(ColMap_.get()));
00262   }
00263     
00265 
00266   virtual int SetUseTranspose(bool)
00267   {
00268     ML_CHK_ERR(-1);
00269   }
00270   
00271   virtual int Apply(const Epetra_MultiVector& X, Epetra_MultiVector& Y) const
00272   {
00273     return(Multiply(false, X, Y));
00274   }
00275 
00276   virtual int ApplyInverse(const Epetra_MultiVector& X,
00277                            Epetra_MultiVector& Y) const
00278   {
00279     ML_CHK_ERR(-1);
00280   }
00281 
00282   virtual const char* Label() const
00283   {
00284     return("Epetra_SerialMatrix");
00285   }
00286 
00287   virtual bool UseTranspose() const
00288   {
00289     return(false);
00290   }
00291 
00292   virtual bool HasNormInf() const
00293   {
00294     return(false);
00295   }
00296 
00297   virtual const Epetra_Comm& Comm() const
00298   {
00299     return(GetEpetra_Comm());
00300   }
00301 
00302   inline double& operator()(const int row, const int col)
00303   {
00304 #ifdef MLAPI_CHECK
00305     if (row < 0 || row >= NumMyRows())
00306       ML_THROW("Requested not valid row (" + GetString(row) +").", -1);
00307     if (col < 0 || row >= NumMyCols())
00308       ML_THROW("Requested not valid column (" + GetString(col) +").", -1);
00309 #endif
00310     where_ = ptr_[row].find(col);
00311 
00312     if (where_ != ptr_[row].end())
00313       // return a reference to this guy
00314       return(where_->second);
00315     else {
00316       ptr_[row][col] = 0.0;
00317       // track number of stored elements
00318       ++NumMyNonzeros_;
00319       // track number of diagonals 
00320       if (row == col)
00321         ++NumMyDiagonals_;
00322       // return a reference to this guy
00323       return(ptr_[row][col]);
00324     }
00325   }
00326            
00327 private:
00328 
00329   Epetra_SerialMatrix(const Epetra_SerialMatrix& rhs)
00330   {
00331   }
00332 
00333   Epetra_SerialMatrix& operator=(const Epetra_SerialMatrix& rhs)
00334   {
00335     return(*this);
00336   }
00337 
00338   int NumMyRows_;
00339   int NumMyCols_;
00340   int NumMyDiagonals_;
00341   int NumMyNonzeros_;
00342 
00343   mutable std::map<int,double>::iterator where_;
00344   mutable std::vector<std::map<int,double> > ptr_;
00345 
00346   Teuchos::RefCountPtr<Epetra_Map> RowMap_;
00347   Teuchos::RefCountPtr<Epetra_Map> ColMap_;
00348 
00349 }; // class Epetra_SerialMatrix
00350 
00351 class SerialMatrix : public Operator 
00352 {
00353 public:
00354   SerialMatrix()
00355   {
00356     Matrix_ = 0;
00357   }
00358 
00359   SerialMatrix& operator()(const SerialMatrix& rhs)
00360   {
00361     Matrix_ = rhs.Matrix_;
00362     Operator::operator=(rhs);
00363     return(*this);
00364   }
00365             
00366   SerialMatrix(const Space& RowSpace, const Space& ColSpace)
00367   {
00368     Matrix_ = new Epetra_SerialMatrix(RowSpace, ColSpace);
00369 
00370     Reshape(RowSpace, ColSpace, Matrix_, true);
00371   }
00372 
00373   inline double& operator()(const int row, const int col)
00374   {
00375     return((*Matrix_)(row, col));
00376   }
00377     
00378   std::ostream& Print(std::ostream& os, const bool verbose = true) const
00379   {
00380     int Length = Matrix_->MaxNumEntries();
00381     std::vector<double> Values(Length);
00382     std::vector<int>    Indices(Length);
00383 
00384     os << endl;
00385     os << "*** MLAPI::SerialMatrix ***" << endl;
00386     os << "Label = " << GetLabel() << endl;
00387     os << "Number of rows = " << Matrix_->NumMyRows() << endl;
00388     os << "Number of columns = " << Matrix_->NumMyCols() << endl;
00389     os << endl;
00390     os.width(10); os << "row ID";
00391     os.width(10); os << "col ID";
00392     os.width(30); os << "value";
00393     os << endl;
00394     os << endl;
00395 
00396     for (int i = 0 ; i < Matrix_->NumMyRows() ; ++i) {
00397       int NnzRow = 0;
00398       Matrix_->ExtractMyRowCopy(i, Length, NnzRow, &Values[0], &Indices[0]);
00399       for (int j = 0 ; j < NnzRow ; ++j) {
00400         os.width(10); os << i;
00401         os.width(10); os << Indices[j];
00402         os.width(30); os << Values[j];
00403         os << endl;
00404       }
00405     }
00406     return(os);
00407   }
00408 
00409 private:
00410   Epetra_SerialMatrix* Matrix_;
00411 };
00412 
00413 } // namespace MLAPI
00414 
00415 #endif // ifndef MLAPI_SERIALMATRIX_H