jax.nn.initializers.variance_scaling
- jax.nn.initializers.variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)
Initializer that adapts its scale to the shape of the weights tensor.
With distribution="truncated_normal" or distribution="normal", samples are drawn from a (truncated) normal distribution with a mean of zero and a standard deviation (after truncation, if applicable) of \(\sqrt{\frac{scale}{n}}\), where n is:
- the number of input units in the weights tensor, if mode="fan_in",
- the number of output units, if mode="fan_out", or
- the average of the numbers of input and output units, if mode="fan_avg".
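The relationship between mode and the resulting standard deviation can be checked empirically. The following is a minimal sketch, assuming a dense kernel of shape (512, 256) with the default in_axis=-2 and out_axis=-1; the shape, seed, and scale value are arbitrary choices made for illustration.

```python
import jax
import jax.numpy as jnp
from jax.nn.initializers import variance_scaling

scale = 2.0
shape = (512, 256)  # assumed dense kernel: 512 input units, 256 output units

# With the default in_axis=-2 and out_axis=-1, the fans are read off the shape.
fan_in, fan_out = shape[-2], shape[-1]
n_by_mode = {
    "fan_in": fan_in,
    "fan_out": fan_out,
    "fan_avg": (fan_in + fan_out) / 2,
}

key = jax.random.PRNGKey(0)  # arbitrary seed, for illustration only
results = {}
for mode, n in n_by_mode.items():
    for distribution in ("normal", "truncated_normal"):
        init = variance_scaling(scale, mode, distribution)
        w = init(key, shape, jnp.float32)
        empirical_std = float(jnp.std(w))
        expected_std = (scale / n) ** 0.5  # sqrt(scale / n)
        results[(mode, distribution)] = (empirical_std, expected_std)
        print(f"{mode:8s} {distribution:17s} "
              f"empirical={empirical_std:.4f} expected={expected_std:.4f}")
```

For a tensor of this size, the empirical and expected values should agree closely for both distributions, since the truncated variant is documented to have the stated standard deviation after truncation.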
This initializer can be configured with in_axis, out_axis, and batch_axis to work with general convolutional or dense layers; axes that are not in any of those arguments are assumed to be the “receptive field” (convolution kernel spatial axes).

With distribution="truncated_normal", the absolute values of the samples are truncated at 2 standard deviations before scaling.

With distribution="uniform", samples are drawn from:
- a uniform interval, if dtype is real, or
- a uniform disk, if dtype is complex,

with a mean of zero and a standard deviation of \(\sqrt{\frac{scale}{n}}\), where n is defined above.
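For a real dtype, a zero-mean uniform interval with standard deviation \(\sqrt{\frac{scale}{n}}\) must be \([-a, a]\) with \(a = \sqrt{\frac{3 \cdot scale}{n}}\), because Uniform(-a, a) has standard deviation \(a / \sqrt{3}\). The sketch below checks this on a convolution kernel. It assumes an HWIO layout of shape (3, 3, 64, 128) and assumes that the receptive-field axes multiply into both fans, which is the usual convention; the shape and seed are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax.nn.initializers import variance_scaling

scale = 1.0
shape = (3, 3, 64, 128)  # assumed HWIO kernel: 3x3 receptive field, 64 in, 128 out

# Axes other than in_axis/out_axis/batch_axis count as the receptive field.
receptive_field = shape[0] * shape[1]
fan_in = shape[-2] * receptive_field   # 64 * 9
fan_out = shape[-1] * receptive_field  # 128 * 9
n = (fan_in + fan_out) / 2             # mode="fan_avg"

target_std = (scale / n) ** 0.5
bound = (3.0 * scale / n) ** 0.5  # half-width a of the interval [-a, a]

init = variance_scaling(scale, "fan_avg", "uniform")
w = init(jax.random.PRNGKey(0), shape, jnp.float32)

max_abs = float(jnp.max(jnp.abs(w)))
empirical_std = float(jnp.std(w))
print(f"bound={bound:.5f} max|w|={max_abs:.5f}")
print(f"target std={target_std:.5f} empirical std={empirical_std:.5f}")
```

The largest absolute sample should sit just below the bound, and the empirical standard deviation should be close to the target.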
- Parameters:
  - scale (Any) – scaling factor (positive float).
  - mode (Union[Literal['fan_in'], Literal['fan_out'], Literal['fan_avg']]) – one of "fan_in", "fan_out", and "fan_avg".
  - distribution (Union[Literal['truncated_normal'], Literal['normal'], Literal['uniform']]) – random distribution to use. One of "truncated_normal", "normal" and "uniform".
  - in_axis (Union[int, Sequence[int]]) – axis or sequence of axes of the input dimension in the weights array.
  - out_axis (Union[int, Sequence[int]]) – axis or sequence of axes of the output dimension in the weights array.
  - batch_axis (Sequence[int]) – axis or sequence of axes in the weight array that should be ignored.
  - dtype (Any) – the dtype of the weights.
- Return type:
Initializer
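
The returned initializer is a callable that takes a PRNG key, a shape, and optionally a dtype. As a hedged usage sketch, the example below configures in_axis and out_axis for a kernel stored in an assumed OIHW layout (output channels first), where the defaults of -2 and -1 would otherwise point at the spatial axes. The layout, shape, and seed are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp
from jax.nn.initializers import variance_scaling

# Assumed OIHW kernel: 128 output channels, 64 input channels, 3x3 spatial.
shape = (128, 64, 3, 3)

# Point the initializer at the channel axes explicitly.
init = variance_scaling(2.0, "fan_in", "normal", in_axis=1, out_axis=0)

key = jax.random.PRNGKey(42)  # arbitrary seed
w = init(key, shape, jnp.float32)

# fan_in counts the input channels times the 3x3 receptive field.
fan_in = shape[1] * shape[2] * shape[3]
expected_std = (2.0 / fan_in) ** 0.5
empirical_std = float(jnp.std(w))
print(w.shape, w.dtype)
print(f"expected std={expected_std:.4f} empirical std={empirical_std:.4f}")
```

Calling the same initializer with a different key produces a different sample with the same statistics, so the initializer can be built once and reused across layers that share a layout.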