Commit

update README
alexteua committed Feb 6, 2023
1 parent 6963f51 commit b07771f
21 changes: 18 additions & 3 deletions README.md
@@ -61,7 +61,12 @@ dp/configs/autoreg_config.yaml
```
for the forward and autoregressive transformer models, respectively.

Prepare data in a tuple-format and use the preprocess and train API:
Distributed training is supported. You can specify which GPUs to use by setting the CUDA_VISIBLE_DEVICES environment variable:
```bash
CUDA_VISIBLE_DEVICES=0,1 python run_training.py
```

Inside the training script, prepare the data in tuple format and use the preprocess and train APIs:

```python
from dp.preprocess import preprocess
@@ -74,9 +79,19 @@ train_data = [('en_us', 'young', 'jʌŋ'),
val_data = [('en_us', 'young', 'jʌŋ'),
('de', 'benützten', 'bənʏt͡stn̩')] * 100

preprocess(config_file='config.yaml', train_data=train_data,
config_file = 'dp/configs/forward_config.yaml'

preprocess(config_file=config_file,
train_data=train_data,
val_data=val_data,
deduplicate_train_data=False)
train(config_file='config.yaml')

num_gpus = torch.cuda.device_count()

if num_gpus > 1:
mp.spawn(train, nprocs=num_gpus, args=(num_gpus, config_file))
else:
train(rank=0, num_gpus=num_gpus, config_file=config_file)
```
Model checkpoints will be stored in the checkpoint path provided in the config file.
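
For reference, the fragments in the diff above assemble into a training script along the following lines. This is a minimal sketch, not the repository's verbatim example: the `from dp.train import train` import path and the `* 1000` repeat count for the toy training data are assumptions, while the config path, the `preprocess` call, and the `train`/`mp.spawn` invocations follow the diff.

```python
import torch
import torch.multiprocessing as mp

from dp.preprocess import preprocess
from dp.train import train  # assumed import path for the train API used below

if __name__ == '__main__':

    # Toy (language, word, phonemes) tuples; the repeat counts are illustrative.
    train_data = [('en_us', 'young', 'jʌŋ'),
                  ('de', 'benützten', 'bənʏt͡stn̩')] * 1000

    val_data = [('en_us', 'young', 'jʌŋ'),
                ('de', 'benützten', 'bənʏt͡stn̩')] * 100

    config_file = 'dp/configs/forward_config.yaml'

    # Build the dataset and tokenizers from the raw tuples.
    preprocess(config_file=config_file,
               train_data=train_data,
               val_data=val_data,
               deduplicate_train_data=False)

    # One training process per visible GPU; otherwise a single process.
    num_gpus = torch.cuda.device_count()

    if num_gpus > 1:
        mp.spawn(train, nprocs=num_gpus, args=(num_gpus, config_file))
    else:
        train(rank=0, num_gpus=num_gpus, config_file=config_file)
```

With `CUDA_VISIBLE_DEVICES=0,1` set as shown earlier, `torch.cuda.device_count()` returns 2 and two training processes are spawned, one per GPU.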
