diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 56f5c32b3af..cfc54ad207c 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -885,7 +885,7 @@ def _standardize( Args: input (Tensor): the input tensor to be standardized. - exclude_dims (Sequence[int]): dimensions to exclude from the statistics, can be negative. Default: (). + exclude_dims (Tuple[int]): dimensions to exclude from the statistics, can be negative. Default: (). mean (Tensor): a mean to be used for standardization. Must be of shape broadcastable to input. Default: None. std (Tensor): a standard deviation to be used for standardization. Must be of shape broadcastable to input. Default: None. eps (float): epsilon to be used for numerical stability. Default: float32 resolution.