Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Allow passing dataset path to first-party recipes #652

Merged
merged 1 commit into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions src/fairseq2/recipes/lm/instruction_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@
from torch import Tensor
from torch.nn import Module

from fairseq2.assets import default_asset_store
from fairseq2.assets import AssetNotFoundError, default_asset_store
from fairseq2.assets.utils import retrieve_asset_card
from fairseq2.checkpoint import CheckpointModelMetadataProvider, FileCheckpointManager
from fairseq2.config_registry import ConfigRegistry
from fairseq2.data.text import load_text_tokenizer
from fairseq2.datasets import LengthBatching
from fairseq2.datasets.instruction import load_instruction_dataset
from fairseq2.datasets.instruction import (
GenericInstructionDataset,
load_instruction_dataset,
)
from fairseq2.gang import Gang
from fairseq2.logging import get_log_writer
from fairseq2.models import load_model
Expand Down Expand Up @@ -57,7 +60,7 @@ class InstructionFinetuneConfig:

# Data
dataset: Union[str, Path] = "openeft" # TODO: change!
"""The name or path to the asset card of the instruction dataset."""
"""The name, path, or path to the asset card of the instruction dataset."""

max_seq_len: int = 8192
"""The maximum sequence length."""
Expand Down Expand Up @@ -236,13 +239,26 @@ def load_instruction_finetuner(
log.info("Tokenizer loaded.")

# Load the dataset.
dataset_card = retrieve_asset_card(config.dataset)
try:
dataset_card = retrieve_asset_card(config.dataset)
except AssetNotFoundError:
dataset_card = None

log.info("Loading {} instruction dataset.", dataset_card.name)
if dataset_card is not None:
log.info("Loading {} instruction dataset.", dataset_card.name)

dataset = load_instruction_dataset(dataset_card)
dataset = load_instruction_dataset(dataset_card)

log.info("Dataset loaded.")
log.info("Dataset loaded.")
else:
try:
path = Path(config.dataset)
except ValueError:
raise AssetNotFoundError(
config.dataset, f"An asset with the name '{config.dataset}' cannot be found." # type: ignore[arg-type]
)

dataset = GenericInstructionDataset.from_path(path)

# Load the model.
init_device = META
Expand Down
2 changes: 1 addition & 1 deletion src/fairseq2/recipes/transformer/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class TransformerEvalConfig:

# Data
dataset: Union[str, Path] = "foo" # TODO: change!
"""The name or path to the asset card of the parallel text dataset."""
"""The name, path, or path to the asset card of the parallel text dataset."""

split: str = "test"
"""The name of the test data split."""
Expand Down
2 changes: 1 addition & 1 deletion src/fairseq2/recipes/transformer/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class TextTranslateConfig:

# Data
dataset: Union[str, Path] = "foo" # TODO: change!
"""The name or path to the asset card of the text dataset."""
"""The name, path, or path to the asset card of the text dataset."""

source_lang: str = "eng_Latn"
"""The code of the language to translate from."""
Expand Down
27 changes: 20 additions & 7 deletions src/fairseq2/recipes/wav2vec2/asr/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
import torch
from torch.nn import Module

from fairseq2.assets import default_asset_store
from fairseq2.assets import AssetNotFoundError, default_asset_store
from fairseq2.assets.utils import retrieve_asset_card
from fairseq2.checkpoint import CheckpointModelMetadataProvider
from fairseq2.config_registry import ConfigRegistry
from fairseq2.data.text import TextTokenDecoder, TextTokenizer, load_text_tokenizer
from fairseq2.datasets import LengthBatching
from fairseq2.datasets.asr import load_asr_dataset
from fairseq2.datasets.asr import GenericAsrDataset, load_asr_dataset
from fairseq2.gang import Gang
from fairseq2.logging import get_log_writer
from fairseq2.metrics.wer import WerMetric
Expand Down Expand Up @@ -49,7 +49,7 @@ class Wav2Vec2AsrEvalConfig:

# Data
dataset: Union[str, Path] = "librilight_asr_10h"
"""The name or path to the asset card of the ASR dataset."""
"""The name, path, or path to the asset card of the ASR dataset."""

split: str = "test_other"
"""The name of the eval data split."""
Expand Down Expand Up @@ -119,13 +119,26 @@ def load_wav2vec2_asr_evaluator(
log.info("Tokenizer loaded.")

# Load the dataset.
dataset_card = retrieve_asset_card(config.dataset)
try:
dataset_card = retrieve_asset_card(config.dataset)
except AssetNotFoundError:
dataset_card = None

if dataset_card is not None:
log.info("Loading {} ASR dataset.", dataset_card.name)

log.info("Loading {} ASR dataset.", dataset_card.name)
dataset = load_asr_dataset(dataset_card)

dataset = load_asr_dataset(dataset_card)
log.info("Dataset loaded.")
else:
try:
path = Path(config.dataset)
except ValueError:
raise AssetNotFoundError(
config.dataset, f"An asset with the name '{config.dataset}' cannot be found." # type: ignore[arg-type]
)

log.info("Dataset loaded.")
dataset = GenericAsrDataset.from_path(path)

# Load the model.
log.info("Loading {} model on rank 0.", model_card.name)
Expand Down
27 changes: 20 additions & 7 deletions src/fairseq2/recipes/wav2vec2/asr/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
from torch import Tensor
from torch.nn import Module

from fairseq2.assets import default_asset_store
from fairseq2.assets import AssetNotFoundError, default_asset_store
from fairseq2.assets.utils import retrieve_asset_card
from fairseq2.checkpoint import CheckpointModelMetadataProvider, FileCheckpointManager
from fairseq2.config_registry import ConfigRegistry
from fairseq2.data.text import load_text_tokenizer
from fairseq2.datasets import LengthBatching
from fairseq2.datasets.asr import load_asr_dataset
from fairseq2.datasets.asr import GenericAsrDataset, load_asr_dataset
from fairseq2.gang import Gang
from fairseq2.logging import get_log_writer
from fairseq2.models.seq2seq import Seq2SeqBatch
Expand Down Expand Up @@ -63,7 +63,7 @@ class Wav2Vec2AsrTrainConfig:

# Data
dataset: Union[str, Path] = "librilight_asr_10h"
"""The name or path to the asset card of the ASR dataset."""
"""The name, path, or path to the asset card of the ASR dataset."""

train_split: str = "train"
"""The name of the train data split."""
Expand Down Expand Up @@ -235,13 +235,26 @@ def load_wav2vec2_asr_trainer(
log.info("Tokenizer loaded.")

# Load the dataset.
dataset_card = retrieve_asset_card(config.dataset)
try:
dataset_card = retrieve_asset_card(config.dataset)
except AssetNotFoundError:
dataset_card = None

log.info("Loading {} ASR dataset.", dataset_card.name)
if dataset_card is not None:
log.info("Loading {} ASR dataset.", dataset_card.name)

dataset = load_asr_dataset(config.dataset)
dataset = load_asr_dataset(dataset_card)

log.info("Dataset loaded.")
log.info("Dataset loaded.")
else:
try:
path = Path(config.dataset)
except ValueError:
raise AssetNotFoundError(
config.dataset, f"An asset with the name '{config.dataset}' cannot be found." # type: ignore[arg-type]
)

dataset = GenericAsrDataset.from_path(path)

# Initialize the model configuration.
if config.model_arch is None:
Expand Down
Loading