Savitha/scdl geneformer integration #363

Draft: wants to merge 17 commits into base: main
5 changes: 3 additions & 2 deletions scripts/singlecell/geneformer/test_train.py
@@ -29,7 +29,7 @@
from bionemo.testing.data.load import load


data_path: Path = load("single_cell/testdata-20240506") / "cellxgene_2023-12-15_small" / "processed_data"
data_path: Path = load("single_cell/testdata-memmap-format") / "cellxgene_2023-12-15_small_mmap"


def test_bionemo2_rootdir():
@@ -104,7 +104,8 @@ def test_pretrain_cli(tmpdir):
--seq-length 128 \
--limit-val-batches 2 \
--micro-batch-size 2 \
--accumulate-grad-batches 2
--accumulate-grad-batches 2 \
--bypass-tokenizer-vocab True
""".strip()
env = dict(**os.environ) # a local copy of the environment
env["MASTER_PORT"] = str(open_port)
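The command string above is only the front half of the test; below is a minimal sketch of how such a CLI test typically finishes. Only the environment copy and `MASTER_PORT` lines appear in the diff itself; the `run_pretrain_cli` helper, the `subprocess.run` call, `cwd`, and the return-code assertion are assumptions for illustration.

    import os
    import subprocess

    def run_pretrain_cli(cmd: str, open_port: int, workdir: str) -> None:
        # Copy the current environment and pin MASTER_PORT so parallel tests do not
        # collide (these two lines mirror the diff above); the rest is an assumed harness.
        env = dict(**os.environ)
        env["MASTER_PORT"] = str(open_port)
        result = subprocess.run(cmd, shell=True, env=env, cwd=workdir, capture_output=True, text=True)
        assert result.returncode == 0, f"pretrain CLI failed:\n{result.stderr}"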
10 changes: 10 additions & 0 deletions scripts/singlecell/geneformer/train.py
@@ -85,6 +85,7 @@ def main(
save_top_k: int = 2,
save_every_n_steps: int = 100,
config_class: Type[BioBertConfig] = GeneformerConfig,
bypass_tokenizer_vocab: bool = False,
# TODO add datamodule class, and ability to change data step to get full support for pretraining workflows
) -> None:
"""Train a Geneformer model on single cell data.
@@ -205,6 +206,7 @@ def main(
persistent_workers=num_dataset_workers > 0,
pin_memory=False,
num_workers=num_dataset_workers,
bypass_tokenizer_vocab=bypass_tokenizer_vocab,
)
geneformer_config = config_class(
num_layers=6,
@@ -474,6 +476,13 @@ def main(
default=None,
help="Path to the checkpoint directory to restore from. Will override `--resume-if-exists` when set.",
)
parser.add_argument(
"--bypass-tokenizer-vocab",
type=bool,
required=False,
default=False,
help="Bypass whether the SingleCellDataLoaderhrows an error when a gene ensemble id is not in the tokenizer vocab. Defaults to False (so the error is thrown by default).",
)
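A side note on this flag: argparse's bare `type=bool` converts any non-empty string (including "False") to True, so passing `--bypass-tokenizer-vocab False` would not behave as expected. A more robust pattern is a small string-to-bool converter; the sketch below is not part of this PR, just an illustration using only the standard library.

    import argparse

    def str2bool(value: str) -> bool:
        # Accept the usual spellings of true/false; a bare `type=bool` would treat
        # any non-empty string (even "False") as True.
        if value.lower() in ("true", "1", "yes"):
            return True
        if value.lower() in ("false", "0", "no"):
            return False
        raise argparse.ArgumentTypeError(f"Expected a boolean, got {value!r}")

    # Hypothetical usage mirroring the flag above:
    # parser.add_argument("--bypass-tokenizer-vocab", type=str2bool, default=False)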

# TODO consider whether nemo.run or some other method can simplify this config class lookup.
config_class_options: Dict[str, Type[BioBertConfig]] = {
@@ -540,4 +549,5 @@ def config_class_type(desc: str) -> Type[BioBertConfig]:
metric_to_monitor_for_checkpoints=args.metric_to_monitor_for_checkpoints,
save_top_k=args.save_top_k,
save_every_n_steps=args.val_check_interval,
bypass_tokenizer_vocab=args.bypass_tokenizer_vocab,
)
1 change: 1 addition & 0 deletions sub-packages/bionemo-geneformer/pyproject.toml
@@ -14,6 +14,7 @@ dependencies = [
# bionemo sub-packages
'bionemo-core',
'bionemo-llm',
'bionemo-scdl',
# external
]

@@ -80,6 +80,7 @@ def __init__( # noqa: D107
num_workers: int = 10, # TODO can this be automatically set?
persistent_workers: bool = True,
pin_memory: bool = True,
bypass_tokenizer_vocab: bool = False,
) -> None:
super().__init__()
self.data_path_train = train_dataset_path
@@ -107,6 +108,7 @@ def __init__( # noqa: D107
mask_token_prob=self.mask_token_prob,
random_token_prob=self.random_token_prob,
seed=random_utils.get_seed_from_rng(rng),
bypass_tokenizer_vocab=bypass_tokenizer_vocab,
)
self._val_dataset_ori = SingleCellDataset(
self.data_path_val,
@@ -117,6 +119,7 @@ def __init__( # noqa: D107
mask_token_prob=self.mask_token_prob,
random_token_prob=self.random_token_prob,
seed=random_utils.get_seed_from_rng(rng),
bypass_tokenizer_vocab=bypass_tokenizer_vocab,
)
self._test_dataset_ori = SingleCellDataset(
self.data_path_test,
@@ -127,6 +130,7 @@ def __init__( # noqa: D107
mask_token_prob=self.mask_token_prob,
random_token_prob=self.random_token_prob,
seed=random_utils.get_seed_from_rng(rng),
bypass_tokenizer_vocab=bypass_tokenizer_vocab,
)

# This is needed here, or you need to specify it in the megatron adapter thing TODO name?
@@ -14,19 +14,17 @@
# limitations under the License.


import json
from pathlib import Path
from typing import Any, Dict, Optional, Sequence, Tuple
from typing import Any, Optional, Sequence

import numpy as np
import torch
from nemo.utils import logging
from torch.utils.data import Dataset

from bionemo.core.utils import random_utils
from bionemo.geneformer.data.singlecell.utils import sample_or_truncate
from bionemo.geneformer.tokenizer.gene_tokenizer import GeneTokenizer
from bionemo.llm.data import masking, types
from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset


__all__: Sequence[str] = (
@@ -40,19 +38,19 @@ class SingleCellDataset(Dataset):
updates will contain more comprehensive workflows for generating a Sparse Memmap from scRNA-seq.

Args:
data_path (str): Path where the single cell files are stored. It should contain the following files:
- `metadata.json`: Path containing feature subset associated with each dataset.
- `features.csv`: Feature subset associated with each sample.
data_path (str): Path where the single cell files are stored in single cell memmap format. It should contain the following files:
- `metadata.json`: File containing the number of rows in the dataset.
- Gene expression matrix stored in CSR format as `numpy.memmap`:
- `gene_expression_data.npy`: Gene expression values.
- `gene_expression_ind.npy`: Gene indices associated with gene values.
- `gene_expression_ptr.npy`: Column indices for each sample.
- `data.npy`: Non-zero gene expression values.
- `col_ptr.npy`: Column (gene) indices for each entry in `data.npy`.
- `row_ptr.npy`: Row offset pointers marking where each cell's entries begin in `data.npy`.
tokenizer: The tokenizer to use for tokenizing the input data.
median_dict (dict, optional): A dictionary containing median values for each gene. Defaults to None.
max_len (int, optional): The maximum length of the input sequence. Defaults to 1024.
bypass_tokenizer_vocab (bool, optional): If True, skips enforcing that all gene Ensembl IDs in the dataset are in the tokenizer vocab. Defaults to False.

Attributes:
data_path (str): Path where the single cell files are stored.
data_path (str): Path where the single cell files are stored in single cell memmap format.
max_len (int): The maximum length of the input sequence.
metadata (dict): Metadata loaded from `metadata.json`.
gene_medians (dict): A dictionary containing median values for each gene. If None, a median of '1' is assumed for all genes.
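As an aside, the CSR layout described in the docstring above can be read back with plain NumPy. This is a minimal sketch, assuming the `data.npy`/`col_ptr.npy`/`row_ptr.npy` file names from the docstring and float32/int64 dtypes (the dtypes are an assumption); in practice `SingleCellMemMapDataset.get_row`, used later in this diff, performs this lookup for you.

    from pathlib import Path
    import numpy as np

    def read_cell(memmap_dir: Path, idx: int) -> tuple[np.ndarray, np.ndarray]:
        # Open the three CSR arrays without loading them fully into RAM.
        data = np.memmap(memmap_dir / "data.npy", dtype="float32", mode="r")
        col_ptr = np.memmap(memmap_dir / "col_ptr.npy", dtype="int64", mode="r")
        row_ptr = np.memmap(memmap_dir / "row_ptr.npy", dtype="int64", mode="r")
        start, end = int(row_ptr[idx]), int(row_ptr[idx + 1])
        values = np.asarray(data[start:end])        # non-zero expression counts for cell `idx`
        gene_cols = np.asarray(col_ptr[start:end])  # gene (column) indices for those counts
        return values, gene_cols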
@@ -86,118 +84,38 @@ def __init__( # noqa: D107
mask_token_prob: float = 0.8,
random_token_prob: float = 0.1,
prepend_cls_token: bool = True,
assert_increasing_columns: bool = True,
bypass_tokenizer_vocab: bool = False,
seed: int = np.random.SeedSequence().entropy, # type: ignore
):
super().__init__()

self.data_path = data_path
self.max_len = max_len
self.random_token_prob = random_token_prob
self.mask_token_prob = mask_token_prob
self.mask_prob = mask_prob
self.prepend_cls_token = prepend_cls_token
self._seed = seed
# check if column indices are increasing for looking up genes. This is a way of spotting if the sc_memmap.py
# script produced properly structured sparse files.
self.assert_increasing_columns = assert_increasing_columns
path = Path(data_path)

# - metadata
metadata = json.load(open(path / "metadata.json", "r"))

# - median dict
self.scdl = SingleCellMemMapDataset(data_path)
self.gene_medians = median_dict

# - train/val idxs sampled contiguously
total_el = sum([v["num_el"] for _, v in metadata.items()])
self.num_samples = sum([v["shape"][0] for _, v in metadata.items()])
# - load data
self.gene_data = np.memmap(path / "gene_expression_data.npy", dtype="float32", mode="r", shape=(total_el,))

self.gene_data_indices = np.memmap(
path / "gene_expression_ind.npy", dtype="int32", mode="r", shape=(total_el,)
)

self.gene_data_ptr = np.memmap(
path / "gene_expression_ptr.npy", dtype="int64", mode="r", shape=(self.num_samples + 1,)
)
self.tokenizer = tokenizer
rnd_key = next(iter(metadata))
feature_ids = np.array(metadata[rnd_key]["feature_ids"])

# Determine if we need to store the full metadata (per file feature_ids) or just a single feature_id
# vector for all files. If we can do the later this is much more memory efficient.
# without this change, if num_workers>0, we seem to hit a memory leak after a relatively small number
# of steps. Online discussion points to native python objects like dictionaries of a lot of data
# being a primary culprit behind large RAM usage in dataloaders that use multiprocessing.
features_all_same = True
for m in metadata.values():
if np.any(np.char.not_equal(np.array(m["feature_ids"]), feature_ids)):
features_all_same = False
break

if not features_all_same:
# We need to store per-file metadata of feature_ids. Make sure you run with a lot of RAM or few dataset workers.
# we need to store per-file metadata in this case because some of the files have different subsets of the
# feature_ids.
logging.warning(
"Feature ids are not the same across datasets. This can cause heavy RAM usage "
"for large datasets, try setting num_workers to 0."
)
self.metadata = metadata
self.feature_ids = None

# map row indices to dataset id
self.dataset_ccum = np.zeros(
len(self.metadata),
)
# Maps dataset ids to dataset names (used in the metadata dict)
self.dataset_map = {}
count = 0
for i, k in enumerate(self.metadata):
self.dataset_ccum[i] = count
self.dataset_map[i] = k
count += self.metadata[k]["shape"][0]
self.dataset_ccum[0] = -1
else:
# We can store a single feature_id vector for all datasets, and do not need to store the full metadata array.
logging.warning(
"Feature ids are the same across datasets. This is good, using the same feature_ids for all datasets."
)
self.feature_ids = feature_ids
self.metadata = None
self.bypass_tokenizer_vocab = bypass_tokenizer_vocab

def __len__(self): # noqa: D105
return self.num_samples

def metadata_lookup(self, idx) -> Dict[str, np.ndarray]:
"""Go from a cell idx to the file-level metadata associated with that cell."""
did = sum(~(self.dataset_ccum > idx)) - 1
metadata = self.metadata[self.dataset_map[did]]
return metadata

def lookup_cell_by_idx(self, idx) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: # noqa: D102
ptr = slice(int(self.gene_data_ptr[idx]), int(self.gene_data_ptr[idx + 1]))
# col idxs point to offsets in the original sparse metadata; this is for looking up metadata, e.g. gene names
col_idxs = np.asarray(self.gene_data_indices[ptr]).astype(int) # keyed by ptr
if self.assert_increasing_columns and len(col_idxs) > 1:
is_increasing = np.diff(col_idxs) > 0
if not np.all(is_increasing):
raise ValueError(f"Column indices are not increasing for {np.sum(~is_increasing)} pairs of genes")
gene_data = np.asarray(self.gene_data[ptr]).astype(int) # keyed by ptr
# Get feature_ids for this particular cell. Either look up by index if we need to, or if we already verified that
# metadata is not needed because feature_ids are the same for every file, then we can just use the single feature_ids
# vector instead.
feature_ids: np.ndarray = (
self.feature_ids if self.metadata is None else self.metadata_lookup(idx)["feature_ids"]
)
return gene_data, col_idxs, feature_ids
return len(self.scdl)

def __getitem__(self, idx: int) -> types.BertSample: # noqa: D105
rng = np.random.default_rng([self._seed, idx])

"""Performs a lookup and the required transformation for the model"""
gene_data, col_idxs, feature_ids = self.lookup_cell_by_idx(idx)
values, feature_ids_df = self.scdl.get_row(idx, return_features=True, feature_vars=["feature_id"])
gene_data, col_idxs = values[0], values[1]
gene_data = gene_data.astype(np.int64)
col_idxs = col_idxs.astype(np.int64)
if len(gene_data) == 0:
raise ValueError(
"SingleCellMemap data provided is invalid; the gene expression data parsed for the specified index is empty."
)
feature_ids = feature_ids_df.values.tolist()
feature_ids = [f[0] for f in feature_ids]
return process_item(
gene_data,
col_idxs,
@@ -210,6 +128,7 @@ def __getitem__(self, idx: int) -> types.BertSample: # noqa: D105
mask_prob=self.mask_prob,
random_token_prob=self.random_token_prob,
prepend_cls_token=self.prepend_cls_token,
bypass_tokenizer_vocab=self.bypass_tokenizer_vocab,
)


@@ -227,6 +146,7 @@ def process_item( # noqa: D417
target_sum: int = 10000,
normalize: bool = True,
prepend_cls_token: bool = True,
bypass_tokenizer_vocab: bool = False,
) -> types.BertSample:
"""Process a single item in the dataset.

@@ -235,7 +155,7 @@

Args:
gene_data (list): List of gene data, these are expression counts.
gene_idxs (list): List of gene indices, these are keys in 'metadata['feature_ids']' and correspond to the CSR entry. These are computed by sc_memmap.
gene_idxs (list): List of gene indices; these are keys in 'metadata['feature_ids']' and correspond to the CSR entries.
feature_ids (list): Feature ids for the full dataset.
tokenizer (Tokenizer): Tokenizer object.
gene_median (optional(dict)): Dictionary of gene medians. Defaults to None. Expects ensembl IDs to be keys.
@@ -263,7 +183,6 @@ def process_item( # noqa: D417
raise ValueError("gene_median must be provided for this tokenizer")

max_len = max_len - 1 # - minus 1 for [CLS] token

gene_names = [feature_ids[idx] for idx in gene_idxs]
genes, tokens, medians = [], [], []
for tok, gene in zip(gene_names, gene_data):
@@ -273,6 +192,8 @@
if normalize:
med = gene_median.get(tok, 1) # If not in the dictionary we default to no normalization (1)
medians.append(med)
elif not bypass_tokenizer_vocab:
raise ValueError("Provided gene id " + str(gene) + " not in tokenizer vocab.")

genes = np.asarray(genes)
token_ids = np.asarray(tokens)
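The new `bypass_tokenizer_vocab` branch above boils down to a simple filtering rule: drop unknown genes when bypassing, raise otherwise. Below is a toy sketch of that rule with plain dicts standing in for `tokenizer.vocab` and `gene_median` (the real GeneTokenizer internals are not shown in this diff, so the dict lookup is an assumption).

    import numpy as np

    def filter_by_vocab(gene_names, gene_data, vocab, gene_median, bypass_tokenizer_vocab=False):
        # Keep only genes present in the vocab; either skip unknown genes
        # (bypass_tokenizer_vocab=True) or raise, matching the behavior in the diff.
        genes, tokens, medians = [], [], []
        for tok, gene in zip(gene_names, gene_data):
            if tok in vocab:
                tokens.append(vocab[tok])
                genes.append(gene)
                medians.append(gene_median.get(tok, 1))  # default 1 means "no normalization"
            elif not bypass_tokenizer_vocab:
                raise ValueError(f"Provided gene id {tok} not in tokenizer vocab.")
        return np.asarray(genes), np.asarray(tokens), np.asarray(medians)

    # Example: with bypass_tokenizer_vocab=True the unknown gene is silently dropped;
    # with the default (False) the same call raises ValueError.
    # filter_by_vocab(["ENSG01", "ENSG_MISSING"], [3, 7], {"ENSG01": 5}, {}, bypass_tokenizer_vocab=True)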
@@ -0,0 +1,51 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from pathlib import Path

import pytest

from bionemo.testing.data.load import load


@pytest.fixture
def test_directory() -> Path:
"""Gets the path to the original synthetic single cell directory with test data (no feature ids).

Returns:
A Path object that is the directory with specified test data.
"""
return load("scdl/sample") / "scdl_data"


@pytest.fixture
def test_directory_feat_ids() -> Path:
"""Gets the path to the directory with the synthetic single cell data (with the feature ids appended).

Returns:
A Path object that is the directory with specified test data.
"""
return load("scdl_feature_ids/sample_scdl_feature_ids") / "scdl_data_with_feature_ids"


@pytest.fixture
def cellx_small_directory() -> Path:
"""Gets the path to the directory with with cellx small dataset in Single Cell Memmap format.

Returns:
A Path object that is the directory with the specified test data.
"""
return load("single_cell/testdata-memmap-format") / "cellxgene_2023-12-15_small_mmap"