Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for tensor parallel using Pytorch 2.0 #34194

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions src/transformers/models/granite/modeling_granite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1127,6 +1127,61 @@ def forward(
attentions=outputs.attentions,
)

def apply_tensor_parallel(self, sub_mesh: torch.distributed._tensor.DeviceMesh):
"""
Applies tensor parallelism to the model using Pytorch 2.0 APIs.

Args:
sub_mesh (`torch.distributed._tensor.DeviceMesh`):
Pytorch 2.0 device mesh for abstracted view of process groups.
To be consumed by TP.
"""
# moving the model to a GPU is needed to apply Pytorch 2.0 TP
# with CUDA device mesh
self.model = self.model.to("cuda")

# conditional import of Pytorch 2.0 APIs for TP
from torch.distributed._tensor import Replicate
from torch.distributed.tensor.parallel import (
ColwiseParallel,
PrepareModuleInput,
RowwiseParallel,
parallelize_module,
)

# device mesh to be prepared outside this function and passed
tensor_parallel_mesh = sub_mesh

# TP plan for layers not part of transformer blocks
init_tp_plan = {
"embed_tokens": RowwiseParallel(input_layouts=Replicate()),
}

# apply the initial tp plan to the model
self.model = parallelize_module(self.model, tensor_parallel_mesh, init_tp_plan)

# plan for each transformer block
layer_tp_plan = {
"self_attn": PrepareModuleInput(),
"self_attn.q_proj": ColwiseParallel(),
"self_attn.k_proj": ColwiseParallel(),
"self_attn.v_proj": ColwiseParallel(),
"self_attn.o_proj": RowwiseParallel(),
"mlp": PrepareModuleInput(),
"mlp.gate_proj": ColwiseParallel(),
"mlp.up_proj": ColwiseParallel(),
"mlp.down_proj": RowwiseParallel(),
}

# apply the plan to each transformer block
for _, llama_layer in enumerate(self.model.layers):
attn_layer = llama_layer.self_attn
# adjust the number heads based on the local TP shard
attn_layer.num_heads = attn_layer.num_heads // tensor_parallel_mesh.mesh.shape[0]
attn_layer.num_key_value_heads = attn_layer.num_key_value_heads // tensor_parallel_mesh.mesh.shape[0]

llama_layer = parallelize_module(llama_layer, tensor_parallel_mesh, layer_tp_plan)

@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
Expand Down
58 changes: 57 additions & 1 deletion src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,7 @@ def forward(

# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
Expand Down Expand Up @@ -1233,6 +1233,62 @@ def forward(
attentions=outputs.attentions,
)

# Copied from transformers.models.granite.modeling_granite.GraniteForCausalLM.apply_tensor_parallel with Granite->Llama
def apply_tensor_parallel(self, sub_mesh: torch.distributed._tensor.DeviceMesh):
"""
Applies tensor parallelism to the model using Pytorch 2.0 APIs.

Args:
sub_mesh (`torch.distributed._tensor.DeviceMesh`):
Pytorch 2.0 device mesh for abstracted view of process groups.
To be consumed by TP.
"""
# moving the model to a GPU is needed to apply Pytorch 2.0 TP
# with CUDA device mesh
self.model = self.model.to("cuda")

# conditional import of Pytorch 2.0 APIs for TP
from torch.distributed._tensor import Replicate
from torch.distributed.tensor.parallel import (
ColwiseParallel,
PrepareModuleInput,
RowwiseParallel,
parallelize_module,
)

# device mesh to be prepared outside this function and passed
tensor_parallel_mesh = sub_mesh

# TP plan for layers not part of transformer blocks
init_tp_plan = {
"embed_tokens": RowwiseParallel(input_layouts=Replicate()),
}

# apply the initial tp plan to the model
self.model = parallelize_module(self.model, tensor_parallel_mesh, init_tp_plan)

# plan for each transformer block
layer_tp_plan = {
"self_attn": PrepareModuleInput(),
"self_attn.q_proj": ColwiseParallel(),
"self_attn.k_proj": ColwiseParallel(),
"self_attn.v_proj": ColwiseParallel(),
"self_attn.o_proj": RowwiseParallel(),
"mlp": PrepareModuleInput(),
"mlp.gate_proj": ColwiseParallel(),
"mlp.up_proj": ColwiseParallel(),
"mlp.down_proj": RowwiseParallel(),
}

# apply the plan to each transformer block
for _, llama_layer in enumerate(self.model.layers):
attn_layer = llama_layer.self_attn
# adjust the number heads based on the local TP shard
attn_layer.num_heads = attn_layer.num_heads // tensor_parallel_mesh.mesh.shape[0]
attn_layer.num_key_value_heads = attn_layer.num_key_value_heads // tensor_parallel_mesh.mesh.shape[0]

llama_layer = parallelize_module(llama_layer, tensor_parallel_mesh, layer_tp_plan)


@add_start_docstrings(
"""
Expand Down
8 changes: 7 additions & 1 deletion src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@
DistributedDataParallelKwargs,
DistributedType,
GradientAccumulationPlugin,
TorchTensorParallelPlugin,
load_fsdp_model,
load_fsdp_optimizer,
save_fsdp_model,
Expand Down Expand Up @@ -4891,6 +4892,11 @@ def create_accelerator_and_postprocess(self):
args["dataloader_config"] = dataloader_config
else:
args.update(accelerator_config)
# tp is initialized at Accelerator init phase so
# args should be prepared here
if self.args.tp_size > 1:
self.is_tp_enabled = True
args["torch_tp_plugin"] = TorchTensorParallelPlugin(tp_size=self.args.tp_size)

# create accelerator object
self.accelerator = Accelerator(**args)
Expand All @@ -4905,7 +4911,7 @@ def create_accelerator_and_postprocess(self):
# deepspeed and accelerate flags covering both trainer args and accelerate launcher
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None

self.is_tp_enabled = getattr(self.accelerator.state, "tp_plugin", None) is not None
# post accelerator creation setup
if self.is_fsdp_enabled:
fsdp_plugin = self.accelerator.state.fsdp_plugin
Expand Down
16 changes: 15 additions & 1 deletion src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,9 @@ class TrainingArguments:
Will use gradient checkpointing over each nested XLA FSDP wrapped layer. This setting can only be
used when the xla flag is set to true, and an auto wrapping policy is specified through
fsdp_min_num_params or fsdp_transformer_layer_cls_to_wrap.

tp_size (`int`, *optional*):
Use tp_size to enable pytorch 2.0 tensor parallelism. Set a value greater than 1 to activate TP. The same is
used to prepare device mesh internally.
deepspeed (`str` or `dict`, *optional*):
Use [Deepspeed](https://github.com/microsoft/deepspeed). This is an experimental feature and its API may
evolve in the future. The value is either the location of DeepSpeed json config file (e.g.,
Expand Down Expand Up @@ -1237,6 +1239,16 @@ class TrainingArguments:
)
},
)
tp_size: Optional[int] = field(
default=0,
metadata={
"help": (
"Use tp_size to enable pytorch 2.0 tensor parallelism."
"Set a value greater than 1 to activate TP."
"The same is used to prepare device mesh internally."
)
},
)
fsdp_transformer_layer_cls_to_wrap: Optional[str] = field(
default=None,
metadata={
Expand Down Expand Up @@ -1935,6 +1947,8 @@ def __post_init__(self):
if self.fsdp_config["xla_fsdp_grad_ckpt"]:
warnings.warn("`--xla_fsdp_grad_ckpt` is useful only when `--xla` is set to true.")

if self.tp_size > 1:
os.environ["ACCELERATE_USE_TP"] = "true"
# accelerate integration for FSDP
if len(self.fsdp) > 0 and is_accelerate_available("0.28.0"):
os.environ["ACCELERATE_USE_FSDP"] = "true"
Expand Down