escript  Revision_
MUMPS.h
Go to the documentation of this file.
1 
2 /*****************************************************************************
3 *
4 * Copyright (c) 2003-2020 by The University of Queensland
5 * http://www.uq.edu.au
6 *
7 * Primary Business: Queensland, Australia
8 * Licensed under the Apache License, version 2.0
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Development until 2012 by Earth Systems Science Computational Center (ESSCC)
12 * Development 2012-2013 by School of Earth Sciences
13 * Development from 2014 by Centre for Geoscience Computing (GeoComp)
14 *
15 *****************************************************************************/
16 
17 
18 /****************************************************************************/
19 
20 /* Paso: interface to the MUMPS library */
21 
22 /****************************************************************************/
23 
24 #ifndef __PASO_MUMPS_H__
25 #define __PASO_MUMPS_H__
26 
27 #include "SparseMatrix.h"
28 #include "Options.h"
29 #include "PasoException.h"
30 
31 #ifdef ESYS_HAVE_MUMPS
32 // TODO: is this needed? #pragma push_macro("MPI_COMM_WORLD")
33 #if defined(MPI_COMM_WORLD)
34 #undef MPI_COMM_WORLD // breaks mumps_mpi.h, defined in escriptcore/src/EsysMPI.h
35 #endif
36 #include <mumps_mpi.h>
37 // TODO: is this needed? #pragma pop_macro("MPI_COMM_WORLD")
38 // #include <zmumps_c.h>
39 #include <dmumps_c.h>
40 #include <zmumps_c.h>
41 #define MUMPS_JOB_INIT -1
42 #define MUMPS_JOB_END -2
43 #define MUMPS_USE_COMM_WORLD -987654
44 #define ICNTL(I) icntl[(I)-1] // macro s.t. indices match documentation
45 
46 #ifdef _WIN32
47 #define NOMINMAX
48 #include <windows.h>
49 #undef NOMINMAX
50 #endif
51 #endif // ESYS_HAVE_MUMPS
52 
53 namespace paso {
54 
56  bool verbose;
57  std::stringstream ssExceptMsg;
58 #ifdef ESYS_HAVE_MUMPS
59  MUMPS_INT myid;
60 #ifdef _WIN32 // workaround for d/zmumps dll clash
61  HINSTANCE h_mumps_c_dll;
62 #endif
63 #endif // ESYS_HAVE_MUMPS
64 };
65 
66 template <typename T>
68  T* rhs;
69 };
70 
71 template <typename T>
73 
74 template <typename T>
75 void MUMPS_solve(SparseMatrix_ptr<T> A, T* out, T* in, dim_t numRefinements, bool verbose);
76 
77 template <typename T>
78 void MUMPS_print_list(const char* name, const T* vals, const int n, const int max_n=100);
79 
80 std::ostream& operator<<(std::ostream& os, const cplx_t& c);
81 
82 template <>
83 struct MUMPS_Handler<double> : MUMPS_Handler_t {
84  double* rhs;
85 #ifdef ESYS_HAVE_MUMPS
86  DMUMPS_STRUC_C id;
87  typedef double mumps_t;
88 #ifdef _WIN32 // workaround for d/zmumps dll clash
89  typedef HRESULT (CALLBACK* MUMPS_C_FUNC_PTR)(DMUMPS_STRUC_C*);
90  MUMPS_C_FUNC_PTR mumps_c;
91  const char* mumps_lib = "libdmumps";
92  const char* mumps_proc = "dmumps_c";
93 #else
94  void (*mumps_c)(DMUMPS_STRUC_C*) = &dmumps_c;
95 #endif
96 #endif // ESYS_HAVE_MUMPS
97 };
98 
99 template <>
102 #ifdef ESYS_HAVE_MUMPS
103  ZMUMPS_STRUC_C id;
104  typedef ZMUMPS_COMPLEX mumps_t;
105 #ifdef _WIN32 // workaround for dmumps/zmumps dll clash
106  typedef HRESULT (CALLBACK* MUMPS_C_FUNC_PTR)(ZMUMPS_STRUC_C*);
107  MUMPS_C_FUNC_PTR mumps_c;
108  const char* mumps_lib = "libzmumps";
109  const char* mumps_proc = "zmumps_c";
110 #else
111  void (*mumps_c)(ZMUMPS_STRUC_C*) = &zmumps_c;
112 #endif
113 #endif // ESYS_HAVE_MUMPS
114 };
115 
117 template <typename T>
119 {
120  if (A && A->solver_p) {
121 #ifdef ESYS_HAVE_MUMPS
122  // Terminate instance.
123  auto pt = static_cast<MUMPS_Handler<T>*>(A->solver_p);
124  delete[] pt->rhs;
125  pt->id.job = MUMPS_JOB_END;
126  pt->mumps_c(&pt->id);
127 #ifdef _WIN32
128  FreeLibrary(pt->h_mumps_c_dll);
129 #endif
130  if (pt->myid == 0) {
131  std::string message = pt->ssExceptMsg.str();
132  if (!message.empty()) {
133  // terminating with solve error message
134  throw PasoException(message);
135  }
136  }
137  MUMPS_INT ierr = MPI_Finalize();
138  if (pt->verbose) {
139  std::cout << "MUMPS: instance terminated." << std::endl;
140  }
141  delete pt;
142 #endif
143  A->solver_p = NULL;
144  }
145 }
146 
148 template <typename T>
149 void MUMPS_solve(SparseMatrix_ptr<T> A, T* out, T* in, dim_t numRefinements, bool verbose)
150 {
151 #ifdef ESYS_HAVE_MUMPS
152  if (! (A->type & (MATRIX_FORMAT_OFFSET1 + MATRIX_FORMAT_BLK1)) ) {
153  throw PasoException("Paso: MUMPS requires CSR format with index offset 1 and block size 1.");
154  }
155 
156  auto pt = reinterpret_cast<MUMPS_Handler<T>*>(A->solver_p);
157  if (pt == NULL) {
158  pt = new MUMPS_Handler<T>;
159 #ifdef _WIN32
160  pt->h_mumps_c_dll = LoadLibrary(pt->mumps_lib);
161  if (pt->h_mumps_c_dll == NULL) {
162  std::stringstream ss;
163  ss << "Paso: MUMPS LoadLibrary failed - \"" << pt->mumps_lib << "\".";
164  throw PasoException(ss.str());
165  }
166  pt->mumps_c = (MUMPS_Handler<T>::MUMPS_C_FUNC_PTR)GetProcAddress(pt->h_mumps_c_dll, pt->mumps_proc);
167  if (pt->mumps_c == NULL) {
168  std::stringstream ss;
169  ss << "Paso: MUMPS GetProcAddress failed - \"" << pt->mumps_proc << "\".";
170  throw PasoException(ss.str());
171  }
172 #endif
173  A->solver_p = (void*) pt;
174  A->solver_package = PASO_MUMPS;
175  double time0 = escript::gettime();
176 
177  A->pattern->csrToHB(); // generate Harwell-Boeing format needed for MUMPS from CSR
178  MUMPS_INT n = A->numRows; // matrix order
179  MUMPS_INT8 nnz = A->pattern->len; // number non-zeros
180  MUMPS_INT* irn = reinterpret_cast<MUMPS_INT*>(A->pattern->hb_row); // row indices array
181  MUMPS_INT* jcn = reinterpret_cast<MUMPS_INT*>(A->pattern->hb_col); // col indices array
182  pt->verbose = verbose;
183  if (pt->verbose) {
184  std::cout << "MUMPS in ===>" << std::endl;
185  std::cout << "n = " << n << std::endl;
186  std::cout << "nnz = " << nnz << std::endl;
187  MUMPS_print_list("val", A->val, nnz);
188  MUMPS_print_list("in", in, n);
189  MUMPS_print_list("ptr", A->pattern->ptr, n+1);
190  MUMPS_print_list("index", A->pattern->index, nnz);
191  MUMPS_print_list("hb_row", A->pattern->hb_row, nnz);
192  MUMPS_print_list("hb_col", A->pattern->hb_col, nnz);
193  }
194  pt->rhs = new T[n];
195  std::memcpy(pt->rhs, in, n*sizeof(T));
196  MUMPS_INT ierr;
197  ierr = MPI_Init(NULL, NULL);
198  ierr = MPI_Comm_rank(MPI_COMM_WORLD, &pt->myid);
199 
200  // Initialize a MUMPS instance. Use MPI_COMM_WORLD
201  pt->id.comm_fortran = MUMPS_USE_COMM_WORLD;
202  pt->id.par = 1; pt->id.sym = 0;
203  pt->id.job = MUMPS_JOB_INIT;
204  pt->mumps_c(&pt->id);
205  // Define the problem on the host
206  if (pt->myid == 0) {
207  pt->id.n = n; pt->id.nnz = nnz;
208  pt->id.irn = irn; pt->id.jcn = jcn;
209  pt->id.a = reinterpret_cast<typename MUMPS_Handler<T>::mumps_t*>(A->val);
210  pt->id.rhs = reinterpret_cast<typename MUMPS_Handler<T>::mumps_t*>(pt->rhs);
211  }
212  if (!pt->verbose) {
213  // No outputs
214  pt->id.ICNTL(1)=-1; pt->id.ICNTL(2)=-1; pt->id.ICNTL(3)=-1; pt->id.ICNTL(4)=0;
215  }
216 
217  // Call the MUMPS package (analyse, factorization and solve).
218  pt->id.job = 6;
219  pt->mumps_c(&pt->id);
220  if (pt->id.infog[0] < 0) {
221  pt->ssExceptMsg << "(PROC " << pt->myid << ") MUMPS ERROR: INFOG(1)=" << pt->id.infog[0]
222  << ", INFOG(2)=" << pt->id.infog[1];
223  } else {
224  std::memcpy(out, reinterpret_cast<T*>(pt->rhs), n*sizeof(T));
225  if (pt->id.infog[0] > 0) {
226  std::cout << "(PROC " << pt->myid << ") MUMPS WARNING: INFOG(1)=" << pt->id.infog[0]
227  << ", INFOG(2)=" << pt->id.infog[1];
228  }
229  if (pt->verbose) {
230  std::cout << "MUMPS out ===>" << std::endl;
231  MUMPS_print_list("out", out, n);
232  std::cout << "MUMPS: factorization and solve completed (time = "
233  << escript::gettime()-time0 << ")." << std::endl;
234  }
235  }
236  }
237 #else // ESYS_HAVE_MUMPS
238  throw PasoException("Paso: Not compiled with MUMPS.");
239 #endif
240 }
241 
// output array data for debugging solver
// writes "name = [ v0, v1, ... ]" followed by a line break to std::cout
// at most max_n values are printed (the declaration defaults max_n to 100);
// ", ..." is appended only if values were left out
// use max_n <= 0 for no limit
template <typename T>
void MUMPS_print_list(const char* name, const T* vals, const int n, const int max_n)
{
    // number of values that are actually written
    const int shown = (max_n > 0 && n > max_n) ? max_n : n;

    std::cout << name << " = [ ";
    for (int i = 0; i < shown; i++) {
        if (i > 0) {
            std::cout << ", ";
        }
        std::cout << vals[i];
    }
    if (shown < n) {
        // mark the truncation
        std::cout << ", ...";
    }
    std::cout << " ]" << std::endl;
}
262 
263 } // namespace paso
264 
265 #endif // __PASO_MUMPS_H__
266 
MATRIX_FORMAT_BLK1
#define MATRIX_FORMAT_BLK1
Definition: Paso.h:63
message
Definition: blocktools.h:70
paso::SparseMatrix_ptr
boost::shared_ptr< SparseMatrix< T > > SparseMatrix_ptr
Definition: SparseMatrix.h:37
paso::MUMPS_solve
void MUMPS_solve(SparseMatrix_ptr< T > A, T *out, T *in, dim_t numRefinements, bool verbose)
calls the solver
Definition: MUMPS.h:149
paso::MUMPS_Handler::rhs
T * rhs
Definition: MUMPS.h:68
paso::MUMPS_print_list
void MUMPS_print_list(const char *name, const T *vals, const int n, const int max_n=100)
Definition: MUMPS.h:245
paso::MUMPS_free
void MUMPS_free(SparseMatrix< T > *A)
frees any MUMPS related data from the matrix
Definition: MUMPS.h:118
paso::operator<<
std::ostream & operator<<(std::ostream &os, const cplx_t &c)
Definition: MUMPS.cpp:34
paso::MUMPS_Handler_t::verbose
bool verbose
Definition: MUMPS.h:56
paso::MUMPS_Handler_t::ssExceptMsg
std::stringstream ssExceptMsg
Definition: MUMPS.h:57
MATRIX_FORMAT_OFFSET1
#define MATRIX_FORMAT_OFFSET1
Definition: Paso.h:64
SparseMatrix.h
Options.h
Paso.h
MPI_COMM_WORLD
#define MPI_COMM_WORLD
Definition: EsysMPI.h:50
escript::DataTypes::dim_t
index_t dim_t
Definition: DataTypes.h:66
paso::MUMPS_Handler< double >::rhs
double * rhs
Definition: MUMPS.h:84
paso::PasoException
PasoException exception class.
Definition: PasoException.h:34
paso::MUMPS_Handler
Definition: MUMPS.h:67
escript::gettime
double gettime()
returns the current ticks for timing
Definition: EsysMPI.h:192
paso::MUMPS_Handler< cplx_t >::rhs
cplx_t * rhs
Definition: MUMPS.h:101
PASO_MUMPS
#define PASO_MUMPS
Definition: Options.h:57
paso::SparseMatrix::solver_p
void * solver_p
pointer to data needed by a solver
Definition: SparseMatrix.h:177
PasoException.h
paso::MUMPS_Handler_t
Definition: MUMPS.h:55
paso
Definition: BiCGStab.cpp:25
escript::DataTypes::cplx_t
std::complex< real_t > cplx_t
complex data type
Definition: DataTypes.h:55
paso::SparseMatrix
Definition: SparseMatrix.h:45
MUMPS.h