Actual source code: fnsqrt.c

slepc-3.11.2 2019-07-30
Report Typos and Errors
  1: /*
  2:    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  3:    SLEPc - Scalable Library for Eigenvalue Problem Computations
  4:    Copyright (c) 2002-2019, Universitat Politecnica de Valencia, Spain

  6:    This file is part of SLEPc.
  7:    SLEPc is distributed under a 2-clause BSD license (see LICENSE).
  8:    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  9: */
 10: /*
 11:    Square root function  sqrt(x)
 12: */

 14: #include <slepc/private/fnimpl.h>      /*I "slepcfn.h" I*/
 15: #include <slepcblaslapack.h>

 17: PetscErrorCode FNEvaluateFunction_Sqrt(FN fn,PetscScalar x,PetscScalar *y)
 18: {
 20: #if !defined(PETSC_USE_COMPLEX)
 21:   if (x<0.0) SETERRQ(PETSC_COMM_SELF,1,"Function not defined in the requested value");
 22: #endif
 23:   *y = PetscSqrtScalar(x);
 24:   return(0);
 25: }

 27: PetscErrorCode FNEvaluateDerivative_Sqrt(FN fn,PetscScalar x,PetscScalar *y)
 28: {
 30:   if (x==0.0) SETERRQ(PETSC_COMM_SELF,1,"Derivative not defined in the requested value");
 31: #if !defined(PETSC_USE_COMPLEX)
 32:   if (x<0.0) SETERRQ(PETSC_COMM_SELF,1,"Derivative not defined in the requested value");
 33: #endif
 34:   *y = 1.0/(2.0*PetscSqrtScalar(x));
 35:   return(0);
 36: }

 38: PetscErrorCode FNEvaluateFunctionMat_Sqrt_Schur(FN fn,Mat A,Mat B)
 39: {
 41:   PetscBLASInt   n;
 42:   PetscScalar    *T;
 43:   PetscInt       m;

 46:   if (A!=B) { MatCopy(A,B,SAME_NONZERO_PATTERN); }
 47:   MatDenseGetArray(B,&T);
 48:   MatGetSize(A,&m,NULL);
 49:   PetscBLASIntCast(m,&n);
 50:   SlepcSqrtmSchur(n,T,n,PETSC_FALSE);
 51:   MatDenseRestoreArray(B,&T);
 52:   return(0);
 53: }

 55: PetscErrorCode FNEvaluateFunctionMatVec_Sqrt_Schur(FN fn,Mat A,Vec v)
 56: {
 58:   PetscBLASInt   n;
 59:   PetscScalar    *T;
 60:   PetscInt       m;
 61:   Mat            B;

 64:   FN_AllocateWorkMat(fn,A,&B);
 65:   MatDenseGetArray(B,&T);
 66:   MatGetSize(A,&m,NULL);
 67:   PetscBLASIntCast(m,&n);
 68:   SlepcSqrtmSchur(n,T,n,PETSC_TRUE);
 69:   MatDenseRestoreArray(B,&T);
 70:   MatGetColumnVector(B,v,0);
 71:   FN_FreeWorkMat(fn,&B);
 72:   return(0);
 73: }

 75: PetscErrorCode FNEvaluateFunctionMat_Sqrt_DBP(FN fn,Mat A,Mat B)
 76: {
 78:   PetscBLASInt   n;
 79:   PetscScalar    *T;
 80:   PetscInt       m;

 83:   if (A!=B) { MatCopy(A,B,SAME_NONZERO_PATTERN); }
 84:   MatDenseGetArray(B,&T);
 85:   MatGetSize(A,&m,NULL);
 86:   PetscBLASIntCast(m,&n);
 87:   SlepcSqrtmDenmanBeavers(n,T,n,PETSC_FALSE);
 88:   MatDenseRestoreArray(B,&T);
 89:   return(0);
 90: }

 92: PetscErrorCode FNEvaluateFunctionMat_Sqrt_NS(FN fn,Mat A,Mat B)
 93: {
 95:   PetscBLASInt   n;
 96:   PetscScalar    *Ba;
 97:   PetscInt       m;

100:   if (A!=B) { MatCopy(A,B,SAME_NONZERO_PATTERN); }
101:   MatDenseGetArray(B,&Ba);
102:   MatGetSize(A,&m,NULL);
103:   PetscBLASIntCast(m,&n);
104:   SlepcSqrtmNewtonSchulz(n,Ba,n,PETSC_FALSE);
105:   MatDenseRestoreArray(B,&Ba);
106:   return(0);
107: }

109: #define MAXIT 50

111: /*
112:    Computes the principal square root of the matrix A using the
113:    Sadeghi iteration. A is overwritten with sqrtm(A).
114:  */
115: static PetscErrorCode SlepcSqrtmSadeghi(PetscBLASInt n,PetscScalar *A,PetscBLASInt ld)
116: {
117: #if defined(PETSC_MISSING_LAPACK_GETRF) || defined(PETSC_MISSING_LAPACK_GETRI)
119:   SETERRQ(PETSC_COMM_SELF,PETSC_ERR_SUP,"GETRF/GETRI - Lapack routine is unavailable");
120: #else
121:   PetscScalar        *M,*M2,*G,*X=A,*work,work1,alpha,sqrtnrm;
122:   PetscScalar        szero=0.0,sone=1.0,smfive=-5.0,s1d16=1.0/16.0;
123:   PetscReal          tol,Mres,nrm,rwork[1];
124:   PetscBLASInt       N,i,it,*piv=NULL,info,lwork,query=-1;
125:   const PetscBLASInt one=1;
126:   PetscBool          converged=PETSC_FALSE;
127:   PetscErrorCode     ierr;
128:   unsigned int       ftz;

131:   N = n*n;
132:   tol = PetscSqrtReal((PetscReal)n)*PETSC_MACHINE_EPSILON/2;
133:   SlepcSetFlushToZero(&ftz);

135:   /* query work size */
136:   PetscStackCallBLAS("LAPACKgetri",LAPACKgetri_(&n,A,&ld,piv,&work1,&query,&info));
137:   PetscBLASIntCast((PetscInt)PetscRealPart(work1),&lwork);

139:   PetscMalloc5(N,&M,N,&M2,N,&G,lwork,&work,n,&piv);
140:   PetscMemcpy(M,A,N*sizeof(PetscScalar));

142:   /* scale M */
143:   nrm = LAPACKlange_("fro",&n,&n,M,&n,rwork);
144:   if (nrm>1.0) {
145:     sqrtnrm = PetscSqrtReal(nrm);
146:     alpha = 1.0/nrm;
147:     PetscStackCallBLAS("BLASscal",BLASscal_(&N,&alpha,M,&one));
148:     tol *= nrm;
149:   }
150:   PetscInfo2(NULL,"||A||_F = %g, new tol: %g\n",(double)nrm,(double)tol);

152:   /* X = I */
153:   PetscMemzero(X,N*sizeof(PetscScalar));
154:   for (i=0;i<n;i++) X[i+i*ld] = 1.0;

156:   for (it=0;it<MAXIT && !converged;it++) {

158:     /* G = (5/16)*I + (1/16)*M*(15*I-5*M+M*M) */
159:     PetscStackCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,M,&ld,M,&ld,&szero,M2,&ld));
160:     PetscStackCallBLAS("BLASaxpy",BLASaxpy_(&N,&smfive,M,&one,M2,&one));
161:     for (i=0;i<n;i++) M2[i+i*ld] += 15.0;
162:     PetscStackCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&s1d16,M,&ld,M2,&ld,&szero,G,&ld));
163:     for (i=0;i<n;i++) G[i+i*ld] += 5.0/16.0;

165:     /* X = X*G */
166:     PetscMemcpy(M2,X,N*sizeof(PetscScalar));
167:     PetscStackCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,M2,&ld,G,&ld,&szero,X,&ld));

169:     /* M = M*inv(G*G) */
170:     PetscStackCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,G,&ld,G,&ld,&szero,M2,&ld));
171:     PetscStackCallBLAS("LAPACKgetrf",LAPACKgetrf_(&n,&n,M2,&ld,piv,&info));
172:     SlepcCheckLapackInfo("getrf",info);
173:     PetscStackCallBLAS("LAPACKgetri",LAPACKgetri_(&n,M2,&ld,piv,work,&lwork,&info));
174:     SlepcCheckLapackInfo("getri",info);

176:     PetscMemcpy(G,M,N*sizeof(PetscScalar));
177:     PetscStackCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,G,&ld,M2,&ld,&szero,M,&ld));

179:     /* check ||I-M|| */
180:     PetscMemcpy(M2,M,N*sizeof(PetscScalar));
181:     for (i=0;i<n;i++) M2[i+i*ld] -= 1.0;
182:     Mres = LAPACKlange_("fro",&n,&n,M2,&n,rwork);
183:     PetscIsNanReal(Mres);
184:     if (Mres<=tol) converged = PETSC_TRUE;
185:     PetscInfo2(NULL,"it: %D res: %g\n",it,(double)Mres);
186:     PetscLogFlops(8.0*n*n*n+2.0*n*n+2.0*n*n*n/3.0+4.0*n*n*n/3.0+2.0*n*n*n+2.0*n*n);
187:   }

189:   if (Mres>tol) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_LIB,"SQRTM not converged after %d iterations",MAXIT);

191:   /* undo scaling */
192:   if (nrm>1.0) PetscStackCallBLAS("BLASscal",BLASscal_(&N,&sqrtnrm,A,&one));

194:   PetscFree5(M,M2,G,work,piv);
195:   SlepcResetFlushToZero(&ftz);
196:   return(0);
197: #endif
198: }

200: PetscErrorCode FNEvaluateFunctionMat_Sqrt_Sadeghi(FN fn,Mat A,Mat B)
201: {
203:   PetscBLASInt   n;
204:   PetscScalar    *Ba;
205:   PetscInt       m;

208:   if (A!=B) { MatCopy(A,B,SAME_NONZERO_PATTERN); }
209:   MatDenseGetArray(B,&Ba);
210:   MatGetSize(A,&m,NULL);
211:   PetscBLASIntCast(m,&n);
212:   SlepcSqrtmSadeghi(n,Ba,n);
213:   MatDenseRestoreArray(B,&Ba);
214:   return(0);
215: }

217: PetscErrorCode FNView_Sqrt(FN fn,PetscViewer viewer)
218: {
220:   PetscBool      isascii;
221:   char           str[50];
222:   const char     *methodname[] = {
223:                   "Schur method for the square root",
224:                   "Denman-Beavers (product form)",
225:                   "Newton-Schulz iteration",
226:                   "Sadeghi iteration"
227:   };
228:   const int      nmeth=sizeof(methodname)/sizeof(methodname[0]);

231:   PetscObjectTypeCompare((PetscObject)viewer,PETSCVIEWERASCII,&isascii);
232:   if (isascii) {
233:     if (fn->beta==(PetscScalar)1.0) {
234:       if (fn->alpha==(PetscScalar)1.0) {
235:         PetscViewerASCIIPrintf(viewer,"  Square root: sqrt(x)\n");
236:       } else {
237:         SlepcSNPrintfScalar(str,50,fn->alpha,PETSC_TRUE);
238:         PetscViewerASCIIPrintf(viewer,"  Square root: sqrt(%s*x)\n",str);
239:       }
240:     } else {
241:       SlepcSNPrintfScalar(str,50,fn->beta,PETSC_TRUE);
242:       if (fn->alpha==(PetscScalar)1.0) {
243:         PetscViewerASCIIPrintf(viewer,"  Square root: %s*sqrt(x)\n",str);
244:       } else {
245:         PetscViewerASCIIPrintf(viewer,"  Square root: %s",str);
246:         PetscViewerASCIIUseTabs(viewer,PETSC_FALSE);
247:         SlepcSNPrintfScalar(str,50,fn->alpha,PETSC_TRUE);
248:         PetscViewerASCIIPrintf(viewer,"*sqrt(%s*x)\n",str);
249:         PetscViewerASCIIUseTabs(viewer,PETSC_TRUE);
250:       }
251:     }
252:     if (fn->method<nmeth) {
253:       PetscViewerASCIIPrintf(viewer,"  computing matrix functions with: %s\n",methodname[fn->method]);
254:     }
255:   }
256:   return(0);
257: }

259: SLEPC_EXTERN PetscErrorCode FNCreate_Sqrt(FN fn)
260: {
262:   fn->ops->evaluatefunction          = FNEvaluateFunction_Sqrt;
263:   fn->ops->evaluatederivative        = FNEvaluateDerivative_Sqrt;
264:   fn->ops->evaluatefunctionmat[0]    = FNEvaluateFunctionMat_Sqrt_Schur;
265:   fn->ops->evaluatefunctionmat[1]    = FNEvaluateFunctionMat_Sqrt_DBP;
266:   fn->ops->evaluatefunctionmat[2]    = FNEvaluateFunctionMat_Sqrt_NS;
267:   fn->ops->evaluatefunctionmat[3]    = FNEvaluateFunctionMat_Sqrt_Sadeghi;
268:   fn->ops->evaluatefunctionmatvec[0] = FNEvaluateFunctionMatVec_Sqrt_Schur;
269:   fn->ops->view                      = FNView_Sqrt;
270:   return(0);
271: }