Improve seed handling (#651)
cbalioglu authored Jul 9, 2024
1 parent 9e4ca4e commit 35ce1ec
Showing 4 changed files with 33 additions and 13 deletions.
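The commit threads a single running seed through each recipe: the value is read once from config.seed, handed to the first component that consumes randomness, and incremented before being handed to the next, so the data reader, device-side model placement, and the trainer/evaluator/generator no longer all reuse the same seed. Below is a minimal, self-contained sketch of that pattern; RecipeConfig, load_recipe, and the *_seed variables are placeholders for illustration, not part of the fairseq2 API.

from dataclasses import dataclass


@dataclass
class RecipeConfig:
    seed: int = 2  # placeholder default for this sketch


def load_recipe(config: RecipeConfig) -> dict[str, int]:
    # Derive a local counter from the configured seed and bump it after
    # every consumer so each one receives a distinct, reproducible value.
    seed = config.seed

    data_reader_seed = seed  # e.g. example and batch shuffling
    seed += 1

    trainer_seed = seed  # e.g. dropout and other training-time randomness
    return {"data_reader": data_reader_seed, "trainer": trainer_seed}


print(load_recipe(RecipeConfig()))  # {'data_reader': 2, 'trainer': 3}

Deriving every downstream seed from one config value keeps a run reproducible from a single setting while avoiding identical, and therefore correlated, random streams across components.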
8 changes: 6 additions & 2 deletions src/fairseq2/recipes/lm/instruction_finetune.py
@@ -224,6 +224,8 @@ def load_instruction_finetuner(
CheckpointModelMetadataProvider(config.resume_checkpoint_dir)
)

+ seed = config.seed

# Load the tokenizer.
model_card = retrieve_asset_card(config.model)

@@ -312,9 +314,11 @@ def load_instruction_finetuner(
batch_shuffle_window=config.batch_shuffle_window,
num_accumulate=config.gradient_accumulation,
num_prefetch=config.num_prefetch,
- seed=config.seed,
+ seed=seed,
)

+ seed += 1

optimizer = AdamW(
model.parameters(),
lr=config.lr,
@@ -352,7 +356,7 @@ def load_instruction_finetuner(
publish_metrics_every_n_steps=config.publish_metrics_every_n_steps,
profile=config.profile,
anomaly_detection=config.anomaly_detection,
- seed=config.seed,
+ seed=seed,
wall_watch=wall_watch,
)

14 changes: 9 additions & 5 deletions src/fairseq2/recipes/transformer/translate.py
@@ -79,7 +79,7 @@ class TextTranslateConfig:
"""The data type of the model."""

# Generation
- mode: Literal["beam_search", "sampling"] = "beam_search"
+ generator_mode: Literal["beam_search", "sampling"] = "beam_search"
"""The mode of sequence generation."""

beam_search: BeamSearchConfig = field(default_factory=lambda: BeamSearchConfig())
@@ -204,6 +204,8 @@ def load_text_translator(

gang = setup_root_gang(log)

+ seed = config.seed

# Load the tokenizer.
model_card = retrieve_asset_card(config.model)

@@ -257,7 +259,7 @@

# Initialize the sequence generator.
generator = _create_sequence_generator(
- model, config.mode, config.beam_search, config.sampling
+ model, config.generator_mode, config.beam_search, config.sampling
)

# Initialize the generator unit.
@@ -305,16 +307,18 @@
config.max_seq_len,
batching=StaticBatching(config.batch_size),
num_prefetch=config.num_prefetch,
- seed=config.seed,
+ seed=seed,
)

+ seed += 1

# Initialize the generator.
return Generator[SequenceBatch](
unit=unit,
data_reader=data_reader,
root_gang=gang,
metrics_dir=output_dir.joinpath("metrics"),
- seed=config.seed,
+ seed=seed,
wall_watch=wall_watch,
)

@@ -404,7 +408,7 @@ def _create_sequence_generator(
return _create_sampling_generator(model, sampling_config)

raise ValueError(
- f"`config.mode` must be 'sampling' or 'beam_search', but is '{mode}' instead."
+ f"`config.generator_mode` must be 'sampling' or 'beam_search', but is '{mode}' instead."
)


8 changes: 6 additions & 2 deletions src/fairseq2/recipes/wav2vec2/asr/eval.py
@@ -107,6 +107,8 @@ def load_wav2vec2_asr_evaluator(

gang = setup_root_gang(log)

+ seed = config.seed

# Load the tokenizer.
model_card = retrieve_asset_card(config.model)

@@ -190,17 +192,19 @@
max_audio_len=config.max_audio_len,
normalize_audio=config.normalize_audio,
num_prefetch=config.num_prefetch,
- seed=config.seed,
+ seed=seed,
)

+ seed += 1

# Initialize the evaluator.
return Evaluator[Seq2SeqBatch](
units=[unit],
data_readers=[data_reader],
root_gang=gang,
tb_dir=output_dir.joinpath("tb"),
metrics_dir=output_dir.joinpath("metrics"),
- seed=config.seed,
+ seed=seed,
wall_watch=wall_watch,
)

16 changes: 12 additions & 4 deletions src/fairseq2/recipes/wav2vec2/asr/train.py
@@ -223,6 +223,8 @@ def load_wav2vec2_asr_trainer(
CheckpointModelMetadataProvider(config.resume_checkpoint_dir)
)

+ seed = config.seed

# Load the tokenizer.
tokenizer_card = retrieve_asset_card(config.tokenizer)

@@ -288,7 +290,9 @@
log.info("Pretrained model loaded on rank 0.")

if gang.rank == 0:
- to_device(model, gang.device, seed=config.seed)
+ to_device(model, gang.device, seed=seed)

+ seed += 1

gang.barrier()

@@ -337,9 +341,11 @@
batch_shuffle_window=config.batch_shuffle_window,
num_accumulate=config.gradient_accumulation,
num_prefetch=config.num_prefetch,
- seed=config.seed,
+ seed=seed,
)

+ seed += 1

optimizer = AdamW(dp_model.parameters(), lr=config.lr, betas=config.betas)

lr_scheduler = TriStageLR(
@@ -363,9 +369,11 @@
max_audio_len=config.max_audio_len,
normalize_audio=config.normalize_audio,
num_prefetch=config.num_prefetch,
- seed=config.seed,
+ seed=seed,
)

+ seed += 1

# Initialize the trainer.
return Trainer[Seq2SeqBatch](
unit=unit,
@@ -390,7 +398,7 @@
publish_metrics_every_n_steps=config.publish_metrics_every_n_steps,
profile=config.profile,
anomaly_detection=config.anomaly_detection,
- seed=config.seed,
+ seed=seed,
wall_watch=wall_watch,
)

