jax.scipy.linalg.sqrtm
- jax.scipy.linalg.sqrtm(A, blocksize=1)
Matrix square root.
LAX-backend implementation of scipy.linalg._matfuncs_sqrtm.sqrtm().

This differs from scipy.linalg.sqrtm in that the return type of jax.scipy.linalg.sqrtm is always complex64 for 32-bit input, and complex128 for 64-bit input.

This function implements the complex Schur method described in [A]. It does not use recursive blocking to speed up computations, because a Sylvester equation solver is not yet available in JAX.
- [A] Björck, Å., & Hammarling, S. (1983). “A Schur method for the square root of a matrix”. Linear Algebra and its Applications, 52, 127-140.
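The core of the Schur method in [A] is a recurrence on the triangular factor. Given a complex Schur decomposition A = Z T Z^H with T upper triangular, the square root is Z R Z^H, where R is upper triangular with R_ii = sqrt(T_ii) and R_ij = (T_ij - sum_{k=i+1}^{j-1} R_ik R_kj) / (R_ii + R_jj). The plain-Python sketch below implements only that triangular step. It assumes T is already upper triangular (the Schur decomposition is not shown), uses the principal branch of cmath.sqrt, and assumes no denominator R_ii + R_jj vanishes. The helpers sqrtm_upper_triangular and matmul are hypothetical names for illustration, not part of JAX, and this is not JAX's actual code.

```python
import cmath


def sqrtm_upper_triangular(T):
    """Square root of an upper-triangular matrix T given as a list of lists.

    Sketch of the Björck-Hammarling recurrence. Assumes every R[i][i] + R[j][j]
    is nonzero (e.g. T has no zero diagonal entries).
    """
    n = len(T)
    R = [[0j] * n for _ in range(n)]
    for j in range(n):
        # Diagonal entry: principal square root of the eigenvalue T[j][j].
        R[j][j] = cmath.sqrt(T[j][j])
        # Fill column j upward; each R[i][j] needs R[k][j] for k > i,
        # which this loop has already computed, and earlier columns of R.
        for i in range(j - 1, -1, -1):
            s = sum(R[i][k] * R[k][j] for k in range(i + 1, j))
            R[i][j] = (T[i][j] - s) / (R[i][i] + R[j][j])
    return R


def matmul(X, Y):
    """Dense product of two square list-of-lists matrices."""
    n = len(X)
    return [[sum(X[i][k] * Y[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]


T = [[1, 2, 3],
     [0, 4, 5],
     [0, 0, 9]]
R = sqrtm_upper_triangular(T)
print(matmul(R, R))  # reproduces T up to rounding
```

Filling each column from the superdiagonal upward is what makes the recurrence well defined: every term on the right-hand side is already known when R_ij is computed.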
Original docstring below.
- Parameters:
A ((N, N) array_like) – Matrix whose square root to evaluate
blocksize (integer, optional) – If the blocksize is not degenerate with respect to the size of the input array, then use a blocked algorithm. (Default: 64 in SciPy; the JAX signature uses blocksize=1.)
- Returns:
sqrtm ((N, N) ndarray) – Value of the sqrt function at A. The dtype is float or complex. The precision (data size) is determined based on the precision of input A. When the dtype is float, the precision is the same as A. When the dtype is complex, the precision is double that of A. The precision might be clipped by each dtype precision range.
errest (float) – (if disp == False) Frobenius norm of the estimated error, ||err||_F / ||A||_F
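A minimal usage sketch of the dtype behavior described above, assuming a CPU backend (the Schur decomposition this method depends on may not be available on every backend). With JAX's default 32-bit mode, float64 inputs are downcast to float32, so a complex128 result requires enabling the jax_enable_x64 flag. The matrix below is symmetric positive definite, so its principal square root is real in exact arithmetic, yet the returned array is still complex.

```python
import jax.numpy as jnp
from jax.scipy.linalg import sqrtm

# Symmetric positive definite input in float32.
A = jnp.array([[4.0, 1.0],
               [1.0, 3.0]], dtype=jnp.float32)

R = sqrtm(A)
print(R.dtype)  # complex64 for 32-bit input

# Relative Frobenius-norm residual of R @ R against A; expected to be
# on the order of float32 rounding error.
residual = jnp.linalg.norm(R @ R - A) / jnp.linalg.norm(A)
print(residual)
```

If a real result is needed for such inputs, one option is to take jnp.real(R) after checking that the imaginary part is negligible.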