FIX: buffer size not aligned
a710128 committed May 1, 2022
1 parent 751aa39 commit 81f1a03
Showing 2 changed files with 4 additions and 4 deletions.
bmtrain/block_layer.py (2 additions, 2 deletions)
@@ -142,11 +142,11 @@ def enter(self):

storage_type = local_param.storage_type()

-             self._param_buffer[kw] = storage_type(val["total"])
+             self._param_buffer[kw] = storage_type(val["partition_size"] * config["world_size"])
self._param_tensor[kw] = torch.tensor([], dtype=self._param_buffer[kw].dtype, device=self._param_buffer[kw].device).set_(self._param_buffer[kw])

if requires_grad and local_param.requires_grad:
-                 self._grad_buffer[kw] = storage_type(val["total"])
+                 self._grad_buffer[kw] = storage_type(val["partition_size"] * config["world_size"])
self._grad_tensor[kw] = torch.tensor([], dtype=self._grad_buffer[kw].dtype, device=self._grad_buffer[kw].device).set_(self._grad_buffer[kw]).zero_()

nccl.groupStart()
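The surrounding lines allocate a flat storage of the gathered size and wrap it in a tensor view with `set_()`, so the same memory can be handed to NCCL and still be used as an ordinary tensor. A minimal standalone sketch of that pattern, illustrative only and not taken from BMTrain (the CPU tensor and the buffer size 12 are made up for the example):

```python
import torch

# Illustrative sketch of the storage + set_() pattern used above (not BMTrain code).
local_param = torch.zeros(4)                   # stand-in for one rank's partition
storage_type = local_param.storage_type()      # storage class matching dtype/device
buffer = storage_type(12)                      # flat, uninitialized buffer of 12 elements
view = torch.tensor([], dtype=local_param.dtype,
                    device=local_param.device).set_(buffer).zero_()
print(view.numel())                            # 12: the tensor views the whole storage
```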
bmtrain/inspect/model.py (2 additions, 2 deletions)
@@ -46,9 +46,9 @@ def inspect_checkpoint_block(model : CheckpointBlock, param_name : str, prefix :
for kw, val in model._storage_info.items():
storage_type = model._storage_params[kw].storage_type()

-         _param_buffer[kw] = storage_type(val["total"])
+         _param_buffer[kw] = storage_type(val["partition_size"] * config['world_size'])
if model._storage_params[kw].grad is not None:
-             _grad_buffer[kw] = storage_type(val["total"])
+             _grad_buffer[kw] = storage_type(val["partition_size"] * config['world_size'])

nccl.groupStart()
for kw, val in model._storage_info.items():
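The fix is the same in both files: the all-gather buffer was sized with the parameter's original element count (`val["total"]`), but each rank's partition is padded so that all partitions have the same length, and the gathered result therefore holds `partition_size * world_size` elements, which can exceed `total`. A minimal sketch of the arithmetic, assuming a round-up partitioning scheme consistent with the `partition_size` key (the `round_up` helper and the concrete numbers are illustrative, not from BMTrain):

```python
def round_up(n: int, multiple: int) -> int:
    """Round n up to the nearest multiple of `multiple`."""
    return (n + multiple - 1) // multiple * multiple

world_size = 4    # ranks sharing the parameter
total = 10        # elements in the original parameter

# Each rank holds an equally sized, padded partition.
partition_size = round_up(total, world_size) // world_size   # 3

# Gathering the partitions yields partition_size * world_size elements (12),
# which exceeds `total` (10) whenever `total` is not a multiple of `world_size`.
# A receive buffer sized with `total` would therefore be too small;
# the commit sizes it with the aligned value instead.
gathered_size = partition_size * world_size
assert gathered_size >= total
print(partition_size, gathered_size)   # 3 12
```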
