Actual source code: fnsqrt.c

slepc-3.17.2 2022-08-09
Report Typos and Errors
  1: /*
  2:    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  3:    SLEPc - Scalable Library for Eigenvalue Problem Computations
  4:    Copyright (c) 2002-, Universitat Politecnica de Valencia, Spain

  6:    This file is part of SLEPc.
  7:    SLEPc is distributed under a 2-clause BSD license (see LICENSE).
  8:    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  9: */
 10: /*
 11:    Square root function  sqrt(x)
 12: */

 14: #include <slepc/private/fnimpl.h>
 15: #include <slepcblaslapack.h>

 17: PetscErrorCode FNEvaluateFunction_Sqrt(FN fn,PetscScalar x,PetscScalar *y)
 18: {
 19: #if !defined(PETSC_USE_COMPLEX)
 21: #endif
 22:   *y = PetscSqrtScalar(x);
 23:   PetscFunctionReturn(0);
 24: }

 26: PetscErrorCode FNEvaluateDerivative_Sqrt(FN fn,PetscScalar x,PetscScalar *y)
 27: {
 29: #if !defined(PETSC_USE_COMPLEX)
 31: #endif
 32:   *y = 1.0/(2.0*PetscSqrtScalar(x));
 33:   PetscFunctionReturn(0);
 34: }

 36: PetscErrorCode FNEvaluateFunctionMat_Sqrt_Schur(FN fn,Mat A,Mat B)
 37: {
 38:   PetscBLASInt   n=0;
 39:   PetscScalar    *T;
 40:   PetscInt       m;

 42:   if (A!=B) MatCopy(A,B,SAME_NONZERO_PATTERN);
 43:   MatDenseGetArray(B,&T);
 44:   MatGetSize(A,&m,NULL);
 45:   PetscBLASIntCast(m,&n);
 46:   FNSqrtmSchur(fn,n,T,n,PETSC_FALSE);
 47:   MatDenseRestoreArray(B,&T);
 48:   PetscFunctionReturn(0);
 49: }

 51: PetscErrorCode FNEvaluateFunctionMatVec_Sqrt_Schur(FN fn,Mat A,Vec v)
 52: {
 53:   PetscBLASInt   n=0;
 54:   PetscScalar    *T;
 55:   PetscInt       m;
 56:   Mat            B;

 58:   FN_AllocateWorkMat(fn,A,&B);
 59:   MatDenseGetArray(B,&T);
 60:   MatGetSize(A,&m,NULL);
 61:   PetscBLASIntCast(m,&n);
 62:   FNSqrtmSchur(fn,n,T,n,PETSC_TRUE);
 63:   MatDenseRestoreArray(B,&T);
 64:   MatGetColumnVector(B,v,0);
 65:   FN_FreeWorkMat(fn,&B);
 66:   PetscFunctionReturn(0);
 67: }

 69: PetscErrorCode FNEvaluateFunctionMat_Sqrt_DBP(FN fn,Mat A,Mat B)
 70: {
 71:   PetscBLASInt   n=0;
 72:   PetscScalar    *T;
 73:   PetscInt       m;

 75:   if (A!=B) MatCopy(A,B,SAME_NONZERO_PATTERN);
 76:   MatDenseGetArray(B,&T);
 77:   MatGetSize(A,&m,NULL);
 78:   PetscBLASIntCast(m,&n);
 79:   FNSqrtmDenmanBeavers(fn,n,T,n,PETSC_FALSE);
 80:   MatDenseRestoreArray(B,&T);
 81:   PetscFunctionReturn(0);
 82: }

 84: PetscErrorCode FNEvaluateFunctionMat_Sqrt_NS(FN fn,Mat A,Mat B)
 85: {
 86:   PetscBLASInt   n=0;
 87:   PetscScalar    *Ba;
 88:   PetscInt       m;

 90:   if (A!=B) MatCopy(A,B,SAME_NONZERO_PATTERN);
 91:   MatDenseGetArray(B,&Ba);
 92:   MatGetSize(A,&m,NULL);
 93:   PetscBLASIntCast(m,&n);
 94:   FNSqrtmNewtonSchulz(fn,n,Ba,n,PETSC_FALSE);
 95:   MatDenseRestoreArray(B,&Ba);
 96:   PetscFunctionReturn(0);
 97: }

 99: #define MAXIT 50

101: /*
102:    Computes the principal square root of the matrix A using the
103:    Sadeghi iteration. A is overwritten with sqrtm(A).
104:  */
105: PetscErrorCode FNSqrtmSadeghi(FN fn,PetscBLASInt n,PetscScalar *A,PetscBLASInt ld)
106: {
107:   PetscScalar    *M,*M2,*G,*X=A,*work,work1,sqrtnrm;
108:   PetscScalar    szero=0.0,sone=1.0,smfive=-5.0,s1d16=1.0/16.0;
109:   PetscReal      tol,Mres=0.0,nrm,rwork[1],done=1.0;
110:   PetscInt       i,it;
111:   PetscBLASInt   N,*piv=NULL,info,lwork=0,query=-1,one=1,zero=0;
112:   PetscBool      converged=PETSC_FALSE;
113:   unsigned int   ftz;

115:   N = n*n;
116:   tol = PetscSqrtReal((PetscReal)n)*PETSC_MACHINE_EPSILON/2;
117:   SlepcSetFlushToZero(&ftz);

119:   /* query work size */
120:   PetscStackCallBLAS("LAPACKgetri",LAPACKgetri_(&n,A,&ld,piv,&work1,&query,&info));
121:   PetscBLASIntCast((PetscInt)PetscRealPart(work1),&lwork);

123:   PetscMalloc5(N,&M,N,&M2,N,&G,lwork,&work,n,&piv);
124:   PetscArraycpy(M,A,N);

126:   /* scale M */
127:   nrm = LAPACKlange_("fro",&n,&n,M,&n,rwork);
128:   if (nrm>1.0) {
129:     sqrtnrm = PetscSqrtReal(nrm);
130:     PetscStackCallBLAS("LAPACKlascl",LAPACKlascl_("G",&zero,&zero,&nrm,&done,&N,&one,M,&N,&info));
131:     SlepcCheckLapackInfo("lascl",info);
132:     tol *= nrm;
133:   }
134:   PetscInfo(fn,"||A||_F = %g, new tol: %g\n",(double)nrm,(double)tol);

136:   /* X = I */
137:   PetscArrayzero(X,N);
138:   for (i=0;i<n;i++) X[i+i*ld] = 1.0;

140:   for (it=0;it<MAXIT && !converged;it++) {

142:     /* G = (5/16)*I + (1/16)*M*(15*I-5*M+M*M) */
143:     PetscStackCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,M,&ld,M,&ld,&szero,M2,&ld));
144:     PetscStackCallBLAS("BLASaxpy",BLASaxpy_(&N,&smfive,M,&one,M2,&one));
145:     for (i=0;i<n;i++) M2[i+i*ld] += 15.0;
146:     PetscStackCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&s1d16,M,&ld,M2,&ld,&szero,G,&ld));
147:     for (i=0;i<n;i++) G[i+i*ld] += 5.0/16.0;

149:     /* X = X*G */
150:     PetscArraycpy(M2,X,N);
151:     PetscStackCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,M2,&ld,G,&ld,&szero,X,&ld));

153:     /* M = M*inv(G*G) */
154:     PetscStackCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,G,&ld,G,&ld,&szero,M2,&ld));
155:     PetscStackCallBLAS("LAPACKgetrf",LAPACKgetrf_(&n,&n,M2,&ld,piv,&info));
156:     SlepcCheckLapackInfo("getrf",info);
157:     PetscStackCallBLAS("LAPACKgetri",LAPACKgetri_(&n,M2,&ld,piv,work,&lwork,&info));
158:     SlepcCheckLapackInfo("getri",info);

160:     PetscArraycpy(G,M,N);
161:     PetscStackCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,G,&ld,M2,&ld,&szero,M,&ld));

163:     /* check ||I-M|| */
164:     PetscArraycpy(M2,M,N);
165:     for (i=0;i<n;i++) M2[i+i*ld] -= 1.0;
166:     Mres = LAPACKlange_("fro",&n,&n,M2,&n,rwork);
168:     if (Mres<=tol) converged = PETSC_TRUE;
169:     PetscInfo(fn,"it: %" PetscInt_FMT " res: %g\n",it,(double)Mres);
170:     PetscLogFlops(8.0*n*n*n+2.0*n*n+2.0*n*n*n/3.0+4.0*n*n*n/3.0+2.0*n*n*n+2.0*n*n);
171:   }


175:   /* undo scaling */
176:   if (nrm>1.0) PetscStackCallBLAS("BLASscal",BLASscal_(&N,&sqrtnrm,A,&one));

178:   PetscFree5(M,M2,G,work,piv);
179:   SlepcResetFlushToZero(&ftz);
180:   PetscFunctionReturn(0);
181: }

183: #if defined(PETSC_HAVE_CUDA)
184: #include "../src/sys/classes/fn/impls/cuda/fnutilcuda.h"
185: #include <slepccublas.h>

187: #if defined(PETSC_HAVE_MAGMA)
188: #include <slepcmagma.h>

190: /*
191:  * Matrix square root by Sadeghi iteration. CUDA version.
192:  * Computes the principal square root of the matrix T using the
193:  * Sadeghi iteration. T is overwritten with sqrtm(T).
194:  */
195: PetscErrorCode FNSqrtmSadeghi_CUDAm(FN fn,PetscBLASInt n,PetscScalar *A,PetscBLASInt ld)
196: {
197:   PetscScalar        *d_X,*d_M,*d_M2,*d_G,*d_work,alpha;
198:   const PetscScalar  szero=0.0,sone=1.0,smfive=-5.0,s15=15.0,s1d16=1.0/16.0;
199:   PetscReal          tol,Mres=0.0,nrm,sqrtnrm=1.0;
200:   PetscInt           it,nb,lwork;
201:   PetscBLASInt       *piv,N;
202:   const PetscBLASInt one=1;
203:   PetscBool          converged=PETSC_FALSE;
204:   cublasHandle_t     cublasv2handle;

206:   PetscDeviceInitialize(PETSC_DEVICE_CUDA); /* For CUDA event timers */
207:   PetscCUBLASGetHandle(&cublasv2handle);
208:   magma_init();
209:   N = n*n;
210:   tol = PetscSqrtReal((PetscReal)n)*PETSC_MACHINE_EPSILON/2;

212:   PetscMalloc1(n,&piv);
213:   cudaMalloc((void **)&d_X,sizeof(PetscScalar)*N);
214:   cudaMalloc((void **)&d_M,sizeof(PetscScalar)*N);
215:   cudaMalloc((void **)&d_M2,sizeof(PetscScalar)*N);
216:   cudaMalloc((void **)&d_G,sizeof(PetscScalar)*N);

218:   nb = magma_get_xgetri_nb(n);
219:   lwork = nb*n;
220:   cudaMalloc((void **)&d_work,sizeof(PetscScalar)*lwork);
221:   PetscLogGpuTimeBegin();

223:   /* M = A */
224:   cudaMemcpy(d_M,A,sizeof(PetscScalar)*N,cudaMemcpyHostToDevice);

226:   /* scale M */
227:   cublasXnrm2(cublasv2handle,N,d_M,one,&nrm);
228:   if (nrm>1.0) {
229:     sqrtnrm = PetscSqrtReal(nrm);
230:     alpha = 1.0/nrm;
231:     cublasXscal(cublasv2handle,N,&alpha,d_M,one);
232:     tol *= nrm;
233:   }
234:   PetscInfo(fn,"||A||_F = %g, new tol: %g\n",(double)nrm,(double)tol);

236:   /* X = I */
237:   cudaMemset(d_X,0,sizeof(PetscScalar)*N);
238:   set_diagonal(n,d_X,ld,sone);

240:   for (it=0;it<MAXIT && !converged;it++) {

242:     /* G = (5/16)*I + (1/16)*M*(15*I-5*M+M*M) */
243:     cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&sone,d_M,ld,d_M,ld,&szero,d_M2,ld);
244:     cublasXaxpy(cublasv2handle,N,&smfive,d_M,one,d_M2,one);
245:     shift_diagonal(n,d_M2,ld,s15);
246:     cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&s1d16,d_M,ld,d_M2,ld,&szero,d_G,ld);
247:     shift_diagonal(n,d_G,ld,5.0/16.0);

249:     /* X = X*G */
250:     cudaMemcpy(d_M2,d_X,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice);
251:     cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&sone,d_M2,ld,d_G,ld,&szero,d_X,ld);

253:     /* M = M*inv(G*G) */
254:     cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&sone,d_G,ld,d_G,ld,&szero,d_M2,ld);
255:     /* magma */
256:     magma_xgetrf_gpu,n,n,d_M2,ld,piv;
257:     magma_xgetri_gpu,n,d_M2,ld,piv,d_work,lwork;
258:     /* magma */
259:     cudaMemcpy(d_G,d_M,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice);
260:     cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&sone,d_G,ld,d_M2,ld,&szero,d_M,ld);

262:     /* check ||I-M|| */
263:     cudaMemcpy(d_M2,d_M,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice);
264:     shift_diagonal(n,d_M2,ld,-1.0);
265:     cublasXnrm2(cublasv2handle,N,d_M2,one,&Mres);
267:     if (Mres<=tol) converged = PETSC_TRUE;
268:     PetscInfo(fn,"it: %" PetscInt_FMT " res: %g\n",it,(double)Mres);
269:     PetscLogGpuFlops(8.0*n*n*n+2.0*n*n+2.0*n*n*n/3.0+4.0*n*n*n/3.0+2.0*n*n*n+2.0*n*n);
270:   }


274:   if (nrm>1.0) {
275:     alpha = sqrtnrm;
276:     cublasXscal(cublasv2handle,N,&alpha,d_X,one);
277:   }
278:   cudaMemcpy(A,d_X,sizeof(PetscScalar)*N,cudaMemcpyDeviceToHost);
279:   PetscLogGpuTimeEnd();

281:   cudaFree(d_X);
282:   cudaFree(d_M);
283:   cudaFree(d_M2);
284:   cudaFree(d_G);
285:   cudaFree(d_work);
286:   PetscFree(piv);

288:   magma_finalize();
289:   PetscFunctionReturn(0);
290: }
291: #endif /* PETSC_HAVE_MAGMA */
292: #endif /* PETSC_HAVE_CUDA */

294: PetscErrorCode FNEvaluateFunctionMat_Sqrt_Sadeghi(FN fn,Mat A,Mat B)
295: {
296:   PetscBLASInt   n=0;
297:   PetscScalar    *Ba;
298:   PetscInt       m;

300:   if (A!=B) MatCopy(A,B,SAME_NONZERO_PATTERN);
301:   MatDenseGetArray(B,&Ba);
302:   MatGetSize(A,&m,NULL);
303:   PetscBLASIntCast(m,&n);
304:   FNSqrtmSadeghi(fn,n,Ba,n);
305:   MatDenseRestoreArray(B,&Ba);
306:   PetscFunctionReturn(0);
307: }

309: #if defined(PETSC_HAVE_CUDA)
310: PetscErrorCode FNEvaluateFunctionMat_Sqrt_NS_CUDA(FN fn,Mat A,Mat B)
311: {
312:   PetscBLASInt   n=0;
313:   PetscScalar    *Ba;
314:   PetscInt       m;

316:   if (A!=B) MatCopy(A,B,SAME_NONZERO_PATTERN);
317:   MatDenseGetArray(B,&Ba);
318:   MatGetSize(A,&m,NULL);
319:   PetscBLASIntCast(m,&n);
320:   FNSqrtmNewtonSchulz_CUDA(fn,n,Ba,n,PETSC_FALSE);
321:   MatDenseRestoreArray(B,&Ba);
322:   PetscFunctionReturn(0);
323: }

325: #if defined(PETSC_HAVE_MAGMA)
326: PetscErrorCode FNEvaluateFunctionMat_Sqrt_DBP_CUDAm(FN fn,Mat A,Mat B)
327: {
328:   PetscBLASInt   n=0;
329:   PetscScalar    *T;
330:   PetscInt       m;

332:   if (A!=B) MatCopy(A,B,SAME_NONZERO_PATTERN);
333:   MatDenseGetArray(B,&T);
334:   MatGetSize(A,&m,NULL);
335:   PetscBLASIntCast(m,&n);
336:   FNSqrtmDenmanBeavers_CUDAm(fn,n,T,n,PETSC_FALSE);
337:   MatDenseRestoreArray(B,&T);
338:   PetscFunctionReturn(0);
339: }

341: PetscErrorCode FNEvaluateFunctionMat_Sqrt_Sadeghi_CUDAm(FN fn,Mat A,Mat B)
342: {
343:   PetscBLASInt   n=0;
344:   PetscScalar    *Ba;
345:   PetscInt       m;

347:   if (A!=B) MatCopy(A,B,SAME_NONZERO_PATTERN);
348:   MatDenseGetArray(B,&Ba);
349:   MatGetSize(A,&m,NULL);
350:   PetscBLASIntCast(m,&n);
351:   FNSqrtmSadeghi_CUDAm(fn,n,Ba,n);
352:   MatDenseRestoreArray(B,&Ba);
353:   PetscFunctionReturn(0);
354: }
355: #endif /* PETSC_HAVE_MAGMA */
356: #endif /* PETSC_HAVE_CUDA */

358: PetscErrorCode FNView_Sqrt(FN fn,PetscViewer viewer)
359: {
360:   PetscBool      isascii;
361:   char           str[50];
362:   const char     *methodname[] = {
363:                   "Schur method for the square root",
364:                   "Denman-Beavers (product form)",
365:                   "Newton-Schulz iteration",
366:                   "Sadeghi iteration"
367: #if defined(PETSC_HAVE_CUDA)
368:                  ,"Newton-Schulz iteration CUDA"
369: #if defined(PETSC_HAVE_MAGMA)
370:                  ,"Denman-Beavers (product form) CUDA/MAGMA",
371:                   "Sadeghi iteration CUDA/MAGMA"
372: #endif
373: #endif
374:   };
375:   const int      nmeth=sizeof(methodname)/sizeof(methodname[0]);

377:   PetscObjectTypeCompare((PetscObject)viewer,PETSCVIEWERASCII,&isascii);
378:   if (isascii) {
379:     if (fn->beta==(PetscScalar)1.0) {
380:       if (fn->alpha==(PetscScalar)1.0) PetscViewerASCIIPrintf(viewer,"  Square root: sqrt(x)\n");
381:       else {
382:         SlepcSNPrintfScalar(str,sizeof(str),fn->alpha,PETSC_TRUE);
383:         PetscViewerASCIIPrintf(viewer,"  Square root: sqrt(%s*x)\n",str);
384:       }
385:     } else {
386:       SlepcSNPrintfScalar(str,sizeof(str),fn->beta,PETSC_TRUE);
387:       if (fn->alpha==(PetscScalar)1.0) PetscViewerASCIIPrintf(viewer,"  Square root: %s*sqrt(x)\n",str);
388:       else {
389:         PetscViewerASCIIPrintf(viewer,"  Square root: %s",str);
390:         PetscViewerASCIIUseTabs(viewer,PETSC_FALSE);
391:         SlepcSNPrintfScalar(str,sizeof(str),fn->alpha,PETSC_TRUE);
392:         PetscViewerASCIIPrintf(viewer,"*sqrt(%s*x)\n",str);
393:         PetscViewerASCIIUseTabs(viewer,PETSC_TRUE);
394:       }
395:     }
396:     if (fn->method<nmeth) PetscViewerASCIIPrintf(viewer,"  computing matrix functions with: %s\n",methodname[fn->method]);
397:   }
398:   PetscFunctionReturn(0);
399: }

401: SLEPC_EXTERN PetscErrorCode FNCreate_Sqrt(FN fn)
402: {
403:   fn->ops->evaluatefunction          = FNEvaluateFunction_Sqrt;
404:   fn->ops->evaluatederivative        = FNEvaluateDerivative_Sqrt;
405:   fn->ops->evaluatefunctionmat[0]    = FNEvaluateFunctionMat_Sqrt_Schur;
406:   fn->ops->evaluatefunctionmat[1]    = FNEvaluateFunctionMat_Sqrt_DBP;
407:   fn->ops->evaluatefunctionmat[2]    = FNEvaluateFunctionMat_Sqrt_NS;
408:   fn->ops->evaluatefunctionmat[3]    = FNEvaluateFunctionMat_Sqrt_Sadeghi;
409: #if defined(PETSC_HAVE_CUDA)
410:   fn->ops->evaluatefunctionmat[4]    = FNEvaluateFunctionMat_Sqrt_NS_CUDA;
411: #if defined(PETSC_HAVE_MAGMA)
412:   fn->ops->evaluatefunctionmat[5]    = FNEvaluateFunctionMat_Sqrt_DBP_CUDAm;
413:   fn->ops->evaluatefunctionmat[6]    = FNEvaluateFunctionMat_Sqrt_Sadeghi_CUDAm;
414: #endif /* PETSC_HAVE_MAGMA */
415: #endif /* PETSC_HAVE_CUDA */
416:   fn->ops->evaluatefunctionmatvec[0] = FNEvaluateFunctionMatVec_Sqrt_Schur;
417:   fn->ops->view                      = FNView_Sqrt;
418:   PetscFunctionReturn(0);
419: }