diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index 8b90c9a0d8..ff8fbb32c3 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -42,7 +42,7 @@ def __init__(self, params, defaults):
         super(BatchedOptimizer, self).__init__(params, defaults)

     @contextlib.contextmanager
-    def batched_params(self, param_group):
+    def batched_params(self, param_group, group_params_names):
         """
         This function returns (technically, yields) a list of
           of tuples (p, state), where
@@ -64,31 +64,44 @@ def batched_params(self, param_group):
           you can do:
             with self.batched_params(group["params"]) as batches:
-                for p, state in batches:
+                for p, state, p_names in batches:
                     ...

         Args:
             group: a parameter group, which is a list of parameters; should be
-                one of self.groups.
+                one of self.param_groups.
+            group_params_names: name for each parameter in group,
+                which is a List[str].
         """
         batches = defaultdict(
             list
         )  # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
+        batches_names = defaultdict(
+            list
+        )  # `batches_names` maps from tuple (dtype_as_str,*shape) to list of str

-        for p in param_group:
+        assert len(param_group) == len(group_params_names)
+        for p, named_p in zip(param_group, group_params_names):
             key = (str(p.dtype), *p.shape)
             batches[key].append(p)
+            batches_names[key].append(named_p)
+
+        batches_names_keys = list(batches_names.keys())
+        sorted_idx = sorted(
+            range(len(batches_names)), key=lambda i: batches_names_keys[i]
+        )
+        batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
+        batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]

         stacked_params_dict = dict()

         # turn batches into a list, in deterministic order.
-        batches = [batches[key] for key in sorted(batches.keys())]
-
-        # pairs will contain pairs of (stacked_param, state), one for each batch
-        # in `batches`.
-        pairs = []
+        # tuples will contain tuples of (stacked_param, state, stacked_params_names),
+        # one for each batch in `batches`.
+        tuples = []

-        for batch in batches:
+        for batch, batch_names in zip(batches, batches_names):
             p = batch[0]
             # we arbitrarily store the state in the
             # state corresponding to the 1st parameter in the
@@ -100,11 +113,11 @@ def batched_params(self, param_group):
                 )
                 p_stacked.grad = grad
                 stacked_params_dict[key] = p_stacked
-                pairs.append((p_stacked, state))
+                tuples.append((p_stacked, state, batch_names))

-        yield pairs  # <-- calling code will do the actual optimization here!
+        yield tuples  # <-- calling code will do the actual optimization here!

-        for ((stacked_params, _state), batch) in zip(pairs, batches):
+        for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
             for i, p in enumerate(batch):  # batch is list of Parameter
                 p.copy_(stacked_params[i])
@@ -165,8 +178,15 @@ def __init__(
         scalar_max=10.0,
         size_update_period=4,
         clipping_update_period=100,
+        parameters_names=None,
+        show_dominant_parameters=True,
     ):
+        assert parameters_names is not None, (
+            "Please prepare parameters_names, "
+            "which is a List[List[str]]. Each List[str] is for a group "
+            "and each str is for a parameter"
+        )
         defaults = dict(
             lr=lr,
             clipping_scale=clipping_scale,
@@ -181,6 +201,9 @@ def __init__(
         )
         super(ScaledAdam, self).__init__(params, defaults)

+        assert len(self.param_groups) == len(parameters_names)
+        self.parameters_names = parameters_names
+        self.show_dominant_parameters = show_dominant_parameters

     def __setstate__(self, state):
         super(ScaledAdam, self).__setstate__(state)
@@ -199,9 +222,10 @@ def step(self, closure=None):
                 loss = closure()

         batch = True
-        for group in self.param_groups:
-            with self.batched_params(group["params"]) as batches:
+        for group, group_params_names in zip(self.param_groups, self.parameters_names):
+
+            with self.batched_params(group["params"], group_params_names) as batches:

                 # batches is list of pairs (stacked_param, state). stacked_param is like
                 # a regular parameter, and will have a .grad, but the 1st dim corresponds to
@@ -214,7 +238,7 @@ def step(self, closure=None):
                 else:
                     clipping_scale = self._get_clipping_scale(group, batches)

-                for p, state in batches:
+                for p, state, _ in batches:
                     # Perform optimization step.
                     # grad is not going to be None, we handled that when creating the batches.
                     grad = p.grad
@@ -276,7 +300,7 @@ def _init_state(self, group: dict, p: Tensor, state: dict):
         state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)

     def _get_clipping_scale(
-        self, group: dict, pairs: List[Tuple[Tensor, dict]]
+        self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]
     ) -> float:
         """
         Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
@@ -284,12 +308,16 @@ def _get_clipping_scale(

         Args:
            group: the parameter group, an item in self.param_groups
-           pairs: a list of pairs of (param, state) where param is a batched set of parameters, with a .grad
-              (1st dim is batch dim) and state is the state-dict where optimization parameters are kept.
+           tuples: a list of tuples of (param, state, param_names)
+               where param is a batched set of parameters,
+               with a .grad (1st dim is batch dim)
+               and state is the state-dict where optimization parameters are kept.
+               param_names is a List[str], where each str is the name of a parameter
+               in the batched set of parameters "param".
         """
-        assert len(pairs) >= 1
+        assert len(tuples) >= 1
         clipping_scale = group["clipping_scale"]
-        (first_p, first_state) = pairs[0]
+        (first_p, first_state, _) = tuples[0]
         step = first_state["step"]
         if clipping_scale is None or step == 0:
             # no clipping.  return early on step == 0 because the other
@@ -298,7 +326,7 @@ def _get_clipping_scale(
             clipping_update_period = group["clipping_update_period"]

         tot_sumsq = torch.tensor(0.0, device=first_p.device)
-        for (p, state) in pairs:
+        for (p, state, param_names) in tuples:
             grad = p.grad
             if grad.is_sparse:
                 raise RuntimeError(
@@ -361,8 +389,74 @@ def _get_clipping_scale(
                 logging.warn(
                     f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
                 )
+                if self.show_dominant_parameters:
+                    assert p.shape[0] == len(param_names)
+                    self._show_gradient_dominating_parameter(tuples, tot_sumsq)
             return ans

+    def _show_gradient_dominating_parameter(
+        self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor
+    ):
+        """
+        Show information about the parameter which dominates tot_sumsq.
+
+        Args:
+            tuples: a list of tuples of (param, state, param_names)
+                where param is a batched set of parameters,
+                with a .grad (1st dim is batch dim)
+                and state is the state-dict where optimization parameters are kept.
+                param_names is a List[str], where each str is the name of a parameter
+                in the batched set of parameters "param".
+            tot_sumsq: sumsq of all parameters. Though it could be calculated
+                from tuples, we still pass it in to save some time.
+        """
+        all_sumsq_orig = {}
+        for (p, state, batch_param_names) in tuples:
+            # p is a batch of stacked parameters.
+            batch_grad = p.grad
+            if p.numel() == p.shape[0]:  # a batch of scalars
+                batch_sumsq_orig = batch_grad**2
+                # Dummy values used by the following `zip` statement.
+                batch_rms_orig = torch.ones(p.shape[0])
+            else:
+                batch_rms_orig = state["param_rms"]
+                batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(
+                    dim=list(range(1, batch_grad.ndim))
+                )
+
+            for name, sumsq_orig, rms, grad in zip(
+                batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
+            ):
+
+                proportion_orig = sumsq_orig / tot_sumsq
+                all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
+
+        assert torch.isclose(
+            sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
+            torch.tensor(1.0),
+        )
+        sorted_by_proportion = {
+            k: v
+            for k, v in sorted(
+                all_sumsq_orig.items(), key=lambda item: item[1][0], reverse=True
+            )
+        }
+        dominant_param_name = next(iter(sorted_by_proportion))
+        (
+            dominant_proportion,
+            dominant_sumsq,
+            dominant_rms,
+            dominant_grad,
+        ) = sorted_by_proportion[dominant_param_name]
+        logging.info(
+            f"Parameter dominating tot_sumsq {dominant_param_name}"
+            f" with proportion {dominant_proportion:.2f},"
+            f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
+            f"={dominant_sumsq:.3e},"
+            f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
+            f" orig_rms_sq={(dominant_rms**2).item():.3e}"
+        )
+
     def _step_one_batch(
         self, group: dict, p: Tensor, state: dict, clipping_scale: float
     ):
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index b27c573ab3..31a3a0505a 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -988,7 +988,16 @@ def run(rank, world_size, args):
         logging.info("Using DDP")
         model = DDP(model, device_ids=[rank], find_unused_parameters=True)

-    optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0)
+    parameters_names = []
+    parameters_names.append(
+        [name_param_pair[0] for name_param_pair in model.named_parameters()]
+    )
+    optimizer = ScaledAdam(
+        model.parameters(),
+        lr=params.base_lr,
+        clipping_scale=2.0,
+        parameters_names=parameters_names,
+    )

     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
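Usage note: with this change, ScaledAdam asserts that parameters_names is passed and that it contains one List[str] per parameter group, with the names aligned element-wise to that group's parameters; show_dominant_parameters=True (the default) then logs which parameter contributes the largest share of tot_sumsq whenever the gradient-clipping warning fires. The train.py hunk above covers the single-group case. The sketch below is only an illustration of the multi-group case; the model variable, the "encoder." prefix used to split the groups, and the learning-rate values are assumptions for illustration, not part of this patch.

# Illustrative sketch (not part of the patch): building parameters_names when
# ScaledAdam is given more than one parameter group.  The model, the
# "encoder." prefix used to split the groups, and the lr values are assumed.
named_params = list(model.named_parameters())

encoder = [(n, p) for n, p in named_params if n.startswith("encoder.")]
others = [(n, p) for n, p in named_params if not n.startswith("encoder.")]

param_groups = [
    {"params": [p for _, p in encoder], "lr": 0.04},
    {"params": [p for _, p in others], "lr": 0.05},
]
# One List[str] per group, in the same order as param_groups, so that
# zip(self.param_groups, self.parameters_names) in step() lines up.
parameters_names = [
    [n for n, _ in encoder],
    [n for n, _ in others],
]

optimizer = ScaledAdam(
    param_groups,
    lr=0.05,
    clipping_scale=2.0,
    parameters_names=parameters_names,
)

The assert in __init__ only checks the number of groups and the assert in batched_params only checks the per-group lengths; keeping each name list in the same order as its group's parameters is what makes the logged parameter names meaningful.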