jax.numpy.roots
- jax.numpy.roots(p, *, strip_zeros=True)
Return the roots of a polynomial with coefficients given in p.
LAX-backend implementation of numpy.roots(). Unlike the NumPy version of this function, the JAX version always returns the roots in a complex array, even when all of the roots are real. Additionally, the JAX version of this function adds the strip_zeros argument, which must be set to False for the function to be compatible with JIT and other JAX transformations. With strip_zeros=False, if your coefficients have leading zeros, the roots will be padded with NaN values:

>>> coeffs = jnp.array([0, 1, 2])

# The default behavior matches numpy and strips leading zeros:
>>> jnp.roots(coeffs)
Array([-2.+0.j], dtype=complex64)

# With strip_zeros=False, extra roots are set to NaN:
>>> jnp.roots(coeffs, strip_zeros=False)
Array([-2. +0.j, nan+nanj], dtype=complex64)
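The following is a minimal sketch of using strip_zeros=False under jax.jit, assuming a CPU backend with JAX's default 32-bit precision. It shows that the output is complex even for real roots, that a leading zero keeps the output length fixed with a NaN placeholder, and one way to drop that placeholder afterwards. The helper name `roots_jit` is hypothetical, and because the order of the returned roots is not guaranteed, the real parts are sorted before they are inspected.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper: with strip_zeros=False the output length depends only
# on the input shape, so jnp.roots can be traced under jax.jit.
@jax.jit
def roots_jit(coeffs):
    return jnp.roots(coeffs, strip_zeros=False)

# x**2 - 3*x + 2 = (x - 1)(x - 2): both roots are real, yet the result is complex.
r = roots_jit(jnp.array([1.0, -3.0, 2.0]))
print(r.dtype, jnp.sort(r.real))

# A leading zero keeps the output length at 2; the undefined root is NaN.
r_padded = roots_jit(jnp.array([0.0, 1.0, 2.0]))
print(r_padded)

# Outside of jit, a boolean mask can drop the NaN placeholders.
# The masked result's shape depends on the data, so this step is not jitted.
print(r_padded[~jnp.isnan(r_padded)])
```

The filtering step is deliberately kept outside the jitted function: boolean-mask indexing produces a data-dependent shape, which is exactly what strip_zeros=False avoids inside the traced computation.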
Original docstring below.
Note
This forms part of the old polynomial API. Since NumPy version 1.4, the new polynomial API defined in numpy.polynomial is preferred. A summary of the differences can be found in the transition guide.
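One difference worth keeping in mind when moving between the two APIs is coefficient order. As a sketch using plain NumPy (this illustrates numpy.polynomial, not a JAX API), numpy.polynomial.Polynomial takes coefficients in increasing order of degree, which is the reverse of the order used by roots:

```python
import numpy as np
from numpy.polynomial import Polynomial

# Old API (and jnp.roots): highest power first -> x**2 - 3*x + 2
old_style = [1.0, -3.0, 2.0]
print(np.roots(old_style))

# New API: lowest power first, so the same polynomial is written [2, -3, 1].
new_style = Polynomial(old_style[::-1])
print(new_style.roots())
```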
The values in the rank-1 array p are coefficients of a polynomial. If the length of p is n+1, then the polynomial is described by:
p[0] * x**n + p[1] * x**(n-1) + ... + p[n-1]*x + p[n]
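To make the coefficient order concrete: p[0] multiplies the highest power of x. As an illustrative sketch (the coefficients are chosen only for the example), [1, -3, 2] encodes x**2 - 3*x + 2 = (x - 1)(x - 2). Evaluating the polynomial at each returned root with jnp.polyval, which uses the same highest-power-first convention, should give values close to zero, while reversing the coefficients describes a different polynomial.

```python
import jax.numpy as jnp

# p[0] * x**2 + p[1] * x + p[2]  ->  x**2 - 3*x + 2
p = jnp.array([1.0, -3.0, 2.0])
r = jnp.roots(p)

# Each residual p(root) should be numerically zero.
residuals = jnp.polyval(p, r)
print(jnp.abs(residuals).max())

# Reversed coefficients give 2*x**2 - 3*x + 1 = (2*x - 1)(x - 1),
# whose roots are 1/2 and 1.
r_rev = jnp.roots(p[::-1])
print(jnp.sort(r_rev.real))
```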
- Parameters:
p (array_like) – Rank-1 array of polynomial coefficients.
strip_zeros (bool, default=True) – If set to True, then leading zeros in the coefficients will be stripped, similar to numpy.roots(). If set to False, leading zeros will not be stripped, and undefined roots will be represented by NaN values in the function output. strip_zeros must be set to False for the function to be compatible with jax.jit() and other JAX transformations.
- Returns:
out – An array containing the roots of the polynomial.
- Return type:
ndarray
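As a sketch of the "other JAX transformations" mentioned under strip_zeros, the following maps jnp.roots over a batch of quadratics with jax.vmap, assuming a CPU backend. Passing strip_zeros=False avoids the data-dependent slicing that stripping leading zeros would require, so the function can be traced under vmap and every row of the output has the same length. The batch values are illustrative.

```python
import jax
import jax.numpy as jnp

# Each row is one quadratic, highest power first:
#   x**2 - 3*x + 2 = (x - 1)(x - 2)
#   x**2 - 4       = (x - 2)(x + 2)
batch = jnp.array([[1.0, -3.0, 2.0],
                   [1.0, 0.0, -4.0]])

# strip_zeros=False gives a fixed output shape per row, so vmap can batch it.
batched_roots = jax.vmap(lambda c: jnp.roots(c, strip_zeros=False))(batch)

print(batched_roots.shape)                  # one row of roots per polynomial
print(jnp.sort(batched_roots.real, axis=1))  # root order is not guaranteed
```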