Actual source code: fnsqrt.c
slepc-3.11.2 2019-07-30
1: /*
2: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
3: SLEPc - Scalable Library for Eigenvalue Problem Computations
4: Copyright (c) 2002-2019, Universitat Politecnica de Valencia, Spain
6: This file is part of SLEPc.
7: SLEPc is distributed under a 2-clause BSD license (see LICENSE).
8: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
9: */
10: /*
11: Square root function sqrt(x)
12: */
14: #include <slepc/private/fnimpl.h> /*I "slepcfn.h" I*/
15: #include <slepcblaslapack.h>
17: PetscErrorCode FNEvaluateFunction_Sqrt(FN fn,PetscScalar x,PetscScalar *y)
18: {
20: #if !defined(PETSC_USE_COMPLEX)
21: if (x<0.0) SETERRQ(PETSC_COMM_SELF,1,"Function not defined in the requested value");
22: #endif
23: *y = PetscSqrtScalar(x);
24: return(0);
25: }
27: PetscErrorCode FNEvaluateDerivative_Sqrt(FN fn,PetscScalar x,PetscScalar *y)
28: {
30: if (x==0.0) SETERRQ(PETSC_COMM_SELF,1,"Derivative not defined in the requested value");
31: #if !defined(PETSC_USE_COMPLEX)
32: if (x<0.0) SETERRQ(PETSC_COMM_SELF,1,"Derivative not defined in the requested value");
33: #endif
34: *y = 1.0/(2.0*PetscSqrtScalar(x));
35: return(0);
36: }
38: PetscErrorCode FNEvaluateFunctionMat_Sqrt_Schur(FN fn,Mat A,Mat B)
39: {
41: PetscBLASInt n;
42: PetscScalar *T;
43: PetscInt m;
46: if (A!=B) { MatCopy(A,B,SAME_NONZERO_PATTERN); }
47: MatDenseGetArray(B,&T);
48: MatGetSize(A,&m,NULL);
49: PetscBLASIntCast(m,&n);
50: SlepcSqrtmSchur(n,T,n,PETSC_FALSE);
51: MatDenseRestoreArray(B,&T);
52: return(0);
53: }
55: PetscErrorCode FNEvaluateFunctionMatVec_Sqrt_Schur(FN fn,Mat A,Vec v)
56: {
58: PetscBLASInt n;
59: PetscScalar *T;
60: PetscInt m;
61: Mat B;
64: FN_AllocateWorkMat(fn,A,&B);
65: MatDenseGetArray(B,&T);
66: MatGetSize(A,&m,NULL);
67: PetscBLASIntCast(m,&n);
68: SlepcSqrtmSchur(n,T,n,PETSC_TRUE);
69: MatDenseRestoreArray(B,&T);
70: MatGetColumnVector(B,v,0);
71: FN_FreeWorkMat(fn,&B);
72: return(0);
73: }
75: PetscErrorCode FNEvaluateFunctionMat_Sqrt_DBP(FN fn,Mat A,Mat B)
76: {
78: PetscBLASInt n;
79: PetscScalar *T;
80: PetscInt m;
83: if (A!=B) { MatCopy(A,B,SAME_NONZERO_PATTERN); }
84: MatDenseGetArray(B,&T);
85: MatGetSize(A,&m,NULL);
86: PetscBLASIntCast(m,&n);
87: SlepcSqrtmDenmanBeavers(n,T,n,PETSC_FALSE);
88: MatDenseRestoreArray(B,&T);
89: return(0);
90: }
92: PetscErrorCode FNEvaluateFunctionMat_Sqrt_NS(FN fn,Mat A,Mat B)
93: {
95: PetscBLASInt n;
96: PetscScalar *Ba;
97: PetscInt m;
100: if (A!=B) { MatCopy(A,B,SAME_NONZERO_PATTERN); }
101: MatDenseGetArray(B,&Ba);
102: MatGetSize(A,&m,NULL);
103: PetscBLASIntCast(m,&n);
104: SlepcSqrtmNewtonSchulz(n,Ba,n,PETSC_FALSE);
105: MatDenseRestoreArray(B,&Ba);
106: return(0);
107: }
109: #define MAXIT 50
111: /*
112: Computes the principal square root of the matrix A using the
113: Sadeghi iteration. A is overwritten with sqrtm(A).
114: */
115: static PetscErrorCode SlepcSqrtmSadeghi(PetscBLASInt n,PetscScalar *A,PetscBLASInt ld)
116: {
117: #if defined(PETSC_MISSING_LAPACK_GETRF) || defined(PETSC_MISSING_LAPACK_GETRI)
119: SETERRQ(PETSC_COMM_SELF,PETSC_ERR_SUP,"GETRF/GETRI - Lapack routine is unavailable");
120: #else
121: PetscScalar *M,*M2,*G,*X=A,*work,work1,alpha,sqrtnrm;
122: PetscScalar szero=0.0,sone=1.0,smfive=-5.0,s1d16=1.0/16.0;
123: PetscReal tol,Mres,nrm,rwork[1];
124: PetscBLASInt N,i,it,*piv=NULL,info,lwork,query=-1;
125: const PetscBLASInt one=1;
126: PetscBool converged=PETSC_FALSE;
127: PetscErrorCode ierr;
128: unsigned int ftz;
131: N = n*n;
132: tol = PetscSqrtReal((PetscReal)n)*PETSC_MACHINE_EPSILON/2;
133: SlepcSetFlushToZero(&ftz);
135: /* query work size */
136: PetscStackCallBLAS("LAPACKgetri",LAPACKgetri_(&n,A,&ld,piv,&work1,&query,&info));
137: PetscBLASIntCast((PetscInt)PetscRealPart(work1),&lwork);
139: PetscMalloc5(N,&M,N,&M2,N,&G,lwork,&work,n,&piv);
140: PetscMemcpy(M,A,N*sizeof(PetscScalar));
142: /* scale M */
143: nrm = LAPACKlange_("fro",&n,&n,M,&n,rwork);
144: if (nrm>1.0) {
145: sqrtnrm = PetscSqrtReal(nrm);
146: alpha = 1.0/nrm;
147: PetscStackCallBLAS("BLASscal",BLASscal_(&N,&alpha,M,&one));
148: tol *= nrm;
149: }
150: PetscInfo2(NULL,"||A||_F = %g, new tol: %g\n",(double)nrm,(double)tol);
152: /* X = I */
153: PetscMemzero(X,N*sizeof(PetscScalar));
154: for (i=0;i<n;i++) X[i+i*ld] = 1.0;
156: for (it=0;it<MAXIT && !converged;it++) {
158: /* G = (5/16)*I + (1/16)*M*(15*I-5*M+M*M) */
159: PetscStackCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,M,&ld,M,&ld,&szero,M2,&ld));
160: PetscStackCallBLAS("BLASaxpy",BLASaxpy_(&N,&smfive,M,&one,M2,&one));
161: for (i=0;i<n;i++) M2[i+i*ld] += 15.0;
162: PetscStackCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&s1d16,M,&ld,M2,&ld,&szero,G,&ld));
163: for (i=0;i<n;i++) G[i+i*ld] += 5.0/16.0;
165: /* X = X*G */
166: PetscMemcpy(M2,X,N*sizeof(PetscScalar));
167: PetscStackCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,M2,&ld,G,&ld,&szero,X,&ld));
169: /* M = M*inv(G*G) */
170: PetscStackCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,G,&ld,G,&ld,&szero,M2,&ld));
171: PetscStackCallBLAS("LAPACKgetrf",LAPACKgetrf_(&n,&n,M2,&ld,piv,&info));
172: SlepcCheckLapackInfo("getrf",info);
173: PetscStackCallBLAS("LAPACKgetri",LAPACKgetri_(&n,M2,&ld,piv,work,&lwork,&info));
174: SlepcCheckLapackInfo("getri",info);
176: PetscMemcpy(G,M,N*sizeof(PetscScalar));
177: PetscStackCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,G,&ld,M2,&ld,&szero,M,&ld));
179: /* check ||I-M|| */
180: PetscMemcpy(M2,M,N*sizeof(PetscScalar));
181: for (i=0;i<n;i++) M2[i+i*ld] -= 1.0;
182: Mres = LAPACKlange_("fro",&n,&n,M2,&n,rwork);
183: PetscIsNanReal(Mres);
184: if (Mres<=tol) converged = PETSC_TRUE;
185: PetscInfo2(NULL,"it: %D res: %g\n",it,(double)Mres);
186: PetscLogFlops(8.0*n*n*n+2.0*n*n+2.0*n*n*n/3.0+4.0*n*n*n/3.0+2.0*n*n*n+2.0*n*n);
187: }
189: if (Mres>tol) SETERRQ1(PETSC_COMM_SELF,PETSC_ERR_LIB,"SQRTM not converged after %d iterations",MAXIT);
191: /* undo scaling */
192: if (nrm>1.0) PetscStackCallBLAS("BLASscal",BLASscal_(&N,&sqrtnrm,A,&one));
194: PetscFree5(M,M2,G,work,piv);
195: SlepcResetFlushToZero(&ftz);
196: return(0);
197: #endif
198: }
200: PetscErrorCode FNEvaluateFunctionMat_Sqrt_Sadeghi(FN fn,Mat A,Mat B)
201: {
203: PetscBLASInt n;
204: PetscScalar *Ba;
205: PetscInt m;
208: if (A!=B) { MatCopy(A,B,SAME_NONZERO_PATTERN); }
209: MatDenseGetArray(B,&Ba);
210: MatGetSize(A,&m,NULL);
211: PetscBLASIntCast(m,&n);
212: SlepcSqrtmSadeghi(n,Ba,n);
213: MatDenseRestoreArray(B,&Ba);
214: return(0);
215: }
217: PetscErrorCode FNView_Sqrt(FN fn,PetscViewer viewer)
218: {
220: PetscBool isascii;
221: char str[50];
222: const char *methodname[] = {
223: "Schur method for the square root",
224: "Denman-Beavers (product form)",
225: "Newton-Schulz iteration",
226: "Sadeghi iteration"
227: };
228: const int nmeth=sizeof(methodname)/sizeof(methodname[0]);
231: PetscObjectTypeCompare((PetscObject)viewer,PETSCVIEWERASCII,&isascii);
232: if (isascii) {
233: if (fn->beta==(PetscScalar)1.0) {
234: if (fn->alpha==(PetscScalar)1.0) {
235: PetscViewerASCIIPrintf(viewer," Square root: sqrt(x)\n");
236: } else {
237: SlepcSNPrintfScalar(str,50,fn->alpha,PETSC_TRUE);
238: PetscViewerASCIIPrintf(viewer," Square root: sqrt(%s*x)\n",str);
239: }
240: } else {
241: SlepcSNPrintfScalar(str,50,fn->beta,PETSC_TRUE);
242: if (fn->alpha==(PetscScalar)1.0) {
243: PetscViewerASCIIPrintf(viewer," Square root: %s*sqrt(x)\n",str);
244: } else {
245: PetscViewerASCIIPrintf(viewer," Square root: %s",str);
246: PetscViewerASCIIUseTabs(viewer,PETSC_FALSE);
247: SlepcSNPrintfScalar(str,50,fn->alpha,PETSC_TRUE);
248: PetscViewerASCIIPrintf(viewer,"*sqrt(%s*x)\n",str);
249: PetscViewerASCIIUseTabs(viewer,PETSC_TRUE);
250: }
251: }
252: if (fn->method<nmeth) {
253: PetscViewerASCIIPrintf(viewer," computing matrix functions with: %s\n",methodname[fn->method]);
254: }
255: }
256: return(0);
257: }
259: SLEPC_EXTERN PetscErrorCode FNCreate_Sqrt(FN fn)
260: {
262: fn->ops->evaluatefunction = FNEvaluateFunction_Sqrt;
263: fn->ops->evaluatederivative = FNEvaluateDerivative_Sqrt;
264: fn->ops->evaluatefunctionmat[0] = FNEvaluateFunctionMat_Sqrt_Schur;
265: fn->ops->evaluatefunctionmat[1] = FNEvaluateFunctionMat_Sqrt_DBP;
266: fn->ops->evaluatefunctionmat[2] = FNEvaluateFunctionMat_Sqrt_NS;
267: fn->ops->evaluatefunctionmat[3] = FNEvaluateFunctionMat_Sqrt_Sadeghi;
268: fn->ops->evaluatefunctionmatvec[0] = FNEvaluateFunctionMatVec_Sqrt_Schur;
269: fn->ops->view = FNView_Sqrt;
270: return(0);
271: }