jax.lax.dot_general
- jax.lax.dot_general(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None)
General dot product/contraction operator.
Wraps XLA’s DotGeneral operator.
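The snippet below is a minimal sketch of the simplest contraction, assuming a standard CPU install of JAX. It multiplies two matrices by contracting axis 1 of the first with axis 0 of the second (no batch dimensions) and checks the result against jax.numpy.matmul(). A second call contracts two axes at once, mirroring jax.numpy.tensordot(). The shapes and values are arbitrary illustrations.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative operands: a (2, 3) matrix and a (3, 4) matrix.
a = jnp.arange(6.0).reshape(2, 3)
b = jnp.arange(12.0).reshape(3, 4)

# Contract axis 1 of `a` with axis 0 of `b`; no batch dimensions.
matmul_dims = (((1,), (0,)), ((), ()))
c = lax.dot_general(a, b, matmul_dims)
print(c.shape)                            # (2, 4)
print(jnp.allclose(c, jnp.matmul(a, b)))  # True

# Contract two axes at once (axes 1, 2 of `x` with axes 0, 1 of `y`),
# the same contraction jnp.tensordot(x, y, axes=([1, 2], [0, 1])) performs.
x = jnp.ones((2, 3, 4))
y = jnp.ones((3, 4, 5))
z = lax.dot_general(x, y, (((1, 2), (0, 1)), ((), ())))
print(z.shape)                            # (2, 5)
```

Empty tuples in the second pair mean "no batch dimensions"; every axis not named in dimension_numbers is a free (non-contracting, non-batch) axis that survives into the output.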
The semantics of dot_general are complicated, but most users should not have to use it directly. Instead, you can use higher-level functions like jax.numpy.dot(), jax.numpy.matmul(), jax.numpy.tensordot(), jax.numpy.einsum(), and others, which construct appropriate calls to dot_general under the hood. If you really want to understand dot_general itself, we recommend reading XLA’s DotGeneral operator documentation.
- Parameters:
  - lhs (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – an array
  - rhs (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – an array
  - dimension_numbers (tuple[tuple[Sequence[int], Sequence[int]], tuple[Sequence[int], Sequence[int]]]) – a tuple of tuples of sequences of ints of the form ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))
  - precision (Union[None, str, Precision, tuple[str, str], tuple[Precision, Precision]]) – Optional. Either None, which means the default precision for the backend, a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST), or a tuple of two Precision enums indicating the precision of lhs and rhs.
  - preferred_element_type (Union[str, type[Any], dtype, SupportsDType, None]) – Optional. Either None, which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype.
- Return type:
Array
- Returns:
An array whose first dimensions are the (shared) batch dimensions, followed by the lhs non-contracting/non-batch dimensions, and finally the rhs non-contracting/non-batch dimensions.
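To make the output ordering concrete, here is a sketch (assuming a CPU install of JAX) with a batch axis that is not leading in either operand: lhs has shape (M, B, K) and rhs has shape (K, B, N), with illustrative sizes M=3, B=2, K=4, N=5. Following the ordering described above, the result should have shape (B, M, N), which the sketch cross-checks against jax.numpy.einsum(). The last few lines show preferred_element_type widening the accumulation of int8 inputs to int32. Passing precision=lax.Precision.HIGHEST is included only to show where that argument goes; on CPU it is not expected to change the result.

```python
import jax.numpy as jnp
from jax import lax

# lhs: shape (M, B, K) = (3, 2, 4); the batch axis B is axis 1.
lhs = jnp.arange(24.0).reshape(3, 2, 4)
# rhs: shape (K, B, N) = (4, 2, 5); the batch axis B is also axis 1.
rhs = jnp.arange(40.0).reshape(4, 2, 5)

# ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))
dims = (((2,), (0,)), ((1,), (1,)))
out = lax.dot_general(lhs, rhs, dims, precision=lax.Precision.HIGHEST)

# Batch dims first (B), then lhs free dims (M), then rhs free dims (N).
print(out.shape)                                                # (2, 3, 5)
print(jnp.allclose(out, jnp.einsum("mbk,kbn->bmn", lhs, rhs)))  # True

# int8 operands whose dot product, 4 * (100 * 100) = 40000, does not fit in int8.
p = jnp.full((1, 4), 100, dtype=jnp.int8)
q = jnp.full((4, 1), 100, dtype=jnp.int8)
wide = lax.dot_general(p, q, (((1,), (0,)), ((), ())),
                       preferred_element_type=jnp.int32)
print(wide.dtype, int(wide[0, 0]))                              # int32 40000
```

Because the batch axes are always moved to the front of the result, code that expects a different layout (for example (M, B, N)) has to transpose the output afterwards; jax.numpy.einsum() does this for you when the output subscripts ask for it.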