Use scan and host offloading for llama model #123

Draft · wants to merge 5 commits into base: main
51 changes: 32 additions & 19 deletions README.md
@@ -121,7 +121,26 @@ and attributes where this model code came from, if any. This also helps to
showcase what changes we have made to make it performant on TPU. The original
version is not expected to be run.

## Run huggingface transformer models
## Contributing

Contributions are welcome! Please feel free to submit a pull request.

When developing, use `pip install -e '.[dev]'` to install dev dependencies such
as the linter and formatter.
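
For example, a typical editable install from the repository root (a sketch; the exact dev extras are defined by the project's packaging configuration):

```sh
# Editable install with the dev extras (linter and formatter, per the note above)
pip install -e '.[dev]'
```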

### How to run tests:

```sh
pytest
```

### How to run some of the tests, and re-run them whenever you change a file:

```sh
tp -i test ... # replace ... with paths to test files or directories
```

### How to run HuggingFace transformer models
Torchprime supports running HuggingFace models through `tp run`.
To use HuggingFace models, you can clone
[huggingface/transformers](https://github.com/huggingface/transformers) under
@@ -137,32 +156,26 @@ add the `--use-hf` flag to the `tp run` command:
tp run --use-hf torchprime/hf_models/train.py
```

### How to run inside the docker container locally
You can also run locally without XPK using docker. When running inside the docker
container, it uses the same dependencies and build process as the XPK approach,
improving hermeticity and reliability.

```sh
tp docker-run torchprime/torch_xla_models/train.py
```

This will run the TorchPrime docker image locally. You can also add `--use-hf`
to run a HuggingFace model locally.

```sh
tp docker-run --use-hf torchprime/hf_models/train.py
```
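
As a side note, `tp docker-run` forwards a few environment variables from your shell
into the container (see `_DOCKER_ENV_FORWARD_LIST` in `torchprime/launcher/cli.py`), so
a hypothetical debugging invocation might look like the sketch below; the values are
placeholders, not recommendations.

```sh
# HF_TOKEN, XLA_IR_DEBUG, XLA_HLO_DEBUG and LIBTPU_INIT_ARGS are forwarded into the
# container when they are set in your environment (hypothetical values for illustration).
export HF_TOKEN=<your token>
XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 tp docker-run torchprime/torch_xla_models/train.py
```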

How to format:
### How to format:

```sh
ruff format
```

How to lint:
### How to lint:

```sh
ruff check [--fix]
7 changes: 4 additions & 3 deletions torchprime/hf_models/configs/default.yaml
@@ -15,7 +15,7 @@ train_script:
  args:
    dataset_name: "wikitext"
    dataset_config_name: "wikitext-103-raw-v1"
    per_device_train_batch_size: 256 # this is the global batch size when using minibatch
    per_device_train_batch_size: 1024 # this is the global batch size when using minibatch
    do_train: true
    output_dir: "test-clm"
    overwrite_output_dir: true
@@ -31,5 +31,6 @@ train_script:
    torch_dtype: "bfloat16"
    dataloader_drop_last: true
    flash_attention: true
    max_steps: 50
    seed: 42
    max_steps: 500
    seed: 42
    ignore_data_skip: true
3 changes: 2 additions & 1 deletion torchprime/launcher/buildpush.py
@@ -61,7 +61,8 @@ def buildpush(
    _run(
      f"{sudo_cmd} docker tag {docker_tag} {docker_url}",
    )
    _run(f"{sudo_cmd} docker push {docker_url}")
    if torchprime_docker_tag != "local_run":
      _run(f"{sudo_cmd} docker push {docker_url}")
  except subprocess.CalledProcessError as e:
    print(f"Error running command: {e}")
    exit(e.returncode)
62 changes: 57 additions & 5 deletions torchprime/launcher/cli.py
@@ -23,6 +23,13 @@
import torchprime.launcher.doctor
from torchprime.launcher.buildpush import buildpush

_DOCKER_ENV_FORWARD_LIST = [
  "HF_TOKEN",  # HuggingFace token
  "XLA_IR_DEBUG",  # torch_xla debugging flag
  "XLA_HLO_DEBUG",  # torch_xla debugging flag
  "LIBTPU_INIT_ARGS",  # XLA flags
]


@dataclass_json
@dataclass
@@ -194,6 +201,55 @@ def create_and_activate_gcloud(gcloud_config_name, config: Config):
)


@cli.command(
  name="docker-run",
  context_settings=dict(
    ignore_unknown_options=True,
  ),
)
@click.argument("args", nargs=-1, type=click.UNPROCESSED)
@click.option("--use-hf", is_flag=True, help="Use HuggingFace transformer")
def docker_run(args, use_hf: bool):
  """
  Runs the provided training command locally for quick testing.
  """
  config = read_config()

  click.echo(get_project_dir().absolute())

  # Build docker image.
  build_arg = "USE_TRANSFORMERS=true" if use_hf else None
  docker_project = config.docker_project
  if docker_project is None:
    docker_project = config.project
  docker_url = buildpush(
    docker_project, torchprime_docker_tag="local_run", build_arg=build_arg
  )
  # Forward a bunch of important env vars.
  env_forwarding = [
    arg for env_var in _DOCKER_ENV_FORWARD_LIST for arg in forward_env(env_var)
  ]
  command = [
    "python",
  ] + list(args)
  docker_command = [
    "docker",
    "run",
    "-i",
    *env_forwarding,
    "--privileged",
    "--net",
    "host",
    "--rm",
    "-v",
    f"{os.getcwd()}:/workspace",
    "-w",
    "/workspace",
    docker_url,
  ] + command
  subprocess.run(docker_command, check=True)


@cli.command(
  context_settings=dict(
    ignore_unknown_options=True,
@@ -235,12 +291,8 @@ def run(args, name: str | None, use_hf: bool):

  # Forward a bunch of important env vars.
  env_forwarding = [
    *forward_env("HF_TOKEN"),  # HuggingFace token
    *forward_env("XLA_IR_DEBUG"),  # torch_xla debugging flag
    *forward_env("XLA_HLO_DEBUG"),  # torch_xla debugging flag
    *forward_env("LIBTPU_INIT_ARGS"),  # XLA flags
    arg for env_var in _DOCKER_ENV_FORWARD_LIST for arg in forward_env(env_var)
  ]

  # Pass artifact dir and jobset name as env vars.
  artifact_arg = [
    "--env",
31 changes: 30 additions & 1 deletion torchprime/sharding/shard_model.py
@@ -5,6 +5,33 @@

import torch.nn


@torch.library.custom_op("xla::aot_mark_sharding", mutates_args=())
def aot_mark_sharding(t: torch.Tensor, partition_spec: str) -> torch.Tensor:
  import torch_xla
  if t is None:
    return None
  import ast
  mesh = torch_xla.distributed.spmd.get_global_mesh()
  partition_spec_eval = ast.literal_eval(partition_spec)
  torch_xla.distributed.spmd.mark_sharding(t, mesh, partition_spec_eval)
  return t.clone()


@aot_mark_sharding.register_fake
def aot_mark_sharding_fake(t: torch.Tensor, partition_spec: str) -> torch.Tensor:
  if t is None:
    return None
  return torch.empty_like(t)


def aot_mark_sharding_backward(ctx, grad):
  return grad, None


aot_mark_sharding.register_autograd(aot_mark_sharding_backward)


ShardWeightFn = Callable[[torch.Tensor, str], torch.Tensor]
"""
ShardWeightFn optionally transforms a weight tensor based on its name.
@@ -216,7 +243,9 @@ def shard_fn(tensor, spec: tuple[str, ...]):
    the_mesh = mesh if mesh is not None else xs.get_global_mesh()
    assert the_mesh is not None, "No mesh found"
    # TODO(https://github.com/pytorch/xla/issues/8678): Shard the gradient too.
    return xs.mark_sharding(tensor, the_mesh, spec).global_tensor
    # Previously we used xs.mark_sharding(tensor, the_mesh, spec).global_tensor.
    # However, this is not supported for AOT compilation.
    return aot_mark_sharding(tensor, str(spec))

  return shard_model_from_config(model, config, shard_fn)

10 changes: 5 additions & 5 deletions torchprime/torch_xla_models/configs/default.yaml
@@ -8,10 +8,10 @@ defaults:
  - model: llama-3-8b # refers to model/llama-3-8b.yaml

dataset_name: wikitext
dataset_config_name: wikitext-2-raw-v1
global_batch_size: 8
logging_steps: 10
max_steps: 15
dataset_config_name: wikitext-103-raw-v1
global_batch_size: 1024
logging_steps: 1
max_steps: 500
block_size: 8192
cache_dir: /tmp/
seed: 42
@@ -27,6 +27,6 @@ lr_scheduler:
  warmup_steps: 0
mesh:
  dcn: 1
  fsdp: 4
  fsdp: 256
  tensor: 1
  expert: 1
1 change: 1 addition & 0 deletions torchprime/torch_xla_models/configs/model/llama-3-8b.yaml
@@ -20,3 +20,4 @@ attention_dropout: false
attention_bias: false
flash_attention: true
rope_theta: 500000.0
scan_decoder_layers: true
@zpcore (Collaborator, Author) commented on Feb 26, 2025: "move to default yaml file"

@@ -1,5 +1,5 @@
activation_checkpoint_layers:
  - LlamaDecoderLayer
# activation_checkpoint_layers:
#   - LlamaDecoderLayer

# Refer to https://github.com/pytorch/xla/issues/6379 for backward optimization barrier info.
optimization_barrier_layers:
@@ -23,3 +23,4 @@ sharding:
  # Activations
  model.layers.*: [fsdp, null, null]
  lm_head: [fsdp, null, null]
