/* Actual source code: fnsqrt.c (slepc-3.18.3, 2023-03-24) */
1: /*
2: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
3: SLEPc - Scalable Library for Eigenvalue Problem Computations
4: Copyright (c) 2002-, Universitat Politecnica de Valencia, Spain
6: This file is part of SLEPc.
7: SLEPc is distributed under a 2-clause BSD license (see LICENSE).
8: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
9: */
10: /*
11: Square root function sqrt(x)
12: */
14: #include <slepc/private/fnimpl.h>
15: #include <slepcblaslapack.h>
17: PetscErrorCode FNEvaluateFunction_Sqrt(FN fn,PetscScalar x,PetscScalar *y)
18: {
19: #if !defined(PETSC_USE_COMPLEX)
21: #endif
22: *y = PetscSqrtScalar(x);
23: return 0;
24: }
26: PetscErrorCode FNEvaluateDerivative_Sqrt(FN fn,PetscScalar x,PetscScalar *y)
27: {
29: #if !defined(PETSC_USE_COMPLEX)
31: #endif
32: *y = 1.0/(2.0*PetscSqrtScalar(x));
33: return 0;
34: }
36: PetscErrorCode FNEvaluateFunctionMat_Sqrt_Schur(FN fn,Mat A,Mat B)
37: {
38: PetscBLASInt n=0;
39: PetscScalar *T;
40: PetscInt m;
42: if (A!=B) MatCopy(A,B,SAME_NONZERO_PATTERN);
43: MatDenseGetArray(B,&T);
44: MatGetSize(A,&m,NULL);
45: PetscBLASIntCast(m,&n);
46: FNSqrtmSchur(fn,n,T,n,PETSC_FALSE);
47: MatDenseRestoreArray(B,&T);
48: return 0;
49: }
51: PetscErrorCode FNEvaluateFunctionMatVec_Sqrt_Schur(FN fn,Mat A,Vec v)
52: {
53: PetscBLASInt n=0;
54: PetscScalar *T;
55: PetscInt m;
56: Mat B;
58: FN_AllocateWorkMat(fn,A,&B);
59: MatDenseGetArray(B,&T);
60: MatGetSize(A,&m,NULL);
61: PetscBLASIntCast(m,&n);
62: FNSqrtmSchur(fn,n,T,n,PETSC_TRUE);
63: MatDenseRestoreArray(B,&T);
64: MatGetColumnVector(B,v,0);
65: FN_FreeWorkMat(fn,&B);
66: return 0;
67: }
69: PetscErrorCode FNEvaluateFunctionMat_Sqrt_DBP(FN fn,Mat A,Mat B)
70: {
71: PetscBLASInt n=0;
72: PetscScalar *T;
73: PetscInt m;
75: if (A!=B) MatCopy(A,B,SAME_NONZERO_PATTERN);
76: MatDenseGetArray(B,&T);
77: MatGetSize(A,&m,NULL);
78: PetscBLASIntCast(m,&n);
79: FNSqrtmDenmanBeavers(fn,n,T,n,PETSC_FALSE);
80: MatDenseRestoreArray(B,&T);
81: return 0;
82: }
84: PetscErrorCode FNEvaluateFunctionMat_Sqrt_NS(FN fn,Mat A,Mat B)
85: {
86: PetscBLASInt n=0;
87: PetscScalar *Ba;
88: PetscInt m;
90: if (A!=B) MatCopy(A,B,SAME_NONZERO_PATTERN);
91: MatDenseGetArray(B,&Ba);
92: MatGetSize(A,&m,NULL);
93: PetscBLASIntCast(m,&n);
94: FNSqrtmNewtonSchulz(fn,n,Ba,n,PETSC_FALSE);
95: MatDenseRestoreArray(B,&Ba);
96: return 0;
97: }
99: #define MAXIT 50
101: /*
102: Computes the principal square root of the matrix A using the
103: Sadeghi iteration. A is overwritten with sqrtm(A).
104: */
/*
   Computes the principal square root of the matrix A using the
   Sadeghi iteration. A is overwritten with sqrtm(A).

   On entry A is an n x n dense matrix with leading dimension ld; on exit
   it holds X = sqrtm(A). The iteration maintains an auxiliary matrix M
   (driven towards the identity) and a correction factor G applied to both
   X and M each sweep. Workspace: three n*n scratch matrices (M, M2, G),
   a getri work array and a pivot array, obtained in one PetscMalloc5.

   NOTE(review): several places assume the matrix is stored contiguously
   (ld==n): lange/lascl/axpy/scal treat A and M as length-N vectors, and
   the callers in this file always pass ld==n — TODO confirm before
   reusing with ld>n.
   NOTE(review): error codes from the PETSc helpers (PetscMalloc5,
   PetscArraycpy, PetscInfo, PetscLogFlops, ...) are discarded here;
   consider wrapping them in PetscCall() as elsewhere in PETSc/SLEPc.
   NOTE(review): there is no check that the iteration converged within
   MAXIT sweeps — the last iterate is returned silently; confirm intended.
*/
PetscErrorCode FNSqrtmSadeghi(FN fn,PetscBLASInt n,PetscScalar *A,PetscBLASInt ld)
{
  PetscScalar  *M,*M2,*G,*X=A,*work,work1,sqrtnrm;   /* X aliases A: result built in place */
  PetscScalar  szero=0.0,sone=1.0,smfive=-5.0,s1d16=1.0/16.0;
  PetscReal    tol,Mres=0.0,nrm,rwork[1],done=1.0;
  PetscInt     i,it;
  PetscBLASInt N,*piv=NULL,info,lwork=0,query=-1,one=1,zero=0;
  PetscBool    converged=PETSC_FALSE;
  unsigned int ftz;

  N = n*n;
  /* stopping tolerance, scaled by sqrt(n) times machine precision */
  tol = PetscSqrtReal((PetscReal)n)*PETSC_MACHINE_EPSILON/2;
  SlepcSetFlushToZero(&ftz);

  /* query work size (workspace query for the LU-based inverse below) */
  PetscCallBLAS("LAPACKgetri",LAPACKgetri_(&n,A,&ld,piv,&work1,&query,&info));
  PetscBLASIntCast((PetscInt)PetscRealPart(work1),&lwork);
  PetscMalloc5(N,&M,N,&M2,N,&G,lwork,&work,n,&piv);
  PetscArraycpy(M,A,N);

  /* scale M so its Frobenius norm is at most 1; undone on X at the end */
  nrm = LAPACKlange_("fro",&n,&n,M,&n,rwork);
  if (nrm>1.0) {
    sqrtnrm = PetscSqrtReal(nrm);
    PetscCallBLAS("LAPACKlascl",LAPACKlascl_("G",&zero,&zero,&nrm,&done,&N,&one,M,&N,&info));
    SlepcCheckLapackInfo("lascl",info);
    tol *= nrm;
  }
  PetscInfo(fn,"||A||_F = %g, new tol: %g\n",(double)nrm,(double)tol);

  /* X = I */
  PetscArrayzero(X,N);
  for (i=0;i<n;i++) X[i+i*ld] = 1.0;

  for (it=0;it<MAXIT && !converged;it++) {

    /* G = (5/16)*I + (1/16)*M*(15*I-5*M+M*M) */
    PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,M,&ld,M,&ld,&szero,M2,&ld));
    PetscCallBLAS("BLASaxpy",BLASaxpy_(&N,&smfive,M,&one,M2,&one));
    for (i=0;i<n;i++) M2[i+i*ld] += 15.0;
    PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&s1d16,M,&ld,M2,&ld,&szero,G,&ld));
    for (i=0;i<n;i++) G[i+i*ld] += 5.0/16.0;

    /* X = X*G  (gemm cannot work in place, so copy X into scratch first) */
    PetscArraycpy(M2,X,N);
    PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,M2,&ld,G,&ld,&szero,X,&ld));

    /* M = M*inv(G*G): form G*G in M2, invert it via LU (getrf+getri) */
    PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,G,&ld,G,&ld,&szero,M2,&ld));
    PetscCallBLAS("LAPACKgetrf",LAPACKgetrf_(&n,&n,M2,&ld,piv,&info));
    SlepcCheckLapackInfo("getrf",info);
    PetscCallBLAS("LAPACKgetri",LAPACKgetri_(&n,M2,&ld,piv,work,&lwork,&info));
    SlepcCheckLapackInfo("getri",info);

    /* G is free now; reuse it as a copy of M for the in-place product */
    PetscArraycpy(G,M,N);
    PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,G,&ld,M2,&ld,&szero,M,&ld));

    /* check ||I-M||: convergence means M has reached the identity */
    PetscArraycpy(M2,M,N);
    for (i=0;i<n;i++) M2[i+i*ld] -= 1.0;
    Mres = LAPACKlange_("fro",&n,&n,M2,&n,rwork);

    if (Mres<=tol) converged = PETSC_TRUE;
    PetscInfo(fn,"it: %" PetscInt_FMT " res: %g\n",it,(double)Mres);
    PetscLogFlops(8.0*n*n*n+2.0*n*n+2.0*n*n*n/3.0+4.0*n*n*n/3.0+2.0*n*n*n+2.0*n*n);
  }

  /* undo scaling: sqrtm(c*A) = sqrt(c)*sqrtm(A) */
  if (nrm>1.0) PetscCallBLAS("BLASscal",BLASscal_(&N,&sqrtnrm,A,&one));

  PetscFree5(M,M2,G,work,piv);
  SlepcResetFlushToZero(&ftz);
  return 0;
}
183: #if defined(PETSC_HAVE_CUDA)
184: #include "../src/sys/classes/fn/impls/cuda/fnutilcuda.h"
185: #include <slepccublas.h>
187: #if defined(PETSC_HAVE_MAGMA)
188: #include <slepcmagma.h>
190: /*
191: * Matrix square root by Sadeghi iteration. CUDA version.
192: * Computes the principal square root of the matrix A using the
193: * Sadeghi iteration. A is overwritten with sqrtm(A).
194: */
/*
 * Matrix square root by Sadeghi iteration. CUDA version.
 * Computes the principal square root of the matrix A using the
 * Sadeghi iteration. A is overwritten with sqrtm(A).
 *
 * Mirrors the host implementation: cuBLAS supplies gemm/axpy/nrm2/scal,
 * MAGMA supplies the LU factorization and inverse on the device, and the
 * small set_diagonal/shift_diagonal kernels (fnutilcuda.h) handle the
 * diagonal updates. d_A is device memory and doubles as the X iterate.
 *
 * NOTE(review): the return values of cudaMalloc/cudaMemcpy/cudaMemset
 * and of the cuBLAS calls are discarded; consider PetscCallCUDA /
 * PetscCallCUBLAS wrappers so device failures are reported.
 * NOTE(review): as in the host version, there is no check that the
 * iteration converged within MAXIT sweeps — confirm intended.
 */
PetscErrorCode FNSqrtmSadeghi_CUDAm(FN fn,PetscBLASInt n,PetscScalar *d_A,PetscBLASInt ld)
{
  PetscScalar        *d_M,*d_M2,*d_G,*d_work,alpha;
  const PetscScalar  szero=0.0,sone=1.0,smfive=-5.0,s15=15.0,s1d16=1.0/16.0;
  PetscReal          tol,Mres=0.0,nrm,sqrtnrm=1.0;
  PetscInt           it,nb,lwork;
  PetscBLASInt       *piv,N;
  const PetscBLASInt one=1;
  PetscBool          converged=PETSC_FALSE;
  cublasHandle_t     cublasv2handle;

  PetscDeviceInitialize(PETSC_DEVICE_CUDA); /* For CUDA event timers */
  PetscCUBLASGetHandle(&cublasv2handle);
  SlepcMagmaInit();
  N = n*n;
  /* stopping tolerance, scaled by sqrt(n) times machine precision */
  tol = PetscSqrtReal((PetscReal)n)*PETSC_MACHINE_EPSILON/2;

  /* pivots live on the host (MAGMA getrf_gpu expects host pivot array) */
  PetscMalloc1(n,&piv);
  cudaMalloc((void **)&d_M,sizeof(PetscScalar)*N);
  cudaMalloc((void **)&d_M2,sizeof(PetscScalar)*N);
  cudaMalloc((void **)&d_G,sizeof(PetscScalar)*N);

  /* device workspace for magma getri, sized by MAGMA's blocking factor */
  nb = magma_get_xgetri_nb(n);
  lwork = nb*n;
  cudaMalloc((void **)&d_work,sizeof(PetscScalar)*lwork);
  PetscLogGpuTimeBegin();

  /* M = A */
  cudaMemcpy(d_M,d_A,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice);

  /* scale M so its Frobenius norm (nrm2 of the vectorized matrix) is <= 1 */
  cublasXnrm2(cublasv2handle,N,d_M,one,&nrm);
  if (nrm>1.0) {
    sqrtnrm = PetscSqrtReal(nrm);
    alpha = 1.0/nrm;
    cublasXscal(cublasv2handle,N,&alpha,d_M,one);
    tol *= nrm;
  }
  PetscInfo(fn,"||A||_F = %g, new tol: %g\n",(double)nrm,(double)tol);

  /* X = I (X aliases d_A) */
  cudaMemset(d_A,0,sizeof(PetscScalar)*N);
  set_diagonal(n,d_A,ld,sone);

  for (it=0;it<MAXIT && !converged;it++) {

    /* G = (5/16)*I + (1/16)*M*(15*I-5*M+M*M) */
    cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&sone,d_M,ld,d_M,ld,&szero,d_M2,ld);
    cublasXaxpy(cublasv2handle,N,&smfive,d_M,one,d_M2,one);
    shift_diagonal(n,d_M2,ld,s15);
    cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&s1d16,d_M,ld,d_M2,ld,&szero,d_G,ld);
    shift_diagonal(n,d_G,ld,5.0/16.0);

    /* X = X*G (copy X to scratch first: gemm cannot run in place) */
    cudaMemcpy(d_M2,d_A,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice);
    cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&sone,d_M2,ld,d_G,ld,&szero,d_A,ld);

    /* M = M*inv(G*G) */
    cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&sone,d_G,ld,d_G,ld,&szero,d_M2,ld);
    /* magma: LU factorization + explicit inverse of G*G on the device */
    PetscCallMAGMA(magma_xgetrf_gpu,n,n,d_M2,ld,piv);
    PetscCallMAGMA(magma_xgetri_gpu,n,d_M2,ld,piv,d_work,lwork);
    /* magma */
    cudaMemcpy(d_G,d_M,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice);
    cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&sone,d_G,ld,d_M2,ld,&szero,d_M,ld);

    /* check ||I-M||: convergence means M has reached the identity */
    cudaMemcpy(d_M2,d_M,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice);
    shift_diagonal(n,d_M2,ld,-1.0);
    cublasXnrm2(cublasv2handle,N,d_M2,one,&Mres);

    if (Mres<=tol) converged = PETSC_TRUE;
    PetscInfo(fn,"it: %" PetscInt_FMT " res: %g\n",it,(double)Mres);
    PetscLogGpuFlops(8.0*n*n*n+2.0*n*n+2.0*n*n*n/3.0+4.0*n*n*n/3.0+2.0*n*n*n+2.0*n*n);
  }

  /* undo scaling: sqrtm(c*A) = sqrt(c)*sqrtm(A) */
  if (nrm>1.0) {
    alpha = sqrtnrm;
    cublasXscal(cublasv2handle,N,&alpha,d_A,one);
  }
  PetscLogGpuTimeEnd();

  cudaFree(d_M);
  cudaFree(d_M2);
  cudaFree(d_G);
  cudaFree(d_work);
  PetscFree(piv);
  return 0;
}
286: #endif /* PETSC_HAVE_MAGMA */
287: #endif /* PETSC_HAVE_CUDA */
289: PetscErrorCode FNEvaluateFunctionMat_Sqrt_Sadeghi(FN fn,Mat A,Mat B)
290: {
291: PetscBLASInt n=0;
292: PetscScalar *Ba;
293: PetscInt m;
295: if (A!=B) MatCopy(A,B,SAME_NONZERO_PATTERN);
296: MatDenseGetArray(B,&Ba);
297: MatGetSize(A,&m,NULL);
298: PetscBLASIntCast(m,&n);
299: FNSqrtmSadeghi(fn,n,Ba,n);
300: MatDenseRestoreArray(B,&Ba);
301: return 0;
302: }
304: #if defined(PETSC_HAVE_CUDA)
305: PetscErrorCode FNEvaluateFunctionMat_Sqrt_NS_CUDA(FN fn,Mat A,Mat B)
306: {
307: PetscBLASInt n=0;
308: PetscScalar *Ba;
309: PetscInt m;
311: if (A!=B) MatCopy(A,B,SAME_NONZERO_PATTERN);
312: MatDenseCUDAGetArray(B,&Ba);
313: MatGetSize(A,&m,NULL);
314: PetscBLASIntCast(m,&n);
315: FNSqrtmNewtonSchulz_CUDA(fn,n,Ba,n,PETSC_FALSE);
316: MatDenseCUDARestoreArray(B,&Ba);
317: return 0;
318: }
320: #if defined(PETSC_HAVE_MAGMA)
321: PetscErrorCode FNEvaluateFunctionMat_Sqrt_DBP_CUDAm(FN fn,Mat A,Mat B)
322: {
323: PetscBLASInt n=0;
324: PetscScalar *T;
325: PetscInt m;
327: if (A!=B) MatCopy(A,B,SAME_NONZERO_PATTERN);
328: MatDenseCUDAGetArray(B,&T);
329: MatGetSize(A,&m,NULL);
330: PetscBLASIntCast(m,&n);
331: FNSqrtmDenmanBeavers_CUDAm(fn,n,T,n,PETSC_FALSE);
332: MatDenseCUDARestoreArray(B,&T);
333: return 0;
334: }
336: PetscErrorCode FNEvaluateFunctionMat_Sqrt_Sadeghi_CUDAm(FN fn,Mat A,Mat B)
337: {
338: PetscBLASInt n=0;
339: PetscScalar *Ba;
340: PetscInt m;
342: if (A!=B) MatCopy(A,B,SAME_NONZERO_PATTERN);
343: MatDenseCUDAGetArray(B,&Ba);
344: MatGetSize(A,&m,NULL);
345: PetscBLASIntCast(m,&n);
346: FNSqrtmSadeghi_CUDAm(fn,n,Ba,n);
347: MatDenseCUDARestoreArray(B,&Ba);
348: return 0;
349: }
350: #endif /* PETSC_HAVE_MAGMA */
351: #endif /* PETSC_HAVE_CUDA */
353: PetscErrorCode FNView_Sqrt(FN fn,PetscViewer viewer)
354: {
355: PetscBool isascii;
356: char str[50];
357: const char *methodname[] = {
358: "Schur method for the square root",
359: "Denman-Beavers (product form)",
360: "Newton-Schulz iteration",
361: "Sadeghi iteration"
362: };
363: const int nmeth=PETSC_STATIC_ARRAY_LENGTH(methodname);
365: PetscObjectTypeCompare((PetscObject)viewer,PETSCVIEWERASCII,&isascii);
366: if (isascii) {
367: if (fn->beta==(PetscScalar)1.0) {
368: if (fn->alpha==(PetscScalar)1.0) PetscViewerASCIIPrintf(viewer," square root: sqrt(x)\n");
369: else {
370: SlepcSNPrintfScalar(str,sizeof(str),fn->alpha,PETSC_TRUE);
371: PetscViewerASCIIPrintf(viewer," square root: sqrt(%s*x)\n",str);
372: }
373: } else {
374: SlepcSNPrintfScalar(str,sizeof(str),fn->beta,PETSC_TRUE);
375: if (fn->alpha==(PetscScalar)1.0) PetscViewerASCIIPrintf(viewer," square root: %s*sqrt(x)\n",str);
376: else {
377: PetscViewerASCIIPrintf(viewer," square root: %s",str);
378: PetscViewerASCIIUseTabs(viewer,PETSC_FALSE);
379: SlepcSNPrintfScalar(str,sizeof(str),fn->alpha,PETSC_TRUE);
380: PetscViewerASCIIPrintf(viewer,"*sqrt(%s*x)\n",str);
381: PetscViewerASCIIUseTabs(viewer,PETSC_TRUE);
382: }
383: }
384: if (fn->method<nmeth) PetscViewerASCIIPrintf(viewer," computing matrix functions with: %s\n",methodname[fn->method]);
385: }
386: return 0;
387: }
389: SLEPC_EXTERN PetscErrorCode FNCreate_Sqrt(FN fn)
390: {
391: fn->ops->evaluatefunction = FNEvaluateFunction_Sqrt;
392: fn->ops->evaluatederivative = FNEvaluateDerivative_Sqrt;
393: fn->ops->evaluatefunctionmat[0] = FNEvaluateFunctionMat_Sqrt_Schur;
394: fn->ops->evaluatefunctionmat[1] = FNEvaluateFunctionMat_Sqrt_DBP;
395: fn->ops->evaluatefunctionmat[2] = FNEvaluateFunctionMat_Sqrt_NS;
396: fn->ops->evaluatefunctionmat[3] = FNEvaluateFunctionMat_Sqrt_Sadeghi;
397: #if defined(PETSC_HAVE_CUDA)
398: fn->ops->evaluatefunctionmatcuda[2] = FNEvaluateFunctionMat_Sqrt_NS_CUDA;
399: #if defined(PETSC_HAVE_MAGMA)
400: fn->ops->evaluatefunctionmatcuda[1] = FNEvaluateFunctionMat_Sqrt_DBP_CUDAm;
401: fn->ops->evaluatefunctionmatcuda[3] = FNEvaluateFunctionMat_Sqrt_Sadeghi_CUDAm;
402: #endif /* PETSC_HAVE_MAGMA */
403: #endif /* PETSC_HAVE_CUDA */
404: fn->ops->evaluatefunctionmatvec[0] = FNEvaluateFunctionMatVec_Sqrt_Schur;
405: fn->ops->view = FNView_Sqrt;
406: return 0;
407: }