jax.numpy.inner
- jax.numpy.inner(a, b, *, precision=None, preferred_element_type=None)
Inner product of two arrays.
LAX-backend implementation of numpy.inner(). In addition to the original NumPy arguments listed below, it also supports precision for extra control over matrix-multiplication precision on supported devices. precision may be set to None, which means the default precision for the backend; a Precision enum value (Precision.DEFAULT, Precision.HIGH, or Precision.HIGHEST); or a tuple of two Precision enums indicating separate precision for each argument.
Original docstring below.
Ordinary inner product of vectors for 1-D arrays (without complex conjugation); in higher dimensions, a sum product over the last axes.
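A minimal sketch of the 1-D case, assuming a recent JAX release in which jax.numpy functions reject plain Python lists (so the inputs are wrapped in jnp.array). The complex example contrasts inner, which does not conjugate, with jnp.vdot, which conjugates its first argument.

```python
import jax.numpy as jnp

# 1-D inputs: an ordinary dot product, 1*4 + 2*5 + 3*6 = 32.
x = jnp.array([1, 2, 3])
y = jnp.array([4, 5, 6])
real_result = jnp.inner(x, y)

# Complex inputs: inner() does not conjugate, so 1j * 1j = -1.
z = jnp.array([1j])
complex_inner = jnp.inner(z, z)

# For contrast, jnp.vdot conjugates its first argument: conj(1j) * 1j = 1.
complex_vdot = jnp.vdot(z, z)

print(real_result, complex_inner, complex_vdot)
```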
- Parameters:
a (array_like) – If a and b are nonscalar, their last dimensions must match.
b (array_like) – If a and b are nonscalar, their last dimensions must match.
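precision (PrecisionLike) – Precision of the underlying matrix multiplication: None, a Precision enum value, or a tuple of two Precision enums, as described above.
preferred_element_type (DType | None) – Either None (default), meaning the default accumulation type for the input types, or a dtype, indicating that results are accumulated in and returned with that dtype.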
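The two JAX-specific keyword arguments can be exercised as sketched below. This assumes a JAX version whose jnp.inner accepts preferred_element_type (as the signature above shows); on CPU the precision setting generally has no visible effect on float32 results, so the three float results are expected to agree. The int8 inputs are an illustrative choice: each product (100 * 100) overflows int8, so the accumulation is widened to int32.

```python
import jax.numpy as jnp
from jax import lax

a = jnp.arange(12.0).reshape(3, 4)  # float32, shape (3, 4)
b = jnp.ones((2, 4))                # float32, shape (2, 4)

# Backend default precision vs. an explicit Precision enum value.
default = jnp.inner(a, b)
highest = jnp.inner(a, b, precision=lax.Precision.HIGHEST)

# A tuple of two Precision values sets the precision for a and b separately.
mixed = jnp.inner(a, b, precision=(lax.Precision.HIGHEST, lax.Precision.DEFAULT))

# int8 inputs whose sum of products (4 * 100 * 100) would overflow int8:
# ask for int32 accumulation and an int32 result.
q = jnp.full((4,), 100, dtype=jnp.int8)
wide = jnp.inner(q, q, preferred_element_type=jnp.int32)

print(highest.shape, wide.dtype, wide)
```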
- Returns:
out – If a and b are both scalars or both 1-D arrays, a scalar is returned; otherwise an array is returned. out.shape = (*a.shape[:-1], *b.shape[:-1])
- Return type:
ndarray
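The shape rule for nonscalar inputs can be checked with a small sketch. The array shapes below are arbitrary illustrative choices; the only requirement is that the last dimensions match. The hand-written version makes the "sum product over the last axes" explicit by broadcasting and summing.

```python
import jax.numpy as jnp

a = jnp.arange(30.0).reshape(2, 3, 5)  # last axis has length 5
b = jnp.arange(20.0).reshape(4, 5)     # last axis must also have length 5

out = jnp.inner(a, b)
# out.shape = (*a.shape[:-1], *b.shape[:-1]) = (2, 3) + (4,) = (2, 3, 4)

# The same sum product written out by hand: insert an axis so that a
# (2, 3, 1, 5) broadcasts against every row of b (4, 5), multiply
# elementwise, then sum over the shared last axis.
manual = (a[..., None, :] * b).sum(axis=-1)

print(out.shape)
```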