diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a97489e3..1d54a477 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -27,4 +27,11 @@ jobs: - name: Run tests run: | pytest + - name: Upload coverage data to coveralls.io + run: | + python -m pip install coveralls[toml] + coveralls --service=github + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + diff --git a/.gitignore b/.gitignore index 0bd78b69..ff12e477 100644 --- a/.gitignore +++ b/.gitignore @@ -55,6 +55,7 @@ htmlcov/ .cache nosetests.xml coverage.xml +coverage_html_report *.cover .hypothesis/ .pytest_cache/ diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 74106534..ed55664e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -31,6 +31,7 @@ pre-commit install --install-hooks - Make sure your code passes all the tests and pre-commit hooks. Use `pytest` from within the root of your local repository. +- For vscode users, disable pytest coverage in `settings.json` to enable pytest debugging: `"python.testing.pytestArgs": ["--no-cov"]` ## Commit Guidelines diff --git a/README.md b/README.md index ced552c2..0573a4ba 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,10 @@ # Modalities +[![Coverage Status](https://coveralls.io/repos/github/Modalities/modalities/badge.svg)](https://coveralls.io/github/Modalities/modalities) # Getting started For training and evaluation a model, feel free to checkout [this](https://github.com/Modalities/modalities/blob/main/examples/getting_started/getting_started_example.md) getting started tutorial, in which we train a small, 60M-parameter GPT model on a tiny subset of the Redpajama V2 dataset. -Also, see our WIki and API reference documentation: https://modalities.github.io/modalities/ +Also, see our Wiki and API reference documentation: https://modalities.github.io/modalities/ # Installation @@ -19,7 +20,7 @@ then, install the repository via pip install -e . ``` -If you want to contribute, have look at `CONTRIBUTING.md`. +If you want to contribute, have a look at `CONTRIBUTING.md`. @@ -56,12 +57,12 @@ Or, if you are a VsCode user, add this to your `launch.json`: # Pydantic and ClassResolver -The mechanismn introduced to instantiate classes via `type_hint` in the `config.yaml`, utilizes +The mechanism introduced to instantiate classes via `type_hint` in the `config.yaml`, utilizes 1) Omegaconf to load the config yaml file 2) Pydantic for the validation of the config 3) ClassResolver to instantiate the correct, concrete class of a class hierarchy. -Firstly, Omegaconf loads the config yaml file and resolves internal refrences such as `${subconfig.attribue}`. +Firstly, Omegaconf loads the config yaml file and resolves internal references such as `${subconfig.attribute}`. Then, Pydantic validates the whole config as is and checks that each of the sub-configs are `pydantic.BaseModel` classes. For configs, which allow different concrete classes to be instantiated by `ClassResolver`, the special member names `type_hint` and `config` are introduced. @@ -79,7 +80,7 @@ activation_kwargs={...} activation_resolver.make(type_hint, activation_kwargs), ``` -In our implmentation we go a step further, as both, +In our implementation we go a step further, as both, * a `type_hint` in a `BaseModel` config must be of type `modalities.config.lookup_types.LookupEnum` and * `config` is a union of allowed concrete configs of base type `BaseModel`. 
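For illustration, such a component then appears in the `config.yaml` roughly as follows (a minimal sketch based on the scheduler classes defined below):

```yaml
scheduler:
  type_hint: StepLR   # mapped to torch.optim.lr_scheduler.StepLR via the LookupEnum
  config:             # validated against the matching pydantic BaseModel (StepLRConfig)
    step_size: 1
    gamma: 0.1
```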
`config` hereby replaces `activation_kwargs` in the example above, and replaces it with pydantic-validated `BaseModel` configs. @@ -88,7 +89,8 @@ With this, a mapping between type hint strings needed for `class-resolver`, and ```python from enum import Enum -from pydantic import BaseModel, PositiveInt, PositiveFloat, conint, confloat +from typing import Annotated +from pydantic import BaseModel, PositiveInt, PositiveFloat, Field class LookupEnum(Enum): @classmethod @@ -101,8 +103,8 @@ class SchedulerTypes(LookupEnum): ConstantLR = torch.optim.lr_scheduler.ConstantLR class StepLRConfig(BaseModel): - step_size: conint(ge=1) - gamma: confloat(ge=0.0) + step_size: Annotated[int, Field(strict=True, ge=1)] + gamma: Annotated[float, Field(strict=True, ge=0.0)] class ConstantLRConfig(BaseModel): @@ -115,7 +117,7 @@ class SchedulerConfig(BaseModel): config: StepLRConfig | ConstantLRConfig ``` -To allow a user-friendly instantiation, all class resolvers are defined in the `ResolverRegistry` and `build_component_by_config` as convenience function is introduced. Dependecies can be passed-through with the `extra_kwargs` argument: +To allow a user-friendly instantiation, all class resolvers are defined in the `ResolverRegistry` and `build_component_by_config` as convenience function is introduced. Dependencies can be passed-through with the `extra_kwargs` argument: ```python resolvers = ResolverRegister(config=config) optimizer = ... # our example dependency @@ -187,20 +189,20 @@ Alternatively, directly use `src/modalities/__main__.py do_stuff --config_file_p The `MemMapDataset` requires an index file providing the necessary pointers into the raw data file. The `MemMapDataset` can create the index file lazily, however, it is advised to create it beforehand. This can be done by running ```sh -modalities create_memmap_index +modalities data create_raw_index ``` -The index will be created in the same directory as the raw data file. For further options you may look into the usage documentation via `modalities create_memmap_index --help`. +The index will be created in the same directory as the raw data file. For further options you may look into the usage documentation via `modalities data create_raw_index --help`. ## Packed Dataset Generator The `PackedMemMapDatasetContinuous` and `PackedMemMapDatasetMegatron` require a packed data file. To create the data file, you first have to generate a `MemMapDataset` index file as described [above](#memmapdataset-index-generator). Assuming the index and raw data are located in the same directory, you can simply execute the following command: ```sh -modalities create_packed_data +modalities data pack_encoded_data ``` -The packed data file will be created in the same directory as the raw data file. For further options you may look into the usage documentation via `modalities create_packed_data --help`. +The packed data file will be created in the same directory as the raw data file. For further options you may look into the usage documentation via `modalities data pack_encoded_data --help`. ### Packed Data Format diff --git a/benchmarks/dataloader/README.md b/benchmarks/dataloader/README.md new file mode 100644 index 00000000..850e8f06 --- /dev/null +++ b/benchmarks/dataloader/README.md @@ -0,0 +1,77 @@ +# Benchmarking of Dataset Implementations + +## Motivation +We want to include a storage efficient, fast and generic dataset implementation in this repository. +Previous work and ideas were based on MegatronLM and its dataset implementation. 
+ +Unfortunately its usage is quite intransparent and causes regularly unexpected side effects. +Those problems are hard to trace, as we are not the original authors of the code. + +Therefore we want to provide an own implementation, which comes with all the above mentioned benefits. +Most importantly, it should be at least as fast as MegatronLM's implementation. + + +## Benchmark Overview + +We want to evaluate multiple aspects of the dataset implementations: +* preparation speed - All datasets need to do some initial steps like tokenization and indexing. +* initialization speed - When firing up a respective `Dataset` object inside the code. +* iteration speed - When accessing elements (in a random order) in the respective datasets + + +## Used Example Dataset + +The experiments were conducted on a small sample of openwebtext. The data is provided in `.jsonl`-format. +The relevant data included can be found under `"text"` and is obviously text-only. +Each dataset with X samples refers to the first X lines in the full openwebtext data, + as it can be obtained from huggingface. + + +## Experimental Setup + +We relied on the functions provided in `launch_benchmark.sh`. One can reproduce those by calling e.g. + +```shell +. launch_benchmark.sh + +INPUT_DIR= + +echo "MegatronLM:" +measure_megatronLM_iteration +echo "Modalities:" +measure_modalities_iteration +``` + +> For launching the preparation of MegatronLM's dataset, refer to: +> https://github.com/OpenGPTX/opengptx_data/tree/docs/modalities-vs-megatronlm-dl and look at the `launch_benchmark.sh` +> script. + +#### Glossary + +* **preparation:** refers here to the task of turning raw data (e.g. jsonl encoded text) into a binary file, + which is loadable later for training. + For MegatronLM this means tokenizing and packing everything according to their defined format. + For Modalities it means, indexing the raw data and packing it afterwards as token-ids. +* **initialization:** refers to the process of initializing a python object, + which represents the respective dataset (mostly represented via the `torch.Dataset`-interface) +* **iteration:** refers to process of iterating over the respective datasets - once sequentially and once shuffled. 
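Analogously, the preparation and initialization timings reported below can be reproduced with the other helper functions defined in `launch_benchmark.sh`; a sketch, assuming `INPUT_DIR` points to your local openwebtext sample in `.jsonl` format:

```shell
. launch_benchmark.sh

INPUT_DIR=<path-to-openwebtext-sample.jsonl>

echo "Modalities preparation:"
measure_modalities_preparation

echo "MegatronLM initialization:"
measure_megatronLM_initialization
echo "Modalities initialization:"
measure_modalities_initialization
```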
+ +## Results + + +| Evaluation Aspect | Implementation | Required Time | # Samples in Data | +|----------------------|----------------|:------------------:|-------------------| +| preparation speed | MegatronLM | `0 min 16.965 sec` | `20000(OWT)` | +| preparation speed | Modalities | `0 min 13.904 sec` | `20000(OWT)` | +| preparation speed | MegatronLM | `2 min 11.856 sec` | `200000(OWT)` | +| preparation speed | Modalities | `0 min 38.738 sec` | `200000(OWT)` | +| initialization speed | MegatronLM | `19.3 msec` | `20000(OWT)` | +| initialization speed | Modalities | `5.85 msec` | `20000(OWT)` | +| initialization speed | MegatronLM | `180 msec ` | `200000(OWT)` | +| initialization speed | Modalities | `58 msec` | `200000(OWT)` | +| iteration speed | MegatronLM | `52.4 msec` | `20000(OWT)` | +| iteration speed | Modalities | `66.8 msec` | `20000(OWT)` | +| iteration speed | MegatronLM | `426 msec ` | `200000(OWT)` | +| iteration speed | Modalities | `545 msec` | `200000(OWT)` | + + diff --git a/benchmarks/dataloader/launch_benchmark.sh b/benchmarks/dataloader/launch_benchmark.sh new file mode 100755 index 00000000..c4e9f69d --- /dev/null +++ b/benchmarks/dataloader/launch_benchmark.sh @@ -0,0 +1,87 @@ +#!/bin/bash + + + +INPUT_DIR="/tmp/i-do-not-exist.jsonl" + + +measure_modalities_preparation() { + time ( + set -e + test -f $INPUT_DIR + rm -f ${INPUT_DIR/.jsonl/.idx} + modalities data create_raw_index $INPUT_DIR &> /dev/null + echo "finished memmap index creation" + rm -f ${INPUT_DIR/.jsonl/.pbin} + modalities data pack_encoded_data $INPUT_DIR &> /dev/null + echo "finished memmap packing" + ) +} + + +measure_modalities_initialization() { + input_file=${INPUT_DIR/.jsonl/.pbin} + python -m timeit -n 50 -r 5 -s " +import sys, io +null_device = io.StringIO() +from modalities.dataloader.dataset import PackedMemMapDatasetMegatron +from pathlib import Path +p = Path(\"${input_file}\") + " -- " +sys.stdout = null_device # deactivate stdout to avoid getting spammed +PackedMemMapDatasetMegatron(raw_data_path=p, block_size=1024, sample_key=\"sample\") +sys.stdout = sys.__stdout__ # reactivate stdout for timeit +" +} + +measure_megatronLM_initialization() { + input_file="${INPUT_DIR/.jsonl/.megLM.bin_text_document}" + python -m timeit -n 50 -r 5 -s " +import sys, io +null_device = io.StringIO() +from modalities.dataloader.open_gptx_dataset.mmap_dataset import MMapIndexedDataset +p = \"${input_file}\" + " -- " +sys.stdout = null_device # deactivate stdout to avoid getting spammed +MMapIndexedDataset(p) +sys.stdout = sys.__stdout__ # reactivate stdout for timeit +" +} + +measure_modalities_iteration() { + input_file=${INPUT_DIR/.jsonl/.pbin} + python -m timeit -n 5 -r 3 -s " +import random, sys, io +null_device = io.StringIO() +from modalities.dataloader.dataset import PackedMemMapDatasetMegatron +from pathlib import Path +p = Path(\"${input_file}\") +sys.stdout = null_device # deactivate stdout to avoid getting spammed +dataset = PackedMemMapDatasetMegatron(raw_data_path=p, block_size=1024, sample_key=\"sample\") +random_indices = random.sample(range(len(dataset)), len(dataset)) +sys.stdout = sys.__stdout__ # reactivate stdout for timeit + " -- " +list(dataset) # sequential access +for i in random_indices: + dataset[i] +" +} + + +measure_megatronLM_iteration() { + input_file="${INPUT_DIR/.jsonl/.megLM.bin_text_document}" + python -m timeit -n 5 -r 3 -s " +import random, sys, io +null_device = io.StringIO() +from modalities.dataloader.open_gptx_dataset.mmap_dataset import MMapIndexedDataset +p = 
\"${input_file}\" +sys.stdout = null_device # deactivate stdout to avoid getting spammed +dataset = MMapIndexedDataset(p) +random_indices = random.sample(range(len(dataset)), len(dataset)) +sys.stdout = sys.__stdout__ # reactivate stdout for timeit + " -- " +list(dataset) # sequential access +for i in random_indices: + dataset[i] +" +} \ No newline at end of file diff --git a/config_files/config.yaml b/config_files/config.yaml index 6925155b..ad57ddef 100644 --- a/config_files/config.yaml +++ b/config_files/config.yaml @@ -142,15 +142,13 @@ model: prediction_key: "logits" block_size: ${data.sequence_len} vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency - n_layer: 12 - n_head: 12 + n_layer_q: 12 + n_head_kv: 12 ffn_hidden: 2048 n_embd: 768 dropout: 0.0 bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster - attention: - attention_type: pytorch_flash_attention - scaling_factor: 3 + attention_type: pytorch_flash_attention activation: gelu epsilon: 1e-5 weight_init: diff --git a/config_files/config_example_hf_meditron_7B_instruction.yaml b/config_files/config_example_hf_meditron_7B_instruction.yaml new file mode 100644 index 00000000..590525dc --- /dev/null +++ b/config_files/config_example_hf_meditron_7B_instruction.yaml @@ -0,0 +1,199 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + training: + callback_interval_in_samples: 2048 + global_num_training_samples: 2048 + global_num_seen_samples: 0 + do_apply_activation_checkpointing: true + gradient_acc_steps: 1 + local_train_micro_batch_size: 1 + sequence_length: 4096 + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + checkpointing_path: data/checkpoints + + +collate_fn: + component_key: collate_fn + variant_key: gpt_2_llm_collator + config: + sample_key: ${settings.referencing_keys.sample_key} + target_key: ${settings.referencing_keys.target_key} + +train_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_megatron + config: + raw_data_path: /raid/s3/opengptx/max_lue/modalities/data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_gpt2_tokenized_num_samples_1050391.pbin + block_size: ${settings.training.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +train_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: "train" + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.training.local_train_micro_batch_size} + drop_last: false + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: true + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +val_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_megatron + config: + raw_data_path: /raid/s3/opengptx/max_lue/modalities/data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_gpt2_tokenized_num_samples_1024.pbin + block_size: ${settings.training.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +val_dataloader: + component_key: 
data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: "val" + dataset: + instance_key: val_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.training.local_train_micro_batch_size} + drop_last: false + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: true + dataset: + instance_key: val_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: [val_dataloader] + +checkpointing: + component_key: checkpointing + variant_key: default + config: + checkpointing_strategy: + component_key: checkpointing_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 # -1 to save all checkpoints + checkpointing_execution: + component_key: checkpointing_execution + variant_key: fsdp_to_disc_checkpointing + config: + checkpoint_path: ${settings.paths.checkpointing_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + mixed_precision_settings: BF_16 + sharding_strategy: FULL_SHARD + block_names: [LlamaDecoderLayer] + + +model: + component_key: model + variant_key: huggingface_pretrained_model + config: + model_type: AutoModelForCausalLM + model_name: epfl-llm/meditron-7b + sample_key: ${settings.referencing_keys.sample_key} + prediction_key: ${settings.referencing_keys.prediction_key} + huggingface_prediction_subscription_key: ${settings.referencing_keys.prediction_key} + kwargs: + cache_dir: /raid/s3/opengptx/max_lue/hf_cache/ + +wrapped_model: + component_key: model + variant_key: fsdp_wrapped + config: + model: + instance_key: model + pass_type: BY_REFERENCE + sync_module_states: true + mixed_precision_settings: BF_16 + sharding_strategy: FULL_SHARD + block_names: [LlamaDecoderLayer] + +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + +# scheduler: +# type_hint: StepLR +# config: +# step_size: 1 +# gamma: 0.1 + +optimizer: + component_key: optimizer + variant_key: adam_w + config: + lr: 0.0001 + wrapped_model: + instance_key: wrapped_model + pass_type: BY_REFERENCE + + +batch_progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + local_rank: ${settings.cuda_env.local_rank} + world_size: ${settings.cuda_env.world_size} + global_num_seen_samples: ${settings.training.global_num_seen_samples} + train_dataloader: + instance_key: train_dataloader + pass_type: BY_REFERENCE + eval_dataloaders: [] + + +evaluation_subscriber: + component_key: results_subscriber + variant_key: wandb + config: + local_rank: ${settings.cuda_env.local_rank} + project: modalities + mode: OFFLINE + experiment_id: ${settings.experiment_id} + directory: "." 
\ No newline at end of file diff --git a/config_files/config_example_mem_map_dataset.yaml b/config_files/config_example_mem_map_dataset.yaml index 005f2fca..2c536d67 100644 --- a/config_files/config_example_mem_map_dataset.yaml +++ b/config_files/config_example_mem_map_dataset.yaml @@ -1,144 +1,211 @@ -modalities_setup: - run_mode: FROM_SCRATCH - settings: +settings: + experiment_id: ${modalities_env:experiment_id} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + training: + callback_interval_in_samples: 32768 + global_num_training_samples: 2048 global_num_seen_samples: 0 + do_apply_activation_checkpointing: false + gradient_acc_steps: 1 + local_train_micro_batch_size: 16 + sequence_length: 4096 + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + checkpointing_path: data/checkpoints -wandb: - project_name: modalities - mode: ONLINE -data: - sample_key: "input_ids" - target_key: "target_ids" - sequence_len: 1024 - train_dataloader: - type_hint: LLMDataLoader - config: - dataloader_tag: "train" - num_workers: 2 - pin_memory: true - shuffle: false - batch_sampler: - type_hint: BatchSampler - config: - batch_size: 8 # per rank - drop_last: false - sampler: - type_hint: DistributedSampler - config: - rank: ${training.global_rank} - num_replicas: ${training.world_size} - shuffle: true - dataset: - type_hint: PackedMemMapDatasetContinuous - config: - raw_data_path: ./data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_samples_1024_train.pbin - block_size: ${data.sequence_len} - sample_key: ${data.sample_key} - collate_fn: - type_hint: GPT2LLMCollator - config: - sample_key: ${data.sample_key} - target_key: ${data.target_key} - eval_dataloaders: - - type_hint: LLMDataLoader +collate_fn: + component_key: collate_fn + variant_key: gpt_2_llm_collator + config: + sample_key: ${settings.referencing_keys.sample_key} + target_key: ${settings.referencing_keys.target_key} + +train_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_megatron + config: + raw_data_path: /raid/s3/opengptx/max_lue/LLMgym/data/redpyjama_v2_default_DE_num_docs_16777216.pbin + block_size: ${settings.training.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +train_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: "train" + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default config: - dataloader_tag: "val" - num_workers: 2 - pin_memory: true - shuffle: false - batch_sampler: - type_hint: BatchSampler - config: - batch_size: 8 # per rank - drop_last: false - sampler: - type_hint: DistributedSampler - config: - rank: ${training.global_rank} - num_replicas: ${training.world_size} - shuffle: false - dataset: - type_hint: PackedMemMapDatasetContinuous - config: - raw_data_path: ./data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_samples_1024_test.pbin - block_size: ${data.sequence_len} - sample_key: ${data.sample_key} - collate_fn: - type_hint: GPT2LLMCollator - config: - sample_key: ${data.sample_key} - target_key: ${data.target_key} + batch_size: ${settings.training.local_train_micro_batch_size} + drop_last: false + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + 
shuffle: true + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE -training: - process_group_backend: "nccl" - global_num_training_samples: 2048 - callback_interval_in_samples: 256 - local_rank: ${oc.env:LOCAL_RANK} - global_rank: ${oc.env:RANK} - world_size: ${oc.env:WORLD_SIZE} - main_rank: 0 - local_train_micro_batch_size: ${data.train_dataloader.config.batch_sampler.config.batch_size} - global_num_seen_samples: ${modalities_setup.settings.global_num_seen_samples} - gradient_acc_step: 1 - do_apply_activation_checkpointing: false +val_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_megatron + config: + raw_data_path: /raid/s3/opengptx/max_lue/LLMgym/data/redpyjama_v2_default_DE_num_docs_1024.pbin + block_size: ${settings.training.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} -checkpointing: - checkpointing_strategy: - type_hint: SaveKMostRecentCheckpointsStrategy - config: - k: -1 # -1 to save all checkpoints - checkpointing_execution: - type_hint: FSDPToDiscCheckpointing - config: - checkpoint_path: ./data/checkpoints - global_rank: ${oc.env:RANK} +val_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: "val" + dataset: + instance_key: val_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.training.local_train_micro_batch_size} + drop_last: false + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: true + dataset: + instance_key: val_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: + - instance_key: val_dataloader + pass_type: BY_REFERENCE -running_env: - type_hint: FSDPRunningEnv +checkpointing: + component_key: checkpointing + variant_key: default config: - process_group_backend: ${training.process_group_backend} - local_rank: ${oc.env:LOCAL_RANK} - mixed_precision_settings: BF_16 - sharding_strategy: FULL_SHARD - auto_wrap_policy: TRANSFORMER_AUTO_WRAP_POLICY + checkpointing_strategy: + component_key: checkpointing_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 # -1 to save all checkpoints + checkpointing_execution: + component_key: checkpointing_execution + variant_key: fsdp_to_disc_checkpointing + config: + checkpoint_path: ${settings.paths.checkpointing_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + mixed_precision_settings: BF_16 + sharding_strategy: FULL_SHARD + block_names: [GPT2Block] model: - type_hint: GPT2LLM + component_key: model + variant_key: gpt2 config: - sample_key: ${data.sample_key} - prediction_key: "logits" - block_size: ${data.sequence_len} + sample_key: ${settings.referencing_keys.sample_key} + prediction_key: ${settings.referencing_keys.prediction_key} + block_size: ${settings.training.sequence_length} vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency n_layer: 12 - n_head: 12 + n_head_q: 12 + n_head_kv: 12 ffn_hidden: 2048 n_embd: 768 dropout: 0.0 bias: true # True: bias in Linears and LayerNorms, like GPT-2. 
False: a bit better and faster - attention: - attention_type: pytorch_flash_attention - scaling_factor: 3 + attention_type: pytorch_flash_attention activation: gelu epsilon: 1e-5 weight_init: mean: 0.0 std: 0.02 -scheduler: - type_hint: StepLR +wrapped_model: + component_key: model + variant_key: fsdp_wrapped config: - step_size: 1 - gamma: 0.1 + model: + instance_key: model + pass_type: BY_REFERENCE + sync_module_states: true + mixed_precision_settings: BF_16 + sharding_strategy: FULL_SHARD + block_names: [GPT2Block] -optimizer: - type_hint: AdamW +# scheduler: +# type_hint: StepLR +# config: +# step_size: 1 +# gamma: 0.1 + +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + +optimizer: + component_key: optimizer + variant_key: adam_w config: lr: 0.0001 + wrapped_model: + instance_key: wrapped_model + pass_type: BY_REFERENCE +batch_progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + local_rank: ${settings.cuda_env.local_rank} + world_size: ${settings.cuda_env.world_size} + global_num_seen_samples: ${settings.training.global_num_seen_samples} + train_dataloader: + instance_key: train_dataloader + pass_type: BY_REFERENCE + eval_dataloaders: + - instance_key: val_dataloader + pass_type: BY_REFERENCE -loss: - type_hint: CLMCrossEntropyLoss + +evaluation_subscriber: + component_key: results_subscriber + variant_key: wandb config: - target_key: ${data.target_key} - prediction_key: ${model.config.prediction_key} \ No newline at end of file + local_rank: ${settings.cuda_env.local_rank} + project: modalities + mode: ONLINE + experiment_id: ${settings.experiment_id} + directory: "." \ No newline at end of file diff --git a/config_files/config_example_openGPTx_dataset.yaml b/config_files/config_example_openGPTx_dataset.yaml index b5f3eef6..8f3c6e35 100644 --- a/config_files/config_example_openGPTx_dataset.yaml +++ b/config_files/config_example_openGPTx_dataset.yaml @@ -145,14 +145,13 @@ model: block_size: ${data.sequence_len} vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency n_layer: 12 - n_head: 12 + n_head_q: 12 + n_head_kv: 12 ffn_hidden: 2048 n_embd: 768 dropout: 0.0 bias: true # True: bias in Linears and LayerNorms, like GPT-2. 
False: a bit better and faster - attention: - attention_type: pytorch_flash_attention - scaling_factor: 3 + attention_type: pytorch_flash_attention activation: fused_swiglu epsilon: 1e-5 weight_init: diff --git a/config_files/config_lorem_ipsum.yaml b/config_files/config_lorem_ipsum.yaml index 9ac7c93f..7e8ffd51 100644 --- a/config_files/config_lorem_ipsum.yaml +++ b/config_files/config_lorem_ipsum.yaml @@ -1,159 +1,248 @@ -modalities_setup: - run_mode: FROM_SCRATCH - settings: +settings: + experiment_id: ${modalities_env:experiment_id} + referencing_keys: + sample_key: input_ids + target_key: target_ids + training: + callback_interval_in_samples: 6 + global_num_training_samples: 12 global_num_seen_samples: 0 + do_apply_activation_checkpointing: true + gradient_acc_steps: 1 + local_train_micro_batch_size: 3 + sequence_length: 256 + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + checkpointing_path: data/checkpoints +tokenizer: + component_key: tokenizer + variant_key: gpt2_tokenizer_fast + config: + tokenizer_file: data/tokenizer/tokenizer_gpt2.json + +collate_fn: + component_key: collate_fn + variant_key: gpt_2_llm_collator + config: + sample_key: ${settings.referencing_keys.sample_key} + target_key: ${settings.referencing_keys.target_key} + +train_dataset: + component_key: dataset + variant_key: mem_map_dataset + config: + raw_data_path: data/lorem_ipsum.jsonl + index_path: data/lorem_ipsum.idx + block_size: ${settings.training.sequence_length} + jq_pattern: ".text" + sample_key: ${settings.referencing_keys.sample_key} + tokenizer: + instance_key: tokenizer + pass_type: BY_REFERENCE -wandb: - project_name: modalities - mode: OFFLINE +train_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: "train" + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.training.local_train_micro_batch_size} + drop_last: false + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: true + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE -data: - sample_key: "input_ids" - target_key: "target_ids" - sequence_len: 128 - train_dataloader: - type_hint: LLMDataLoader - config: - num_workers: 2 - pin_memory: true - shuffle: false - dataloader_tag: "train" - dataset: - type_hint: MemMapDataset - config: - raw_data_path: data/lorem_ipsum.jsonl - index_path: data/lorem_ipsum.idx - block_size: ${data.sequence_len} - jq_pattern: ".text" - sample_key: ${data.sample_key} - tokenizer: - type_hint: GPT2TokenizerFast - config: - tokenizer_file: data/tokenizer/tokenizer.json - batch_sampler: - type_hint: BatchSampler - config: - batch_size: 3 - drop_last: false - sampler: - type_hint: DistributedSampler - config: - rank: ${training.global_rank} - num_replicas: ${training.world_size} - shuffle: true - collate_fn: - type_hint: GPT2LLMCollator - config: - sample_key: ${data.sample_key} - target_key: ${data.target_key} - eval_dataloaders: - - type_hint: LLMDataLoader +val_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: "val" + dataset: + 
instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default config: - num_workers: 2 - pin_memory: true - shuffle: false - dataloader_tag: "val" - dataset: ${data.train_dataloader.config.dataset} - batch_sampler: - type_hint: BatchSampler + batch_size: 3 + drop_last: false + sampler: + component_key: sampler + variant_key: distributed_sampler config: - batch_size: 3 - drop_last: false - sampler: - type_hint: DistributedSampler - config: - rank: ${training.global_rank} - num_replicas: ${training.world_size} - shuffle: true - collate_fn: ${data.train_dataloader.config.collate_fn} - - type_hint: LLMDataLoader + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: false + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +test_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: "test" + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default config: - num_workers: 2 - pin_memory: true - shuffle: false - dataloader_tag: "test" - dataset: ${data.train_dataloader.config.dataset} - batch_sampler: - type_hint: BatchSampler + batch_size: 3 + drop_last: false + sampler: + component_key: sampler + variant_key: distributed_sampler config: - batch_size: 3 - drop_last: false - sampler: - type_hint: DistributedSampler - config: - rank: ${training.global_rank} - num_replicas: ${training.world_size} - shuffle: true - collate_fn: ${data.train_dataloader.config.collate_fn} + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: false + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE -training: - process_group_backend: "nccl" - global_num_training_samples: 12 - callback_interval_in_samples: 6 - local_rank: ${oc.env:LOCAL_RANK} - global_rank: ${oc.env:RANK} - world_size: ${oc.env:WORLD_SIZE} - main_rank: 0 - local_train_micro_batch_size: ${data.train_dataloader.config.batch_sampler.config.batch_size} - global_num_seen_samples: ${modalities_setup.settings.global_num_seen_samples} - gradient_acc_step: 1 - do_apply_activation_checkpointing: True +eval_dataloaders: + - instance_key: val_dataloader + pass_type: BY_REFERENCE + - instance_key: test_dataloader + pass_type: BY_REFERENCE checkpointing: - checkpointing_strategy: - type_hint: SaveKMostRecentCheckpointsStrategy - config: - k: -1 # -1 to save all checkpoints - checkpointing_execution: - type_hint: FSDPToDiscCheckpointing - config: - checkpoint_path: data/checkpoints - global_rank: ${oc.env:RANK} + component_key: checkpointing + variant_key: default + config: + checkpointing_strategy: + component_key: checkpointing_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 # -1 to save all checkpoints + checkpointing_execution: + component_key: checkpointing_execution + variant_key: fsdp_to_disc_checkpointing + config: + checkpoint_path: ${settings.paths.checkpointing_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + mixed_precision_settings: BF_16 + sharding_strategy: FULL_SHARD + block_names: [GPT2Block] -loss: - type_hint: CLMCrossEntropyLoss +# resolving class types via different enums sucks... 
+loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss config: - target_key: ${data.target_key} - prediction_key: ${model.config.prediction_key} + target_key: target_ids + prediction_key: logits -running_env: - type_hint: FSDPRunningEnv +wrapped_model: + component_key: model + variant_key: fsdp_wrapped config: - process_group_backend: ${training.process_group_backend} - local_rank: ${oc.env:LOCAL_RANK} - mixed_precision_settings: FP_16 + model: + instance_key: model + pass_type: BY_REFERENCE + sync_module_states: true + mixed_precision_settings: BF_16 sharding_strategy: FULL_SHARD - auto_wrap_policy: TRANSFORMER_AUTO_WRAP_POLICY + block_names: [GPT2Block] model: - type_hint: GPT2LLM + component_key: model + variant_key: gpt2 config: - sample_key: ${data.sample_key} - prediction_key: "logits" - block_size: ${data.sequence_len} + sample_key: "input_ids" # TODO reference this + prediction_key: "logits" # TODO reference this + block_size: 256 # TODO reference this (same as sequence length) vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency n_layer: 2 - n_head: 4 + n_head_q: 8 + n_head_kv: 2 ffn_hidden: 128 n_embd: 128 dropout: 0.0 bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster - attention: - attention_type: pytorch_flash_attention - scaling_factor: 3 + attention_type: default_attention # pytorch_flash_attention activation: gelu epsilon: 1e-5 weight_init: mean: 0.0 std: 0.02 -scheduler: - type_hint: StepLR - config: - step_size: 1 - gamma: 0.1 +# scheduler: +# type_hint: StepLR +# config: +# step_size: 1 +# gamma: 0.1 -optimizer: - type_hint: AdamW +optimizer: + component_key: optimizer + variant_key: adam_w config: lr: 0.0001 + wrapped_model: + instance_key: wrapped_model + pass_type: BY_REFERENCE + +# message subscriber + +batch_progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + local_rank: ${settings.cuda_env.local_rank} + world_size: ${settings.cuda_env.world_size} + global_num_seen_samples: ${settings.training.global_num_seen_samples} + train_dataloader: + instance_key: train_dataloader + pass_type: BY_REFERENCE + eval_dataloaders: + instance_key: eval_dataloaders + pass_type: BY_REFERENCE + + +evaluation_subscriber: + component_key: results_subscriber + variant_key: wandb + config: + local_rank: ${settings.cuda_env.local_rank} + project: modalities + mode: OFFLINE + experiment_id: ${settings.experiment_id} + directory: "." 
+ \ No newline at end of file diff --git a/data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_samples_512_test.idx b/data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_gpt2_tokenized_num_samples_512_test.idx similarity index 100% rename from data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_samples_512_test.idx rename to data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_gpt2_tokenized_num_samples_512_test.idx diff --git a/data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_samples_512_test.pbin b/data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_gpt2_tokenized_num_samples_512_test.pbin similarity index 100% rename from data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_samples_512_test.pbin rename to data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_gpt2_tokenized_num_samples_512_test.pbin diff --git a/data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_samples_512_train.idx b/data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_gpt2_tokenized_num_samples_512_train.idx similarity index 100% rename from data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_samples_512_train.idx rename to data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_gpt2_tokenized_num_samples_512_train.idx diff --git a/data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_samples_512_train.pbin b/data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_gpt2_tokenized_num_samples_512_train.pbin similarity index 100% rename from data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_samples_512_train.pbin rename to data/sample_datasets/redpajama_v2/mem_map/redpajama_v2_gpt2_tokenized_num_samples_512_train.pbin diff --git a/data/tokenizer/tokenizer.json b/data/tokenizer/tokenizer_gpt2.json similarity index 100% rename from data/tokenizer/tokenizer.json rename to data/tokenizer/tokenizer_gpt2.json diff --git a/docs/source/configuration.rst b/docs/source/configuration.rst index e51335fa..85545671 100644 --- a/docs/source/configuration.rst +++ b/docs/source/configuration.rst @@ -47,7 +47,8 @@ With this, a mapping between type hint strings needed for `class-resolver`, and .. code-block:: python from enum import Enum - from pydantic import BaseModel, PositiveInt, PositiveFloat, conint, confloat + from typing import Annotated + from pydantic import BaseModel, PositiveInt, PositiveFloat, Field class LookupEnum(Enum): @classmethod @@ -60,8 +61,8 @@ With this, a mapping between type hint strings needed for `class-resolver`, and ConstantLR = torch.optim.lr_scheduler.ConstantLR class StepLRConfig(BaseModel): - step_size: conint(ge=1) - gamma: confloat(ge=0.0) + step_size: Annotated[int, Field(strict=True, ge=1)] + gamma: Annotated[float, Field(strict=True, ge=0.0)] class ConstantLRConfig(BaseModel): diff --git a/docs/source/memmap.rst b/docs/source/memmap.rst index 22793c08..84326fc4 100644 --- a/docs/source/memmap.rst +++ b/docs/source/memmap.rst @@ -14,9 +14,9 @@ The :python:`MemMapDataset` requires an index file providing the necessary point .. code-block:: bash - modalities create_memmap_index + modalities data create_raw_index -The index will be created in the same directory as the raw data file. For further options you may look into the usage documentation via :bash:`modalities create_memmap_index --help`. +The index will be created in the same directory as the raw data file. For further options you may look into the usage documentation via :bash:`modalities data create_raw_index --help`. 
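For example, using the lorem ipsum sample shipped with the repository (see also the quickstart), the index could be created via:

.. code-block:: bash

    modalities data create_raw_index data/lorem_ipsum.jsonl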
Packed Dataset Generator -------------------------------------------------------------------------------- @@ -25,9 +25,9 @@ The :python:`PackedMemMapDatasetContinuous` and :python:`PackedMemMapDatasetMega .. code-block:: bash - modalities create_packed_data + modalities data pack_encoded_data -The packed data file will be created in the same directory as the raw data file. For further options you may look into the usage documentation via :bash:`modalities create_packed_data --help`. +The packed data file will be created in the same directory as the raw data file. For further options you may look into the usage documentation via :bash:`modalities data pack_encoded_data --help`. Packed Data Format ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index d39a78a8..a8a81900 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -12,20 +12,20 @@ To start a training you need to create memmap dataset out of a jsonl file first, .. code-block:: bash # Create memmap dataset from jsonl file. - modalities create_memmap_index + modalities data create_raw_index # Create packed dataset. - modalities create_packed_data + modalities data pack_encoded_data For example, using the lorem ipsum example: .. code-block:: bash # Create memmap dataset from jsonl file. - modalities create_memmap_index data/lorem_ipsum.jsonl + modalities data create_raw_index data/lorem_ipsum.jsonl # Create packed dataset. - modalities create_packed_data data/lorem_ipsum.jsonl + modalities data pack_encoded_data data/lorem_ipsum.jsonl Training ---------------------------------------------------- diff --git a/examples/getting_started/getting_started_example.md b/examples/getting_started/README.md similarity index 81% rename from examples/getting_started/getting_started_example.md rename to examples/getting_started/README.md index a5ffdd74..53f49371 100644 --- a/examples/getting_started/getting_started_example.md +++ b/examples/getting_started/README.md @@ -11,7 +11,7 @@ As a reference, this example has the following folder structure. Folders in <> w ├── example_config.yaml ├── data │ ├── mem_map - │ │ └ + │ ├── │ └── raw │ ├── redpajama_v2_samples_512_test.jsonl │ └── redpajama_v2_samples_512_train.jsonl @@ -23,8 +23,7 @@ As a reference, this example has the following folder structure. Folders in <> w ``` ## 1. Preprocessing -A single line of the Redpajama V2 JSONL file has the structure denoted below. Since we are not interested in the meta data and quality signals for this minimal example, we consider the `raw_content` from each line without any filtering. -for model training. +A single line of the Redpajama V2 JSONL file has the structure denoted below. Since we are not interested in the meta data and quality signals for this minimal example, we consider the `raw_content` from each line without any filtering for model training. 
```json { "raw_content":"Archivio Tag: 25 aprile\nSupermercati aperti 25 aprile 2019: centri commerciali e negozi a Roma, Milano, Napoli e Torino\nNell\u2019articolo odierno troverete tutte le informazioni utili su quali saranno i supermercati e le attivit\u00e0 commerciali che resteranno aperti in occasione...\nAuguri di Buon 25 Aprile 2017: frasi e pensieri originali sulla Festa della Liberazione", @@ -42,29 +41,29 @@ Firstly, we create the dataset index via cd modalities/examples/getting_started/ # train split -modalities create_memmap_index --index_path data/mem_map/redpajama_v2_samples_512_train.idx \ +modalities data create_raw_index --index_path data/mem_map/redpajama_v2_samples_512_train.idx \ data/raw/redpajama_v2_samples_512_train.jsonl # test split -modalities create_memmap_index --index_path data/mem_map/redpajama_v2_samples_512_test.idx \ +modalities data create_raw_index --index_path data/mem_map/redpajama_v2_samples_512_test.idx \ data/raw/redpajama_v2_samples_512_test.jsonl ``` -In this step, we read the JSON file as a binary file, iterate over all characters und build up the sample index (char-wisestart and end position for each JSON sample) -as determined by the `\n` character positions. The sample index is stored in the specified `index_path`. Internally, the `create_memmap_index` command -instantiates and calls the the [IndexGenerator](https://github.com/Modalities/modalities/blob/main/src/modalities/dataloader/create_index.py#L14). +In this step, we read the JSON file as a binary file, iterate over all characters and build up the sample index (char-wise start and end position for each JSON sample) +as determined by the `\n` character positions. The sample index is stored in the specified `index_path`. Internally, the `create_raw_index` command +instantiates and calls the [IndexGenerator](https://github.com/Modalities/modalities/blob/main/src/modalities/dataloader/create_index.py#L14). After having determined the index, we create the packed dataset as described below by leveraging the tokenizer, jsonl file and the created index. ```sh # train split -modalities create_packed_data --jq_pattern .raw_content \ +modalities data pack_encoded_data --jq_pattern .raw_content \ --index_path data/mem_map/redpajama_v2_samples_512_train.idx \ --dst_path data/mem_map/redpajama_v2_samples_512_train.pbin \ --tokenizer_file tokenizer/tokenizer.json \ data/raw/redpajama_v2_samples_512_train.jsonl # test split -modalities create_packed_data --jq_pattern .raw_content \ +modalities data pack_encoded_data --jq_pattern .raw_content \ --index_path data/mem_map/redpajama_v2_samples_512_test.idx \ --dst_path data/mem_map/redpajama_v2_samples_512_test.pbin \ --tokenizer_file tokenizer/tokenizer.json \ @@ -84,15 +83,21 @@ Technically, packed datasets are defined a self-contained format that stores the **Packed MemMap File Format** ``` -|--8-BYTES-HEADER--|-------------------DATA-SEGMENT-------------------|----INDEX-SEGMENT----| +|--HEADER--|-------------------DATA-SEGMENT-------------------|----INDEX-SEGMENT----| -8 bytes header: +header: =============== -specifies the size of the data segment in bytes. Since the header size is fixed to 8 bytes, -the start and end position of each segment (i.e, header, data, index) is specified. Therefore, the theoretical maximum size of the data segment -is 2^64 bytes = 18,446 peta bytes or 4600e+15 tokens or 4.6 quintillion tokens, given that a token has 4 bytes. - +Contains two elements: +* Specifies the size of the data segment in bytes. 
Since the header size is fixed to 8 bytes, + the start and end position of each segment (i.e, header, data, index) is specified. + Therefore, the theoretical maximum size of the data segment + is 2^64 bytes = 18,446 peta bytes or 4600e+15 tokens or 4.6 quintillion tokens, given that a token has 4 bytes. +* The size of a each represented single token in the data segment in bytes. + This values is inferred from the source data of this `.pbin` + and depends solely on the tokenizer's vocabulary used for encoding. + A 4-byte integer is used for this. +Therefore the header is always 8+4=12 bytes long. Data segment: ============= @@ -115,7 +120,7 @@ first and then divides it into chunks of size context-length. In modalities, we describe the entire training and evaluation setup (i.e., components such das model, trainer, evaluator, dataloder etc.) within a single config file. Not only does this increase reproducibility but also allows for having the entire training runs under version control. -The example config file for this experiment can be found in `examples/mem_map_redpajama_gpt/config_example_mem_map_dataset.yaml`. +The example config file for this experiment can be found in `examples/getting_started/example_config.yaml`. ## 2. Training @@ -151,8 +156,8 @@ The command can be broken down into the following parts: 7. **`run`**: - Command argument for the `modalities` executable to initiate the training. -8. **`--config_file_path config_example_mem_map_dataset.yaml`**: - - Specifies the path to the configuration file. The file `config_example_mem_map_dataset.yaml` contains mentinoed configuratino of the components, including dataset and model configurations, training parameters, etc. +8. **`--config_file_path example_config.yaml`**: + - Specifies the path to the configuration file. The file `example_config.yaml` contains the configuration of the components, including dataset and model configurations, training parameters, etc. Already during the training, the checkpoints can be found locally in `checkpoints/` and the loss and metric developments can be inspected online in [Weights&Biases](https://wandb.ai/). @@ -171,5 +176,4 @@ which opens an interactive chatting CMD interface. ``` enter prompt> Once upon a time, there was ... - ``` \ No newline at end of file diff --git a/examples/getting_started/example_config.yaml b/examples/getting_started/example_config.yaml index b4c788f6..3505b392 100644 --- a/examples/getting_started/example_config.yaml +++ b/examples/getting_started/example_config.yaml @@ -114,14 +114,13 @@ model: block_size: ${data.sequence_len} vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency n_layer: 12 - n_head: 12 + n_head_q: 12 + n_head_kv: 12 ffn_hidden: 2048 n_embd: 768 dropout: 0.0 bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster - attention: - attention_type: pytorch_flash_attention - scaling_factor: 3 + attention_type: pytorch_flash_attention activation: gelu epsilon: 1e-5 weight_init: diff --git a/examples/library_usage/README.md b/examples/library_usage/README.md new file mode 100644 index 00000000..89898c53 --- /dev/null +++ b/examples/library_usage/README.md @@ -0,0 +1,79 @@ +# Running Modalities like a package + +Modalities can be used in a library fashion by installing the package via `pip`, as described in the [README](https://github.com/Modalities/modalities?tab=readme-ov-file#installation). 
The framework allows for the addition of custom components to the registry at runtime without necessitating any code changes to modalities. This functionality is achieved in Modalities with the introduction of a component registry, containing all the internal components (e.g., Dataloader, Loss function etc.). To support the addition of custom components (e.g., new model architectures) at runtime, Modalities exposes a function endpoint adding custom components to the internal registry. + +A typical use case for running Modalities in package-like fashion would be to have a custom model implemented in a repository parallel to modalities. To train the model, we would register the model class and its config class within Modalities' registry and additionally provide the typical training config (see [here](https://github.com/Modalities/modalities/blob/main/examples/getting_started/example_config.yaml) for an example) that also references the new model. Since modalities is aware of the model and config class, the model can be built from the config YAML file and used for training. + +## Concrete Example + +Given the explanation above, we now provide a minimal dummy example of the process of implementing, registering and instantiating a custom component via the example of a custom collate function. +The full example code can be found [here](https://github.com/Modalities/modalities/tree/hierarchical_instantiation/examples/library_usage). + +The code for the custom collate function, its config and registering is implemented in +[main.py](https://github.com/Modalities/modalities/blob/hierarchical_instantiation/examples/library_usage/main.py). Firstly, the script implements the custom collate function by first defining the config that parameterizes the collate function. Here, we took the two attributes from the original [GPT2LLMCollateFnConfig]() and added the custom field `custom_attribute`. + +```python + class CustomGPT2LLMCollateFnConfig(BaseModel): + sample_key: str + target_key: str + custom_attribute: str +``` + +The collate function implements the `CollateFnIF` interface. Its constructor expects the attributes from the previously defined `CustomGPT2LLMCollateFnConfig`. Since this is only a minimal example to demonstrate the registering of custom components, we just print the custom attribute without adding any senseful functionality. + +```python +class CustomGPT2LLMCollateFn(CollateFnIF): + def __init__(self, sample_key: str, target_key: str, custom_attribute: str): + self.sample_key = sample_key + self.target_key = target_key + self.custom_attribute = custom_attribute + + def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> DatasetBatch: + sample_tensor = torch.stack([torch.tensor(d[self.sample_key]) for d in batch]) + samples = {self.sample_key: sample_tensor[:, :-1]} + targets = {self.target_key: sample_tensor[:, 1:]} + + print(f"Custom attribute: {self.custom_attribute}") + + return DatasetBatch(targets=targets, samples=samples) +``` + +Given `CustomGPT2LLMCollateFnConfig` and `CustomGPT2LLMCollateFnConfig`, we register the new component via `add_custom_component(...)` by providing the respective component key and variant key together with the two previously defined classes. Note that even though the `component_key` and `variant_key` are in principle arbitrary, it is good practice to follow the patterns used for the internal components, as defined in [components.py](https://github.com/Modalities/modalities/blob/hierarchical_instantiation/src/modalities/registry/components.py#L64). 
+ +```python +def main(): + # load and parse the config file + config_file_path = Path("config_lorem_ipsum.yaml") + config_dict = load_app_config_dict(config_file_path) + + # instantiate the Main entrypoint of modalities by passing in the config + modalities_main = Main(config_dict=config_dict) + + # add the custom component to modalities + modalities_main.add_custom_component( + component_key="collate_fn", + variant_key="custom_gpt_2_llm_collator", + custom_component=CustomGPT2LLMCollateFn, + custom_config=CustomGPT2LLMCollateFnConfig, + ) + # run the experiment + modalities_main.run() +``` + +Lastly, we add the `collate_fn` to the [example YAML config](https://github.com/Modalities/modalities/blob/hierarchical_instantiation/examples/library_usage/config_lorem_ipsum.yaml) with the the new collator. +```yaml +collate_fn: + component_key: collate_fn + variant_key: custom_gpt_2_llm_collator + config: + sample_key: ${settings.referencing_keys.sample_key} + target_key: ${settings.referencing_keys.target_key} + custom_attribute: "custom_value" +``` + +Given the changes above, we are now ready to run the training by executing the following bash command in the example directory. +```sh +CUDA_VISIBLE_DEVICES=0,1 torchrun --rdzv-endpoint localhost:29504 --nnodes 1 --nproc_per_node 2 main.py +``` + + diff --git a/examples/library_usage/config_lorem_ipsum.yaml b/examples/library_usage/config_lorem_ipsum.yaml new file mode 100644 index 00000000..02eeca79 --- /dev/null +++ b/examples/library_usage/config_lorem_ipsum.yaml @@ -0,0 +1,244 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + referencing_keys: + sample_key: input_ids + target_key: target_ids + training: + callback_interval_in_samples: 6 + global_num_training_samples: 12 + global_num_seen_samples: 0 + do_apply_activation_checkpointing: true + gradient_acc_steps: 1 + local_train_micro_batch_size: 3 + sequence_length: 256 + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + checkpointing_path: data/checkpoints + +tokenizer: + component_key: tokenizer + variant_key: gpt2_tokenizer_fast + config: + tokenizer_file: tokenizer_gpt2.json + +collate_fn: + component_key: collate_fn + variant_key: custom_gpt_2_llm_collator + config: + sample_key: ${settings.referencing_keys.sample_key} + target_key: ${settings.referencing_keys.target_key} + custom_attribute: "custom_value" + +train_dataset: + component_key: dataset + variant_key: mem_map_dataset + config: + raw_data_path: ../../data/lorem_ipsum.jsonl + index_path: ../../data/lorem_ipsum.idx + block_size: ${settings.training.sequence_length} + jq_pattern: ".text" + sample_key: ${settings.referencing_keys.sample_key} + tokenizer: + instance_key: tokenizer + pass_type: BY_REFERENCE + +train_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: "train" + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.training.local_train_micro_batch_size} + drop_last: false + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: true + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +val_dataloader: + component_key: 
data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: "val" + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: 3 + drop_last: false + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: false + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +test_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: "test" + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: 3 + drop_last: false + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: false + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: + - instance_key: val_dataloader + pass_type: BY_REFERENCE + - instance_key: test_dataloader + pass_type: BY_REFERENCE + +checkpointing: + component_key: checkpointing + variant_key: default + config: + checkpointing_strategy: + component_key: checkpointing_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 # -1 to save all checkpoints + checkpointing_execution: + component_key: checkpointing_execution + variant_key: fsdp_to_disc_checkpointing + config: + checkpoint_path: ${settings.paths.checkpointing_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + mixed_precision_settings: BF_16 + sharding_strategy: FULL_SHARD + block_names: [GPT2Block] + +# resolving class types via different enums sucks... +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: target_ids + prediction_key: logits + +wrapped_model: + component_key: model + variant_key: fsdp_wrapped + config: + model: + instance_key: model + pass_type: BY_REFERENCE + sync_module_states: true + mixed_precision_settings: BF_16 + sharding_strategy: FULL_SHARD + block_names: [GPT2Block] + +model: + component_key: model + variant_key: gpt2 + config: + sample_key: "input_ids" # TODO reference this + prediction_key: "logits" # TODO reference this + block_size: 256 # TODO reference this (same as sequence length) + vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: 2 + n_head_q: 4 + n_head_kv: 4 + ffn_hidden: 128 + n_embd: 128 + dropout: 0.0 + bias: true # True: bias in Linears and LayerNorms, like GPT-2. 
False: a bit better and faster + attention_type: default_attention # pytorch_flash_attention + activation: gelu + epsilon: 1e-5 + weight_init: + mean: 0.0 + std: 0.02 + +optimizer: + component_key: optimizer + variant_key: adam_w + config: + lr: 0.0001 + wrapped_model: + instance_key: wrapped_model + pass_type: BY_REFERENCE + +# message subscriber + +batch_progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + local_rank: ${settings.cuda_env.local_rank} + world_size: ${settings.cuda_env.world_size} + global_num_seen_samples: ${settings.training.global_num_seen_samples} + train_dataloader: + instance_key: train_dataloader + pass_type: BY_REFERENCE + eval_dataloaders: + instance_key: eval_dataloaders + pass_type: BY_REFERENCE + + +evaluation_subscriber: + component_key: results_subscriber + variant_key: wandb + config: + local_rank: ${settings.cuda_env.local_rank} + project: modalities + mode: OFFLINE + experiment_id: ${settings.experiment_id} + directory: "." + \ No newline at end of file diff --git a/examples/library_usage/main.py b/examples/library_usage/main.py new file mode 100644 index 00000000..cb03eb63 --- /dev/null +++ b/examples/library_usage/main.py @@ -0,0 +1,55 @@ +from pathlib import Path +from typing import Dict, List + +import torch +from pydantic import BaseModel + +from modalities.__main__ import Main +from modalities.batch import DatasetBatch +from modalities.config.config import load_app_config_dict +from modalities.models.gpt2.collator import CollateFnIF + + +class CustomGPT2LLMCollateFnConfig(BaseModel): + sample_key: str + target_key: str + custom_attribute: str + + +class CustomGPT2LLMCollateFn(CollateFnIF): + def __init__(self, sample_key: str, target_key: str, custom_attribute: str): + self.sample_key = sample_key + self.target_key = target_key + self.custom_attribute = custom_attribute + + def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> DatasetBatch: + sample_tensor = torch.stack([torch.tensor(d[self.sample_key]) for d in batch]) + samples = {self.sample_key: sample_tensor[:, :-1]} + targets = {self.target_key: sample_tensor[:, 1:]} + + print(f"Custom attribute: {self.custom_attribute}") + + return DatasetBatch(targets=targets, samples=samples) + + +def main(): + # load and parse the config file + config_file_path = Path("config_lorem_ipsum.yaml") + config_dict = load_app_config_dict(config_file_path) + + # instantiate the Main entrypoint of modalities by passing in the config + modalities_main = Main(config_dict=config_dict, config_path=config_file_path) + + # add the custom component to modalities + modalities_main.add_custom_component( + component_key="collate_fn", + variant_key="custom_gpt_2_llm_collator", + custom_component=CustomGPT2LLMCollateFn, + custom_config=CustomGPT2LLMCollateFnConfig, + ) + # run the experiment + modalities_main.run() + + +if __name__ == "__main__": + main() diff --git a/examples/library_usage/run.sh b/examples/library_usage/run.sh new file mode 100644 index 00000000..89effdc7 --- /dev/null +++ b/examples/library_usage/run.sh @@ -0,0 +1,3 @@ +#!/bin/sh + +CUDA_VISIBLE_DEVICES=0,1 torchrun --rdzv-endpoint localhost:29504 --nnodes 1 --nproc_per_node 2 main.py \ No newline at end of file diff --git a/examples/pretraining_llama2/train.sh b/examples/pretraining_llama2/train.sh new file mode 100644 index 00000000..c2393e4e --- /dev/null +++ b/examples/pretraining_llama2/train.sh @@ -0,0 +1,3 @@ +#!/bin/sh + +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 torchrun --rdzv-endpoint localhost:29504 --nnodes 1 
--nproc_per_node 6 $(which modalities) run --config_file_path config_example_hf_meditron_7B_instruction.yaml \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index caa789cd..1e3fa5c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ [project.optional-dependencies] linting = ["pre-commit"] -tests = ["pytest"] +tests = ["pytest", "pytest-cov"] [project.scripts] modalities = "modalities.__main__:main" @@ -45,3 +45,35 @@ line_length = 120 [tool.ruff] line-length = 120 + +[tool.pytest.ini_options] +addopts = "--cov=src --cov-report term --cov-report html" + +[tool.coverage.run] +branch = true +omit = ["*/src/modalities/dataloader/open_gptx_dataset/*"] + +[tool.coverage.report] +# Regexes for lines to exclude from consideration +exclude_also = [ + # Don't complain about missing debug-only code: + "def __repr__", + "if self\\.debug", + + # Don't complain if tests don't hit defensive assertion code: + "raise AssertionError", + "raise NotImplementedError", + + # Don't complain if non-runnable code isn't run: + "if 0:", + "if __name__ == .__main__.:", + + # Don't complain about abstract methods, they aren't run: + "@(abc\\.)?abstractmethod", + ] + + +ignore_errors = true + +[tool.coverage.html] +directory = "coverage_html_report" \ No newline at end of file diff --git a/scripts/train.sh b/scripts/train.sh index 1f110c27..142a0e33 100644 --- a/scripts/train.sh +++ b/scripts/train.sh @@ -1,3 +1,3 @@ #!/bin/sh -CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --rdzv-endpoint localhost:29504 --nnodes 1 --nproc_per_node 8 $(which modalities) run --config_file_path /raid/s3/opengptx/max_lue/modalities/config_files/config_example_mem_map_dataset.yaml \ No newline at end of file +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 torchrun --rdzv-endpoint localhost:29504 --nnodes 1 --nproc_per_node 6 $(which modalities) run --config_file_path ../config_files/config_example_mem_map_dataset.yaml \ No newline at end of file diff --git a/src/modalities/__main__.py b/src/modalities/__main__.py index 8f709130..1b712250 100644 --- a/src/modalities/__main__.py +++ b/src/modalities/__main__.py @@ -1,42 +1,32 @@ #!/usr/bin/env python import logging +import os +import shutil from pathlib import Path -from typing import Dict, List, Tuple +from typing import Dict, Tuple import click import click_pathlib -import torch -import torch.nn as nn -from omegaconf import OmegaConf -from torch.optim import Optimizer from modalities.activation_checkpointing import apply_activation_checkpointing_inplace from modalities.batch import EvaluationResultBatch -from modalities.checkpointing.checkpointing import Checkpointing, CheckpointingIF -from modalities.checkpointing.checkpointing_factory import CheckpointingFactory -from modalities.config.config import AppConfig, ModalitiesSetupConfig, RunMode -from modalities.config.lookup_types import TokenizerTypes +from modalities.config.component_factory import ComponentFactory +from modalities.config.config import ComponentsModel, ProcessGroupBackendType, TokenizerTypes, load_app_config_dict from modalities.dataloader.create_index import IndexGenerator -from modalities.dataloader.create_packed_data import PackedDataGenerator -from modalities.dataloader.dataloader import LLMDataLoader -from modalities.dataloader.dataloader_factory import DataloaderFactory +from modalities.dataloader.create_packed_data import EmbeddedStreamData, PackedDataGenerator, join_embedded_stream_data from modalities.dataloader.large_file_lines_reader import LargeFileLinesReader from 
modalities.evaluator import Evaluator from modalities.gym import Gym from modalities.logging_broker.message_broker import MessageBroker from modalities.logging_broker.messages import BatchProgressUpdate, MessageTypes from modalities.logging_broker.publisher import MessagePublisher -from modalities.logging_broker.subscriber_impl.batch_progress_subscriber import ( - DummyProgressSubscriber, - RichProgressSubscriber, -) -from modalities.logging_broker.subscriber_impl.results_subscriber import WandBEvaluationResultSubscriber -from modalities.loss_functions import Loss -from modalities.resolver_register import ResolverRegister -from modalities.running_env.fsdp.fsdp_running_env import RunningEnv +from modalities.logging_broker.subscriber import MessageSubscriberIF +from modalities.registry.components import COMPONENTS +from modalities.registry.registry import Registry +from modalities.running_env.cuda_env import CudaEnv from modalities.trainer import Trainer -from modalities.util import compute_number_of_trainable_parameters, get_date_of_run +from modalities.util import compute_number_of_trainable_parameters, get_callback_interval_in_batches_per_rank from modalities.utils.generate_text import main as generate_text_main @@ -54,8 +44,7 @@ def main() -> None: ) def entry_point_run_modalities(config_file_path: Path): config_dict = load_app_config_dict(config_file_path) - config = AppConfig.model_validate(config_dict) - main = Main(config) + main = Main(config_dict, config_file_path) main.run() @@ -83,7 +72,15 @@ def entry_point_generate_text(model_path, config_path, tokenizer_type, tokenizer generate_text_main(model_path, config_path, tokenizer, max_new_tokens, chat) -@main.command(name="create_memmap_index") +@main.group(name="data") +def data(): + """ + Collection of utilities to preprocess, analyse and modify training data. + """ + pass + + +@data.command(name="create_raw_index") @click.argument("src_path", type=Path) @click.option( "--index_path", @@ -91,7 +88,13 @@ def entry_point_generate_text(model_path, config_path, tokenizer_type, tokenizer default=None, help="output path for index. will use parent directory of src_path if none.", ) -def entry_point_create_memmap_index(src_path, index_path): +def entry_point_data_create_raw_index(src_path, index_path): + """ + Utility for indexing a large jsonl-file's content. + Background is the ability to further process the respective file without loading it, + while splitting its content line-based. This step is necessary in advance of further processing like tokenization. + It is only necessary once for a jsonl-file and allows therefore different tokenizations without re-indexing. + """ index_path = LargeFileLinesReader.default_index_path(src_path, index_path) if index_path.exists(): raise ValueError("index already exists. delete it or specify different output folder.") @@ -102,7 +105,7 @@ def entry_point_create_memmap_index(src_path, index_path): generator.create_index(index_path) -@main.command(name="create_packed_data") +@data.command(name="pack_encoded_data") @click.argument("src_path", type=Path) @click.option( "--dst_path", @@ -137,7 +140,21 @@ def entry_point_create_memmap_index(src_path, index_path): default=".text", help="jq pattern to extract the data from the json line.", ) -def entry_point_create_packed_data(src_path, dst_path, index_path, tokenizer_type, tokenizer_file, jq_pattern): +@click.option( + "--num-cpus", + type=int, + show_default=True, + default=os.cpu_count(), + help="Specify the number of tokenization workers. 
Default is the number of available CPUs.", +) +def entry_point_pack_encoded_data(src_path, dst_path, index_path, tokenizer_type, tokenizer_file, jq_pattern, num_cpus): + """ + Utility to encode an indexed, large jsonl-file. + + (see also `create_index` for more information) + Returns .pbin-file, which can be inserted into a training process directly + and does not require its original jsonl-file or the respective index file anymore. + """ # TODO: if we want to use alternative entrypoints together with the ResolverRegistry, # we can currently not rely on the existing class resolver. # This is based on its connection to the overall `AppConfig`. @@ -145,205 +162,139 @@ def entry_point_create_packed_data(src_path, dst_path, index_path, tokenizer_typ # This could get resolved by implementing on own ResolverRegistry for each entrypoint or adapting the existing # ResolverRegistry to work dynamically with any type-hinted config object from config.py. tokenizer = tokenizer_type.value(tokenizer_file=str(tokenizer_file)) - generator = PackedDataGenerator(src_path, index_path=index_path, tokenizer=tokenizer, jq_pattern=jq_pattern) + generator = PackedDataGenerator( + src_path, + index_path=index_path, + tokenizer=tokenizer, + jq_pattern=jq_pattern, + number_of_processes=num_cpus, + ) generator.run(dst_path) -def load_app_config_dict(config_file_path: Path) -> Dict: - cfg = OmegaConf.load(config_file_path) - logging.info(f"Config\n {OmegaConf.to_yaml(cfg, resolve=True)}") - return OmegaConf.to_container(cfg, resolve=True) +@data.command(name="merge_packed_data") +@click.argument("src_paths", type=click.types.Path(exists=True, path_type=Path), nargs=-1, required=True) +@click.argument("target_path", type=click.types.Path(file_okay=False, dir_okay=False, path_type=Path)) +def entry_point_merge_packed_data(src_paths, target_path): + """ + Utility for merging different pbin-files into one. + This is especially useful, if different datasets were at different points in time or if one encoding takes so long, + that the overall process was done in chunks. + It is important that the same tokenizer got used for all chunks. + + Specify an arbitrary amount of pbin-files and/or directory containing such as input. 
+ """ + input_files = [] + for p in src_paths: + p: Path + if p.is_dir(): + input_files.extend(p.glob("**/*.pbin")) + else: + input_files.append(p) + embedded_datasets = list(map(EmbeddedStreamData, input_files)) + join_embedded_stream_data(embedded_datasets, target_path) class Main: - def __init__(self, config: AppConfig) -> None: - self.config = config - self.experiment_id = get_date_of_run() - - self.resolvers = ResolverRegister(config=config) - self.running_env: RunningEnv = self.resolvers.build_component_by_config(config=self.config.running_env) + def __init__(self, config_dict: Dict, config_path: Path) -> None: + self.config_dict = config_dict + self.config_path = config_path + + self.registry = Registry(COMPONENTS) + self.component_factory = ComponentFactory(registry=self.registry) + + def add_custom_component(self, component_key: str, variant_key: str, custom_component, custom_config) -> None: + self.registry.add_entity( + component_key=component_key, + variant_key=variant_key, + component_type=custom_component, + component_config_type=custom_config, + ) def run(self): - with self.running_env as running_env: - ( - gym, - train_dataloader, - eval_data_loaders, - checkpointing, - wrapped_model, - optimizer, - ) = self.construct_components(resolvers=self.resolvers, config=self.config, running_env=running_env) - - logging.info(f"Training model with {compute_number_of_trainable_parameters(wrapped_model)} parameters.") - - gym.run( - callback_interval_in_batches=self.config.training.callback_interval_in_batches_per_rank, - train_data_loader=train_dataloader, - evaluation_data_loaders=eval_data_loaders, - checkpointing=checkpointing, - model=wrapped_model, - optimizer=optimizer, + with CudaEnv(process_group_backend=ProcessGroupBackendType.nccl): + components: ComponentsModel = self.component_factory.build_components( + config_dict=self.config_dict, components_model_type=ComponentsModel ) - def construct_components( - self, resolvers: ResolverRegister, config: AppConfig, running_env: RunningEnv - ) -> Tuple[Gym, LLMDataLoader, List[LLMDataLoader], CheckpointingIF, nn.Module, Optimizer]: - # Checkpointing - - checkpointing = CheckpointingFactory.get_checkpointing( - resolvers=self.resolvers, - config=config.checkpointing, - running_env=running_env, - experiment_id=self.experiment_id, - num_ranks=config.training.world_size, - ) - - # Model and optimizer - wrapped_model, optimizer = self.get_model_and_optimizer( - config=config, running_env=running_env, checkpointing=checkpointing - ) - if config.training.do_apply_activation_checkpointing: - apply_activation_checkpointing_inplace(wrapped_model) - logging.info("Applied activation checkpointing!") - - # Loss function - loss_fun: Loss = resolvers.build_component_by_config(config=config.loss) - - # Dataloaders - # skip_num_samples = 0 - # if run_mode == RunMode.WARM_START: - # skip_num_samples = config.modalities_setup.settings.checkpoint_num_seen_samples - - skip_num_local_train_batches = config.training.skip_num_local_train_batches - train_dataloader = DataloaderFactory.get_dataloader( - resolvers=resolvers, config=config.data.train_dataloader, skip_num_batches=skip_num_local_train_batches - ) - eval_dataloaders = [ - DataloaderFactory.get_dataloader(resolvers=resolvers, config=dataloader_config) - for dataloader_config in config.data.eval_dataloaders - ] - - # Logging - eval_split_lengths = { - dataloader.dataloader_tag: len(dataloader) * config.training.world_size * dataloader.sampler_batch_size - for dataloader in eval_dataloaders - } - - # 
TODO: check why not *config.training.world_size - # and consider just using config.training.num_training_samples for progress Subscriber - train_split_lengths = { - train_dataloader.dataloader_tag: (len(train_dataloader) + skip_num_local_train_batches) - * config.training.world_size - * train_dataloader.sampler_batch_size - } - - evaluation_result_publisher, batch_processed_publisher = self.get_logging_publishers( - config=config, train_split_lengths=train_split_lengths, eval_split_lengths=eval_split_lengths - ) - - # Trainer - trainer = Trainer( - local_rank=config.training.local_rank, - batch_progress_publisher=batch_processed_publisher, - evaluation_result_publisher=evaluation_result_publisher, - gradient_acc_step=config.training.gradient_acc_step, - ) - - # Evaluator - evaluator = Evaluator( - local_rank=config.training.local_rank, - batch_progress_publisher=batch_processed_publisher, - evaluation_result_publisher=evaluation_result_publisher, - ) - - # Gym - gym = Gym(trainer=trainer, evaluator=evaluator, loss_fun=loss_fun, num_ranks=config.training.world_size) - - return gym, train_dataloader, eval_dataloaders, checkpointing, wrapped_model, optimizer - - def get_model_and_optimizer( - self, config: AppConfig, running_env: RunningEnv, checkpointing: Checkpointing - ) -> Tuple[nn.Module, Optimizer]: - run_mode = config.modalities_setup.run_mode - - model: torch.nn.Module = self.resolvers.build_component_by_config(config=config.model) - - if run_mode == RunMode.WARM_START: - warm_start_settings: ModalitiesSetupConfig.WarmStartSettings = config.modalities_setup.settings - wrapped_model = checkpointing.load_model_checkpoint( - file_path=warm_start_settings.checkpoint_model_path, - model=model, + # save the config file to the checkpointing path + if components.settings.cuda_env.global_rank == 0: + experiment_path = components.settings.paths.checkpointing_path / components.settings.experiment_id + os.makedirs(experiment_path, exist_ok=True) + shutil.copy(self.config_path, experiment_path / self.config_path.name) + + evaluation_result_publisher, batch_processed_publisher = self.get_logging_publishers( + progress_subscriber=components.batch_progress_subscriber, + results_subscriber=components.evaluation_subscriber, + global_rank=components.settings.cuda_env.global_rank, + local_rank=components.settings.cuda_env.local_rank, ) - optimizer: torch.optim.Optimizer = self.resolvers.build_component_by_config( - config=config.optimizer, extra_kwargs=dict(params=wrapped_model.parameters()) + # Trainer + trainer = Trainer( + local_rank=components.settings.cuda_env.local_rank, + batch_progress_publisher=batch_processed_publisher, + evaluation_result_publisher=evaluation_result_publisher, + gradient_acc_steps=components.settings.training.gradient_acc_steps, ) - # TODO improve this - if warm_start_settings.checkpoint_optimizer_path is None: - raise ( - NotImplementedError( - "So far we always have to provide an optimizer checkpoint. " - "For fine-tuning a pre-trained, we might not want to load " - "an optimizer checkpoint." 
- ) - ) - - optimizer = checkpointing.load_optimizer_checkpoint( - optimizer=optimizer, model=wrapped_model, file_path=warm_start_settings.checkpoint_optimizer_path + # Evaluator + evaluator = Evaluator( + local_rank=components.settings.cuda_env.local_rank, + batch_progress_publisher=batch_processed_publisher, + evaluation_result_publisher=evaluation_result_publisher, ) - else: - wrapped_model = running_env.wrap_model(model=model, sync_module_states=False) - optimizer: torch.optim.Optimizer = self.resolvers.build_component_by_config( - config=config.optimizer, extra_kwargs=dict(params=wrapped_model.parameters()) + # Gym + gym = Gym( + trainer=trainer, + evaluator=evaluator, + loss_fun=components.loss_fn, + num_ranks=components.settings.cuda_env.world_size, ) + wrapped_model = components.wrapped_model + logging.info(f"Training model with {compute_number_of_trainable_parameters(wrapped_model)} parameters.") + + if components.settings.training.do_apply_activation_checkpointing: + apply_activation_checkpointing_inplace(wrapped_model) - # TODO implement scheduler - # scheduler = self.resolvers.build_component_by_config( - # config=config.scheduler, extra_kwargs=dict(optimizer=self.optimizer) - # ) + callback_interval_in_batches_per_rank = get_callback_interval_in_batches_per_rank( + callback_interval_in_samples=components.settings.training.callback_interval_in_samples, + local_train_micro_batch_size=components.settings.training.local_train_micro_batch_size, + gradient_acc_steps=components.settings.training.gradient_acc_steps, + world_size=components.settings.cuda_env.world_size, + ) - return wrapped_model, optimizer + gym.run( + callback_interval_in_batches=callback_interval_in_batches_per_rank, + train_data_loader=components.train_dataloader, + evaluation_data_loaders=components.eval_dataloaders, + checkpointing=components.checkpointing, + model=wrapped_model, + optimizer=components.optimizer, + ) + print("done") def get_logging_publishers( - self, config: AppConfig, train_split_lengths: Dict[str, int], eval_split_lengths: Dict[str, int] + self, + progress_subscriber: MessageSubscriberIF[BatchProgressUpdate], + results_subscriber: MessageSubscriberIF[EvaluationResultBatch], + global_rank: int, + local_rank: int, ) -> Tuple[MessagePublisher[EvaluationResultBatch], MessagePublisher[BatchProgressUpdate],]: - # Message Broker message_broker = MessageBroker() batch_processed_publisher = MessagePublisher[BatchProgressUpdate]( message_broker=message_broker, - global_rank=config.training.global_rank, - local_rank=config.training.local_rank, + global_rank=global_rank, + local_rank=local_rank, ) evaluation_result_publisher = MessagePublisher[EvaluationResultBatch]( message_broker=message_broker, - global_rank=config.training.global_rank, - local_rank=config.training.local_rank, + global_rank=global_rank, + local_rank=local_rank, ) - # TODO make logging rank configurable - # TODO: make this instantiation of subscribers configurable via config.yml and use "build_component_by_config" - if config.training.global_rank == 0: - progress_subscriber = RichProgressSubscriber( - num_ranks=config.training.world_size, - train_split_num_samples=train_split_lengths, - eval_splits_num_samples=eval_split_lengths, - ) - evaluation_result_subscriber = WandBEvaluationResultSubscriber( - num_ranks=config.training.world_size, - project=config.wandb.project_name, - experiment_id=self.experiment_id, - mode=config.wandb.mode, - dir=config.wandb.dir, - experiment_config=config, - ) - message_broker.add_subscriber( - 
subscription=MessageTypes.EVALUATION_RESULT, subscriber=evaluation_result_subscriber - ) - - else: - progress_subscriber = DummyProgressSubscriber() + message_broker.add_subscriber(subscription=MessageTypes.EVALUATION_RESULT, subscriber=results_subscriber) message_broker.add_subscriber( subscription=MessageTypes.BATCH_PROGRESS_UPDATE, subscriber=progress_subscriber, diff --git a/src/modalities/activation_checkpointing.py b/src/modalities/activation_checkpointing.py index a6a4fd9a..288dc09f 100644 --- a/src/modalities/activation_checkpointing.py +++ b/src/modalities/activation_checkpointing.py @@ -8,11 +8,11 @@ ) from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP -from modalities.models.gpt2.gpt2_model import Block +from modalities.models.gpt2.gpt2_model import GPT2Block def is_module_to_apply_activation_checkpointing(submodule: torch.nn.Module) -> bool: - return isinstance(submodule, Block) + return isinstance(submodule, GPT2Block) def apply_activation_checkpointing_inplace(model: torch.nn.Module): diff --git a/src/modalities/batch.py b/src/modalities/batch.py index bc6c62c0..7cf3f34e 100644 --- a/src/modalities/batch.py +++ b/src/modalities/batch.py @@ -103,12 +103,15 @@ class EvaluationResultBatch(Batch): losses: Dict[str, torch.Tensor] = field(default_factory=lambda: dict()) metrics: Dict[str, torch.Tensor] = field(default_factory=lambda: dict()) throughput_metrics: Dict[str, torch.Tensor] = field(default_factory=lambda: dict()) + def __str__(self) -> str: eval_str = ( f"Evaluation result on dataset tag {self.dataloader_tag} after {self.global_train_sample_id + 1} samples:" ) eval_str += "\n\nlosses: " + "\n\t".join([f"{k}: {v.mean().item()}" for k, v in self.losses.items()]) eval_str += "\n\nmetrics: " + "\n\t".join([f"{k}: {v.mean().item()}" for k, v in self.metrics.items()]) - eval_str += "\n\nthroughput metrics: " + "\n\t".join([f"{k}: {v.mean().item()}" for k, v in self.throughput_metrics.items()]) + eval_str += "\n\nthroughput metrics: " + "\n\t".join( + [f"{k}: {v.mean().item()}" for k, v in self.throughput_metrics.items()] + ) eval_str += "\n===============================================" return eval_str diff --git a/src/modalities/checkpointing/checkpointing.py b/src/modalities/checkpointing/checkpointing.py index 611e8e48..6d5d80b5 100644 --- a/src/modalities/checkpointing/checkpointing.py +++ b/src/modalities/checkpointing/checkpointing.py @@ -45,11 +45,9 @@ def __init__( self, checkpointing_strategy: CheckpointingStrategyIF, checkpointing_execution: CheckpointingExecutionIF, - num_ranks: int, ): self.checkpointing_strategy = checkpointing_strategy self.checkpointing_execution = checkpointing_execution - self.num_ranks = num_ranks def save_checkpoint( self, @@ -76,10 +74,10 @@ def load_model_checkpoint(self, model: nn.Module, file_path: Path) -> nn.Module: model = self.checkpointing_execution.load_model_checkpoint(model=model, file_path=file_path) return model - def load_optimizer_checkpoint(self, optimizer: Optimizer, model: nn.Module, file_path: Path) -> Optimizer: + def load_optimizer_checkpoint(self, optimizer: Optimizer, wrapped_model: nn.Module, file_path: Path) -> Optimizer: optimizer = self.checkpointing_execution.load_optimizer_checkpoint( optimizer=optimizer, - model=model, + wrapped_model=wrapped_model, file_path=file_path, ) return optimizer diff --git a/src/modalities/checkpointing/checkpointing_execution.py b/src/modalities/checkpointing/checkpointing_execution.py index 4b94c414..cfe89306 100644 --- 
a/src/modalities/checkpointing/checkpointing_execution.py +++ b/src/modalities/checkpointing/checkpointing_execution.py @@ -1,18 +1,19 @@ from abc import ABC, abstractmethod from enum import Enum from pathlib import Path -from typing import Callable, List +from typing import List import torch import torch.distributed as dist import torch.nn as nn from torch.distributed.fsdp import FullOptimStateDictConfig, FullStateDictConfig from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp import StateDictType +from torch.distributed.fsdp import ShardingStrategy, StateDictType from torch.optim import Optimizer from modalities.checkpointing.checkpointing_instruction import CheckpointingInstruction from modalities.exceptions import CheckpointingError +from modalities.running_env.env_utils import MixedPrecisionSettings class CheckpointingEntityType(Enum): @@ -29,7 +30,7 @@ def load_model_checkpoint(self, model: nn.Module, file_path: Path) -> nn.Module: def load_optimizer_checkpoint( self, optimizer: Optimizer, - model: nn.Module, + wrapped_model: nn.Module, file_path: Path, ) -> Optimizer: raise NotImplementedError @@ -76,7 +77,9 @@ def __init__( checkpoint_path: Path, experiment_id: str, global_rank: int, - model_wrapping_fn: Callable[[nn.Module, bool], FSDP], + block_names: List[str], + mixed_precision_settings: MixedPrecisionSettings, + sharding_strategy: ShardingStrategy, ): """ Implementation of checkpointing to disc via FSDP @@ -85,13 +88,13 @@ def __init__( checkpoint_path (Path): folder path to the checkpoint experiment_id (str): ID of the experiment global_rank (int): global rank within the current process group - model_wrapping_fn (Callable[[nn.Module, bool], FSDP]): Wrapping function that wraps raw model. - For FSDP, we pass in FSDPRunningEnv.wrap_model """ self.checkpoint_path = checkpoint_path self.global_rank = global_rank - self.model_wrapping_fn = model_wrapping_fn self.experiment_id = experiment_id + self.block_names = block_names + self.mixed_precision_settings = mixed_precision_settings + self.sharding_strategy = sharding_strategy def _get_checkpointing_path( self, @@ -186,10 +189,20 @@ def load_model_checkpoint(self, model: nn.Module, file_path: Path) -> nn.Module: # load model on rank 0 into CPU RAM model_state = torch.load(file_path) model.load_state_dict(model_state) - fsdp_model = self.model_wrapping_fn(model=model, sync_module_states=True) + + # TODO nasty workaround to prevent circular imports + from modalities.models.model_factory import ModelFactory + + fsdp_model = ModelFactory.get_fsdp_wrapped_model( + model=model, + sync_module_states=True, + block_names=self.block_names, + mixed_precision_settings=self.mixed_precision_settings, + sharding_strategy=self.sharding_strategy, + ) return fsdp_model - def load_optimizer_checkpoint(self, optimizer: Optimizer, model: FSDP, file_path: Path) -> Optimizer: + def load_optimizer_checkpoint(self, optimizer: Optimizer, wrapped_model: FSDP, file_path: Path) -> Optimizer: # load optimizer full_optimizer_state_dict = None if self.global_rank == 0: @@ -198,7 +211,7 @@ def load_optimizer_checkpoint(self, optimizer: Optimizer, model: FSDP, file_path # distribute the optimizer state dict from rank 0 to all the other ranks sharded_optimizer_state_dict = FSDP.scatter_full_optim_state_dict( - full_optim_state_dict=full_optimizer_state_dict, model=model, group=None + full_optim_state_dict=full_optimizer_state_dict, model=wrapped_model, group=None ) optimizer.load_state_dict(sharded_optimizer_state_dict) 
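As an aside on the refactored `FSDPToDiscCheckpointing` above: its constructor now receives `block_names`, `mixed_precision_settings`, and `sharding_strategy` instead of a `model_wrapping_fn`. In a training run these values come from the `checkpointing_execution` section of the YAML config (see `fsdp_to_disc_checkpointing` in the example config above) and the object is built by the component factory. The following is only a rough sketch of a direct, programmatic construction under the same assumptions, with a placeholder experiment id:

```python
from pathlib import Path

from torch.distributed.fsdp import ShardingStrategy

from modalities.checkpointing.checkpointing_execution import FSDPToDiscCheckpointing
from modalities.running_env.env_utils import MixedPrecisionSettings

# Mirrors the fsdp_to_disc_checkpointing entry of the example YAML config;
# the experiment_id value is a placeholder.
checkpointing_execution = FSDPToDiscCheckpointing(
    checkpoint_path=Path("data/checkpoints"),
    experiment_id="example_experiment",
    global_rank=0,
    block_names=["GPT2Block"],
    mixed_precision_settings=MixedPrecisionSettings.BF_16,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
)
```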
diff --git a/src/modalities/checkpointing/checkpointing_factory.py b/src/modalities/checkpointing/checkpointing_factory.py deleted file mode 100644 index 138569be..00000000 --- a/src/modalities/checkpointing/checkpointing_factory.py +++ /dev/null @@ -1,36 +0,0 @@ -from modalities.checkpointing.checkpointing import ( - Checkpointing, - CheckpointingExecutionIF, - CheckpointingIF, - CheckpointingStrategyIF, -) -from modalities.config.config import CheckpointingConfig -from modalities.resolver_register import ResolverRegister -from modalities.running_env.fsdp.fsdp_running_env import RunningEnv - - -class CheckpointingFactory: - @staticmethod - def get_checkpointing( - resolvers: ResolverRegister, - config: CheckpointingConfig, - running_env: RunningEnv, - experiment_id: str, - num_ranks: int, - ) -> CheckpointingIF: - checkpointing_strategy: CheckpointingStrategyIF = resolvers.build_component_by_config( - config=config.checkpointing_strategy, extra_kwargs={} - ) - - checkpointing_execution: CheckpointingExecutionIF = resolvers.build_component_by_config( - config=config.checkpointing_execution, - extra_kwargs={"experiment_id": experiment_id, "model_wrapping_fn": running_env.wrap_model}, - ) - - checkpointing = Checkpointing( - checkpointing_strategy=checkpointing_strategy, - checkpointing_execution=checkpointing_execution, - num_ranks=num_ranks, - ) - - return checkpointing diff --git a/src/modalities/config/component_factory.py b/src/modalities/config/component_factory.py new file mode 100644 index 00000000..c3a3dfd5 --- /dev/null +++ b/src/modalities/config/component_factory.py @@ -0,0 +1,144 @@ +from typing import Any, Dict, List, Type, TypeVar, Union + +from pydantic import BaseModel + +from modalities.registry.registry import Registry + + +class ComponentFactory: + def __init__(self, registry: Registry) -> None: + self.registry = registry + + BaseModelChild = TypeVar("BaseModelChild", bound=BaseModel) + + def build_components(self, config_dict: Dict, components_model_type: Type[BaseModelChild]) -> BaseModelChild: + component_names = list(components_model_type.model_fields.keys()) + component_dict = self._build_config(config_dict=config_dict, component_names=component_names) + print(component_dict) + components = components_model_type(**component_dict) + return components + + def _build_config(self, config_dict: Dict, component_names: List[str]) -> Dict[str, Any]: + component_dict_filtered = {name: config_dict[name] for name in component_names} + components, _ = self._build_component( + current_component_config=component_dict_filtered, + component_config=config_dict, + top_level_components={}, + traversal_path=[], + ) + return components + + def _build_component( + self, + current_component_config: Union[Dict, List, Any], + component_config: Union[Dict, List, Any], + top_level_components: Dict[str, Any], + traversal_path: List, + ) -> Any: + # build sub components first via recursion + if isinstance(current_component_config, dict): + # if the entities are top level components, we return the component, + # as it must have been built already via a referencing component + if len(traversal_path) > 0 and traversal_path[-1] in top_level_components: + entity_key = traversal_path[-1] + return top_level_components[entity_key], top_level_components + # if it is not a component that has been built already, we need to build it. + # We first traverse the config for possible sub components that need to build beforehand. 
+ materialized_component_config = {} + for sub_entity_key, sub_component_config_dict in current_component_config.items(): + materialized_component_config[sub_entity_key], top_level_components = self._build_component( + current_component_config=sub_component_config_dict, + component_config=component_config, + top_level_components=top_level_components, + traversal_path=traversal_path + [sub_entity_key], + ) + # After building all the sub components, we can now build the actual component + # if the config is component_config then we instantiate the component + if ComponentFactory._is_component_config(config_dict=current_component_config): + # instantiate component config + component_key = current_component_config["component_key"] + variant_key = current_component_config["variant_key"] + current_component_config = self._instantiate_component_config( + component_key=component_key, + variant_key=variant_key, + config_dict=materialized_component_config["config"], + ) + # instantiate component + component = self._instantiate_component( + component_key=component_key, variant_key=variant_key, component_config=current_component_config + ) + print(" -> ".join(traversal_path) + ":", component) + + # if the component is a top level component, then we add it to the top level components dictionary + # to make sure that we don't build it again. Building it again would mean that we work by-value + # instead of by reference. + if len(traversal_path) == 1: + entity_key = traversal_path[-1] + top_level_components[entity_key] = component + return component, top_level_components + + # if the config is a reference_config then check if it exists and if not, we build it + if ComponentFactory._is_reference_config(config_dict=current_component_config): + referenced_entity_key = current_component_config["instance_key"] + if referenced_entity_key not in top_level_components: + materialized_referenced_component, top_level_components = self._build_component( + current_component_config=component_config[referenced_entity_key], + component_config=component_config, + top_level_components=top_level_components, + traversal_path=[referenced_entity_key], + ) + # we add the newly build reference config to the top level components dict + # so that we don't instantiate it again when we reach the respective component config + # in the subsequent config traversal + top_level_components[referenced_entity_key] = materialized_referenced_component + print(" -> ".join(traversal_path) + ": ", f"--ref--> {top_level_components[referenced_entity_key]}") + return top_level_components[referenced_entity_key], top_level_components + + return materialized_component_config, top_level_components + + elif isinstance(current_component_config, list): + materialized_component_configs = [] + for sub_entity_key, sub_component_config in enumerate(current_component_config): + materialized_component_config, top_level_components = self._build_component( + current_component_config=sub_component_config, + component_config=component_config, + top_level_components=top_level_components, + traversal_path=traversal_path + [str(sub_entity_key)], + ) + materialized_component_configs.append(materialized_component_config) + return materialized_component_configs, top_level_components + + else: + # we return the raw sub config if the sub config is not a dictionary or a list + # i.e., just a "scalar" value (e.g., string, int, etc.), since we don't have to build it. 
+ return current_component_config, top_level_components + + @staticmethod + def _is_component_config(config_dict: Dict) -> bool: + # TODO instead of field checks, we should introduce an enum for the config type. + return "component_key" in config_dict.keys() + + @staticmethod + def _is_reference_config(config_dict: Dict) -> bool: + # TODO instead of field checks, we should introduce an enum for the config type. + return {"instance_key", "pass_type"} == config_dict.keys() + + def _instantiate_component_config(self, component_key: str, variant_key: str, config_dict: Dict) -> BaseModel: + component_config_type: Type[BaseModel] = self.registry.get_config(component_key, variant_key) + comp_config = component_config_type(**config_dict, strict=True) + return comp_config + + def _instantiate_component(self, component_key: str, variant_key: str, component_config: BaseModel) -> Any: + component_type: Type = self.registry.get_component(component_key, variant_key) + component_config_dict = self.base_model_to_dict(component_config) + component = component_type(**component_config_dict) + return component + + @staticmethod + def base_model_to_dict(base_model: BaseModel) -> Dict: + # converts top level structure of base_model into dictionary while maintaining substructure + output = {} + for name, _ in base_model.model_fields.items(): + value = getattr(base_model, name) + output[name] = value + return output diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index 0e166242..353cf14e 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -1,345 +1,336 @@ -import json -import warnings -from enum import Enum +import os from pathlib import Path -from typing import List, Optional, Union - -from pydantic import BaseModel, Field, FilePath, PositiveFloat, PositiveInt, confloat, conint, model_validator -from transformers import PretrainedConfig - -from modalities.config.lookup_types import ( - BatchSamplerTypes, - CheckpointingExectionTypes, - CheckpointingStrategyTypes, - CollatorTypes, - DataloaderTypes, - DatasetTypes, - LossTypes, - ModelTypes, - OptimizerTypes, - SamplerTypes, - SchedulerTypes, - TokenizerTypes, -) -from modalities.config.types import ProcessGroupBackendType -from modalities.models.gpt2.gpt2_model import GPT2Config -from modalities.running_env.fsdp.fsdp_running_env import RunningEnvConfig - - -class WandbConfig(BaseModel): - class WandbMode(Enum): - ONLINE = "ONLINE" - OFFLINE = "OFFLINE" - DISABLED = "DISABLED" - - project_name: str - mode: WandbMode - dir: Optional[Path] = Field(default_factory=lambda: Path(".")) +from typing import Annotated, Any, Dict, List, Optional + +import torch.nn as nn +from omegaconf import OmegaConf +from pydantic import BaseModel, Field, FilePath, GetCoreSchemaHandler, PositiveInt, field_validator +from pydantic_core import core_schema +from torch.distributed.fsdp import ShardingStrategy +from torch.optim import Optimizer +from torch.utils.data import Sampler +from torch.utils.data.dataset import Dataset +from transformers import GPT2TokenizerFast +from transformers.models.llama.tokenization_llama_fast import LlamaTokenizerFast +from transformers.tokenization_utils_fast import PreTrainedTokenizerFast + +from modalities.checkpointing.checkpointing import CheckpointingIF +from modalities.checkpointing.checkpointing_execution import CheckpointingExecutionIF +from modalities.checkpointing.checkpointing_strategies import CheckpointingStrategyIF +from modalities.config.lookup_enum import LookupEnum +from 
modalities.dataloader.dataloader import LLMDataLoader +from modalities.logging_broker.subscriber import MessageSubscriberIF +from modalities.loss_functions import Loss +from modalities.models.gpt2.collator import CollateFnIF +from modalities.running_env.env_utils import MixedPrecisionSettings, has_bfloat_support +from modalities.util import get_date_of_run, parse_enum_by_name + + +class PydanticThirdPartyTypeIF: + def __init__(self, third_party_type): + self.third_party_type = third_party_type + + def __get_pydantic_core_schema__( + self, + _source_type: Any, + _handler: GetCoreSchemaHandler, + ) -> core_schema.CoreSchema: + # see: https://docs.pydantic.dev/latest/concepts/types/#handling-third-party-types + return core_schema.json_or_python_schema( + json_schema=core_schema.is_instance_schema(self.third_party_type), + python_schema=core_schema.is_instance_schema(self.third_party_type), + # serialization=core_schema.plain_serializer_function_ser_schema( + # lambda instance: instance.x + # ), + ) + + +PydanticCheckpointingIFType = Annotated[CheckpointingIF, PydanticThirdPartyTypeIF(CheckpointingIF)] +PydanticCheckpointingStrategyIFType = Annotated[ + CheckpointingStrategyIF, PydanticThirdPartyTypeIF(CheckpointingStrategyIF) +] +PydanticCheckpointingExecutionIFType = Annotated[ + CheckpointingExecutionIF, PydanticThirdPartyTypeIF(CheckpointingExecutionIF) +] +PydanticModelIFType = Annotated[nn.Module, PydanticThirdPartyTypeIF(nn.Module)] +PydanticTokenizerIFType = Annotated[PreTrainedTokenizerFast, PydanticThirdPartyTypeIF(PreTrainedTokenizerFast)] +PydanticDatasetIFType = Annotated[Dataset, PydanticThirdPartyTypeIF(Dataset)] +PydanticSamplerIFType = Annotated[Sampler, PydanticThirdPartyTypeIF(Sampler)] +PydanticCollateFnIFType = Annotated[CollateFnIF, PydanticThirdPartyTypeIF(CollateFnIF)] +PydanticLLMDataLoaderIFType = Annotated[LLMDataLoader, PydanticThirdPartyTypeIF(LLMDataLoader)] +PydanticOptimizerIFType = Annotated[Optimizer, PydanticThirdPartyTypeIF(Optimizer)] +PydanticLossIFType = Annotated[Loss, PydanticThirdPartyTypeIF(Loss)] +PydanticMessageSubscriberIFType = Annotated[MessageSubscriberIF, PydanticThirdPartyTypeIF(MessageSubscriberIF)] + + +class ProcessGroupBackendType(LookupEnum): + nccl = "nccl" + + +class TokenizerTypes(LookupEnum): + GPT2TokenizerFast = GPT2TokenizerFast + LlamaTokenizerFast = LlamaTokenizerFast + + +class PassType(LookupEnum): + BY_VALUE = "by_value" + BY_REFERENCE = "by_reference" + + +class WandbMode(LookupEnum): + ONLINE = "ONLINE" + OFFLINE = "OFFLINE" + DISABLED = "DISABLED" + + +class ReferenceConfig(BaseModel): + instance_key: str + pass_type: PassType -class CudaKwargsConfig(BaseModel): - num_workers: conint(ge=0) - pin_memory: bool - shuffle: bool +class CLMCrossEntropyLossConfig(BaseModel): + target_key: str + prediction_key: str + +# Checkpointing +class SaveEveryKStepsCheckpointingStrategyConfig(BaseModel): + k: PositiveInt -class TokenizerConfig(BaseModel): - class GPT2TokenizerFastConfig(BaseModel): - tokenizer_file: str # FilePath not possible, since transformers.PretrainedTokenizers can only handle strings - type_hint: TokenizerTypes - config: GPT2TokenizerFastConfig +class SaveKMostRecentCheckpointsStrategyConfig(BaseModel): + k: Annotated[int, Field(strict=True, ge=-1)] -class DatasetConfig(BaseModel): - class MemMapDatasetConfig(BaseModel): - raw_data_path: FilePath - index_path: Optional[FilePath] = None - block_size: conint(gt=0) - tokenizer: TokenizerConfig - jq_pattern: str - sample_key: str +class 
FSDPToDiscCheckpointingConfig(BaseModel): + checkpoint_path: Path + global_rank: Annotated[int, Field(strict=True, ge=0)] + experiment_id: str + block_names: List[str] + mixed_precision_settings: MixedPrecisionSettings + sharding_strategy: ShardingStrategy - class PackedMemMapDatasetContinuousConfig(BaseModel): - raw_data_path: Path - block_size: conint(gt=0) - sample_key: str + @field_validator("mixed_precision_settings", mode="before") + def parse_mixed_precision_setting_by_name(cls, name): + mixed_precision_settings: MixedPrecisionSettings = parse_enum_by_name( + name=name, enum_type=MixedPrecisionSettings + ) + if not has_bfloat_support() and ( + mixed_precision_settings == MixedPrecisionSettings.BF_16 + or mixed_precision_settings == MixedPrecisionSettings.BF_16_WORKING + ): + raise ValueError("BF16 not supported in the current environment") + return mixed_precision_settings - class PackedMemMapDatasetMegatronConfig(BaseModel): - raw_data_path: Path - block_size: conint(gt=0) - sample_key: str + @field_validator("sharding_strategy", mode="before") + def parse_sharding_strategy_by_name(cls, name): + return parse_enum_by_name(name=name, enum_type=ShardingStrategy) - class MMapIndexedDatasetConfig(BaseModel): - path: Path - skip_warmup: bool - class OpenGPTXMMapDatasetConfig(BaseModel): - num_samples: conint(ge=1) - path: FilePath - sample_key: str - sequence_len: PositiveInt +class CheckpointingConfig(BaseModel): + checkpointing_strategy: PydanticCheckpointingStrategyIFType + checkpointing_execution: PydanticCheckpointingExecutionIFType - type_hint: DatasetTypes - config: Union[ - MemMapDatasetConfig, - OpenGPTXMMapDatasetConfig, - PackedMemMapDatasetContinuousConfig, - PackedMemMapDatasetMegatronConfig, - MMapIndexedDatasetConfig, - ] = Field(union_mode="left_to_right") +class AdamWOptimizerConfig(BaseModel): + lr: float + wrapped_model: PydanticModelIFType -class SamplerConfig(BaseModel): - class DistributedSamplerConfig(BaseModel): - rank: conint(ge=0) - num_replicas: conint(ge=0) - shuffle: bool - type_hint: SamplerTypes - config: DistributedSamplerConfig +class CheckpointedOptimizerConfig(BaseModel): + checkpointing: PydanticCheckpointingIFType + checkpoint_path: Path + wrapped_model: PydanticModelIFType + optimizer: PydanticOptimizerIFType -class BatchSamplerConfig(BaseModel): - class StandardBatchSamplerConfig(BaseModel): - sampler: SamplerConfig - batch_size: conint(gt=0) - drop_last: bool +class CheckpointedModelConfig(BaseModel): + checkpointing: PydanticCheckpointingIFType + checkpoint_path: Path + model: PydanticModelIFType - type_hint: BatchSamplerTypes - config: StandardBatchSamplerConfig +class FSDPWrappedModelConfig(BaseModel): + model: PydanticModelIFType + sync_module_states: bool + mixed_precision_settings: MixedPrecisionSettings + sharding_strategy: ShardingStrategy + block_names: List[str] -class CollatorConfig(BaseModel): - class GPT2LLMCollatorConfig(BaseModel): - sample_key: str - target_key: str + @field_validator("mixed_precision_settings", mode="before") + def parse_mixed_precision_setting_by_name(cls, name): + mixed_precision_settings: MixedPrecisionSettings = parse_enum_by_name( + name=name, enum_type=MixedPrecisionSettings + ) + if not has_bfloat_support() and ( + mixed_precision_settings == MixedPrecisionSettings.BF_16 + or mixed_precision_settings == MixedPrecisionSettings.BF_16_WORKING + ): + raise ValueError("BF16 not supported in the current environment") + return mixed_precision_settings - type_hint: CollatorTypes - config: GPT2LLMCollatorConfig + 
@field_validator("sharding_strategy", mode="before") + def parse_sharding_strategy_by_name(cls, name): + return parse_enum_by_name(name=name, enum_type=ShardingStrategy) -class DataLoaderConfig(BaseModel): - class LLMDataLoaderConfig(CudaKwargsConfig): - dataloader_tag: str - dataset: DatasetConfig - batch_sampler: BatchSamplerConfig - collate_fn: CollatorConfig +class GPT2TokenizerFastConfig(BaseModel): + # Note: huggingface tokenizers expect file path as string + tokenizer_file: str - type_hint: DataloaderTypes - config: LLMDataLoaderConfig + +class DistributedSamplerConfig(BaseModel): + rank: Annotated[int, Field(strict=True, ge=0)] + num_replicas: Annotated[int, Field(strict=True, ge=0)] + shuffle: bool + dataset: PydanticDatasetIFType -class DataConfig(BaseModel): +class MemMapDatasetConfig(BaseModel): + raw_data_path: FilePath + index_path: Optional[FilePath] = None + block_size: Annotated[int, Field(strict=True, gt=0)] + tokenizer: PydanticTokenizerIFType + jq_pattern: str sample_key: str - target_key: str - sequence_len: int - train_dataloader: DataLoaderConfig - eval_dataloaders: List[DataLoaderConfig] -class ModelConfig(BaseModel): - type_hint: ModelTypes - config: GPT2Config +class PackedMemMapDatasetContinuousConfig(BaseModel): + raw_data_path: Path + block_size: Annotated[int, Field(strict=True, gt=0)] + sample_key: str -class CLMCrossEntropyLossConfig(BaseModel): +class PackedMemMapDatasetMegatronConfig(BaseModel): + raw_data_path: Path + block_size: Annotated[int, Field(strict=True, gt=0)] + sample_key: str + + +class MMapIndexedDatasetConfig(BaseModel): + path: Path + skip_warmup: bool + + +class OpenGPTXMMapDatasetConfig(BaseModel): + num_samples: Annotated[int, Field(strict=True, ge=1)] + path: FilePath + sample_key: str + sequence_len: PositiveInt + + +class BatchSamplerConfig(BaseModel): + sampler: PydanticSamplerIFType + batch_size: Annotated[int, Field(strict=True, gt=0)] + drop_last: bool + + +class ResumableBatchSamplerConfig(BaseModel): + sampler: PydanticSamplerIFType + start_index: Annotated[int, Field(strict=True, gt=0)] + + +class GPT2LLMCollateFnConfig(BaseModel): + sample_key: str target_key: str - prediction_key: str -class LossConfig(BaseModel): - type_hint: LossTypes - config: CLMCrossEntropyLossConfig - - -class TrainingConfig(BaseModel): - # TODO: use this in Progress Logging - global_num_training_samples: conint(gt=0) - callback_interval_in_samples: conint(gt=0) - process_group_backend: ProcessGroupBackendType - local_rank: conint(ge=0) - global_rank: conint(ge=0) - world_size: conint(ge=0) - main_rank: conint(ge=0) - local_train_micro_batch_size: conint(gt=0) - global_num_seen_samples: conint(ge=0) - do_apply_activation_checkpointing: bool - gradient_acc_step: conint(gt=0) - - @property - def local_train_batch_size(self): - return self.local_train_micro_batch_size * self.gradient_acc_step - - @property - def global_train_batch_size(self): - return self.local_train_batch_size * self.world_size - - @property - def local_num_train_samples(self): - exact = self.global_num_training_samples / self.world_size - ret = self.global_num_training_samples // self.world_size - if exact != ret: - print(f"Calculated local_num_training_samples is not an integer. Clipping {exact} to {ret} ") - return ret - - @property - def local_num_seen_train_samples(self): - exact = self.global_num_seen_samples / self.world_size - ret = self.global_num_seen_samples // self.world_size - if exact != ret: - print(f"Calculated global_num_seen_samples is not an integer. 
Clipping {exact} to {ret} ") - return ret - - @property - def skip_num_local_train_batches(self) -> int: - exact = self.global_num_seen_samples / self.world_size / self.local_train_micro_batch_size - ret = self.global_num_seen_samples // self.world_size // self.local_train_micro_batch_size - if exact != ret: - print(f"Calculated skip_num_local_train_batches is not an integer. Clipping {exact} to {ret} ") - return ret - - @property - def num_training_batches(self) -> int: - exact = self.global_num_training_samples / self.local_train_micro_batch_size - ret = self.global_num_training_samples // self.local_train_micro_batch_size - if exact != ret: - warnings.warn(f"Calculated num_training_batches is not an integer. Clipping {exact} to {ret} ") - return ret - - @property - def callback_interval_in_batches_per_rank(self): - exact = self.callback_interval_in_samples / self.local_train_micro_batch_size / self.world_size - ret = max(self.callback_interval_in_samples // self.local_train_micro_batch_size // self.world_size, 1) - if exact != ret: - warnings.warn( - f"Calculated callback_interval_in_batches_per_rank is not an integer. Clipping {exact} to {ret} " - ) - return ret - - -class AdamWConfig(BaseModel): - lr: confloat(ge=0.0) - - -class OptimizerConfig(BaseModel): - type_hint: OptimizerTypes - config: AdamWConfig - - -class OneCycleLRConfig(BaseModel): - max_lr: PositiveFloat - total_steps: conint(ge=1) - pct_start: confloat(ge=0.0) - anneal_strategy: str - cycle_momentum: bool - base_momentum: float | List - max_momentum: float | List - div_factor: PositiveFloat - final_div_factor: PositiveFloat - three_phase: bool - last_epochs: int - verbose: bool - - -class StepLRConfig(BaseModel): - step_size: conint(ge=1) - gamma: confloat(ge=0.0) - - -class ConstantLRConfig(BaseModel): - factor: PositiveFloat - total_iters: PositiveInt - - -class SchedulerConfig(BaseModel): - type_hint: SchedulerTypes - config: StepLRConfig | ConstantLRConfig | OneCycleLRConfig +class LLMDataLoaderConfig(BaseModel): + dataloader_tag: str + dataset: PydanticDatasetIFType + batch_sampler: PydanticSamplerIFType + collate_fn: PydanticCollateFnIFType + num_workers: Annotated[int, Field(strict=True, ge=0)] + pin_memory: bool + shuffle: bool + skip_num_batches: Optional[int] = 0 -class CheckpointingConfig(BaseModel): - class CheckpointingStrategyConfig(BaseModel): - class SaveEveryKStepsCheckpointingStrategyConfig(BaseModel): - k: PositiveInt - - class SaveKMostRecentCheckpointsStrategyConfig(BaseModel): - k: conint(ge=-1) - - type_hint: CheckpointingStrategyTypes - config: SaveEveryKStepsCheckpointingStrategyConfig | SaveKMostRecentCheckpointsStrategyConfig - - class CheckpointingExecutionConfig(BaseModel): - class FSDPToDiscCheckpointingConfig(BaseModel): - checkpoint_path: Path - global_rank: conint(ge=0) - - type_hint: CheckpointingExectionTypes - config: FSDPToDiscCheckpointingConfig - - checkpointing_strategy: CheckpointingStrategyConfig - checkpointing_execution: CheckpointingExecutionConfig - - -class RunMode(Enum): - FROM_SCRATCH = "FROM_SCRATCH" - WARM_START = "WARM_START" - -class ModalitiesSetupConfig(BaseModel): - class WarmStartSettings(BaseModel): - checkpoint_model_path: Path - global_num_seen_samples: conint(gt=0) - checkpoint_optimizer_path: Optional[Path] = None - checkpoint_lr_scheduler_path: Optional[Path] = None - - class FromScratchSettings(BaseModel): - global_num_seen_samples: int = 0 - - run_mode: RunMode - settings: FromScratchSettings - # settings: WarmStartSettings - - 
@model_validator(mode="after") - def check_passwords_match(self) -> "ModalitiesSetupConfig": - if self.run_mode == RunMode.FROM_SCRATCH: - if self.settings.global_num_seen_samples != 0: - raise ValueError("When starting from scratch, global_num_seen_samples must be 0.") - return self - - -class AppConfig(BaseModel): - modalities_setup: ModalitiesSetupConfig - data: DataConfig - training: TrainingConfig - running_env: RunningEnvConfig - model: ModelConfig - optimizer: OptimizerConfig - scheduler: SchedulerConfig - checkpointing: CheckpointingConfig - wandb: WandbConfig - loss: LossConfig - - -class PretrainedGPTConfig(PretrainedConfig): - model_type = "modalities_gpt2" - - def __init__(self, config: GPT2Config = None, **kwargs): - if type(config) == dict: - config = GPT2Config(**config) - self.config = config - - super().__init__(**kwargs) - - def to_json_string(self, use_diff: bool = True) -> str: - if self.config: - json_dict = {"config": self.config.__dict__.copy(), "model_type": self.model_type} - json_dict["config"]["attention"] = { - "attention_type": self.config.attention.attention_type.value, - "scaling_factor": self.config.attention.scaling_factor, - } - json_dict["config"]["weight_init"] = { - "mean": self.config.weight_init.mean, - "std": self.config.weight_init.std, - } - else: - json_dict = {} - return json.dumps(json_dict) +class DummyProgressSubscriberConfig(BaseModel): + pass + + +class RichProgressSubscriberConfig(BaseModel): + train_dataloader: PydanticLLMDataLoaderIFType + eval_dataloaders: Optional[List[PydanticLLMDataLoaderIFType]] = Field(default_factory=list) + world_size: int + global_num_seen_samples: int + local_rank: int + + +class DummyResultSubscriberConfig(BaseModel): + pass + + +class WandBEvaluationResultSubscriberConfig(BaseModel): + local_rank: int + project: str + experiment_id: str + mode: WandbMode + directory: Path + experiment_config: Optional[Dict] = None + + +class RichResultSubscriberConfig(BaseModel): + num_ranks: int + local_rank: int + + +class CudaEnv(BaseModel): + local_rank: Annotated[int, Field(strict=True, ge=0)] + world_size: Annotated[int, Field(strict=True, ge=1)] + global_rank: Annotated[int, Field(strict=True, ge=0)] + + +class Settings(BaseModel): + class Training(BaseModel): + callback_interval_in_samples: Annotated[int, Field(strict=True, ge=1)] + global_num_training_samples: Annotated[int, Field(strict=True, ge=1)] + global_num_seen_samples: Annotated[int, Field(strict=True, ge=0)] + do_apply_activation_checkpointing: bool + gradient_acc_steps: Annotated[int, Field(strict=True, ge=1)] + local_train_micro_batch_size: Annotated[int, Field(strict=True, ge=1)] + sequence_length: Annotated[int, Field(strict=True, ge=1)] + + class Paths(BaseModel): + checkpointing_path: Path + + experiment_id: str + referencing_keys: Dict[str, str] + training: Training + cuda_env: CudaEnv + paths: Paths + + +class ComponentsModel(BaseModel): + wrapped_model: PydanticModelIFType + optimizer: PydanticOptimizerIFType + loss_fn: PydanticLossIFType + train_dataloader: PydanticLLMDataLoaderIFType + eval_dataloaders: List[PydanticLLMDataLoaderIFType] + batch_progress_subscriber: PydanticMessageSubscriberIFType + evaluation_subscriber: PydanticMessageSubscriberIFType + checkpointing: PydanticCheckpointingIFType + settings: Settings + + +class ComponentsInferenceModel(BaseModel): + wrapped_model: PydanticModelIFType + cuda_env: CudaEnv + + +def load_app_config_dict(config_file_path: Path) -> Dict: + def cuda_env_resolver_fun(var_name: str) -> int: + 
int_env_variable_names = ["LOCAL_RANK", "WORLD_SIZE", "RANK"] + return int(os.getenv(var_name)) if var_name in int_env_variable_names else os.getenv(var_name) + + def modalities_env_resolver_fun(var_name: str) -> int: + if var_name == "experiment_id": + return get_date_of_run() + + OmegaConf.register_new_resolver("cuda_env", cuda_env_resolver_fun, replace=True) + OmegaConf.register_new_resolver("modalities_env", modalities_env_resolver_fun, replace=True) + + cfg = OmegaConf.load(config_file_path) + config_dict = OmegaConf.to_container(cfg, resolve=True) + return config_dict diff --git a/src/modalities/config/lookup_enum.py b/src/modalities/config/lookup_enum.py new file mode 100644 index 00000000..1e033735 --- /dev/null +++ b/src/modalities/config/lookup_enum.py @@ -0,0 +1,8 @@ +from enum import Enum + + +class LookupEnum(Enum): + @classmethod + def _missing_(cls, value: str) -> type: + """constructs Enum by member name, if not constructable by value""" + return cls.__dict__[value] diff --git a/src/modalities/config/lookup_types.py b/src/modalities/config/lookup_types.py deleted file mode 100644 index 46147480..00000000 --- a/src/modalities/config/lookup_types.py +++ /dev/null @@ -1,83 +0,0 @@ -from enum import Enum - -import torch -from torch.utils.data import BatchSampler, DistributedSampler -from transformers import GPT2TokenizerFast - -from modalities.checkpointing.checkpointing_execution import FSDPToDiscCheckpointing -from modalities.checkpointing.checkpointing_strategies import ( - SaveEveryKStepsCheckpointingStrategy, - SaveKMostRecentCheckpointsStrategy, -) -from modalities.dataloader.dataloader import LLMDataLoader, RepeatingDataLoader -from modalities.dataloader.dataset import MemMapDataset, PackedMemMapDatasetContinuous, PackedMemMapDatasetMegatron -from modalities.dataloader.open_gptx_dataset.mmap_dataset import MMapIndexedDatasetBuilder -from modalities.dataloader.open_gptx_dataset.open_gptx_dataset import OpenGPTXMMapDataset -from modalities.loss_functions import CLMCrossEntropyLoss -from modalities.models.gpt2.collator import GPT2LLMCollator -from modalities.models.gpt2.gpt2_model import GPT2LLM - - -class LookupEnum(Enum): - @classmethod - def _missing_(cls, value: str) -> type: - """constructs Enum by member name, if not constructable by value""" - return cls.__dict__[value] - - -class ModelTypes(LookupEnum): - GPT2LLM = GPT2LLM - - -class LossTypes(LookupEnum): - CLMCrossEntropyLoss = CLMCrossEntropyLoss - - -class OptimizerTypes(LookupEnum): - AdamW = torch.optim.AdamW - - -class SchedulerTypes(LookupEnum): - StepLR = torch.optim.lr_scheduler.StepLR - ConstantLR = torch.optim.lr_scheduler.ConstantLR - OneCycleLR = torch.optim.lr_scheduler.OneCycleLR - - -class TokenizerTypes(LookupEnum): - GPT2TokenizerFast = GPT2TokenizerFast - - -class DatasetTypes(LookupEnum): - MemMapDataset = MemMapDataset - PackedMemMapDatasetContinuous = PackedMemMapDatasetContinuous - PackedMemMapDatasetMegatron = PackedMemMapDatasetMegatron - MMapIndexedDataset = MMapIndexedDatasetBuilder - # TODO: ClassResolver does not work with functions ... therefore there is also no - # support for factories. 
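The `_missing_` hook factored out into `lookup_enum.py` above means a `LookupEnum` member can be constructed from its member name as well as from its value. A small usage sketch, re-creating the deleted `OptimizerTypes` purely for illustration:

```python
import torch

from modalities.config.lookup_enum import LookupEnum


class OptimizerTypes(LookupEnum):
    AdamW = torch.optim.AdamW


# construction by value and the _missing_ fallback by member name yield the same member
assert OptimizerTypes(torch.optim.AdamW) is OptimizerTypes.AdamW
assert OptimizerTypes("AdamW") is OptimizerTypes.AdamW
assert OptimizerTypes.AdamW.value is torch.optim.AdamW
```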
- OpenGPTXMMapDataset = OpenGPTXMMapDataset # member(OpenGPTXDatasetFactory.create_dataset) - - -class SamplerTypes(LookupEnum): - DistributedSampler = DistributedSampler - - -class BatchSamplerTypes(LookupEnum): - BatchSampler = BatchSampler - - -class CollatorTypes(LookupEnum): - GPT2LLMCollator = GPT2LLMCollator - - -class DataloaderTypes(LookupEnum): - RepeatingDataLoader = RepeatingDataLoader - LLMDataLoader = LLMDataLoader - - -class CheckpointingStrategyTypes(LookupEnum): - SaveKMostRecentCheckpointsStrategy = SaveKMostRecentCheckpointsStrategy - SaveEveryKStepsCheckpointingStrategy = SaveEveryKStepsCheckpointingStrategy - - -class CheckpointingExectionTypes(LookupEnum): - FSDPToDiscCheckpointing = FSDPToDiscCheckpointing diff --git a/src/modalities/config/types.py b/src/modalities/config/types.py deleted file mode 100644 index 803abc45..00000000 --- a/src/modalities/config/types.py +++ /dev/null @@ -1,5 +0,0 @@ -from enum import Enum - - -class ProcessGroupBackendType(Enum): - nccl = "nccl" diff --git a/src/modalities/dataloader/create_index.py b/src/modalities/dataloader/create_index.py index 8b5e0e3c..1fc0d4d9 100644 --- a/src/modalities/dataloader/create_index.py +++ b/src/modalities/dataloader/create_index.py @@ -6,16 +6,14 @@ import warnings from pathlib import Path -import numpy as np from tqdm import tqdm -# TODO: benchmark against pyspark class IndexGenerator: def __init__(self, src_file: Path, chunksize: int = 4096, drop_faulty_entries: bool = False): """ Reads in a JSON file as a binary file, iterates character by character und builds up - the sample index (char-wisestart and end position for each JSON sample) via "\n" character positions. + the sample index (char-wise start and end position for each JSON sample) via "\n" character positions. :param src_file: Path to a jsonl-file. :param chunksize: defines the size of byte chunks that are processed via a producer-consumer approach. 
@@ -26,12 +24,11 @@ def __init__(self, src_file: Path, chunksize: int = 4096, drop_faulty_entries: b self.src_file = src_file self.chunksize = chunksize self.drop_faulty_entries = drop_faulty_entries - with self.src_file.open(mode="r", encoding="utf-8") as fin: + with self.src_file.open(mode="r") as fin: fin.seek(0, os.SEEK_END) - num_chars = fin.tell() - self.num_chunks = num_chars // self.chunksize - self.reminder = num_chars % self.chunksize - self._chunk_queue = queue.Queue() + self._total_num_chars = fin.tell() + self.num_chunks = self._total_num_chars // self.chunksize + self._queue_of_raw_lines = queue.Queue() self._index_map = [] self._exception_buffer = [] @@ -51,49 +48,42 @@ def create_index(self, target_path_for_index_file: Path): def _indexer_thread(self): def queue_generator(): while True: - chunk = self._chunk_queue.get() - if chunk is None: + line = self._queue_of_raw_lines.get() + if line is None: break - yield chunk + yield line - def process_line(last_index: int, curr_index: int): - segment_len = curr_index - last_index + def parse_line_as_json(line_start_idx: int, line: str): try: # check if line is a valid json - line = np.memmap(self.src_file, mode="r", offset=last_index, shape=(segment_len,)).view("S1").tolist() - line = [c.decode("utf8") for c in line] - line = "".join(line) json.loads(line) - self._index_map.append((last_index, segment_len)) + self._index_map.append((line_start_idx, len(line))) except Exception as low_level_err: if self.drop_faulty_entries: - warnings.warn(f"faulty line at {last_index}-{curr_index}, skipping...") + warnings.warn(f'faulty line "{line}", skipping...') else: - warnings.warn(f"faulty line: {line=}") - err = ValueError(f"faulty line at {last_index}-{curr_index}") + err = ValueError(f'faulty line "{line}", skipping...') err.__cause__ = low_level_err self._exception_buffer.append(err) self._index_map = [] - last_index = 0 - for chunk_idx, chunk in tqdm(enumerate(queue_generator()), desc="Processed Chunks", total=self.num_chunks): - for char_index, c in enumerate(chunk): - curr_index = chunk_idx * self.chunksize + char_index - if c == ord("\n"): - process_line(last_index, curr_index) - last_index = curr_index + 1 - # prevents automatically added "\n"-chars at the end of files getting interpreted as own sample - if curr_index >= last_index: - process_line(last_index, curr_index + 1) + for line_start_idx, line in tqdm(queue_generator(), desc="Processed Lines"): + if self._check_for_parallel_errors(): + return + parse_line_as_json(line_start_idx, line) def _reader_thread(self): - with open(self.src_file, "rb") as fin: + with open(self.src_file, "r") as fin: while True: - chunk = fin.read(self.chunksize) - if self._exception_buffer: - raise RuntimeError( - "Exception found in exception buffer. Probably the indexer thread ran into an error..." 
- ) - if not chunk: + cursor = fin.tell() + line = fin.readline() + if self._check_for_parallel_errors(): + return + if fin.tell() == self._total_num_chars: + self._queue_of_raw_lines.put((cursor, line)) break - self._chunk_queue.put(chunk) - self._chunk_queue.put(None) + line_without_newline_char = line[:-1] + self._queue_of_raw_lines.put((cursor, line_without_newline_char)) + self._queue_of_raw_lines.put(None) + + def _check_for_parallel_errors(self) -> bool: + return bool(self._exception_buffer) diff --git a/src/modalities/dataloader/create_packed_data.py b/src/modalities/dataloader/create_packed_data.py index 6e8d4d3c..f2ba6419 100644 --- a/src/modalities/dataloader/create_packed_data.py +++ b/src/modalities/dataloader/create_packed_data.py @@ -1,7 +1,11 @@ +import logging +import math +import multiprocessing +import os import pickle import warnings from pathlib import Path -from typing import IO +from typing import Callable, Iterator, List, Tuple import jq import numpy as np @@ -10,23 +14,21 @@ from modalities.dataloader.large_file_lines_reader import LargeFileLinesReader +logger = logging.getLogger(__name__) + + +class EmptySampleError(RuntimeError): + pass -class PackedDataGenerator: - # amount of bytes to represent tokens as integers. - # If the vocabulary exceeds 2^(8*`size_in_bytes`), this requires adaptation. - TOKEN_SIZE_IN_BYTES = 4 - # amount of bytes to represent number of all tokens in dataset. - # If the amount exceeds 2^(8*`header_size_in_bytes`), this requires adaptation. - # Decided to keep this constant, since a size of 8 bytes requires more data than the internet currently provides - HEAD_SIZE_IN_BYTES = 8 +class PackedDataGenerator: def __init__( self, src_path: Path, tokenizer: PreTrainedTokenizer, index_path: Path = None, jq_pattern: str = ".text", - max_number_of_tokens: int = None, + number_of_processes: int = os.cpu_count(), ): """ Reads in a jsonl file and the corresponding index file and packs dataset file for LLM training. @@ -38,18 +40,25 @@ def __init__( :param tokenizer: PretrainedTokenizer object, which is used to pre-tokenize the provided data in `src_path`. Tokenization is necessary to work on final lengths of token sequences. :param jq_pattern: jq-pattern applied on every jsonl-entry. Results are afterwards tokenized and packed - :param max_number_of_tokens: Limit the total amount of tokens in the packed dataset. - If not specified, the whole data is packed into the dataset. 
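For orientation, the index emitted by the rewritten `IndexGenerator` above is a pickled list of `(line_start, line_length)` tuples over the jsonl file, one entry per valid JSON line. A hedged, standalone sketch of the equivalent bookkeeping (the manual loop and the index file name are illustrative, not the class itself):

```python
import json
import pickle
import tempfile
from pathlib import Path

lines = ['{"text": "hello"}', '{"text": "world"}']
src_path = Path(tempfile.mkdtemp()) / "data.jsonl"
src_path.write_text("\n".join(lines) + "\n")

index, cursor = [], 0
for line in lines:
    json.loads(line)                       # only lines that parse as JSON get indexed
    index.append((cursor, len(line)))      # (start position, length without the newline)
    cursor += len(line) + 1                # advance past the trailing "\n"

index_path = src_path.with_suffix(".idx")  # illustrative name; pass index_path explicitly in practice
index_path.write_bytes(pickle.dumps(index))
print(index)                               # [(0, 17), (18, 17)]
```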
""" self.src_path = src_path self.tokenizer = tokenizer + self._token_size_in_bytes = self._get_required_num_of_bytes_to_repr(self.tokenizer.vocab_size) + encoded_eos_token = self.tokenizer(self.tokenizer.eos_token)["input_ids"][0] + self._encoded_eos_token_as_bytes = self._encoded_token_to_bytes(encoded_eos_token) self.jq_filter = jq.compile(jq_pattern) - self.max_tokens = max_number_of_tokens - + self._number_of_processes = number_of_processes self._reader = LargeFileLinesReader(src_path, index_path=index_path) self._total_num_of_tokens = 0 - self._curr_offset = self.HEAD_SIZE_IN_BYTES - self._index_list = [] + self._tokens_write_queue = multiprocessing.Queue() + self._exception_buffer = [] + + @staticmethod + def _get_required_num_of_bytes_to_repr(int_to_get_repr: int) -> int: + return math.ceil(math.log(math.log2(int_to_get_repr), 8)) + + def _encoded_token_to_bytes(self, encoded_token: int) -> bytes: + return encoded_token.to_bytes(self._token_size_in_bytes, byteorder="big", signed=False) def _default_destination_path(self, destination_path: Path = None) -> Path: if destination_path is None: @@ -68,54 +77,177 @@ def run(self, dst_path: Path = None): if dst_path.exists(): raise ValueError(f"file already exists at destination path '{dst_path}'.") - encoded_eos_token = self.tokenizer(self.tokenizer.eos_token)["input_ids"][0] - encoded_eos_token_as_bytes = encoded_eos_token.to_bytes(self.TOKEN_SIZE_IN_BYTES, byteorder="big") - with dst_path.open("wb") as f: - # allocate first self.header_size_in_bytes bytes for header (encodes length of data section) - # not possible to prepend header after determining size of data section - f.write((0).to_bytes(self.HEAD_SIZE_IN_BYTES, byteorder="big")) - - # write data section (tokens) - for idx, line in tqdm(enumerate(self._reader)): - try: - self._process_line(encoded_eos_token_as_bytes, f, line) - except ValueError: - warnings.warn(f"Encountered empty sample in line {idx} of file {self.src_path}") - except StopIteration: - break - except Exception as exception: - warnings.warn(f"could not process line: {exception=}") - - # write index - f.write(pickle.dumps(self._index_list)) - - self._update_data_length_in_pre_allocated_header(dst_path) - - def _update_data_length_in_pre_allocated_header(self, dst_path: Path): - start_of_index_in_bytes = self._index_list[-1][0] + self._index_list[-1][1] - length_of_byte_encoded_data_section = start_of_index_in_bytes - self.HEAD_SIZE_IN_BYTES - header_content = length_of_byte_encoded_data_section.to_bytes(self.HEAD_SIZE_IN_BYTES, byteorder="big") - header_content = np.frombuffer(header_content, dtype="uint8") - # write the header content to the packed dataset file - m = np.memmap(dst_path, mode="r+", offset=0, shape=(self.HEAD_SIZE_IN_BYTES,)) - m[:] = header_content[:] - - def _process_line(self, eos_token_as_bytes: bytes, f: IO, line: str): + self._exception_buffer = [] + try: + # not setting this can cause deadlocks when using hf's "FastTokenizers". 
See also: + # https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning/67254879#67254879 + os.environ["TOKENIZERS_PARALLELISM"] = "false" + self._launch_parallelized_workers(dst_path) + finally: + os.unsetenv("TOKENIZERS_PARALLELISM") + + if self._exception_buffer: + raise self._exception_buffer[0] + + def _launch_parallelized_workers(self, dst_path: Path): + writer = multiprocessing.Process(target=self._writer_thread(dst_path)) + writer.start() + processor_threads = [ + multiprocessing.Process(target=self._process_thread, args=(i,)) for i in range(self._number_of_processes) + ] + for p in processor_threads: + p.start() + for p in processor_threads: + p.join() + self._stop_processing() + writer.join() + + def _stop_processing(self): + self._tokens_write_queue.put(None) + + def _generator_for_tokens_to_get_written(self): + while True: + if self._check_for_parallel_errors(): + return + tokens = self._tokens_write_queue.get() + if tokens is None: + break + yield tokens + + def _check_for_parallel_errors(self) -> bool: + return bool(self._exception_buffer) + + def _writer_thread(self, dst_path: Path) -> Callable: + def writer(): + index_list = [] + with dst_path.open("wb") as f: + # allocate first self.header_size_in_bytes bytes for header (encodes length of data section) + # not possible to prepend header after determining size of data section + f.write((0).to_bytes(EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES, byteorder="big")) + f.write( + self._token_size_in_bytes.to_bytes( + EmbeddedStreamData.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES, byteorder="big" + ) + ) + curr_offset = EmbeddedStreamData.HEADER_SIZE_IN_BYTES + + # write data section (tokens) + for tokens_as_bytes in tqdm( + self._generator_for_tokens_to_get_written(), desc="Processed Samples", total=len(self._reader) + ): + f.write(tokens_as_bytes) + segment_length = len(tokens_as_bytes) + index_list.append((curr_offset, segment_length)) + curr_offset += segment_length + + # write index + f.write(pickle.dumps(index_list)) + + self._update_data_length_in_pre_allocated_header(dst_path, index_list) + + return writer + + def _process_thread(self, process_id: int): + if self._check_for_parallel_errors(): + return + for idx in range(process_id, len(self._reader), self._number_of_processes): + line = self._reader[idx] + try: + self._tokens_write_queue.put(self._process_line(line)) + except EmptySampleError: + warnings.warn(f"Encountered empty sample in line {idx} of file {self.src_path}") + except Exception as exception: + warnings.warn(f"could not process line of number {idx}. 
Raised the following error: {exception=}") + + def _update_data_length_in_pre_allocated_header(self, dst_path: Path, index_list: List[Tuple[int, int]]): + start_of_index_in_bytes = index_list[-1][0] + index_list[-1][1] + length_of_byte_encoded_data_section = start_of_index_in_bytes - EmbeddedStreamData.HEADER_SIZE_IN_BYTES + data_section_length_in_bytes = length_of_byte_encoded_data_section.to_bytes( + EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES, byteorder="big" + ) + with dst_path.open("rb+") as fout: + fout.seek(0) + fout.write(data_section_length_in_bytes) + + def _process_line(self, line: str) -> bytes: jq_retrieved_text = self.jq_filter.input_text(line).first() + if jq_retrieved_text is None: + raise ValueError(f"jq was not able to find anything using the expression: {self.jq_filter}") tokens = self.tokenizer(jq_retrieved_text)["input_ids"] if len(tokens) == 0: - raise ValueError("Received empty sample...") - token_idx = 0 - for token in tokens: - token_as_bytes = token.to_bytes(self.TOKEN_SIZE_IN_BYTES, byteorder="big") - f.write(token_as_bytes) - self._total_num_of_tokens += 1 - if self._total_num_of_tokens == self.max_tokens: - segment_length = (token_idx + 1) * self.TOKEN_SIZE_IN_BYTES - self._index_list.append((self._curr_offset, segment_length)) - raise StopIteration - token_idx += 1 - f.write(eos_token_as_bytes) - segment_length = (token_idx + 1) * self.TOKEN_SIZE_IN_BYTES # segment_length in bytes - self._index_list.append((self._curr_offset, segment_length)) - self._curr_offset += segment_length + raise EmptySampleError("Received empty sample...") + return b"".join(map(self._encoded_token_to_bytes, tokens)) + self._encoded_eos_token_as_bytes + + +class EmbeddedStreamData: + # amount of bytes to represent number of all tokens in dataset. + # If the amount exceeds 2^(8*`header_size_in_bytes`), this requires adaptation. + # Decided to keep this constant, since a size of 8 bytes requires more data than the internet currently provides + DATA_SECTION_LENGTH_IN_BYTES = 8 + TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES = 4 + HEADER_SIZE_IN_BYTES = DATA_SECTION_LENGTH_IN_BYTES + TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES + + def __init__(self, data_path: Path): + self._data_path = data_path + if not self._data_path.is_file(): + raise FileNotFoundError( + f"Packed Data was not found at {self._data_path}." + f"Create on in advance by using `modalities data pack_encoded_data`." 
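A hedged usage sketch of the parallelized generator above; the paths are placeholders and assume the raw index created via `modalities data create_raw_index` already exists next to the jsonl file:

```python
from pathlib import Path

from transformers import GPT2TokenizerFast

from modalities.dataloader.create_packed_data import PackedDataGenerator

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
generator = PackedDataGenerator(
    src_path=Path("data.jsonl"),       # placeholder paths
    index_path=Path("data.idx"),
    tokenizer=tokenizer,
    jq_pattern=".text",
    number_of_processes=4,
)
generator.run(Path("data.pbin"))       # writes | header | token bytes | pickled index |
```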
+ ) + + with self._data_path.open("rb") as f: + # get number of bytes in data section + data_section_length_in_bytes = f.read(self.DATA_SECTION_LENGTH_IN_BYTES) + self.data_len = int.from_bytes(data_section_length_in_bytes, byteorder="big") + + # get number of bytes for encoding a single token + f.seek(self.DATA_SECTION_LENGTH_IN_BYTES) + token_size_as_bytes = f.read(self.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES) + self.token_size_in_bytes = int.from_bytes(token_size_as_bytes, byteorder="big", signed=False) + + # get index + f.seek(self.HEADER_SIZE_IN_BYTES + self.data_len) + pkl_encoded_index = f.read() + self.index_base = pickle.loads(pkl_encoded_index) + + # initialize memmapped data section + self.data = np.memmap(self._data_path, mode="r", offset=self.HEADER_SIZE_IN_BYTES, shape=(self.data_len,)) + + +def join_embedded_stream_data(stream_data: List[EmbeddedStreamData], target_file: Path, chunk_size: int = 2048): + if target_file.exists(): + raise FileExistsError(f'Target File at "{target_file}" exists!') + data_len = sum(d.data_len for d in stream_data) + assert len({d.token_size_in_bytes for d in stream_data}) == 1, ( + "Found different token representation sizes. This could indicate the usage of different tokenizers. " + "Not supported!" + ) + token_size_in_bytes = stream_data[0].token_size_in_bytes + + num_data_chunks = sum(math.ceil(d.data_len / chunk_size) for d in stream_data) + data_stream_generator = (d.data[i : i + chunk_size] for d in stream_data for i in range(0, d.data_len, chunk_size)) + + num_entries = sum(len(d.index_base) for d in stream_data) + + def index_stream_generator() -> Iterator[Tuple[int, int]]: + curr_offset = 0 + for embedded_stream_data in stream_data: + for entry_offset, segment_length in embedded_stream_data.index_base: + yield entry_offset + curr_offset, segment_length + curr_offset += embedded_stream_data.data_len + curr_offset -= embedded_stream_data.HEADER_SIZE_IN_BYTES + + with target_file.open("wb") as fout: + fout.write(data_len.to_bytes(EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES, byteorder="big")) + fout.write( + token_size_in_bytes.to_bytes(EmbeddedStreamData.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES, byteorder="big") + ) + for data_chunk in tqdm(data_stream_generator, total=num_data_chunks, desc="Writing Data Chunks..."): + fout.write(data_chunk) + + joint_index = [entry for entry in tqdm(index_stream_generator(), total=num_entries, desc="Concatenating Index")] + pickled_index = pickle.dumps(joint_index) + pickled_index_as_chunks = (pickled_index[i : i + chunk_size] for i in range(0, len(pickled_index), chunk_size)) + num_index_chunks = math.ceil(len(pickled_index) / chunk_size) + for index_chunk in tqdm(pickled_index_as_chunks, total=num_index_chunks, desc="Writing Index Chunks..."): + fout.write(index_chunk) diff --git a/src/modalities/dataloader/dataloader.py b/src/modalities/dataloader/dataloader.py index a3c6d3f4..695cc2a7 100644 --- a/src/modalities/dataloader/dataloader.py +++ b/src/modalities/dataloader/dataloader.py @@ -2,15 +2,14 @@ from torch.utils.data import Dataset, Sampler from torch.utils.data.dataloader import DataLoader, T_co, _collate_fn_t, _worker_init_fn_t - -from modalities.dataloader.samplers import ResumableBatchSampler +from torch.utils.data.sampler import BatchSampler class LLMDataLoader(DataLoader[T_co]): def __init__( self, dataloader_tag: str, - batch_sampler: ResumableBatchSampler, + batch_sampler: BatchSampler, dataset: Dataset[T_co], batch_size: Optional[int] = 1, shuffle: Optional[bool] = None, @@ -49,19 +48,24 @@ 
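To make the header layout introduced above concrete: `EmbeddedStreamData` reads back the 8-byte data-section length, the 4-byte token-size descriptor, the memmapped token bytes, and the pickled index. An inspection sketch (the path is a placeholder for an existing packed file):

```python
from pathlib import Path

from modalities.dataloader.create_packed_data import EmbeddedStreamData

stream = EmbeddedStreamData(Path("data.pbin"))   # placeholder path

print(stream.data_len)               # size of the token section in bytes
print(stream.token_size_in_bytes)    # bytes used per token id, from the 4-byte descriptor
print(stream.index_base[:3])         # [(byte offset, segment length in bytes), ...]
print(stream.data[:16])              # raw token bytes, memmapped as uint8
```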
def __init__( ) self._dataloader_tag = dataloader_tag + self._batch_size = batch_sampler.batch_size @property def dataloader_tag(self) -> str: return self._dataloader_tag @property - def sampler_batch_size(self) -> int: + def batch_size(self) -> int: # The parent Dataloader class has already a batch_size property defined which is originally used # when the batch_sampler is not specified. Since the LLMDataLoader enforces to always use a BatchSampler, - # we defined the property sampler_batch_size to return the actual batch size used in the dataloder. + # we defined/ override the property batch_size to return the actual batch size used in the dataloder. # BatchSampler is required, as we must seek forward in the dataloder during a warm start and # we don't want to load all the data during the fast-forward. - return self.batch_sampler.sampler_batch_size + return self._batch_size + + @batch_size.setter + def batch_size(self, value: int): + self._batch_size = value @property def fast_forward_sample_id(self) -> int: @@ -70,7 +74,7 @@ def fast_forward_sample_id(self) -> int: Returns: int: fast forward sample id """ - return self.sampler_batch_size * self.batch_sampler.start_index + return self.batch_size * self.batch_sampler.start_index @property def fast_forward_batch_id(self) -> int: @@ -83,15 +87,15 @@ def fast_forward_batch_id(self) -> int: class RepeatingDataLoader(LLMDataLoader[T_co]): - def __init__(self, data_loader: LLMDataLoader[T_co], reshuffle_after_epoch: bool = False): + def __init__(self, dataloader: LLMDataLoader[T_co], reshuffle_after_epoch: bool = False): """Wraps an iterator to allow for infinite iteration. This is especially useful for DataLoader types that we wish to automatically restart upon completion. Args: loader (iterator): The data loader to repeat. """ - self.data_loader = data_loader - self.data_iter = iter(self.data_loader) + self.dataloader = dataloader + self.data_iter = iter(self.dataloader) self.current_epoch = 0 self.reshuffle_after_epoch = reshuffle_after_epoch @@ -102,24 +106,24 @@ def __next__(self): try: batch = next(self.data_iter) except StopIteration: - if self.data_loader.sampler is not None: + if self.dataloader.sampler is not None: # In distributed mode, calling the set_epoch() method at the beginning of each epoch before creating # the DataLoader iterator is necessary to make shuffling work properly across multiple epochs. # Otherwise, the same ordering will be always used. 
See discussion: # https://discuss.pytorch.org/t/why-is-sampler-set-epoch-epoch-needed-for-distributedsampler/149672 self.current_epoch += 1 - self.data_loader.sampler.set_epoch(self.current_epoch) - self.data_iter = iter(self.data_loader) + self.dataloader.sampler.set_epoch(self.current_epoch) + self.data_iter = iter(self.dataloader) batch = next(self.data_iter) return batch @property def dataloader_tag(self) -> str: - return self.data_loader._dataloader_tag + return self.dataloader._dataloader_tag @property - def sampler_batch_size(self) -> int: - return self.data_loader.batch_sampler.batch_size + def batch_size(self) -> int: + return self.dataloader.batch_sampler.batch_size @property def fast_forward_sample_id(self) -> int: @@ -128,7 +132,7 @@ def fast_forward_sample_id(self) -> int: Returns: int: fast forward sample id """ - return self.data_loader.sampler_batch_size * self.batch_sampler.start_index + return self.dataloader.batch_size * self.batch_sampler.start_index @property def fast_forward_batch_id(self) -> int: @@ -137,4 +141,4 @@ def fast_forward_batch_id(self) -> int: Returns: int: fast forward batch id """ - return self.data_loader.batch_sampler.start_index + return self.dataloader.batch_sampler.start_index diff --git a/src/modalities/dataloader/dataloader_factory.py b/src/modalities/dataloader/dataloader_factory.py index 225a4583..09606415 100644 --- a/src/modalities/dataloader/dataloader_factory.py +++ b/src/modalities/dataloader/dataloader_factory.py @@ -1,82 +1,34 @@ +from typing import Callable, Optional + +from torch.utils.data import BatchSampler from torch.utils.data.dataset import Dataset -from modalities.config.config import DataLoaderConfig, DatasetConfig from modalities.dataloader.dataloader import LLMDataLoader -from modalities.dataloader.open_gptx_dataset.open_gptx_dataset import OpenGPTXMMapDataset from modalities.dataloader.samplers import ResumableBatchSampler -from modalities.resolver_register import ResolverRegister - - -class OpenGPTXDatasetWrapper(Dataset): - def __init__(self, open_gptx_dataset: OpenGPTXMMapDataset, num_samples: int) -> None: - super().__init__() - self.open_gptx_dataset = open_gptx_dataset - self.num_samples = num_samples - - def __len__(self): - return self.num_samples - - def __getitem__(self, idx: int): - if self.num_samples > idx: - return self.open_gptx_dataset.__getitem__(idx) - else: - raise ValueError("num_samples <= idx") class DataloaderFactory: @staticmethod def get_dataloader( - resolvers: ResolverRegister, config: DataLoaderConfig, skip_num_batches: int = 0 + dataloader_tag: str, + dataset: Dataset, + batch_sampler: BatchSampler, + collate_fn: Callable, + num_workers: int, + pin_memory: bool, + shuffle: bool, + skip_num_batches: Optional[int] = 0, ) -> LLMDataLoader: - # TODO: replace this with dynamic nested object instantiation. (More details: Different Dataloaders require - # different objects in their constructors. the resolvers should be able to provide the necessary complex - # objects automatically, without us manually creating this complex factory.) 
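The `set_epoch` call above is what makes shuffling differ between epochs when a `DistributedSampler` drives the wrapped dataloader. A minimal single-rank sketch of that behaviour (no process group is needed because `num_replicas` and `rank` are passed explicitly):

```python
from torch.utils.data import DistributedSampler

sampler = DistributedSampler(dataset=list(range(8)), num_replicas=1, rank=0, shuffle=True, seed=0)

epoch_0_order = list(sampler)
sampler.set_epoch(1)                 # re-seeds the shuffle for the next epoch
epoch_1_order = list(sampler)

print(epoch_0_order)
print(epoch_1_order)                 # almost surely a different permutation than epoch 0
```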
- additional_init_payload = {} - if hasattr(config.config.dataset.config, "tokenizer"): - tokenizer = resolvers.build_component_by_config(config=config.config.dataset.config.tokenizer) - tokenizer.pad_token = tokenizer.eos_token - additional_init_payload.update(tokenizer=tokenizer) - - dataset = resolvers.build_component_by_config( - config=config.config.dataset, extra_kwargs=additional_init_payload - ) - - # BUG: Sometimes the dataset genereated by the OpenGPTXMMap implementation has too many samples. - # This is a workaround to fix the dataset to the size, as specified in the config! - # TODO: Fix the OpenGPTX implementation and get rid of this hack. - if isinstance(config.config.dataset.config, DatasetConfig.OpenGPTXMMapDatasetConfig): - dataset = OpenGPTXDatasetWrapper( - open_gptx_dataset=dataset, num_samples=config.config.dataset.config.num_samples - ) - - collator = resolvers.build_component_by_config(config=config.config.collate_fn) - sampler = resolvers.build_component_by_config( - config=config.config.batch_sampler.config.sampler, extra_kwargs=dict(dataset=dataset) - ) - - batch_sampler = resolvers.build_component_by_config( - config=config.config.batch_sampler, - extra_kwargs=dict( - sampler=sampler, - ), - ) - - resumable_batch_sampler = ResumableBatchSampler( - start_index=skip_num_batches, underlying_batch_sampler=batch_sampler - ) - - dataloader = resolvers.build_component_by_config( - config=config, - extra_kwargs=dict( - dataset=dataset, - batch_sampler=resumable_batch_sampler, - collate_fn=collator, - ), + batch_sampler = ResumableBatchSampler(start_index=skip_num_batches, underlying_batch_sampler=batch_sampler) + + dataloader = LLMDataLoader( + dataloader_tag=dataloader_tag, + batch_sampler=batch_sampler, + dataset=dataset, + collate_fn=collate_fn, + num_workers=num_workers, + pin_memory=pin_memory, + shuffle=shuffle, ) - # TODO we should have this check rather in the gym. Here, it is clear that - # we are using the LLMDataLoader - assert isinstance( - dataloader, LLMDataLoader - ), f"Dataloader Class must use the {LLMDataLoader.__name__}-Interface" return dataloader diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index 8e7a4c3b..ef0ae2ad 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -1,7 +1,5 @@ from __future__ import annotations -import os -import pickle from pathlib import Path from typing import List, Optional, Tuple @@ -12,6 +10,7 @@ from transformers import BatchEncoding, PreTrainedTokenizer from ..dataloader.large_file_lines_reader import LargeFileLinesReader +from .create_packed_data import EmbeddedStreamData class Dataset(TorchdataSet): @@ -70,98 +69,72 @@ def __getitem__(self, idx: int) -> BatchEncoding: class PackedMemMapDatasetBase(Dataset): - INT_SIZE_IN_BYTES = 4 - HEADER_SIZE_IN_BYTES = 8 + DATA_SECTION_LENGTH_IN_BYTES = EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES + TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES = EmbeddedStreamData.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES + HEADER_SIZE_IN_BYTES = EmbeddedStreamData.HEADER_SIZE_IN_BYTES + np_dtype_from_num_bytes = { + 1: np.dtype(np.uint8).newbyteorder(">"), + 2: np.dtype(np.uint16).newbyteorder(">"), + 4: np.dtype(np.uint32).newbyteorder(">"), + 8: np.dtype(np.uint64).newbyteorder(">"), + } def __init__(self, raw_data_path: Path, block_size: int, sample_key: str): """ Base class for packed memmapped datasets. 
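An illustrative call of the slimmed-down factory above; the dataset, sampler, and collate function are stand-ins for the objects that normally come out of the config-driven component construction:

```python
from torch.utils.data import BatchSampler, SequentialSampler

from modalities.dataloader.dataloader_factory import DataloaderFactory

dataset = list(range(100))                                   # stand-in for a real Dataset
batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False)

dataloader = DataloaderFactory.get_dataloader(
    dataloader_tag="train",
    dataset=dataset,
    batch_sampler=batch_sampler,
    collate_fn=lambda batch: batch,                          # stand-in for GPT2LLMCollateFn
    num_workers=0,
    pin_memory=False,
    shuffle=False,
    skip_num_batches=2,                                      # warm start: skip the first two batches
)

assert dataloader.batch_size == 4
assert len(dataloader) == 23   # 25 batches minus the 2 skipped by the ResumableBatchSampler
```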
The underlying dataset file has the structure: | header | data | index | - The header contains information about the length of the subsequent data sequence. The index contains - the tuple information (start, end) in terms of byte positions. + The header contains information about the length of the subsequent data sequence and the amount of bytes + required to represent tokens in the data section. The index contains the tuple information (start, end) in terms + of byte positions. :param raw_data_path: Path to a packed binary file (*.pbin). - Use `modalities create_packed_data` to create one based on a jsonl-file. + Use `modalities data pack_encoded_data` to create one based on a jsonl-file. :param block_size: alias for max sequence length. The amount of tokens the model can handle. :param sample_key: model-specific parameter to indicate where in the BatchEncoding the input_token_ids are. TODO: If this setting should support multi-modal features using separately encoded inputs, this needs to get replaced with a list of sample keys! """ super().__init__(raw_data_path=raw_data_path, block_size=block_size, sample_key=sample_key) - if not self.raw_data_path.is_file(): - raise FileNotFoundError( - f"Packed Data was not found at {self.raw_data_path}." - f"Create on in advance by using `modalities create_packed_data`." + self._embedded_stream_data = EmbeddedStreamData(raw_data_path) + self._token_size_in_bytes = self._embedded_stream_data.token_size_in_bytes + try: + self._token_dtype = self.np_dtype_from_num_bytes[self._token_size_in_bytes] + except KeyError: + raise RuntimeError( + f"Encountered a required token representation with {self._token_size_in_bytes}," + " which is not supported. Consider using a smaller vocabulary." ) + self._index = self._generate_packing_index() - # get number of total bytes in file - with self.raw_data_path.open("rb") as f: - f.seek(0, os.SEEK_END) - self.total_bytes = f.tell() - f.seek(0) - - # get number of bytes in data section - self.data_len = np.memmap( - self.raw_data_path, - mode="r", - offset=0, - shape=(self.HEADER_SIZE_IN_BYTES,), - ).view(f"S{self.HEADER_SIZE_IN_BYTES}") - self.data_len = int.from_bytes(self.data_len, byteorder="big") - - # get index - self.index_base = np.memmap( - self.raw_data_path, - mode="r", - offset=self.HEADER_SIZE_IN_BYTES + self.data_len, - shape=(self.total_bytes - self.data_len - self.HEADER_SIZE_IN_BYTES,), - ).view(f"S{self.total_bytes-self.data_len-self.HEADER_SIZE_IN_BYTES}") - self.index_base = pickle.loads(self.index_base) - - -class PackedMemMapDatasetContinuous(PackedMemMapDatasetBase): - def __init__(self, raw_data_path: Path, block_size: int, sample_key: str): - """ - PackedMemMapDatasetContinuous iterates through the data in block_size sized chunks, - irrespective of the samples' start and end position, as defined in the index. - Therefore, for this datset, the index is irrelevant. - - :param raw_data_path: Path to a packed binary file (*.pbin). - Use `modalities create_packed_data` to create one based on a jsonl-file. - :param block_size: alias for max sequence length. The amount of tokens the model can handle. - :param sample_key: model-specific parameter to indicate where in the BatchEncoding the input_token_ids are. - TODO: If this setting should support multi-modal features using separately encoded inputs, - this needs to get replaced with a list of sample keys! 
- """ - super().__init__(raw_data_path=raw_data_path, block_size=block_size, sample_key=sample_key) - - # get number of total tokens in file - total_tokens = self.data_len // self.INT_SIZE_IN_BYTES - self._num_samples = total_tokens // self.block_size + def _generate_packing_index(self) -> List[Tuple[int, int]]: + raise NotImplementedError def __len__(self) -> int: - return self._num_samples + return len(self._index) def __getitem__(self, idx: int) -> BatchEncoding: self._check_if_inbounds(idx) - tokens_as_byte_strings = np.memmap( - self.raw_data_path, - mode="r", - offset=self.HEADER_SIZE_IN_BYTES + idx * self.INT_SIZE_IN_BYTES * self.block_size, - shape=(self.INT_SIZE_IN_BYTES * self.block_size,), - ).view(f"S{self.INT_SIZE_IN_BYTES}") - tokens = [int.from_bytes(token, byteorder="big") for token in tokens_as_byte_strings] + offset, length = self._index[idx] + tokens = np.frombuffer(self._embedded_stream_data.data, dtype=self._token_dtype, count=length, offset=offset) return BatchEncoding(data={self.sample_key: tokens}) +class PackedMemMapDatasetContinuous(PackedMemMapDatasetBase): + def _generate_packing_index(self) -> List[Tuple[int, int]]: + # get number of total tokens in file + total_tokens = self._embedded_stream_data.data_len // self._token_size_in_bytes + num_samples = total_tokens // self.block_size + return [(i * self.block_size * self._token_size_in_bytes, self.block_size) for i in range(num_samples)] + + class PackedMemMapDatasetMegatron(PackedMemMapDatasetBase): - def generate_megatron_index(self) -> List[Tuple[int, int]]: + def _generate_packing_index(self) -> List[Tuple[int, int]]: index = [] curr_offset = self.HEADER_SIZE_IN_BYTES curr_len = 0 - block_size_in_bytes = self.block_size * self.INT_SIZE_IN_BYTES - for segment_offset, segment_len in tqdm(self.index_base): - # When the sum of of the length of the current previously seen samples doesn't + block_size_in_bytes = self.block_size * self._token_size_in_bytes + for segment_offset, segment_len in tqdm(self._embedded_stream_data.index_base): + # When the sum of the length of the current previously seen samples doesn't # exceed block_size_in_bytes, we add the current segment length to the previous # ones and continue. if curr_len + segment_len < block_size_in_bytes: @@ -169,14 +142,14 @@ def generate_megatron_index(self) -> List[Tuple[int, int]]: # If the previous and current length equals block_size_in_bytes, we add the starting index # and the total sequences length to the index list as a new sample. elif curr_len + segment_len == block_size_in_bytes: - index.append((curr_offset, block_size_in_bytes)) + index.append((curr_offset, self.block_size)) curr_len = 0 curr_offset += block_size_in_bytes # Else case is executed when the current and previous segment length exceed the block_size. # In this case we set the starting point of the next sample to the end of the current sample. # This way, the start of a sample is never in the middle of a sentence. else: - index.append((curr_offset, block_size_in_bytes)) + index.append((curr_offset, self.block_size)) if segment_len > block_size_in_bytes: curr_offset += block_size_in_bytes curr_len = 0 @@ -184,30 +157,3 @@ def generate_megatron_index(self) -> List[Tuple[int, int]]: curr_offset = segment_offset curr_len = segment_len return index - - def __init__(self, raw_data_path: Path, block_size: int, sample_key: str): - """ - :param raw_data_path: Path to a packed binary file (*.pbin). - Use `modalities create_packed_data` to create one based on a jsonl-file. 
- :param block_size: alias for max sequence length. The amount of tokens the model can handle. - :param sample_key: model-specific parameter to indicate where in the BatchEncoding the input_token_ids are. - TODO: If this setting should support multi-modal features using separately encoded inputs, - this needs to get replaced with a list of sample keys! - """ - super().__init__(raw_data_path=raw_data_path, block_size=block_size, sample_key=sample_key) - self._index = self.generate_megatron_index() - - def __len__(self) -> int: - return len(self._index) - - def __getitem__(self, idx: int) -> BatchEncoding: - self._check_if_inbounds(idx) - offset, length = self._index[idx] - tokens_as_byte_strings = np.memmap( - self.raw_data_path, - mode="r", - offset=offset, - shape=(length,), - ).view(f"S{self.INT_SIZE_IN_BYTES}") - tokens = [int.from_bytes(token, byteorder="big") for token in tokens_as_byte_strings] - return BatchEncoding(data={self.sample_key: tokens}) diff --git a/src/modalities/dataloader/dataset_factory.py b/src/modalities/dataloader/dataset_factory.py new file mode 100644 index 00000000..157e98d0 --- /dev/null +++ b/src/modalities/dataloader/dataset_factory.py @@ -0,0 +1,85 @@ +from pathlib import Path +from typing import Optional + +from pydantic import FilePath +from torch.utils.data.dataset import Dataset +from transformers import PreTrainedTokenizer + +from modalities.dataloader.dataset import MemMapDataset, PackedMemMapDatasetContinuous, PackedMemMapDatasetMegatron +from modalities.dataloader.open_gptx_dataset.open_gptx_dataset import OpenGPTXMMapDataset + + +class OpenGPTXDatasetWrapper(Dataset): + def __init__(self, open_gptx_dataset: OpenGPTXMMapDataset, num_samples: int) -> None: + super().__init__() + self.open_gptx_dataset = open_gptx_dataset + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx: int): + if self.num_samples > idx: + return self.open_gptx_dataset.__getitem__(idx) + else: + raise ValueError("num_samples <= idx") + + +class DatasetFactory: + @staticmethod + def get_mem_map_dataset( + raw_data_path: Path, + block_size: int, + tokenizer: PreTrainedTokenizer, + sample_key: str, + index_path: Optional[Path] = None, + jq_pattern: str = ".text", + ) -> MemMapDataset: + # TODO this was part of the old Dataloader implementation. + # we need to check if this is actually wanted generally. 
+ tokenizer.pad_token = tokenizer.eos_token + + dataset = MemMapDataset( + raw_data_path=raw_data_path, + block_size=block_size, + tokenizer=tokenizer, + sample_key=sample_key, + index_path=index_path, + jq_pattern=jq_pattern, + ) + return dataset + + @staticmethod + def get_packed_mem_map_dataset_continuous( + raw_data_path: Path, block_size: int, sample_key: str + ) -> PackedMemMapDatasetContinuous: + dataset = PackedMemMapDatasetContinuous( + raw_data_path=raw_data_path, block_size=block_size, sample_key=sample_key + ) + return dataset + + @staticmethod + def get_packed_mem_map_dataset_megatron( + raw_data_path: Path, block_size: int, sample_key: str + ) -> PackedMemMapDatasetMegatron: + dataset = PackedMemMapDatasetMegatron(raw_data_path=raw_data_path, block_size=block_size, sample_key=sample_key) + return dataset + + @staticmethod + def get_open_gptx_mmap_dataset( + sample_key: str, + path: FilePath, + sequence_len: int, + num_samples: int, + seed: int = 47, + ) -> OpenGPTXMMapDataset: + # part of open gptx + dataset = OpenGPTXMMapDataset( + sample_key=sample_key, path=path, sequence_len=sequence_len, num_samples=num_samples, seed=seed + ) + + # BUG: Sometimes the dataset genereated by the OpenGPTXMMap implementation has too many samples. + # This is a workaround to fix the dataset to the size, as specified in the config! + # TODO: Fix the OpenGPTX implementation and get rid of this hack. + dataset_wrapped = OpenGPTXDatasetWrapper(open_gptx_dataset=dataset, num_samples=num_samples) + return dataset_wrapped diff --git a/src/modalities/dataloader/large_file_lines_reader.py b/src/modalities/dataloader/large_file_lines_reader.py index d548c1f4..3d45072b 100644 --- a/src/modalities/dataloader/large_file_lines_reader.py +++ b/src/modalities/dataloader/large_file_lines_reader.py @@ -1,11 +1,8 @@ import pickle -import warnings from abc import ABC, abstractmethod from pathlib import Path from typing import List -import numpy as np - class BaseReader(ABC): @abstractmethod @@ -17,7 +14,6 @@ def __getitem__(self, key: int | slice) -> str | List[str]: raise NotImplementedError -# TODO: benchmark tokenized version vs plain text version (regarding speed and storage consumption) class LargeFileLinesReader(BaseReader): def __init__(self, raw_data_path: Path, index_path: Path = None): """ @@ -33,7 +29,7 @@ def __init__(self, raw_data_path: Path, index_path: Path = None): if not self.raw_data_path.is_file(): raise FileNotFoundError("Raw data file does not exist") if not self.index_path.is_file(): - raise FileNotFoundError("Index file does not exist. Use `modalities create_memmap_index` to create one.") + raise FileNotFoundError("Index file does not exist. Use `modalities data create_raw_index` to create one.") with self.index_path.open("rb") as f: self.index = pickle.load(f) @@ -56,21 +52,6 @@ def __getitem__(self, key: int | slice) -> str | List[str]: return self.__read_from_raw_file(offset, sample_length_in_bytes) def __read_from_raw_file(self, offset: int, sample_length_in_bytes: int) -> str: - def safe_decoder(byte_char): - try: - # TODO: verify why iso-8859-1 was necessary here in the path. 
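Illustrative use of the new `DatasetFactory` above for the continuous packed variant; the path is a placeholder for a `.pbin` file created with `modalities data pack_encoded_data`:

```python
from pathlib import Path

from modalities.dataloader.dataset_factory import DatasetFactory

dataset = DatasetFactory.get_packed_mem_map_dataset_continuous(
    raw_data_path=Path("data.pbin"),   # placeholder
    block_size=1024,
    sample_key="input_ids",
)

print(len(dataset))                    # number of block_size-sized samples
print(dataset[0]["input_ids"][:8])     # first tokens of the first packed sample
```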
- # Maybe there was an issue with the actual loading of the jsonl-files - c = byte_char.decode("utf8") - except Exception as exception: - c = "" - warnings.warn(f'Encountered invalid char: "{byte_char}".') - warnings.warn(f"Encountered problem: {exception}") - return c - - string = ( - np.memmap(self.raw_data_path, mode="r", offset=offset, shape=(sample_length_in_bytes,)).view("S1").tolist() - ) - decoded_string = [] - for c in string: - decoded_string.append(safe_decoder(c)) - return "".join(decoded_string) + f = self.raw_data_path.open() + f.seek(offset) + return f.read(sample_length_in_bytes) diff --git a/src/modalities/dataloader/open_gptx_dataset/open_gptx_dataset.py b/src/modalities/dataloader/open_gptx_dataset/open_gptx_dataset.py index 793649dc..43439377 100644 --- a/src/modalities/dataloader/open_gptx_dataset/open_gptx_dataset.py +++ b/src/modalities/dataloader/open_gptx_dataset/open_gptx_dataset.py @@ -417,23 +417,3 @@ def __getitem__(self, idx: int): # Sample is of length sequence_len + 1 because target toke is part of the sample return {self.sample_key: np.array(sample, dtype=np.int64)} - - -class OpenGPTXMMapDatasetFactory: - @staticmethod - def create_dataset(num_samples: int, path: FilePath, sample_key: str, sequence_len: int) -> OpenGPTXMMapDataset: - # dataset_dir = path.parents[0] - # dataset_filename_prefix = path.stem - # text_dataset = make_dataset(path=dataset_dir.joinpath(dataset_filename_prefix)) - - # instances = OpenGPTXDataset( - # sample_key=sample_key, - # text_dataset=text_dataset, - # doc_idx=np.arange(0, len(text_dataset)), - # dataset_dir=dataset_dir, - # num_samples=num_samples, - # dataset_name=dataset_filename_prefix, - # sequence_len=sequence_len, - # ) - # return instances - pass diff --git a/src/modalities/dataloader/samplers.py b/src/modalities/dataloader/samplers.py index af3a4aa2..c5ab2699 100644 --- a/src/modalities/dataloader/samplers.py +++ b/src/modalities/dataloader/samplers.py @@ -13,6 +13,8 @@ def __init__(self, start_index: int, underlying_batch_sampler: BatchSampler): self.start_index = start_index self.underlying_batch_sampler = underlying_batch_sampler + # NOTE: we are only iterating ove the indices not the actual data + # so this is relatively cheap self.indices = list(iter(self.underlying_batch_sampler)) def __iter__(self): @@ -22,5 +24,5 @@ def __len__(self): return len(self.indices) - self.start_index @property - def sampler_batch_size(self) -> int: + def batch_size(self) -> int: return self.underlying_batch_sampler.batch_size diff --git a/src/modalities/evaluator.py b/src/modalities/evaluator.py index 78adfe1e..a2823a2a 100644 --- a/src/modalities/evaluator.py +++ b/src/modalities/evaluator.py @@ -45,11 +45,11 @@ def evaluate( ) -> Dict[str, EvaluationResultBatch]: result_dict: Dict[str, EvaluationResultBatch] = {} model.eval() + + device = torch.device(self.local_rank if torch.cuda.is_available() else "cpu") + for data_loader in data_loaders: - if torch.cuda.is_available(): - cummulated_loss = torch.zeros(3).to(torch.device(self.local_rank)) - else: - cummulated_loss = torch.zeros(3).to("cpu") + cumulated_loss = torch.zeros(3).to(device) Evaluator._publish_progress( batch_progress_publisher=self.batch_progress_publisher, @@ -66,13 +66,13 @@ def evaluate( loss_fun=loss_fun, ) - cummulated_loss[0] += batch_loss.item() # sum up batch loss - cummulated_loss[1] += len(batch) - batch_length_tensor = torch.tensor(len(batch)).to(torch.device(self.local_rank)) + cumulated_loss[0] += batch_loss.item() # sum up batch loss + cumulated_loss[1] 
+= len(batch) + batch_length_tensor = torch.tensor(len(batch)).to(device) thoughput_aggregator.add_value(key=ThroughputAggregationKeys.NUM_SAMPLES, value=batch_length_tensor) local_dataset_sample_id = Evaluator._get_local_sample_id( - batch_id=batch_id, batch_size=data_loader.sampler_batch_size + batch_id=batch_id, batch_size=data_loader.batch_size ) global_dataset_sample_id = local_sample_id_to_global_sample_id(local_dataset_sample_id) @@ -85,22 +85,20 @@ def evaluate( ) # TODO: insert reducer from outside so Evaluator is independent of FSDP total_loss = Reducer.reduce( - tensor=cummulated_loss, + tensor=cumulated_loss, operation=dist.ReduceOp.SUM, post_processing_fun=lambda t: t[0] / t[1], ) - foward_backward_time = torch.tensor(forward_backward_timer_recorder.delta_t).to( - torch.device(self.local_rank) - ) + forward_backward_time = torch.tensor(forward_backward_timer_recorder.delta_t).to(device) thoughput_aggregator.add_value( - key=ThroughputAggregationKeys.FORWARD_BACKWARD_TIME, value=foward_backward_time + key=ThroughputAggregationKeys.FORWARD_BACKWARD_TIME, value=forward_backward_time ) synced_num_samples = thoughput_aggregator.get_all_reduced_value(ThroughputAggregationKeys.NUM_SAMPLES) - synced_foward_backward_time = thoughput_aggregator.get_all_reduced_value( + synced_forward_backward_time = thoughput_aggregator.get_all_reduced_value( ThroughputAggregationKeys.FORWARD_BACKWARD_TIME, reduce_operation=dist.ReduceOp.MAX ) - num_samples_per_second = synced_num_samples / synced_foward_backward_time + num_samples_per_second = synced_num_samples / synced_forward_backward_time evaluation_result = EvaluationResultBatch( losses={loss_fun.tag: total_loss}, diff --git a/src/modalities/exceptions.py b/src/modalities/exceptions.py index c5e5e3a2..07e344d5 100644 --- a/src/modalities/exceptions.py +++ b/src/modalities/exceptions.py @@ -15,4 +15,4 @@ class RunningEnvError(Exception): class TimeRecorderStateError(Exception): - pass \ No newline at end of file + pass diff --git a/src/modalities/logging_broker/message_broker.py b/src/modalities/logging_broker/message_broker.py index d5f4aec2..7b38e58f 100644 --- a/src/modalities/logging_broker/message_broker.py +++ b/src/modalities/logging_broker/message_broker.py @@ -1,12 +1,14 @@ from abc import ABC, abstractmethod from collections import defaultdict +from typing import Dict, List + from modalities.logging_broker.messages import Message, MessageTypes from modalities.logging_broker.subscriber import MessageSubscriberIF -from typing import Dict, List class MessageBrokerIF(ABC): """Interface for message broker objects.""" + @abstractmethod def add_subscriber(self, subscription: MessageTypes, subscriber: MessageSubscriberIF): raise NotImplementedError @@ -18,6 +20,7 @@ def distribute_message(self, message: Message): class MessageBroker(MessageBrokerIF): """The MessageBroker sends notifications to its subscribers.""" + def __init__(self) -> None: self.subscriptions: Dict[MessageTypes, List[MessageSubscriberIF]] = defaultdict(list) diff --git a/src/modalities/logging_broker/publisher.py b/src/modalities/logging_broker/publisher.py index 34ff834b..28cc27de 100644 --- a/src/modalities/logging_broker/publisher.py +++ b/src/modalities/logging_broker/publisher.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Generic, TypeVar -from modalities.logging_broker.message_broker import Message, MessageBroker +from modalities.logging_broker.message_broker import Message, MessageBroker from modalities.logging_broker.messages import 
MessageTypes T = TypeVar("T") @@ -15,6 +15,7 @@ def publish_message(self, payload: T, message_type: MessageTypes): class MessagePublisher(MessagePublisherIF[T]): """The MessagePublisher sends messages through a message broker.""" + def __init__( self, message_broker: MessageBroker, diff --git a/src/modalities/logging_broker/subscriber.py b/src/modalities/logging_broker/subscriber.py index 7e965b75..6b4e5c2d 100644 --- a/src/modalities/logging_broker/subscriber.py +++ b/src/modalities/logging_broker/subscriber.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from typing import Generic, TypeVar + from modalities.logging_broker.messages import Message T = TypeVar("T") @@ -11,4 +12,3 @@ class MessageSubscriberIF(ABC, Generic[T]): @abstractmethod def consume_message(self, message: Message[T]): raise NotImplementedError - diff --git a/src/modalities/logging_broker/subscriber_impl/results_subscriber.py b/src/modalities/logging_broker/subscriber_impl/results_subscriber.py index b2965725..ff558a4d 100644 --- a/src/modalities/logging_broker/subscriber_impl/results_subscriber.py +++ b/src/modalities/logging_broker/subscriber_impl/results_subscriber.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Optional +from typing import Dict, Optional import rich from rich.console import Group @@ -7,9 +7,10 @@ import wandb from modalities.batch import EvaluationResultBatch +from modalities.config.config import WandbMode from modalities.logging_broker.messages import Message from modalities.logging_broker.subscriber import MessageSubscriberIF -from modalities.config.config import AppConfig, WandbConfig + class DummyResultSubscriber(MessageSubscriberIF[EvaluationResultBatch]): def consume_message(self, message: Message[EvaluationResultBatch]): @@ -49,16 +50,23 @@ def consume_message(self, message: Message[EvaluationResultBatch]): class WandBEvaluationResultSubscriber(MessageSubscriberIF[EvaluationResultBatch]): """A subscriber object for the WandBEvaluationResult observable.""" - def __init__(self, num_ranks: int, project: str, experiment_id: str, mode: WandbConfig.WandbMode, dir: Path, - experiment_config: Optional[AppConfig] = None) -> None: + def __init__( + self, + project: str, + experiment_id: str, + mode: WandbMode, + directory: Path, + experiment_config: Optional[Dict] = None, + ) -> None: super().__init__() - self.num_ranks = num_ranks # experiment_config_json = None # if experiment_config is not None: # experiment_config_json = experiment_config.model_dump(mode="json") - wandb.init(project=project, name=experiment_id, mode=mode.value.lower(), dir=dir, config=experiment_config) + wandb.init( + project=project, name=experiment_id, mode=mode.value.lower(), dir=directory, config=experiment_config + ) def consume_message(self, message: Message[EvaluationResultBatch]): """Consumes a message from a message broker.""" @@ -82,6 +90,4 @@ def consume_message(self, message: Message[EvaluationResultBatch]): f"{eval_result.dataloader_tag} {metric_key}": metric_values for metric_key, metric_values in eval_result.throughput_metrics.items() } - wandb.log( - data=throughput_metrics, step=eval_result.global_train_sample_id + 1 - ) + wandb.log(data=throughput_metrics, step=eval_result.global_train_sample_id + 1) diff --git a/src/modalities/logging_broker/subscriber_impl/subscriber_factory.py b/src/modalities/logging_broker/subscriber_impl/subscriber_factory.py new file mode 100644 index 00000000..3d63cdad --- /dev/null +++ b/src/modalities/logging_broker/subscriber_impl/subscriber_factory.py @@ -0,0 
+1,76 @@ +from pathlib import Path +from typing import Dict, List + +from modalities.config.config import WandbMode +from modalities.dataloader.dataloader import LLMDataLoader +from modalities.logging_broker.subscriber_impl.batch_progress_subscriber import ( + DummyProgressSubscriber, + RichProgressSubscriber, +) +from modalities.logging_broker.subscriber_impl.results_subscriber import ( + DummyResultSubscriber, + RichResultSubscriber, + WandBEvaluationResultSubscriber, +) + + +class ProgressSubscriberFactory: + @staticmethod + def get_rich_progress_subscriber( + train_dataloader: LLMDataLoader, + eval_dataloaders: List[LLMDataLoader], + world_size: int, + global_num_seen_samples: int, + local_rank: int, + ) -> RichProgressSubscriber: + if local_rank == 0: + skip_num_local_train_batches = global_num_seen_samples // world_size // train_dataloader.batch_size + train_split_num_samples = { + train_dataloader.dataloader_tag: (len(train_dataloader) + skip_num_local_train_batches) + * world_size + * train_dataloader.batch_size + } + + eval_splits_num_samples = { + dataloader.dataloader_tag: len(dataloader) * world_size * dataloader.batch_size + for dataloader in eval_dataloaders + } + + subscriber = RichProgressSubscriber(world_size, train_split_num_samples, eval_splits_num_samples) + else: + subscriber = ProgressSubscriberFactory.get_dummy_progress_subscriber() + return subscriber + + @staticmethod + def get_dummy_progress_subscriber() -> DummyProgressSubscriber: + return DummyProgressSubscriber() + + +class ResultsSubscriberFactory: + @staticmethod + def get_rich_result_subscriber(num_ranks: int, local_rank: int) -> RichResultSubscriber: + if local_rank == 0: + return RichResultSubscriber(num_ranks) + else: + return ResultsSubscriberFactory.get_dummy_result_subscriber() + + @staticmethod + def get_dummy_result_subscriber() -> DummyResultSubscriber: + return DummyResultSubscriber() + + @staticmethod + def get_wandb_result_subscriber( + local_rank: int, + project: str, + experiment_id: str, + mode: WandbMode, + directory: Path = None, + experiment_config: Dict = None, + ) -> WandBEvaluationResultSubscriber: + if local_rank == 0 and (mode == WandbMode.ONLINE or mode == WandbMode.OFFLINE): + result_subscriber = WandBEvaluationResultSubscriber( + project, experiment_id, mode, directory, experiment_config + ) + else: + result_subscriber = ResultsSubscriberFactory.get_dummy_result_subscriber() + return result_subscriber diff --git a/src/modalities/loss_functions.py b/src/modalities/loss_functions.py index c3144597..bf7b4251 100644 --- a/src/modalities/loss_functions.py +++ b/src/modalities/loss_functions.py @@ -41,3 +41,83 @@ def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: # Flatten the tokens loss = self.loss_fun(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) return loss + + +def nce_loss( + embedding1: torch.Tensor, embedding2: torch.Tensor, device: torch.device, is_asymmetric: bool, temperature: float +) -> torch.Tensor: + """ + This implementation calculates the noise contrastive estimation loss between embeddings of two different modalities + Implementation slightly adapted from https://arxiv.org/pdf/1912.06430.pdf, https://github.com/antoine77340/MIL-NCE_HowTo100M + changes include adding a temperature value and the choice of calculating asymmetric loss w.r.t. 
one modality + This implementation is adapted to contrastive loss from CoCa model https://arxiv.org/pdf/2205.01917.pdf + + Args: + embedding1 (torch.Tensor): embeddings from modality 1 of size batch_size x embed_dim. + embedding2 (torch.Tensor): embeddings from modality 2 of size batch_size x embed_dim. + device (torch.device): torch device for calculating loss. + is_asymmetric (bool): boolean value to specify if the loss is calculated in one direction or both directions. + temperature (float): temperature value for regulating loss. + + Returns: + torch.Tensor: loss tensor. + """ + # calculating the similarity matrix of size (batch_size x batch_size) + sim_matrix = torch.matmul(embedding1, embedding2.t()) / temperature + # numerator of loss: using similarity scores for all positive pairs (e.g., image and its caption) + numerator = sim_matrix * torch.eye(sim_matrix.shape[0], device=device) + numerator = numerator.sum(dim=0).view(sim_matrix.shape[0], -1) + numerator = torch.logsumexp(numerator, dim=1) + if is_asymmetric: + # denominator of loss: using all similarity scores for all pairs (positive and negative) + denominator = torch.logsumexp(sim_matrix, dim=1) + else: + # calculate bidirectional loss + numerator *= 2 + denominator = torch.logsumexp(sim_matrix, dim=1) + torch.logsumexp(sim_matrix.t(), dim=1) + return torch.mean(denominator - numerator) # calculated in log space + + +class NCELoss(Loss): + def __init__( + self, + prediction_key1: str, + prediction_key2: str, + is_asymmetric: bool = True, + temperature: float = 1.0, + tag: str = "NCELoss", + ): + """ + Noise Contrastive Estimation Loss + + Args: + prediction_key1 (str): key to access embedding 1. + prediction_key2 (str): key to access embedding 2. + is_asymmetric (bool, optional): specifies symmetric or asymmetric calculation of NCEloss. Defaults to True. + temperature (float, optional): temperature. Defaults to 1.0. + tag (str, optional): Defaults to "NCELoss". + """ + super().__init__(tag) + self.prediction_key1 = prediction_key1 + self.prediction_key2 = prediction_key2 + self.is_asymmetric = is_asymmetric + self.temperature = temperature + + def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: + """ + Args: + forward_batch (InferenceResultBatch): data batch. + + Returns: + torch.Tensor: loss tensor. 
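        Example (an illustrative sketch; the prediction keys below are hypothetical and must match
        the keys under which the two modality embeddings were stored in the forward batch):

            loss_fn = NCELoss(prediction_key1="image_embeddings", prediction_key2="text_embeddings")
            loss = loss_fn(forward_batch)  # scalar torch.Tensor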
+ """ + embedding1 = forward_batch.get_predictions(self.prediction_key1) + embedding2 = forward_batch.get_predictions(self.prediction_key2) + + contiguous_embedding1 = embedding1.contiguous() + contiguous_embedding2 = embedding2.contiguous() + + loss = nce_loss( + contiguous_embedding1, contiguous_embedding2, embedding1.device, self.is_asymmetric, self.temperature + ) + return loss diff --git a/src/modalities/models/gpt2/collator.py b/src/modalities/models/gpt2/collator.py index 5004842f..0f7ce515 100644 --- a/src/modalities/models/gpt2/collator.py +++ b/src/modalities/models/gpt2/collator.py @@ -1,4 +1,4 @@ -from dataclasses import field +from abc import ABC, abstractmethod from typing import Dict, List import torch @@ -6,9 +6,14 @@ from modalities.batch import DatasetBatch -class GPT2LLMCollator: +class CollateFnIF(ABC): + @abstractmethod + def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> DatasetBatch: + raise NotImplementedError + + +class GPT2LLMCollateFn(CollateFnIF): def __init__(self, sample_key: str, target_key: str): - self.device: torch.device = field(default_factory=lambda: torch.device("cpu")) self.sample_key = sample_key self.target_key = target_key diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index f5356710..45d8f0d4 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -1,12 +1,12 @@ import math from enum import Enum from functools import partial -from typing import Dict +from typing import Annotated, Dict import torch import torch.nn as nn import xformers.ops as xops -from pydantic import BaseModel, confloat, conint, model_validator +from pydantic import BaseModel, Field, model_validator from torch.nn import functional as F from modalities.models.model import NNModel @@ -24,34 +24,33 @@ class ActivationType(str, Enum): FUSED_SWIGLU = "fused_swiglu" -class AttentionConfig(BaseModel): - attention_type: AttentionType - scaling_factor: conint(ge=1) - - class WeightInitailizationConfig(BaseModel): - mean: confloat(ge=0.0) - std: confloat(ge=0.0) + mean: Annotated[float, Field(strict=True, ge=0.0)] + std: Annotated[float, Field(strict=True, ge=0.0)] -class GPT2Config(BaseModel): +class GPT2LLMConfig(BaseModel): sample_key: str prediction_key: str - block_size: conint(ge=1) - vocab_size: conint(ge=1) # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency - n_layer: conint(ge=1) - n_head: conint(ge=1) - n_embd: conint(ge=1) - ffn_hidden: conint(ge=1) - dropout: confloat(ge=0.0) + block_size: Annotated[int, Field(strict=True, ge=1)] + vocab_size: Annotated[ + int, Field(strict=True, ge=1) + ] # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: Annotated[int, Field(strict=True, ge=1)] + n_head_q: Annotated[int, Field(strict=True, ge=1)] + n_head_kv: Annotated[int, Field(strict=True, ge=1)] + n_embd: Annotated[int, Field(strict=True, ge=1)] + ffn_hidden: Annotated[int, Field(strict=True, ge=1)] + + dropout: Annotated[float, Field(strict=True, ge=0.0)] bias: bool # True: bias in Linears and LayerNorms, like GPT-2. 
False: a bit better and faster - attention: AttentionConfig + attention_type: AttentionType activation: ActivationType - epsilon: confloat(ge=0.0) + epsilon: Annotated[float, Field(strict=True, ge=0.0)] weight_init: WeightInitailizationConfig @model_validator(mode="after") - def validate_sizes(self) -> "GPT2Config": + def validate_sizes(self) -> "GPT2LLMConfig": for param, param_name in zip( [self.ffn_hidden, self.vocab_size, self.n_embd], ["ffn_hidden", "vocab_size", "n_embd"] ): @@ -82,14 +81,41 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: class CausalSelfAttention(nn.Module): def __init__( - self, n_head: int, n_embd: int, attention: AttentionConfig, bias: bool, dropout: float, block_size: int + self, + n_head_q: int, + n_head_kv: int, + n_embd: int, + attention_type: AttentionType, + bias: bool, + dropout: float, + block_size: int, ): super().__init__() - assert n_embd % n_head == 0 - # key, query, value projections for all heads, but in a batch - self.c_attn = nn.Linear( + assert n_embd % n_head_q == 0, ( + "Embeddings get passed to `n_head_q` different heads " + "and their dimension needs to be divisible by `n_head_q`." + ) + assert n_head_q % n_head_kv == 0, ( + "It is necessary to have `n_head_q` divisible by `n_head_kv`." + ' For more details, read about "Grouped Query Attention"' + ) + + self.n_rep = n_head_q // n_head_kv + + # query, key, value projections (separate) + self.q_attn = nn.Linear( in_features=n_embd, - out_features=attention.scaling_factor * n_embd, + out_features=n_embd, + bias=bias, + ) + self.k_attn = nn.Linear( + in_features=n_embd, + out_features=n_embd // self.n_rep, + bias=bias, + ) + self.v_attn = nn.Linear( + in_features=n_embd, + out_features=n_embd // self.n_rep, bias=bias, ) @@ -103,10 +129,12 @@ def __init__( # regularization self.attn_dropout = nn.Dropout(dropout) self.resid_dropout = nn.Dropout(dropout) - self.n_head = n_head + self.n_head_q = n_head_q + self.n_head_kv = n_head_kv + self.n_embd = n_embd self.dropout = dropout - self.flash = attention.attention_type == AttentionType.PYTORCH_FLASH_ATTENTION + self.flash = attention_type == AttentionType.PYTORCH_FLASH_ATTENTION if not self.flash: # causal mask to ensure that attention is only applied to the left in the input sequence @@ -116,15 +144,22 @@ def __init__( ) def forward(self, x: torch.Tensor) -> torch.Tensor: - B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + B, T, _ = x.size() # batch size (B), sequence length (T), embedding dimensionality (self.n_embd) # calculate query, key, values for all heads in batch and move head forward to be the batch dim - q, k, v = self.c_attn(x).split(self.n_embd, dim=2) - k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = self.q_attn(x) # (B, T, n_embd) + k = self.k_attn(x) # (B, T, n_embd / n_rep) + v = self.v_attn(x) # (B, T, n_embd / n_rep) - # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + q = q.view(B, T, self.n_head_q, self.n_embd // self.n_head_q).transpose(1, 2) # (B, nh_q, T, hs) + k = k.view(B, T, self.n_head_kv, self.n_embd // self.n_head_q).transpose(1, 2) # (B, nh_kv, T, hs) + v = v.view(B, T, self.n_head_kv, self.n_embd // self.n_head_q).transpose(1, 2) # (B, nh_kv, T, hs) + + # repeat k/v heads if self.n_rep > 1 + k = repeat_kv(k, self.n_rep) # (B, nh_q, T, 
hs) + v = repeat_kv(v, self.n_rep) # (B, nh_q, T, hs) + + # causal self-attention; Self-attend: (B, nh_q, T, hs) x (B, nh_q, hs, T) -> (B, nh_q, T, T) if self.flash: # efficient attention using Flash Attention CUDA kernels y = torch.nn.functional.scaled_dot_product_attention( @@ -141,8 +176,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) att = F.softmax(att, dim=-1) att = self.attn_dropout(att) - y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) - y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + y = att @ v # (B, nh_q, T, T) x (B, nh_q, T, hs) -> (B, nh_q, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, self.n_embd) # re-assemble all head outputs side by side # output projection y = self.resid_dropout(self.c_proj(y)) @@ -173,15 +208,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class Block(nn.Module): +class GPT2Block(nn.Module): def __init__( self, n_embd: int, bias: bool, epsilon: float, activation: ActivationType, - n_head: int, - attention: AttentionConfig, + n_head_q: int, + n_head_kv: int, + attention_type: AttentionType, dropout: float, block_size: int, ffn_hidden: int, @@ -189,7 +225,13 @@ def __init__( super().__init__() self.ln_1 = LayerNorm(ndim=n_embd, bias=bias, epsilon=epsilon) self.attn = CausalSelfAttention( - n_head=n_head, n_embd=n_embd, attention=attention, bias=bias, dropout=dropout, block_size=block_size + n_head_q=n_head_q, + n_head_kv=n_head_kv, + n_embd=n_embd, + attention_type=attention_type, + bias=bias, + dropout=dropout, + block_size=block_size, ) self.ln_2 = LayerNorm(ndim=n_embd, bias=bias, epsilon=epsilon) @@ -215,12 +257,13 @@ def __init__( block_size: int, vocab_size: int, n_layer: int, - n_head: int, + n_head_q: int, + n_head_kv: int, n_embd: int, ffn_hidden: int, dropout: float, bias: bool, - attention: AttentionConfig, + attention_type: AttentionType, activation: ActivationType, epsilon: float, weight_init: WeightInitailizationConfig, @@ -240,13 +283,14 @@ def __init__( drop=nn.Dropout(dropout), h=nn.ModuleList( [ - Block( + GPT2Block( n_embd=n_embd, bias=bias, epsilon=epsilon, activation=activation, - n_head=n_head, - attention=attention, + n_head_q=n_head_q, + n_head_kv=n_head_kv, + attention_type=attention_type, dropout=dropout, block_size=block_size, ffn_hidden=ffn_hidden, @@ -298,3 +342,15 @@ def forward_impl(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tenso def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: return self.forward_impl(inputs) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + Source code adopted from + https://github.com/facebookresearch/llama/blob/9a001c7a0987afd7b8de94e538916eff8950a73a/llama/model.py#L164 + Adapted ordered dimensions and namings: bs=B, n_kv_heads=nh_kv, slen=T, head_dim=hs + """ + B, nh_kv, T, hs = x.shape + if n_rep == 1: + return x + return x[:, :, None, :, :].expand(B, nh_kv, n_rep, T, hs).reshape(B, nh_kv * n_rep, T, hs) diff --git a/src/modalities/models/gpt2/preprocess_dataset.py b/src/modalities/models/gpt2/preprocess_dataset.py index 99afb069..e89d591e 100644 --- a/src/modalities/models/gpt2/preprocess_dataset.py +++ b/src/modalities/models/gpt2/preprocess_dataset.py @@ -1,21 +1,25 @@ +import os from itertools import chain -from datasets import load_dataset -from transformers import GPT2TokenizerFast, GPT2LMHeadModel, GPT2Config + from accelerate import Accelerator -import os +from 
datasets import load_dataset +from transformers import GPT2Config, GPT2LMHeadModel, GPT2TokenizerFast def main(): - def group_texts(examples): # Concatenate all texts. concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} total_length = len(concatenated_examples[list(examples.keys())[0]]) - # We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict. - # We could add padding if the model supported it instead of this drop, you can customize this part to your needs. + # We drop the small remainder, and if the total_length < block_size + # we exclude this batch and return an empty dict. We could add padding if the + # model supported it instead of this drop, you can customize this part to your needs. total_length = (total_length // block_size) * block_size # Split by chunks of max_len. - result = {k: [t[i: i + block_size] for i in range(0, total_length, block_size)] for k, t in concatenated_examples.items()} + result = { + k: [t[i : i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() + } result["labels"] = result["input_ids"].copy() return result diff --git a/src/modalities/models/huggingface/__init__.py b/src/modalities/models/huggingface/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/src/modalities/models/huggingface/__init__.py @@ -0,0 +1 @@ + diff --git a/src/modalities/models/huggingface/huggingface_models.py b/src/modalities/models/huggingface/huggingface_models.py new file mode 100644 index 00000000..4c66d46f --- /dev/null +++ b/src/modalities/models/huggingface/huggingface_models.py @@ -0,0 +1,84 @@ +from pathlib import Path +from typing import Any, Dict, List, Optional + +import torch +from pydantic import BaseModel +from transformers import AutoModelForCausalLM, AutoModelForMaskedLM, AutoTokenizer + +from modalities.config.lookup_enum import LookupEnum +from modalities.models.model import NNModel + +# Huggingface Model dependencies +# +# ModuleUtilsMixin +# GenerationMixin +# PushToHubMixin +# PeftAdapterMixin +# <- PreTrainedModel +# <- LlamaPreTrainedModel The bare LLaMA Model outputting raw hidden-states without any specific head on top. +# <- LlamaModel The bare LLaMA Model outputting raw hidden-states without any specific head on top. +# <- LlamaForCausalLM +# <- LlamaForSequenceClassification The LLaMa transformer with a sequence classif. head on top (lin. layer) + + +class HuggingFaceModelTypes(LookupEnum): + AutoModelForCausalLM = AutoModelForCausalLM + AutoModelForMaskedLM = AutoModelForMaskedLM + + +class HuggingFacePretrainedModelConfig(BaseModel): + model_type: HuggingFaceModelTypes + model_name: Path + prediction_key: str + huggingface_prediction_subscription_key: str + sample_key: str + model_args: Optional[Any] = None + kwargs: Optional[Any] = None + + +class HuggingFacePretrainedModel(NNModel): + def __init__( + self, + model_type: HuggingFaceModelTypes, + model_name: str, + prediction_key: str, + huggingface_prediction_subscription_key: str, + sample_key: str, + model_args: Optional[Any] = None, + kwargs: Optional[Any] = None, + ): + super().__init__() + if model_args is None: + model_args = [] + if kwargs is None: + kwargs = {} + self.prediction_key = prediction_key + self.huggingface_prediction_subscription_key = huggingface_prediction_subscription_key + self.sample_key = sample_key + + # NOTE: If the model needs to be downloaded, it is NOT necessary to guard the access for rank 0. 
+ # This is taken care of internally in huggingface hub see: + # https://github.com/huggingface/huggingface_hub/blob/3788f537b10c7d02149d6bf017d2ce19885f90a2/src/huggingface_hub/file_download.py#L1457 + self.huggingface_model = model_type.value.from_pretrained( + model_name, local_files_only=False, *model_args, **kwargs + ) + + def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + output = self.huggingface_model.forward(inputs[self.sample_key]) + return {self.prediction_key: output[self.huggingface_prediction_subscription_key]} + + @property + def fsdp_block_names(self) -> List[str]: + return self.huggingface_model._no_split_modules + + +if __name__ == "__main__": + tokenizer = AutoTokenizer.from_pretrained("epfl-llm/meditron-7b") + model = HuggingFacePretrainedModel( + model_type=HuggingFaceModelTypes.AutoModelForCausalLM, + model_name="epfl-llm/meditron-7b", + prediction_key="logits", + huggingface_prediction_subscription_key="logits", + sample_key="input_ids", + ) + print(model) diff --git a/src/modalities/models/model.py b/src/modalities/models/model.py index d00a8043..511419b9 100644 --- a/src/modalities/models/model.py +++ b/src/modalities/models/model.py @@ -1,9 +1,11 @@ from abc import abstractmethod from typing import Dict -from modalities.batch import DatasetBatch, InferenceResultBatch + import torch import torch.nn as nn +from modalities.batch import DatasetBatch, InferenceResultBatch + class NNModel(nn.Module): def __init__(self, seed: int = None): diff --git a/src/modalities/models/model_factory.py b/src/modalities/models/model_factory.py new file mode 100644 index 00000000..3df388d9 --- /dev/null +++ b/src/modalities/models/model_factory.py @@ -0,0 +1,44 @@ +from pathlib import Path +from typing import List + +import torch +import torch.nn as nn +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import ShardingStrategy + +from modalities.checkpointing.checkpointing import Checkpointing +from modalities.running_env.env_utils import MixedPrecisionSettings +from modalities.running_env.fsdp.fsdp_auto_wrapper import FSDPTransformerAutoWrapPolicyFactory + + +class ModelFactory: + @staticmethod + def get_checkpointed_model(checkpointing: Checkpointing, checkpoint_path: Path, model: nn.Module) -> nn.Module: + wrapped_model = checkpointing.load_model_checkpoint( + file_path=checkpoint_path, + model=model, + ) + return wrapped_model + + @staticmethod + def get_fsdp_wrapped_model( + model: nn.Module, + sync_module_states: bool, + block_names: List[str], + mixed_precision_settings: MixedPrecisionSettings, + sharding_strategy: ShardingStrategy, + ) -> FSDP: + # Here, FSDPTransformerAutoWrapPolicyFactory is hardcoded and should be passed in instead! + # we also might want to have different auto wrap policies later... 
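# A typical call looks as follows (the concrete values are illustrative and taken from the
# test fixtures in this diff, not the only supported combination):
#
#     wrapped_model = ModelFactory.get_fsdp_wrapped_model(
#         gpt2_model,
#         sync_module_states=True,
#         block_names=["GPT2Block"],
#         mixed_precision_settings=MixedPrecisionSettings.BF_16,
#         sharding_strategy=ShardingStrategy.FULL_SHARD,
#     )
#
# Every module whose class is listed in block_names (e.g. each GPT2Block instance) then becomes
# its own FSDP unit via the auto-wrap policy constructed below.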
+ fsdp_auto_wrap_factory = FSDPTransformerAutoWrapPolicyFactory(model=model, block_names=block_names) + + # model is on CPU before input to FSDP + fsdp_model = FSDP( + model, + auto_wrap_policy=fsdp_auto_wrap_factory.get_auto_wrap_policy(), + mixed_precision=mixed_precision_settings.value, + sharding_strategy=sharding_strategy, + device_id=torch.cuda.current_device(), + sync_module_states=sync_module_states, + ) + return fsdp_model diff --git a/src/__init__.py b/src/modalities/optimizers/__init__.py similarity index 100% rename from src/__init__.py rename to src/modalities/optimizers/__init__.py diff --git a/src/modalities/optimizers/optimizer_factory.py b/src/modalities/optimizers/optimizer_factory.py new file mode 100644 index 00000000..e1282068 --- /dev/null +++ b/src/modalities/optimizers/optimizer_factory.py @@ -0,0 +1,21 @@ +import torch.nn as nn +from torch.optim import AdamW, Optimizer + +from modalities.checkpointing.checkpointing import Checkpointing + + +class OptimizerFactory: + @staticmethod + def get_adam_w(lr: float, wrapped_model: nn.Module): + model_parameters = wrapped_model.parameters() + optimizer = AdamW(params=model_parameters, lr=lr) + return optimizer + + @staticmethod + def get_checkpointed_optimizer( + checkpointing: Checkpointing, checkpoint_path, wrapped_model: nn.Module, optimizer: Optimizer + ): + wrapped_optimizer = checkpointing.load_optimizer_checkpoint( + file_path=checkpoint_path, optimizer=optimizer, wrapped_model=wrapped_model + ) + return wrapped_optimizer diff --git a/src/modalities/registry/__init__.py b/src/modalities/registry/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py new file mode 100644 index 00000000..40dfccfd --- /dev/null +++ b/src/modalities/registry/components.py @@ -0,0 +1,157 @@ +from dataclasses import dataclass +from typing import Type + +from pydantic import BaseModel +from torch.utils.data import BatchSampler, DistributedSampler +from transformers import GPT2TokenizerFast + +from modalities.checkpointing.checkpointing import Checkpointing +from modalities.checkpointing.checkpointing_execution import FSDPToDiscCheckpointing +from modalities.checkpointing.checkpointing_strategies import ( + SaveEveryKStepsCheckpointingStrategy, + SaveKMostRecentCheckpointsStrategy, +) +from modalities.config.config import ( + AdamWOptimizerConfig, + BatchSamplerConfig, + CheckpointedModelConfig, + CheckpointedOptimizerConfig, + CheckpointingConfig, + CLMCrossEntropyLossConfig, + DistributedSamplerConfig, + DummyProgressSubscriberConfig, + DummyResultSubscriberConfig, + FSDPToDiscCheckpointingConfig, + FSDPWrappedModelConfig, + GPT2LLMCollateFnConfig, + GPT2TokenizerFastConfig, + LLMDataLoaderConfig, + MemMapDatasetConfig, + OpenGPTXMMapDatasetConfig, + PackedMemMapDatasetContinuousConfig, + PackedMemMapDatasetMegatronConfig, + RichProgressSubscriberConfig, + RichResultSubscriberConfig, + SaveEveryKStepsCheckpointingStrategyConfig, + SaveKMostRecentCheckpointsStrategyConfig, + WandBEvaluationResultSubscriberConfig, +) +from modalities.dataloader.dataloader_factory import DataloaderFactory +from modalities.dataloader.dataset_factory import DatasetFactory +from modalities.logging_broker.subscriber_impl.subscriber_factory import ( + ProgressSubscriberFactory, + ResultsSubscriberFactory, +) +from modalities.loss_functions import CLMCrossEntropyLoss +from modalities.models.gpt2.collator import GPT2LLMCollateFn +from modalities.models.gpt2.gpt2_model import 
GPT2LLM, GPT2LLMConfig +from modalities.models.huggingface.huggingface_models import ( + HuggingFacePretrainedModel, + HuggingFacePretrainedModelConfig, +) +from modalities.models.model_factory import ModelFactory +from modalities.optimizers.optimizer_factory import OptimizerFactory + + +@dataclass +class ComponentEntity: + component_key: str + variant_key: str + component_type: Type + component_config_type: Type[BaseModel] + + +COMPONENTS = [ + # models + ComponentEntity("model", "gpt2", GPT2LLM, GPT2LLMConfig), + ComponentEntity( + "model", "huggingface_pretrained_model", HuggingFacePretrainedModel, HuggingFacePretrainedModelConfig + ), + ComponentEntity("model", "checkpointed", ModelFactory.get_checkpointed_model, CheckpointedModelConfig), + ComponentEntity("model", "fsdp_wrapped", ModelFactory.get_fsdp_wrapped_model, FSDPWrappedModelConfig), + # losses + ComponentEntity("loss", "clm_cross_entropy_loss", CLMCrossEntropyLoss, CLMCrossEntropyLossConfig), + # optmizers + ComponentEntity("optimizer", "adam_w", OptimizerFactory.get_adam_w, AdamWOptimizerConfig), + ComponentEntity( + "optimizer", "checkpointed", OptimizerFactory.get_checkpointed_optimizer, CheckpointedOptimizerConfig + ), + # schedulers + # ComponentEntity("scheduler", "step_lr", torch.optim.lr_scheduler.StepLR, None), # TODO + # ComponentEntity("scheduler", "constant_lr", torch.optim.lr_scheduler.ConstantLR, None), # TODO + # ComponentEntity("scheduler", "onecycle_lr", torch.optim.lr_scheduler.OneCycleLR, None), # TODO + # tokenizers + ComponentEntity("tokenizer", "gpt2_tokenizer_fast", GPT2TokenizerFast, GPT2TokenizerFastConfig), + # ComponentEntity("tokenizer", "llama_tokenizer_fast", GPT2TokenizerFast, None), # TODO + # datasets + ComponentEntity("dataset", "mem_map_dataset", DatasetFactory.get_mem_map_dataset, MemMapDatasetConfig), + ComponentEntity( + "dataset", + "packed_mem_map_dataset_continuous", + DatasetFactory.get_packed_mem_map_dataset_continuous, + PackedMemMapDatasetContinuousConfig, + ), + ComponentEntity( + "dataset", + "packed_mem_map_dataset_megatron", + DatasetFactory.get_packed_mem_map_dataset_megatron, + PackedMemMapDatasetMegatronConfig, + ), + ComponentEntity( + "dataset", "open_gptx_mmap_dataset", DatasetFactory.get_open_gptx_mmap_dataset, OpenGPTXMMapDatasetConfig + ), + # samplers + ComponentEntity("sampler", "distributed_sampler", DistributedSampler, DistributedSamplerConfig), + # batch samplers + ComponentEntity("batch_sampler", "default", BatchSampler, BatchSamplerConfig), + # collators + ComponentEntity("collate_fn", "gpt_2_llm_collator", GPT2LLMCollateFn, GPT2LLMCollateFnConfig), + # data loaders + ComponentEntity("data_loader", "default", DataloaderFactory.get_dataloader, LLMDataLoaderConfig), + # ComponentEntity("data_loader", "repeating_data_loader",(RepeatingDataLoader, None), # TODO + # checkpointing + ComponentEntity("checkpointing", "default", Checkpointing, CheckpointingConfig), + # checkpointing strategies + ComponentEntity( + "checkpointing_strategy", + "save_every_k_steps_checkpointing_strategy", + SaveEveryKStepsCheckpointingStrategy, + SaveEveryKStepsCheckpointingStrategyConfig, + ), + ComponentEntity( + "checkpointing_strategy", + "save_k_most_recent_checkpoints_strategy", + SaveKMostRecentCheckpointsStrategy, + SaveKMostRecentCheckpointsStrategyConfig, + ), + # checkpointing execution + ComponentEntity( + "checkpointing_execution", "fsdp_to_disc_checkpointing", FSDPToDiscCheckpointing, FSDPToDiscCheckpointingConfig + ), + # Progress subscriber + ComponentEntity( + 
"progress_subscriber", + "dummy", + ProgressSubscriberFactory.get_dummy_progress_subscriber, + DummyProgressSubscriberConfig, + ), + ComponentEntity( + "progress_subscriber", + "rich", + ProgressSubscriberFactory.get_rich_progress_subscriber, + RichProgressSubscriberConfig, + ), + # Results subscriber + ComponentEntity( + "results_subscriber", "dummy", ResultsSubscriberFactory.get_dummy_result_subscriber, DummyResultSubscriberConfig + ), + ComponentEntity( + "results_subscriber", "rich", ResultsSubscriberFactory.get_rich_result_subscriber, RichResultSubscriberConfig + ), + ComponentEntity( + "results_subscriber", + "wandb", + ResultsSubscriberFactory.get_wandb_result_subscriber, + WandBEvaluationResultSubscriberConfig, + ), +] diff --git a/src/modalities/registry/registry.py b/src/modalities/registry/registry.py new file mode 100644 index 00000000..a55df227 --- /dev/null +++ b/src/modalities/registry/registry.py @@ -0,0 +1,44 @@ +from dataclasses import asdict +from typing import Dict, List, Optional, Tuple, Type + +from pydantic import BaseModel + +from modalities.registry.components import ComponentEntity + +Entity = Tuple[Type, Type[BaseModel]] + + +class Registry: + def __init__(self, components: Optional[List[ComponentEntity]] = None) -> None: + # maps component_key -> variant_key -> entity = (component, config) + self._registry_dict: Dict[str, Dict[str, Entity]] = {} + if components is not None: + for component in components: + self.add_entity(**asdict(component)) + + def add_entity( + self, component_key: str, variant_key: str, component_type: Type, component_config_type: Type[BaseModel] + ) -> None: + if component_key not in self._registry_dict: + self._registry_dict[component_key] = {} + self._registry_dict[component_key][variant_key] = (component_type, component_config_type) + + def get_component(self, component_key: str, variant_key: str) -> Type: + entity = self._get_entity(component_key, variant_key) + try: + return entity[0] + except IndexError as e: + raise ValueError(f"0 is not a valid index in registry[{component_key}][{variant_key}]") from e + + def get_config(self, component_key: str, variant_key: str) -> Type[BaseModel]: + entity = self._get_entity(component_key, variant_key) + try: + return entity[1] + except IndexError as e: + raise ValueError(f"1 is not a valid index in registry[{component_key}][{variant_key}]") from e + + def _get_entity(self, component_key: str, variant_key: str) -> Entity: + try: + return self._registry_dict[component_key][variant_key] + except KeyError as e: + raise ValueError(f"[{component_key}][{variant_key}] are not valid keys in registry") from e diff --git a/src/modalities/resolver_register.py b/src/modalities/resolver_register.py deleted file mode 100644 index 9f571efe..00000000 --- a/src/modalities/resolver_register.py +++ /dev/null @@ -1,147 +0,0 @@ -from typing import Any, Dict, List - -import torch.optim as optim -from class_resolver import ClassResolver -from pydantic import BaseModel -from torch.utils.data import BatchSampler, DataLoader, Sampler -from torch.utils.data.distributed import DistributedSampler -from transformers import PreTrainedTokenizer - -from modalities.checkpointing.checkpointing import CheckpointingExecutionIF, CheckpointingStrategyIF -from modalities.config.config import AppConfig, OptimizerTypes, SchedulerTypes -from modalities.config.lookup_types import ( - BatchSamplerTypes, - CheckpointingExectionTypes, - CheckpointingStrategyTypes, - CollatorTypes, - DataloaderTypes, - DatasetTypes, - LossTypes, - ModelTypes, 
- SamplerTypes, - TokenizerTypes, -) -from modalities.dataloader.dataloader import LLMDataLoader -from modalities.dataloader.dataset import Dataset -from modalities.loss_functions import CLMCrossEntropyLoss, Loss -from modalities.models.gpt2.collator import GPT2LLMCollator -from modalities.models.gpt2.gpt2_model import GPT2LLM, NNModel -from modalities.running_env.fsdp.fsdp_running_env import FSDPRunningEnv, RunningEnv, RunningEnvTypes - - -class ResolverRegister: - def __init__(self, config: AppConfig) -> None: - self._resolver_register: Dict[str, ClassResolver] = self._create_resolver_register(config=config) - - def build_component_by_config(self, config: BaseModel, extra_kwargs: Dict = {}) -> Any: - assert ( - "type_hint" in config.model_fields.keys() - ), f"Field 'type_hint' missing but needed for initalisation in {config}" - - kwargs = {key: getattr(config.config, key) for key in config.config.model_dump().keys()} - kwargs.update(extra_kwargs) # allow override via extra_kwargs, to add nested objects - return self._build_component( - register_key=config.type_hint, - register_query=config.type_hint.name, - extra_kwargs=kwargs, - ) - - def build_component_by_key_query(self, register_key: str, type_hint: str, extra_kwargs: Dict = {}) -> Any: - return self._build_component(register_key=register_key, register_query=type_hint, extra_kwargs=extra_kwargs) - - def _build_component(self, register_key: str, register_query: str, extra_kwargs: Dict = {}): - return self._resolver_register[register_key].make( - query=register_query, - pos_kwargs=extra_kwargs, - ) - - def _find_values_with_key_in_nested_structure(self, nested_structure: Dict, key: str) -> List[Any]: - found_values = [] - for k, v in nested_structure.items(): - if k == key: - found_values.append(v) - elif isinstance(v, dict): - found_values.extend(self._find_values_with_key_in_nested_structure(v, key)) - return found_values - - def _create_resolver_register(self, config: AppConfig) -> Dict[str, ClassResolver]: - set(self._find_values_with_key_in_nested_structure(nested_structure=config.model_dump(), key="type_hint")) - resolvers = { - config.running_env.type_hint: ClassResolver( - [t.value for t in RunningEnvTypes], - base=RunningEnv, - default=FSDPRunningEnv, - ), - config.model.type_hint: ClassResolver( - [t.value for t in ModelTypes], - base=NNModel, - default=GPT2LLM, - ), - config.optimizer.type_hint: ClassResolver( - [t.value for t in OptimizerTypes], - base=optim.Optimizer, - default=optim.AdamW, - ), - config.scheduler.type_hint: ClassResolver( - [t.value for t in SchedulerTypes], - base=optim.lr_scheduler.LRScheduler, - default=optim.lr_scheduler.StepLR, - ), - config.loss.type_hint: ClassResolver( - [t.value for t in LossTypes], - base=Loss, - default=CLMCrossEntropyLoss, - ), - **{ - sampler_type: ClassResolver( - classes=[t.value for t in SamplerTypes], - base=Sampler, - default=DistributedSampler, - ) - for sampler_type in SamplerTypes - }, - **{ - batch_sampler_type: ClassResolver( - classes=[t.value for t in BatchSamplerTypes], - base=BatchSampler, - default=BatchSampler, - ) - for batch_sampler_type in BatchSamplerTypes - }, - **{ - dataloader_type: ClassResolver( - [t.value for t in DataloaderTypes], - base=DataLoader, - default=LLMDataLoader, - ) - for dataloader_type in DataloaderTypes - }, - **{ - dataset_type: ClassResolver([t.value for t in DatasetTypes], base=Dataset) - for dataset_type in DatasetTypes - }, - **{ - collator_type: ClassResolver([t.value for t in CollatorTypes], base=GPT2LLMCollator) - for 
collator_type in CollatorTypes - }, - **{ - tokenizer_type: ClassResolver([t.value for t in TokenizerTypes], base=PreTrainedTokenizer) - for tokenizer_type in TokenizerTypes - }, - **{ - checkpointing_strategy_type: ClassResolver( - [t.value for t in CheckpointingStrategyTypes], base=CheckpointingStrategyIF - ) - for checkpointing_strategy_type in CheckpointingStrategyTypes - }, - **{ - checkpointing_execution_type: ClassResolver( - [t.value for t in CheckpointingExectionTypes], base=CheckpointingExecutionIF - ) - for checkpointing_execution_type in CheckpointingExectionTypes - }, - } - return resolvers - - def add_resolver(self, resolver_key: str, resolver: ClassResolver): - self._resolver_register[resolver_key] = resolver diff --git a/src/modalities/running_env/cuda_env.py b/src/modalities/running_env/cuda_env.py new file mode 100644 index 00000000..e8551869 --- /dev/null +++ b/src/modalities/running_env/cuda_env.py @@ -0,0 +1,27 @@ +import os + +import torch +import torch.distributed as dist + +from modalities.config.config import ProcessGroupBackendType + + +class CudaEnv: + def __init__( + self, + process_group_backend: ProcessGroupBackendType, + ) -> None: + self.process_group_backend = process_group_backend + # TODO we might want to set this from outside via the config + self.local_rank = int(os.getenv("LOCAL_RANK", "0")) + + def __enter__(self) -> "CudaEnv": + dist.init_process_group(self.process_group_backend.value) + torch.cuda.set_device(self.local_rank) + return self + + def __exit__(self, type, value, traceback): + pass + # TODO uncomment part below + # dist.barrier() # TODO check for concurrency issues + # dist.destroy_process_group() diff --git a/src/modalities/running_env/env_utils.py b/src/modalities/running_env/env_utils.py index 04d4d05a..87a41c62 100644 --- a/src/modalities/running_env/env_utils.py +++ b/src/modalities/running_env/env_utils.py @@ -1,11 +1,11 @@ -from enum import Enum - import torch import torch.cuda.nccl as nccl import torch.distributed as dist from pkg_resources import packaging from torch.distributed.fsdp import MixedPrecision +from modalities.config.lookup_enum import LookupEnum + def has_bfloat_support(): return ( @@ -48,7 +48,7 @@ def has_bfloat_support(): ) -class MixedPrecisionSettings(Enum): +class MixedPrecisionSettings(LookupEnum): FP_16 = fpSixteen BF_16 = bfSixteen BF_16_WORKING = bfSixteen_working diff --git a/src/modalities/running_env/fsdp/fsdp_auto_wrapper.py b/src/modalities/running_env/fsdp/fsdp_auto_wrapper.py new file mode 100644 index 00000000..634b0c12 --- /dev/null +++ b/src/modalities/running_env/fsdp/fsdp_auto_wrapper.py @@ -0,0 +1,55 @@ +import functools +import logging +from abc import ABC, abstractmethod +from typing import Callable, List + +import torch.nn as nn +from accelerate.utils.dataclasses import FullyShardedDataParallelPlugin +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy + +from modalities.config.lookup_enum import LookupEnum + + +class FSDPAutoWrapFactoryIF(ABC): + @abstractmethod + def get_auto_wrap_policy(self) -> Callable: + raise NotImplementedError + + +class FSDPTransformerAutoWrapPolicyFactory(FSDPAutoWrapFactoryIF): + def __init__(self, model: nn.Module, block_names: List[str]) -> None: + # TODO it's problematic that we store the model in-memory here. Might get too large in RAM... 
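# Note: block_names holds plain class-name strings (e.g. "GPT2Block"); they are resolved to the
# actual nn.Module subclasses further down via Accelerate's name matching, and every instance of
# those classes is then treated as a separate FSDP wrapping unit.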
+ self.model = model + self.block_names = block_names + + @staticmethod + def _get_fsdp_blocks_from_block_names(model: nn.Module, block_names: List[str]) -> List[nn.Module]: + fsdp_block_types = [] + for cls_block_name in block_names: + # TODO FullyShardedDataParallelPlugin from Accelerate uses string matching to find the correct + # block class. In the long-term we should implement this ourselves in a more robust fashion. + block_type = FullyShardedDataParallelPlugin.get_module_class_from_name(model, cls_block_name) + if block_type is None: + raise ValueError(f"Could not find block with name {cls_block_name} in model") + fsdp_block_types.append(block_type) + return fsdp_block_types + + def get_auto_wrap_policy(self) -> Callable: + transformer_layer_cls = self._get_fsdp_blocks_from_block_names(model=self.model, block_names=self.block_names) + logging.info(f"Wrapped layer classes: {transformer_layer_cls}\n") + print(f"\nWrapped layer classes: {transformer_layer_cls}\n") + + if len(transformer_layer_cls) == 0: + raise ValueError("No FSDP blocks found in model") + + auto_wrapper_policy = functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls={ + *transformer_layer_cls, + }, + ) + return auto_wrapper_policy + + +class FSDPAutoWrapFactoryTypes(LookupEnum): + FSDPTransformerAutoWrapPolicyFactory = FSDPTransformerAutoWrapPolicyFactory diff --git a/src/modalities/running_env/fsdp/fsdp_running_env.py b/src/modalities/running_env/fsdp/fsdp_running_env.py deleted file mode 100644 index 434444bf..00000000 --- a/src/modalities/running_env/fsdp/fsdp_running_env.py +++ /dev/null @@ -1,111 +0,0 @@ -import functools -from enum import Enum -from typing import Type - -import torch -import torch.distributed as dist -import torch.nn as nn -from pydantic import BaseModel, ValidationError, validator -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp import ShardingStrategy -from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy - -from modalities.config.lookup_types import LookupEnum -from modalities.config.types import ProcessGroupBackendType -from modalities.models.gpt2.gpt2_model import Block -from modalities.running_env.env_utils import MixedPrecisionSettings, has_bfloat_support -from modalities.running_env.running_env import RunningEnv - - -def parse_enum_by_name(name: str, enum_type: Type[Enum]) -> Enum: - try: - return enum_type[name] - except KeyError: - raise ValidationError(f"Invalid {enum_type} member name: {name}") - - -transformer_auto_wrapper_policy = functools.partial( - transformer_auto_wrap_policy, - transformer_layer_cls={ - Block, - }, -) - - -class AutoWrapPolicies(Enum): - TRANSFORMER_AUTO_WRAP_POLICY = transformer_auto_wrapper_policy - - -class FSDPRunningEnvConfig(BaseModel): - process_group_backend: ProcessGroupBackendType - local_rank: int - mixed_precision_settings: MixedPrecisionSettings - sharding_strategy: ShardingStrategy - auto_wrap_policy: AutoWrapPolicies - - @validator("mixed_precision_settings", pre=True, always=True) - def parse_mixed_precision_setting_by_name(cls, name): - mixed_precision_settings: MixedPrecisionSettings = parse_enum_by_name( - name=name, enum_type=MixedPrecisionSettings - ) - if not has_bfloat_support() and ( - mixed_precision_settings == MixedPrecisionSettings.BF_16 - or mixed_precision_settings == MixedPrecisionSettings.BF_16_WORKING - ): - raise ValueError("BF16 not supported in the current environment") - return mixed_precision_settings - - @validator("sharding_strategy",
pre=True, always=True) - def parse_sharding_strategy_by_name(cls, name): - return parse_enum_by_name(name=name, enum_type=ShardingStrategy) - - @validator("auto_wrap_policy", pre=True, always=True) - def parse_auto_wrap_policy_by_name(cls, name): - return parse_enum_by_name(name=name, enum_type=AutoWrapPolicies) - - -class FSDPRunningEnv(RunningEnv): - def __init__( - self, - process_group_backend: ProcessGroupBackendType, - local_rank: int, - mixed_precision_settings: MixedPrecisionSettings, - sharding_strategy: ShardingStrategy, - auto_wrap_policy: AutoWrapPolicies, - ) -> None: - self.process_group_backend = process_group_backend - self.local_rank = local_rank - self.mixed_precision_settings = mixed_precision_settings - self.sharding_strategy = sharding_strategy - self.auto_wrap_policy = auto_wrap_policy - - def __enter__(self) -> "RunningEnv": - dist.init_process_group(self.process_group_backend.value) - torch.cuda.set_device(self.local_rank) - return self - - def __exit__(self, type, value, traceback): - pass # TODO uncomment part below - # dist.barrier() # TODO check for concurrency issues - # dist.destroy_process_group() - - def wrap_model(self, model: nn.Module, sync_module_states: bool) -> FSDP: - # model is on CPU before input to FSDP - fsdp_model = FSDP( - model, - auto_wrap_policy=self.auto_wrap_policy.value, - mixed_precision=self.mixed_precision_settings.value, - sharding_strategy=self.sharding_strategy, - device_id=torch.cuda.current_device(), - sync_module_states=sync_module_states, - ) - return fsdp_model - - -class RunningEnvTypes(LookupEnum): - FSDPRunningEnv = FSDPRunningEnv - - -class RunningEnvConfig(BaseModel): - type_hint: RunningEnvTypes - config: FSDPRunningEnvConfig diff --git a/src/modalities/running_env/running_env.py b/src/modalities/running_env/running_env.py deleted file mode 100644 index 8d121353..00000000 --- a/src/modalities/running_env/running_env.py +++ /dev/null @@ -1,15 +0,0 @@ -from abc import ABC, abstractmethod - -import torch.nn as nn - - -class RunningEnv(ABC, object): - def __enter__(self) -> "RunningEnv": - raise NotImplementedError - - def __exit__(self, type, value, traceback): - raise NotImplementedError - - @abstractmethod - def wrap_model(self, model: nn.Module, sync_module_states: bool) -> nn.Module: - raise NotImplementedError diff --git a/src/modalities/test.py b/src/modalities/test.py index ea16a091..f81c3630 100644 --- a/src/modalities/test.py +++ b/src/modalities/test.py @@ -3,7 +3,6 @@ from rich.progress import Progress with Progress() as progress: - task1 = progress.add_task("[red]Downloading...", total=1000) task2 = progress.add_task("[green]Processing...", total=1000) task3 = progress.add_task("[cyan]Cooking...", total=1000) @@ -12,4 +11,4 @@ progress.update(task1, advance=0.5) progress.update(task2, advance=0.3) progress.update(task3, advance=0.9) - time.sleep(0.02) \ No newline at end of file + time.sleep(0.02) diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py index 12e8193c..6994af98 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -26,12 +26,12 @@ def __init__( local_rank: int, batch_progress_publisher: MessagePublisher[BatchProgressUpdate], evaluation_result_publisher: MessagePublisher[EvaluationResultBatch], - gradient_acc_step: int, + gradient_acc_steps: int, ) -> None: self.local_rank = local_rank self.batch_progress_publisher = batch_progress_publisher self.evaluation_result_publisher = evaluation_result_publisher - self.gradient_acc_step = gradient_acc_step + 
self.gradient_acc_steps = gradient_acc_steps def _train_batch( self, @@ -43,10 +43,10 @@ def _train_batch( data_loader: LLMDataLoader, ) -> torch.Tensor: result_batch = model_predict_batch(model=model, batch=batch) - loss = loss_fun(result_batch) / self.gradient_acc_step + loss = loss_fun(result_batch) / self.gradient_acc_steps loss.backward() - if (batch_id + 1) % self.gradient_acc_step == 0 or (batch_id + 1) == len(data_loader): + if (batch_id + 1) % self.gradient_acc_steps == 0 or (batch_id + 1) == len(data_loader): optimizer.step() optimizer.zero_grad() return loss @@ -63,9 +63,11 @@ def train( local_sample_id_to_global_sample_id: Callable[[int], int], ): model.train() - cummulated_loss = self._reset_loss() + cumulated_loss = self._reset_loss() thoughput_aggregator = Aggregator[ThroughputAggregationKeys]() + device = torch.device(self.local_rank if torch.cuda.is_available() else "cpu") + # batch loop batch: DatasetBatch # TODO: why do we need a barrier here? @@ -86,14 +88,14 @@ def train( ) forward_backward_time_recorder.stop() # Save the batch loss - cummulated_loss[0] += batch_loss.item() - cummulated_loss[1] += len(batch) - batch_length_tensor = torch.tensor(len(batch)).to(torch.device(self.local_rank)) + cumulated_loss[0] += batch_loss.item() + cumulated_loss[1] += len(batch) + batch_length_tensor = torch.tensor(len(batch)).to(device) thoughput_aggregator.add_value(key=ThroughputAggregationKeys.NUM_SAMPLES, value=batch_length_tensor) self._publish_progress( batch_progress_publisher=self.batch_progress_publisher, local_batch_id=local_train_batch_id, - batch_size=train_loader.sampler_batch_size, + batch_size=train_loader.batch_size, dataloader_tag=train_loader.dataloader_tag, local_sample_id_to_global_sample_id=local_sample_id_to_global_sample_id, ) @@ -101,29 +103,27 @@ def train( # Check, if model should be evaluated if (local_train_batch_id + 1) % callback_interval_in_batches == 0: if local_train_batch_id > 0: - foward_backward_time = torch.tensor(forward_backward_time_recorder.delta_t).to( - torch.device(self.local_rank) - ) + forward_backward_time = torch.tensor(forward_backward_time_recorder.delta_t).to(device) forward_backward_time_recorder.reset() thoughput_aggregator.add_value( - key=ThroughputAggregationKeys.FORWARD_BACKWARD_TIME, value=foward_backward_time + key=ThroughputAggregationKeys.FORWARD_BACKWARD_TIME, value=forward_backward_time ) synced_num_samples = thoughput_aggregator.get_all_reduced_value( ThroughputAggregationKeys.NUM_SAMPLES ) - synced_foward_backward_time = thoughput_aggregator.get_all_reduced_value( + synced_forward_backward_time = thoughput_aggregator.get_all_reduced_value( ThroughputAggregationKeys.FORWARD_BACKWARD_TIME, reduce_operation=dist.ReduceOp.MAX ) - synced_num_samples_per_second = synced_num_samples / synced_foward_backward_time + synced_num_samples_per_second = synced_num_samples / synced_forward_backward_time # TODO: insert reducer from outside so Trainer is independent of FSDP train_loss = Reducer.reduce( - tensor=cummulated_loss, + tensor=cumulated_loss, operation=dist.ReduceOp.SUM, post_processing_fun=lambda t: t[0] / t[1], ) local_train_sample_id = Trainer._get_local_sample_id( - batch_id=local_train_batch_id, batch_size=train_loader.sampler_batch_size + batch_id=local_train_batch_id, batch_size=train_loader.batch_size ) global_train_sample_id = local_sample_id_to_global_sample_id(local_train_sample_id) @@ -144,19 +144,19 @@ def train( model.train() # TODO early stopping - cummulated_loss = self._reset_loss() + cumulated_loss = 
self._reset_loss() # we start the time recorder here again to also capture the time spent loading # via the dataloader. forward_backward_time_recorder.start() def _reset_loss(self): # TODO: we should handle the device assignment more centrally. - cummulated_loss = torch.zeros(2) + cumulated_loss = torch.zeros(2) if torch.cuda.is_available(): - cummulated_loss = cummulated_loss.to(torch.device(self.local_rank)) + cumulated_loss = cumulated_loss.to(torch.device(self.local_rank)) else: - cummulated_loss = cummulated_loss.to("cpu") - return cummulated_loss + cumulated_loss = cumulated_loss.to("cpu") + return cumulated_loss @staticmethod def _publish_progress( diff --git a/src/modalities/util.py b/src/modalities/util.py index 027e14a1..7eafb921 100644 --- a/src/modalities/util.py +++ b/src/modalities/util.py @@ -1,16 +1,44 @@ import time +import warnings from datetime import datetime from enum import Enum from types import TracebackType -from typing import Callable, Dict, Generic, TypeVar +from typing import Callable, Dict, Generic, Type, TypeVar import torch import torch.distributed as dist +from pydantic import ValidationError from modalities.exceptions import TimeRecorderStateError from modalities.running_env.fsdp.reducer import Reducer +def get_callback_interval_in_batches_per_rank( + callback_interval_in_samples: int, local_train_micro_batch_size: int, world_size: int, gradient_acc_steps: int +): + num_local_train_micro_batches_exact = callback_interval_in_samples / local_train_micro_batch_size / world_size + num_local_train_micro_batches_ret = max( + callback_interval_in_samples // local_train_micro_batch_size // world_size, 1 + ) + if num_local_train_micro_batches_exact != num_local_train_micro_batches_ret: + warnings.warn( + f"Calculated callback_interval_in_batches_per_rank is not an integer. " + f"Clipping {num_local_train_micro_batches_exact} to {num_local_train_micro_batches_ret} " + ) + assert ( + num_local_train_micro_batches_ret % gradient_acc_steps == 0 + ), "callback_interval_in_batches_per_rank must be divisible by gradient_acc_steps" + return num_local_train_micro_batches_ret + + +def parse_enum_by_name(name: str, enum_type: Type[Enum]) -> Enum: + try: + val = enum_type[name] + return val + except KeyError: + raise ValidationError(f"Invalid {enum_type} member name: {name}") + + def get_date_of_run(): """create date and time for file save uniqueness example: 2022-05-07__14-31-22' diff --git a/src/modalities/utils/generate_text.py b/src/modalities/utils/generate_text.py index 21d86c2c..d503b4fd 100755 --- a/src/modalities/utils/generate_text.py +++ b/src/modalities/utils/generate_text.py @@ -10,15 +10,16 @@ from pathlib import Path import torch -from omegaconf import OmegaConf from torch.nn import functional as F from transformers import PreTrainedTokenizer -from modalities.config.config import AppConfig -from modalities.resolver_register import ResolverRegister +from modalities.config.component_factory import ComponentFactory +from modalities.config.config import ComponentsInferenceModel, load_app_config_dict +from modalities.registry.components import COMPONENTS +from modalities.registry.registry import Registry chat_prefix = """ -This is a converstation between a user and a helpful bot, which answers the user's questsions as good as possible. +This is a conversation between a user and a helpful bot, which answers the user's questions as well as possible. user: What is 1+1? bot: 1+1 is 2.
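# Worked example for the get_callback_interval_in_batches_per_rank helper added in
# src/modalities/util.py above (the numbers are hypothetical): with
# callback_interval_in_samples=25600, local_train_micro_batch_size=16, world_size=8 and
# gradient_acc_steps=1, the exact value is 25600 / 16 / 8 = 200, so 200 local batches per rank
# are returned; a non-integer result would be floored (and clipped to at least 1) with a
# warning, and the result must remain divisible by gradient_acc_steps.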
@@ -95,11 +96,15 @@ def main(model_path: Path, config_path: Path, tokenizer: PreTrainedTokenizer, ma state_dict = torch.load(path) print(f"using {model_path}") - config_dict = OmegaConf.load(config_path) - config_dict = OmegaConf.to_container(config_dict, resolve=True) - config = AppConfig.model_validate(config_dict) - resolvers = ResolverRegister(config=config) - model: torch.nn.Module = resolvers.build_component_by_config(config=config.model) + config_dict = load_app_config_dict(config_path) + registry = Registry(COMPONENTS) + component_factory = ComponentFactory(registry=registry) + components = component_factory.build_components( + config_dict=config_dict, components_model_type=ComponentsInferenceModel + ) + + model = components.wrapped_model + model.load_state_dict(state_dict) model.eval() @@ -109,11 +114,11 @@ def main(model_path: Path, config_path: Path, tokenizer: PreTrainedTokenizer, ma if chat is True: prompt = input("enter question> ").strip() prompt = chat_prefix + chat_prompt_template.format(prompt=prompt) - generate(model, tokenizer, prompt, config.model.config.block_size, max_new_tokens) + generate(model, tokenizer, prompt, model.config.block_size, max_new_tokens) else: prompt = input("enter prompt> ") print(prompt, end="") - generate(model, tokenizer, prompt, config.model.config.block_size, max_new_tokens) + generate(model, tokenizer, prompt, model.config.block_size, max_new_tokens) except KeyboardInterrupt: print("closing app...") break diff --git a/tests/checkpointing/gpt2_config.yaml b/tests/checkpointing/gpt2_config.yaml index ea401600..8a4b4946 100644 --- a/tests/checkpointing/gpt2_config.yaml +++ b/tests/checkpointing/gpt2_config.yaml @@ -1,24 +1,22 @@ -llm_model_conf: - sample_key: input_ids - prediction_key: "logits" - block_size: 1024 - vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency - n_layer: 12 - n_head: 12 - ffn_hidden: 2048 - n_embd: 768 - dropout: 0.0 - bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster - attention: - attention_type: pytorch_flash_attention - scaling_factor: 3 - activation: fused_swiglu - epsilon: 1e-5 - weight_init: - mean: 0.0 - std: 0.02 - -running_env_conf: - process_group_backend: "nccl" - local_rank: ${oc.env:LOCAL_RANK} +model: + component_key: model + variant_key: gpt2 + config: + sample_key: "input_ids" # TODO reference this + prediction_key: "logits" # TODO reference this + block_size: 256 # TODO reference this (same as sequence length) + vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: 2 + n_head_q: 4 + n_head_kv: 4 + ffn_hidden: 128 + n_embd: 128 + dropout: 0.0 + bias: true # True: bias in Linears and LayerNorms, like GPT-2. 
False: a bit better and faster + attention_type: default_attention # pytorch_flash_attention + activation: gelu + epsilon: 1e-5 + weight_init: + mean: 0.0 + std: 0.02 diff --git a/tests/checkpointing/test_checkpoint_execution_functions.py b/tests/checkpointing/test_checkpoint_execution_functions.py index 83e0dbe0..229e703a 100644 --- a/tests/checkpointing/test_checkpoint_execution_functions.py +++ b/tests/checkpointing/test_checkpoint_execution_functions.py @@ -3,8 +3,10 @@ import pytest import torch.nn as nn from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import ShardingStrategy -from src.modalities.checkpointing.checkpointing_execution import FSDPToDiscCheckpointing +from modalities.checkpointing.checkpointing_execution import FSDPToDiscCheckpointing +from modalities.running_env.env_utils import MixedPrecisionSettings @pytest.mark.skip @@ -28,7 +30,12 @@ def test_get_paths_to_delete(tmp_path): # pytest temp path p.write_text(CONTENT) checkpointing = FSDPToDiscCheckpointing( - checkpoint_path=d, experiment_id=str(1), global_rank=0, model_wrapping_fn=dummy_method + checkpoint_path=d, + experiment_id=str(1), + global_rank=0, + block_names=["model"], + mixed_precision_settings=MixedPrecisionSettings.BF_16, + sharding_strategy=ShardingStrategy.FULL_SHARD, ) files_paths_to_delete = checkpointing._get_paths_to_delete(global_train_sample_id=100) assert len(files_paths_to_delete) != 0 @@ -50,7 +57,9 @@ def test_delete_checkpoint(tmpdir): checkpoint_path=directory, experiment_id=experiment_id, global_rank=0, - model_wrapping_fn=dummy_method, + block_names=["model"], + mixed_precision_settings=MixedPrecisionSettings.BF_16, + sharding_strategy=ShardingStrategy.FULL_SHARD, ) checkpointing._delete_checkpoint(global_train_sample_id=100) assert is_empty_directory((directory / experiment_id).__str__()) diff --git a/tests/checkpointing/test_fsdp_to_disc_checkpointing.py b/tests/checkpointing/test_fsdp_to_disc_checkpointing.py index 13bface2..a04bd44a 100644 --- a/tests/checkpointing/test_fsdp_to_disc_checkpointing.py +++ b/tests/checkpointing/test_fsdp_to_disc_checkpointing.py @@ -1,69 +1,73 @@ +import os import tempfile from copy import deepcopy from pathlib import Path -from typing import Any, Dict, Generator +from typing import Dict import pytest import torch import torch.distributed as dist -from pydantic import BaseModel +import torch.nn as nn from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import ShardingStrategy from torch.nn import CrossEntropyLoss from torch.optim import AdamW, Optimizer from modalities.__main__ import load_app_config_dict -from modalities.checkpointing.checkpointing_execution import FSDPToDiscCheckpointing -from modalities.models.gpt2.gpt2_model import GPT2LLM, GPT2Config -from modalities.running_env.fsdp.fsdp_running_env import FSDPRunningEnv, FSDPRunningEnvConfig, RunningEnv +from modalities.checkpointing.checkpointing_execution import CheckpointingEntityType, FSDPToDiscCheckpointing +from modalities.config.component_factory import ComponentFactory +from modalities.config.config import ProcessGroupBackendType +from modalities.models.gpt2.gpt2_model import GPT2LLM, GPT2LLMConfig +from modalities.models.model_factory import ModelFactory +from modalities.optimizers.optimizer_factory import OptimizerFactory +from modalities.running_env.cuda_env import CudaEnv +from modalities.running_env.env_utils import MixedPrecisionSettings # NOTE: We need to run the tests in a torch distributed 
environment with at least two GPUs. # CUDA_VISIBLE_DEVICES=0,1 torchrun --rdzv-endpoint localhost:29502 --nnodes 1 --nproc_per_node 2 \ -# /path/to/pytest path/to/test_fsdp_to_disc_checkpointing.py +# $(which pytest) path/to/test_fsdp_to_disc_checkpointing.py _ROOT_DIR = Path(__file__).parents[1] -class ExperimentConfig(BaseModel): - llm_model_conf: GPT2Config # Named it llm_model_conf as model_ is a protected namespace in pydantic - running_env_conf: FSDPRunningEnvConfig - - -@pytest.mark.skip( - reason="Need to fix absolute path for config_file_path and needs to be run via " - "torchrun in a torch distributed environment (torchrun)" +@pytest.mark.skipif( + "RANK" not in os.environ or torch.cuda.device_count() < 2, + reason="This e2e test requires 2 GPUs and a torchrun distributed environment.", ) class TestFSDPToDiscCheckpointing: - @pytest.fixture - def experiment_config(self) -> ExperimentConfig: - config_file_path = _ROOT_DIR / Path("tests/checkpointing/gpt2_config.yaml") + @pytest.fixture(scope="function") + def gpt2_model_config(self) -> GPT2LLMConfig: + config_file_path = Path("tests/checkpointing/gpt2_config.yaml") config_dict = load_app_config_dict(config_file_path=config_file_path) - experiment_config = ExperimentConfig.model_validate(config_dict) - return experiment_config + config = GPT2LLMConfig(**config_dict["model"]["config"]) + return config @pytest.fixture(scope="function") - def gpt2_model(self, experiment_config: ExperimentConfig) -> GPT2LLM: - model = GPT2LLM(config=experiment_config.llm_model_conf) + def gpt2_model(self, gpt2_model_config: GPT2LLMConfig) -> GPT2LLM: + config_dict = ComponentFactory.base_model_to_dict(gpt2_model_config) + model = GPT2LLM(**config_dict) return model @pytest.fixture(scope="function") - def gpt2_model_2(self, experiment_config: ExperimentConfig) -> GPT2LLM: - model = GPT2LLM(config=experiment_config.llm_model_conf) + def gpt2_model_2(self, gpt2_model_config: GPT2LLMConfig) -> GPT2LLM: + config_dict = ComponentFactory.base_model_to_dict(gpt2_model_config) + model = GPT2LLM(**config_dict) return model @pytest.fixture - def fsdp_running_env(self, experiment_config: ExperimentConfig) -> Generator[RunningEnv, Any, Any]: - running_env = FSDPRunningEnv(**dict(experiment_config.running_env_conf)) - with running_env as running_env: - yield running_env - - @pytest.fixture - def fsdp_wrapped_model(self, gpt2_model: GPT2LLM, fsdp_running_env) -> FSDP: - wrapped_model: FSDP = FSDPRunningEnv.wrap_model(gpt2_model, sync_module_states=True) + def fsdp_wrapped_model(self, gpt2_model: GPT2LLM) -> FSDP: + wrapped_model: FSDP = ModelFactory.get_fsdp_wrapped_model( + gpt2_model, + sync_module_states=True, + block_names=["GPT2Block"], + mixed_precision_settings=MixedPrecisionSettings.FP_16, + sharding_strategy=ShardingStrategy.FULL_SHARD, + ) return wrapped_model @pytest.fixture - def optimizer(self, fsdp_wrapped_model: GPT2LLM) -> Optimizer: - optimizer = AdamW(fsdp_wrapped_model.parameters(), lr=0.001) + def optimizer(self, fsdp_wrapped_model: nn.Module) -> Optimizer: + optimizer = OptimizerFactory.get_adam_w(wrapped_model=fsdp_wrapped_model, lr=0.001) return optimizer @pytest.fixture @@ -71,20 +75,23 @@ def temporary_checkpoint_folder_path(self): with tempfile.TemporaryDirectory() as tmp_dir_path: yield Path(tmp_dir_path) + @pytest.fixture(autouse=True) + def cuda_env_context(self): + with CudaEnv(process_group_backend=ProcessGroupBackendType.nccl): + yield + @staticmethod - def _generate_batch(experiment_config: ExperimentConfig): + def 
_generate_batch(gpt2_model_config: GPT2LLMConfig): # prepare input and targets - data = torch.randint( - 0, experiment_config.llm_model_conf.vocab_size, (8, experiment_config.llm_model_conf.block_size + 1) - ).cuda() - batch_input_ids_dict = {experiment_config.llm_model_conf.sample_key: data[:, :-1]} + data = torch.randint(0, gpt2_model_config.vocab_size, (8, gpt2_model_config.block_size + 1)).cuda() + batch_input_ids_dict = {gpt2_model_config.sample_key: data[:, :-1]} batch_target_ids = data[:, 1:] batch_target_ids = batch_target_ids.contiguous() return batch_input_ids_dict, batch_target_ids @staticmethod def _forward_backward_pass( - experiment_config: ExperimentConfig, + gpt2_model_config: GPT2LLMConfig, model: FSDP, optimizer: Optimizer, batch_input_ids_dict: Dict, @@ -96,7 +103,7 @@ def _forward_backward_pass( optimizer.zero_grad() # forward pass - predictions = model.forward(inputs=batch_input_ids_dict)[experiment_config.llm_model_conf.prediction_key] + predictions = model.forward(inputs=batch_input_ids_dict)[gpt2_model_config.prediction_key] predictions = predictions.contiguous() # backward pass loss = ce_loss(predictions.view(-1, predictions.size(-1)), batch_target_ids.view(-1)) @@ -147,25 +154,27 @@ def test_save_checkpoint_after_backward_pass( optimizer: Optimizer, temporary_checkpoint_folder_path: Path, gpt2_model_2: GPT2LLM, - experiment_config: ExperimentConfig, + gpt2_model_config: GPT2LLMConfig, ): experiment_id = "0" - global_train_batch_id = 1 + global_train_sample_id = 1 checkpointing = FSDPToDiscCheckpointing( checkpoint_path=temporary_checkpoint_folder_path, experiment_id=experiment_id, global_rank=dist.get_rank(), - model_wrapping_fn=FSDPRunningEnv.wrap_model, + block_names=["GPT2Block"], + mixed_precision_settings=MixedPrecisionSettings.FP_16, + sharding_strategy=ShardingStrategy.FULL_SHARD, ) untrained_model_parameters = [p.clone() for p in fsdp_wrapped_model.parameters()] untrained_optimizer_state_dict = deepcopy(optimizer.state_dict()) # run backward pass - batch_input_ids_dict, batch_target_ids = self._generate_batch(experiment_config) + batch_input_ids_dict, batch_target_ids = self._generate_batch(gpt2_model_config) self._forward_backward_pass( - experiment_config=experiment_config, + gpt2_model_config=gpt2_model_config, model=fsdp_wrapped_model, optimizer=optimizer, batch_input_ids_dict=batch_input_ids_dict, @@ -176,23 +185,28 @@ def test_save_checkpoint_after_backward_pass( # save model and optimizer before backward pass checkpointing._save_checkpoint( - model=fsdp_wrapped_model, optimizer=optimizer, global_train_batch_id=global_train_batch_id + model=fsdp_wrapped_model, optimizer=optimizer, global_train_sample_id=global_train_sample_id ) # load the model checkpoint - fsdp_wrapped_model_2 = checkpointing.load_model_checkpoint( - model=gpt2_model_2, + model_checkpointing_path = checkpointing._get_checkpointing_path( experiment_id=experiment_id, - global_train_batch_id=global_train_batch_id, + global_train_sample_id=global_train_sample_id, + entity_type=CheckpointingEntityType.MODEL, + ) + fsdp_wrapped_model_2 = checkpointing.load_model_checkpoint( + model=gpt2_model_2, file_path=model_checkpointing_path ) optimizer_2 = AdamW(fsdp_wrapped_model_2.parameters(), lr=0.001) - checkpointing.load_optimizer_checkpoint( - optimizer=optimizer_2, - model=fsdp_wrapped_model_2, + optimizer_checkpointing_path = checkpointing._get_checkpointing_path( experiment_id=experiment_id, - global_train_batch_id=global_train_batch_id, + global_train_sample_id=global_train_sample_id, + 
entity_type=CheckpointingEntityType.OPTIMIZER, + ) + checkpointing.load_optimizer_checkpoint( + optimizer=optimizer_2, wrapped_model=fsdp_wrapped_model_2, file_path=optimizer_checkpointing_path ) loaded_and_updated_model_parameters = [p.clone() for p in fsdp_wrapped_model_2.parameters()] @@ -218,17 +232,17 @@ def test_save_checkpoint_after_backward_pass( # we do another forward/backward pass and check # if the weights are equally updated for the loaded model as for the not-loaded model # run backward pass - batch_input_ids_dict, batch_target_ids = self._generate_batch(experiment_config) + batch_input_ids_dict, batch_target_ids = self._generate_batch(gpt2_model_config) loss_1 = self._forward_backward_pass( - experiment_config=experiment_config, + gpt2_model_config=gpt2_model_config, model=fsdp_wrapped_model, optimizer=optimizer, batch_input_ids_dict=batch_input_ids_dict, batch_target_ids=batch_target_ids, ) loss_2 = self._forward_backward_pass( - experiment_config=experiment_config, + gpt2_model_config=gpt2_model_config, model=fsdp_wrapped_model_2, optimizer=optimizer_2, batch_input_ids_dict=batch_input_ids_dict, diff --git a/tests/config/__init__.py b/tests/config/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/config/components.py b/tests/config/components.py new file mode 100644 index 00000000..67c9e9a3 --- /dev/null +++ b/tests/config/components.py @@ -0,0 +1,48 @@ +from enum import Enum +from typing import List + + +class Component_V_W_X_IF: + def print(self) -> None: + print("ComponentIF") + + +# Dependencies + + +class ComponentV(Component_V_W_X_IF): + def __init__(self, val_v: str) -> None: + self.val_v = val_v + + +class ComponentW(Component_V_W_X_IF): + def __init__(self, val_w: str) -> None: + self.val_w = val_w + + +# Components + + +class ComponentX(Component_V_W_X_IF): + def __init__(self, val_x: str, single_dependency: Component_V_W_X_IF) -> None: + self.val_x = val_x + self.single_dependency = single_dependency + + +class ComponentY: + def __init__(self, val_y: str, multi_dependency: List[Component_V_W_X_IF]) -> None: + self.val_y = val_y + self.multi_dependency = multi_dependency + + +class ComponentZ: + def __init__(self, val_z: str) -> None: + self.val_z = val_z + + +class ComponentTypes(Enum): + COMP_V = ComponentV + COMP_W = ComponentW + COMP_X = ComponentX + COMP_Y = ComponentY + COMP_Z = ComponentZ diff --git a/tests/config/configs.py b/tests/config/configs.py new file mode 100644 index 00000000..9b59b748 --- /dev/null +++ b/tests/config/configs.py @@ -0,0 +1,30 @@ +from typing import Annotated, List + +from pydantic import BaseModel + +from modalities.config.config import PydanticThirdPartyTypeIF +from tests.config.components import Component_V_W_X_IF + +PydanticComponent_V_W_X_IF_Type = Annotated[Component_V_W_X_IF, PydanticThirdPartyTypeIF(Component_V_W_X_IF)] + + +class CompVConfig(BaseModel): + val_v: str + + +class CompWConfig(BaseModel): + val_w: str + + +class CompXConfig(BaseModel): + val_x: str + single_dependency: PydanticComponent_V_W_X_IF_Type + + +class CompYConfig(BaseModel): + val_y: str + multi_dependency: List[PydanticComponent_V_W_X_IF_Type] + + +class CompZConfig(BaseModel): + val_z: str diff --git a/tests/config/custom_components.py b/tests/config/custom_components.py new file mode 100644 index 00000000..dd849e0d --- /dev/null +++ b/tests/config/custom_components.py @@ -0,0 +1,35 @@ +from abc import ABC +from enum import Enum +from typing import Literal + +from pydantic import BaseModel, validator + + +class 
CustomComponent1: + def __init__(self, val_1: str) -> None: + self.val_1 = val_1 + + +class CustomComponentTypes(Enum): + CUSTOM_COMP_1 = CustomComponent1 + + +class CustomCompConfigABC(BaseModel, ABC): + # TODO make this a string and then implement the mapping + # to the class outside of the basemodel (i.e. in the factory) + type_hint: Enum + + @validator("type_hint", pre=True, allow_reuse=True, check_fields=False) + def _string_to_enum(cls, key: str): + if isinstance(key, str): + try: + key = CustomComponentTypes[key] + except KeyError as e: + raise ValueError(f"{key} is not a valid ComponentType") from e + return key + return key + + +class CustomComp1Config(CustomCompConfigABC): + type_hint: Literal[CustomComponentTypes.CUSTOM_COMP_1] + val_1: str diff --git a/tests/config/test_component_factory.py b/tests/config/test_component_factory.py new file mode 100644 index 00000000..94644da5 --- /dev/null +++ b/tests/config/test_component_factory.py @@ -0,0 +1,111 @@ +from pathlib import Path + +import pytest + +from modalities.config.component_factory import ComponentFactory +from modalities.config.config import load_app_config_dict +from modalities.registry.components import ComponentEntity +from modalities.registry.registry import Registry +from tests.config.components import ComponentV, ComponentW, ComponentX, ComponentY +from tests.config.configs import CompVConfig, CompWConfig, CompXConfig, CompYConfig + + +@pytest.fixture(scope="function") +def component_factory() -> ComponentFactory: + components = [ + ComponentEntity("COMP_V", "default", ComponentV, CompVConfig), + ComponentEntity("COMP_W", "default", ComponentW, CompWConfig), + ComponentEntity("COMP_X", "default", ComponentX, CompXConfig), + ComponentEntity("COMP_Y", "default", ComponentY, CompYConfig), + ] + + registry = Registry(components=components) + component_factory = ComponentFactory(registry=registry) + return component_factory + + +@pytest.mark.parametrize( + "config_file_path", + [ + Path("tests/config/test_configs/config_backward_reference.yaml"), + Path("tests/config/test_configs/config_forward_reference.yaml"), + ], +) +def test_backward_reference(config_file_path: Path, component_factory: ComponentFactory): + component_names = ["comp_x_1", "comp_y_1"] + + config_dict = load_app_config_dict(config_file_path=config_file_path) + + components = component_factory._build_config(config_dict=config_dict, component_names=component_names) + + # make sure that the reference is not identical, despite both being of type COMP_W + assert components["comp_x_1"].single_dependency != components["comp_y_1"].multi_dependency[0] + # make sure that the reference is identical, since we are referencing comp_x_1 in the multi dependency of comp_y_1 + assert components["comp_x_1"] == components["comp_y_1"].multi_dependency[2] + + +@pytest.mark.parametrize( + "config_file_path", + [ + Path("tests/config/test_configs/config_non_existing_reference.yaml"), + ], +) +def test_non_existing_reference(config_file_path: Path, component_factory: ComponentFactory): + component_names = ["comp_x_1", "comp_y_1"] + + config_dict = load_app_config_dict(config_file_path=config_file_path) + + with pytest.raises(KeyError): + component_factory._build_config(config_dict=config_dict, component_names=component_names) + + +@pytest.mark.parametrize( + "config_file_path", + [ + Path("tests/config/test_configs/config_hierarchical_list_component.yaml"), + ], +) +def test_hierarchical_component_instantiation(config_file_path: Path, component_factory: ComponentFactory):
component_names = ["comp_y_1"] + + config_dict = load_app_config_dict(config_file_path=config_file_path) + + components = component_factory._build_config(config_dict=config_dict, component_names=component_names) + + assert isinstance(components["comp_y_1"].multi_dependency[0], ComponentW) + assert isinstance(components["comp_y_1"].multi_dependency[1], ComponentV) + assert isinstance(components["comp_y_1"], ComponentY) + + +@pytest.mark.parametrize( + "config_file_path", + [ + Path("tests/config/test_configs/config_hierarchical_list_component.yaml"), + ], +) +def test_component_filter(config_file_path: Path, component_factory: ComponentFactory): + component_names = ["comp_y_1"] + + config_dict = load_app_config_dict(config_file_path=config_file_path) + + components = component_factory._build_config(config_dict=config_dict, component_names=component_names) + assert "comp_y_1" in components + + component_names += "abc" + with pytest.raises(KeyError): + components = component_factory._build_config(config_dict=config_dict, component_names=component_names) + + +@pytest.mark.parametrize( + "config_file_path", + [ + Path("tests/config/test_configs/config_single_component.yaml"), + ], +) +def test_single_component(config_file_path: Path, component_factory: ComponentFactory): + component_names = ["custom_comp_1"] + + config_dict = load_app_config_dict(config_file_path=config_file_path) + + components = component_factory._build_config(config_dict=config_dict, component_names=component_names) + assert "custom_comp_1" in components diff --git a/tests/config/test_configs/config_backward_reference.yaml b/tests/config/test_configs/config_backward_reference.yaml new file mode 100644 index 00000000..bf82cb7d --- /dev/null +++ b/tests/config/test_configs/config_backward_reference.yaml @@ -0,0 +1,28 @@ +comp_x_1: + component_key: COMP_X + variant_key: default + config: + val_x: "some other value X" + single_dependency: + component_key: COMP_W + variant_key: default + config: + val_w: "some other value w" + +comp_y_1: + component_key: COMP_Y + variant_key: default + config: + val_y: "some other value y" + multi_dependency: + - component_key: COMP_W + variant_key: default + config: + val_w: "some other value w" + - component_key: COMP_V + variant_key: default + config: + val_v: "some other value v" + - instance_key: comp_x_1 + pass_type: BY_REFERENCE + diff --git a/tests/config/test_configs/config_forward_reference.yaml b/tests/config/test_configs/config_forward_reference.yaml new file mode 100644 index 00000000..d0ff73f2 --- /dev/null +++ b/tests/config/test_configs/config_forward_reference.yaml @@ -0,0 +1,27 @@ +comp_y_1: + component_key: COMP_Y + variant_key: default + config: + val_y: "some other value y" + multi_dependency: + - component_key: COMP_W + variant_key: default + config: + val_w: "some other value w" + - component_key: COMP_V + variant_key: default + config: + val_v: "some other value v" + - instance_key: comp_x_1 + pass_type: BY_REFERENCE + +comp_x_1: + component_key: COMP_X + variant_key: default + config: + val_x: "some other value X" + single_dependency: + component_key: COMP_W + variant_key: default + config: + val_w: "some other value w" \ No newline at end of file diff --git a/tests/config/test_configs/config_hierarchical_list_component.yaml b/tests/config/test_configs/config_hierarchical_list_component.yaml new file mode 100644 index 00000000..1298b590 --- /dev/null +++ b/tests/config/test_configs/config_hierarchical_list_component.yaml @@ -0,0 +1,15 @@ + +comp_y_1: + component_key: 
COMP_Y + variant_key: default + config: + val_y: "some other value y" + multi_dependency: + - component_key: COMP_W + variant_key: default + config: + val_w: "some other value w" + - component_key: COMP_V + variant_key: default + config: + val_v: "some other value v" \ No newline at end of file diff --git a/tests/config/test_configs/config_non_existing_reference.yaml b/tests/config/test_configs/config_non_existing_reference.yaml new file mode 100644 index 00000000..91bfa504 --- /dev/null +++ b/tests/config/test_configs/config_non_existing_reference.yaml @@ -0,0 +1,17 @@ +comp_y_1: + component_key: COMP_Y + variant_key: default + config: + val_y: "some other value y" + multi_dependency: + - component_key: COMP_W + variant_key: default + config: + val_w: "some other value w" + - component_key: COMP_V + variant_key: default + config: + val_v: "some other value v" + - instance_key: comp_x_1 + pass_type: BY_REFERENCE + diff --git a/tests/config/test_configs/config_single_component.yaml b/tests/config/test_configs/config_single_component.yaml new file mode 100644 index 00000000..01bedf8a --- /dev/null +++ b/tests/config/test_configs/config_single_component.yaml @@ -0,0 +1,5 @@ +custom_comp_1: + component_key: COMP_V + variant_key: default + config: + val_v: "some value v" \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index f94133ce..1ee049ca 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ import os import pickle from pathlib import Path +from typing import Dict from unittest.mock import MagicMock import pytest @@ -10,9 +11,8 @@ from torch.utils.data.sampler import BatchSampler, SequentialSampler from transformers import GPT2TokenizerFast -from modalities.__main__ import load_app_config_dict from modalities.checkpointing.checkpointing import CheckpointingIF -from modalities.config.config import AppConfig +from modalities.config.config import load_app_config_dict from modalities.dataloader.create_index import IndexGenerator from modalities.dataloader.dataloader import LLMDataLoader from modalities.dataloader.large_file_lines_reader import LargeFileLinesReader @@ -30,10 +30,11 @@ def dummy_packed_data_path(tmpdir) -> Path: data = b"" header_size_in_bytes = 8 - int_size_in_bytes = 4 + token_size_in_bytes = 4 tokens = list(range(20)) - data += (len(tokens) * int_size_in_bytes).to_bytes(header_size_in_bytes, byteorder="big") - data += b"".join([t.to_bytes(int_size_in_bytes, byteorder="big") for t in tokens]) + data += (len(tokens) * token_size_in_bytes).to_bytes(header_size_in_bytes, byteorder="big") + data += token_size_in_bytes.to_bytes(4, byteorder="big") + data += b"".join([t.to_bytes(token_size_in_bytes, byteorder="big") for t in tokens]) index = [(4, 24), (28, 40), (68, 12), (80, 4)] # [(index,len), ...] 
-> in 4 bytes #lengths: 6,10,3,1 data += pickle.dumps(index) dummy_packed_data_path = Path(tmpdir, "dummy.pbin") @@ -42,14 +43,13 @@ def dummy_packed_data_path(tmpdir) -> Path: @pytest.fixture -def dummy_config(monkeypatch) -> AppConfig: +def dummy_config(monkeypatch) -> Dict: monkeypatch.setenv("RANK", "0") monkeypatch.setenv("LOCAL_RANK", "0") monkeypatch.setenv("WORLD_SIZE", "1") dummy_config_path = _ROOT_DIR / Path("config_files/config_lorem_ipsum.yaml") config_dict = load_app_config_dict(dummy_config_path) - app_config = AppConfig.model_validate(config_dict) - return app_config + return config_dict, dummy_config_path @dataclasses.dataclass @@ -77,7 +77,7 @@ def indexed_dummy_data_path(dummy_data_path) -> DataPathCollection: @pytest.fixture def gpt2_tokenizer() -> GPT2TokenizerFast: - default_gpt2_tokenizer_path = Path(__file__).parents[1] / Path("data", "tokenizer", "tokenizer.json") + default_gpt2_tokenizer_path = Path(__file__).parents[1] / Path("data", "tokenizer", "tokenizer_gpt2.json") assert default_gpt2_tokenizer_path.is_file() return GPT2TokenizerFast(tokenizer_file=str(default_gpt2_tokenizer_path)) @@ -123,7 +123,7 @@ def trainer(progress_publisher_mock): local_rank=int(os.getenv("LOCAL_RANK")), batch_progress_publisher=progress_publisher_mock, evaluation_result_publisher=progress_publisher_mock, - gradient_acc_step=1, + gradient_acc_steps=1, ) diff --git a/tests/dataloader/test_dataloader.py b/tests/dataloader/test_dataloader.py index 399226d8..3676cb4f 100644 --- a/tests/dataloader/test_dataloader.py +++ b/tests/dataloader/test_dataloader.py @@ -1,11 +1,15 @@ +from typing import Dict + import torch +from pydantic import BaseModel from torch.utils.data import BatchSampler, SequentialSampler -from modalities.config.config import AppConfig +from modalities.config.component_factory import ComponentFactory +from modalities.config.config import PydanticLLMDataLoaderIFType from modalities.dataloader.dataloader import LLMDataLoader -from modalities.dataloader.dataloader_factory import DataloaderFactory from modalities.dataloader.samplers import ResumableBatchSampler -from modalities.resolver_register import ResolverRegister +from modalities.registry.components import COMPONENTS +from modalities.registry.registry import Registry def test_resumable_dataloader() -> LLMDataLoader: @@ -21,18 +25,25 @@ def test_resumable_dataloader() -> LLMDataLoader: assert (flat_samples == original_samples).all() -def test_dataloader_from_config(dummy_config: AppConfig): - resolvers = ResolverRegister(config=dummy_config) +def test_dataloader_from_config(dummy_config: Dict): start_index = 2 - dataloader_1: LLMDataLoader = DataloaderFactory.get_dataloader( - resolvers=resolvers, config=dummy_config.data.train_dataloader, skip_num_batches=start_index - ) - dataset = dataloader_1.dataset + config_dict, _ = dummy_config + config_dict["train_dataloader"]["config"]["skip_num_batches"] = start_index - distributed_sampler = dataloader_1.batch_sampler.underlying_batch_sampler.sampler - batch_sampler = BatchSampler( - sampler=distributed_sampler, batch_size=dataloader_1.sampler_batch_size, drop_last=False + class DataloaderTestModel(BaseModel): + train_dataloader: PydanticLLMDataLoaderIFType + + registry = Registry(COMPONENTS) + component_factory = ComponentFactory(registry=registry) + components: DataloaderTestModel = component_factory.build_components( + config_dict=config_dict, components_model_type=DataloaderTestModel ) + + dataloader_1: LLMDataLoader = components.train_dataloader + dataset = 
dataloader_1.dataset + resumable_batch_sampler: ResumableBatchSampler = dataloader_1.batch_sampler + distributed_sampler = resumable_batch_sampler.underlying_batch_sampler.sampler + batch_sampler = BatchSampler(sampler=distributed_sampler, batch_size=dataloader_1.batch_size, drop_last=False) dataloader_2 = LLMDataLoader( dataloader_tag="train", dataset=dataset, batch_sampler=batch_sampler, collate_fn=dataloader_1.collate_fn ) @@ -40,7 +51,7 @@ def test_dataloader_from_config(dummy_config: AppConfig): samples_1 = [batch for _, batch in zip(range(10), dataloader_1)] samples_2 = [batch for _, batch in zip(range(10), dataloader_2)] - assert dataloader_1.sampler_batch_size * len(dataloader_2) == len(dataset) + assert dataloader_1.batch_size * len(dataloader_2) == len(dataset) assert len(dataloader_1) + start_index == len(dataloader_2) diff --git a/tests/dataloader/test_large_file_lines_reader.py b/tests/dataloader/test_large_file_lines_reader.py index e91287ed..a2dc546e 100644 --- a/tests/dataloader/test_large_file_lines_reader.py +++ b/tests/dataloader/test_large_file_lines_reader.py @@ -1,6 +1,7 @@ import json import pickle import tempfile +import warnings from pathlib import Path import pytest @@ -37,9 +38,11 @@ def generate_data_index_file(data_path: Path, **kwargs): dummy_dst_path.unlink(missing_ok=True) indexer.create_index(dummy_dst_path) - with pytest.raises(ValueError): - generate_data_index_file(plain_text_data_path) - generate_data_index_file(plain_text_data_path, drop_faulty_entries=True) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + with pytest.raises(ValueError): + generate_data_index_file(plain_text_data_path) + generate_data_index_file(plain_text_data_path, drop_faulty_entries=True) generate_data_index_file(jsonl_data_path) index = pickle.loads(dummy_dst_path.read_bytes()) diff --git a/tests/dataloader/test_packed_dataset.py b/tests/dataloader/test_packed_dataset.py index 64df0e9c..996565da 100644 --- a/tests/dataloader/test_packed_dataset.py +++ b/tests/dataloader/test_packed_dataset.py @@ -1,6 +1,9 @@ +from pathlib import Path + +import numpy as np import pytest -from modalities.dataloader.create_packed_data import PackedDataGenerator +from modalities.dataloader.create_packed_data import EmbeddedStreamData, PackedDataGenerator, join_embedded_stream_data from modalities.dataloader.dataset import PackedMemMapDatasetContinuous, PackedMemMapDatasetMegatron @@ -35,11 +38,10 @@ def test_packed_continuous_dataset_missing_file(dummy_packed_data_path): PackedMemMapDatasetContinuous(dummy_packed_data_path, block_size=10, sample_key="input_ids") -@pytest.mark.parametrize("max_num_of_tokens, expected_index_size", [(None, 12), (10, 1)]) -def test_create_packed_dataset(indexed_dummy_data_path, gpt2_tokenizer, max_num_of_tokens, expected_index_size): +def test_create_packed_dataset(indexed_dummy_data_path, gpt2_tokenizer): block_size = 5 packed_generator = PackedDataGenerator( - src_path=indexed_dummy_data_path.raw_data_path, tokenizer=gpt2_tokenizer, max_number_of_tokens=max_num_of_tokens + src_path=indexed_dummy_data_path.raw_data_path, tokenizer=gpt2_tokenizer, number_of_processes=2 ) default_packed_dataset_path = packed_generator._default_destination_path() assert not default_packed_dataset_path.is_file() @@ -51,10 +53,35 @@ def test_create_packed_dataset(indexed_dummy_data_path, gpt2_tokenizer, max_num_ start_of_jsonl_content = "0 Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam nonumy eirmod tempor" tokenized_start_of_jsonl_content = 
gpt2_tokenizer(start_of_jsonl_content)["input_ids"] packed_dataset_iterator = iter(packed_dataset) - assert tokenized_start_of_jsonl_content[:block_size] == next(packed_dataset_iterator)["input_ids"] - assert tokenized_start_of_jsonl_content[block_size : 2 * block_size] == next(packed_dataset_iterator)["input_ids"] - assert len(packed_dataset.index_base) == expected_index_size + np.testing.assert_equal(tokenized_start_of_jsonl_content[:block_size], next(packed_dataset_iterator)["input_ids"]) + np.testing.assert_equal( + tokenized_start_of_jsonl_content[block_size : 2 * block_size], next(packed_dataset_iterator)["input_ids"] + ) + assert len(packed_dataset._embedded_stream_data.index_base) == 12 # check validity of index section in packed dataset - for idx, (offset, entry_length) in enumerate(packed_dataset.index_base[:-1]): - assert offset + entry_length == packed_dataset.index_base[idx + 1][0] + for idx, (offset, entry_length) in enumerate(packed_dataset._embedded_stream_data.index_base[:-1]): + assert offset + entry_length == packed_dataset._embedded_stream_data.index_base[idx + 1][0] + + +def test_join_packed_datasets(dummy_packed_data_path, tmpdir): + packed_data_clones = [Path(tmpdir, f"clone{i}.pbin") for i in range(3)] + for clone in packed_data_clones: + clone.write_bytes(dummy_packed_data_path.read_bytes()) + + joined_target_file = Path(tmpdir, "joined.pbin") + + stream_data = list(map(EmbeddedStreamData, packed_data_clones)) + join_embedded_stream_data(stream_data, joined_target_file) + + loaded_joint_data = EmbeddedStreamData(joined_target_file) + assert loaded_joint_data + assert loaded_joint_data.data_len == sum(d.data_len for d in stream_data) + + loaded_dataset = PackedMemMapDatasetContinuous(joined_target_file, block_size=2, sample_key="whatever") + original_datasets = [ + PackedMemMapDatasetContinuous(p, block_size=2, sample_key="whatever") for p in packed_data_clones + ] + assert [v for batch in loaded_dataset for v in batch["whatever"]] == [ + v for ds in original_datasets for batch in ds for v in batch["whatever"] + ] diff --git a/tests/models/test_attention.py b/tests/models/test_attention.py new file mode 100644 index 00000000..deb5f474 --- /dev/null +++ b/tests/models/test_attention.py @@ -0,0 +1,48 @@ +import pytest +import torch + +from modalities.models.gpt2.gpt2_model import AttentionType, CausalSelfAttention + + +@pytest.mark.parametrize( + "n_head_q, n_head_kv, n_embd, attention_type, successful", + [ + # Flash Attention + (4, 4, 32, AttentionType.PYTORCH_FLASH_ATTENTION, True), + (8, 2, 32, AttentionType.PYTORCH_FLASH_ATTENTION, True), + (9, 8, 32, AttentionType.PYTORCH_FLASH_ATTENTION, False), + (8, 3, 32, AttentionType.PYTORCH_FLASH_ATTENTION, False), + # Default Attention + (4, 4, 32, AttentionType.DEFAULT_ATTENTION, True), + (8, 2, 32, AttentionType.DEFAULT_ATTENTION, True), + (9, 8, 32, AttentionType.DEFAULT_ATTENTION, False), + (8, 3, 32, AttentionType.DEFAULT_ATTENTION, False), + ], +) +def test_grouped_query_attention_forward(n_head_q, n_head_kv, n_embd, attention_type, successful): + batch_size = 2 + block_size = 10 + embedding_shape = (batch_size, block_size, n_embd) + embedded_input_seq = torch.rand(size=embedding_shape, dtype=torch.float32) + + def attention_forward_pass(attention_type, block_size, embedded_input_seq, n_embd, n_head_kv, n_head_q): + attention_layer = CausalSelfAttention( + n_head_q=n_head_q, + n_head_kv=n_head_kv, + n_embd=n_embd, + attention_type=attention_type, + bias=False, + dropout=False, + block_size=block_size, + ) + 
output_tensor: torch.Tensor = attention_layer(embedded_input_seq) + return output_tensor + + if not successful: + with pytest.raises(Exception): + attention_forward_pass(attention_type, block_size, embedded_input_seq, n_embd, n_head_kv, n_head_q) + else: + output_tensor = attention_forward_pass( + attention_type, block_size, embedded_input_seq, n_embd, n_head_kv, n_head_q + ) + assert output_tensor.size() == embedding_shape diff --git a/tests/test_evaluation.py b/tests/test_evaluation.py deleted file mode 100644 index 2cffe687..00000000 --- a/tests/test_evaluation.py +++ /dev/null @@ -1,52 +0,0 @@ -import torch -from transformers import AutoConfig, AutoModelForCausalLM - -from modalities.config.config import PretrainedGPTConfig -from modalities.models.gpt2.gpt2_model import ( - ActivationType, - AttentionConfig, - AttentionType, - GPT2Config, - WeightInitailizationConfig, -) -from modalities.models.gpt2.pretrained_gpt_model import PretrainedGPTModel - - -def test_pretrained_gpt_model(tmp_path): - # setup config and model - attention_config = AttentionConfig(attention_type=AttentionType("default_attention"), scaling_factor=3) - config = GPT2Config( - block_size=12, - vocab_size=128, - n_layer=2, - n_head=2, - n_embd=128, - ffn_hidden=128, - dropout=0.01, - bias=True, - attention=attention_config, - activation=ActivationType.GELU, - epsilon=1e-5, - sample_key="input_ids", - prediction_key="logits", - weight_init=WeightInitailizationConfig(mean=0, std=0.02), - ) - pretrained_config = PretrainedGPTConfig(config=config) - - model = PretrainedGPTModel(config=pretrained_config) - model.save_pretrained(tmp_path) - model = model.eval() - - # register config and model - AutoConfig.register("modalities_gpt2", PretrainedGPTConfig) - AutoModelForCausalLM.register(PretrainedGPTConfig, PretrainedGPTModel) - - # load saved model - loaded_model = AutoModelForCausalLM.from_pretrained(tmp_path) - loaded_model = loaded_model.eval() - - # check that model before and after loading return the same output - test_tensor = torch.randint(10, size=(5, 10)) - output_before_loading = model.forward(test_tensor) - output_after_loading = loaded_model.forward(test_tensor) - assert (output_after_loading == output_before_loading).all() diff --git a/tests/test_gym.py b/tests/test_gym.py index f03d9b92..d650f736 100644 --- a/tests/test_gym.py +++ b/tests/test_gym.py @@ -1,10 +1,9 @@ -from unittest.mock import call, patch +from unittest.mock import call import torch from modalities.batch import DatasetBatch from modalities.gym import Gym -from modalities.running_env.fsdp.reducer import Reducer def test_run_cpu_only( @@ -37,15 +36,13 @@ def test_run_cpu_only( llm_data_loader_mock.__len__ = lambda _: num_batches gym = Gym(trainer=trainer, evaluator=evaluator_mock, loss_fun=loss_mock, num_ranks=num_ranks) - with patch.object(Reducer, "reduce", return_value=None) as reduce_mock: - gym.run( - model=nn_model_mock, - optimizer=optimizer_mock, - callback_interval_in_batches=int(num_batches), - train_data_loader=llm_data_loader_mock, - evaluation_data_loaders=[], - checkpointing=checkpointing_mock, - ) - nn_model_mock.forward.assert_has_calls([call(b.samples) for b in batches]) - optimizer_mock.step.assert_called() - reduce_mock.assert_called() + gym.run( + model=nn_model_mock, + optimizer=optimizer_mock, + callback_interval_in_batches=int(num_batches), + train_data_loader=llm_data_loader_mock, + evaluation_data_loaders=[], + checkpointing=checkpointing_mock, + ) + nn_model_mock.forward.assert_has_calls([call(b.samples) for b in 
batches]) + optimizer_mock.step.assert_called() diff --git a/tests/test_loss_functions.py b/tests/test_loss_functions.py new file mode 100644 index 00000000..8825f15c --- /dev/null +++ b/tests/test_loss_functions.py @@ -0,0 +1,38 @@ +import pytest +import torch + +from modalities.batch import InferenceResultBatch +from modalities.loss_functions import NCELoss, nce_loss + + +@pytest.fixture +def dummy_result_batch() -> InferenceResultBatch: + predictions = {"embedding": torch.rand(1024, 512)} + targets = {"target": torch.zeros(1024, 512)} + batch_dim = 1024 + result_batch = InferenceResultBatch(targets, predictions, batch_dim) + return result_batch + + +# calculating asymmetric NCELoss between a batch of embeddings and itself --> zero +@pytest.mark.parametrize("key", ["embedding"]) +def test_asymm_NCELoss_is_zero(dummy_result_batch, key): + loss_func = NCELoss(prediction_key1=key, prediction_key2=key) + assert loss_func(dummy_result_batch) <= 10e-6 + + +# calculating nce_loss for two randomly generated batches of embeddings (manually calculated) +@pytest.mark.parametrize( + "embedding1,embedding2", + [ + ( + torch.Tensor([[0.38, 0.18], [0.36, 0.66], [0.72, 0.09]]), + torch.Tensor([[0.48, 0.01], [0.54, 0.28], [0.08, 0.34]]), + ) + ], +) +def test_nce_loss_correctness(embedding1, embedding2): + unidirectional_loss = nce_loss(embedding1, embedding2, device="cpu", is_asymmetric=True, temperature=1.0) + bidirectional_loss = nce_loss(embedding1, embedding2, device="cpu", is_asymmetric=False, temperature=1.0) + assert unidirectional_loss == pytest.approx(1.1300, 0.0001) + assert bidirectional_loss == pytest.approx(2.2577, 0.0001) diff --git a/tests/test_main.py b/tests/test_main.py index 92ab8f9f..b81d9d7c 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -4,20 +4,12 @@ from modalities.__main__ import Main -def no_gpu_available() -> bool: - return not torch.cuda.is_available() - - -@pytest.mark.skipif( - no_gpu_available(), reason="This e2e test verifies a GPU-Setup and uses components, which do not support CPU-only." -) +@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This e2e test requires 1 GPU.") def test_e2e_training_run_wout_ckpt(monkeypatch, indexed_dummy_data_path, dummy_config): # patch in env variables monkeypatch.setenv("MASTER_ADDR", "localhost") monkeypatch.setenv("MASTER_PORT", "9948") - - dummy_config.data.train_dataloader.config.dataset.config.raw_data_path = indexed_dummy_data_path.raw_data_path - for val_dataloader_config in dummy_config.data.eval_dataloaders: - val_dataloader_config.config.dataset.config.raw_data_path = indexed_dummy_data_path.raw_data_path - main = Main(dummy_config) + config_dict, config_path = dummy_config + config_dict["train_dataset"]["config"]["raw_data_path"] = indexed_dummy_data_path.raw_data_path + main = Main(config_dict, config_path) main.run()