diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 117a9919..5012eff2 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -1,26 +1,93 @@ -name: Build wheels +name: Publish Python 🐍 distribution 📦 to PyPI and TestPyPI -on: - release: - types: [ published ] +on: push jobs: - pypi-publish: - name: Build and upload release to PyPI + build: + name: Build distribution 📦 runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.x" + - name: Install pypa/build + run: >- + python3 -m + pip install + build + --user + - name: Build a binary wheel and a source tarball + run: python3 -m build + - name: Store the distribution packages + uses: actions/upload-artifact@v3 + with: + name: python-package-distributions + path: dist/ + + publish-to-pypi: + name: >- + Publish Python 🐍 distribution 📦 to PyPI + if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes + needs: + - build + runs-on: ubuntu-latest + environment: + name: pypi + url: https://pypi.org/p/femr # Replace with your PyPI project name + permissions: + id-token: write # IMPORTANT: mandatory for trusted publishing + + steps: + - name: Download all the dists + uses: actions/download-artifact@v3 + with: + name: python-package-distributions + path: dist/ + - name: Publish distribution 📦 to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + + github-release: + name: >- + Sign the Python 🐍 distribution 📦 with Sigstore + and upload them to GitHub Release + needs: + - publish-to-pypi + runs-on: ubuntu-latest + + permissions: + contents: write # IMPORTANT: mandatory for making GitHub Releases + id-token: write # IMPORTANT: mandatory for sigstore + steps: - - uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: '3.x' - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install build - - name: Build package - run: python -m build - - name: Publish package - uses: pypa/gh-action-pypi-publish@release/v1 - with: - password: ${{ secrets.pypi_token }} + - name: Download all the dists + uses: actions/download-artifact@v3 + with: + name: python-package-distributions + path: dist/ + - name: Sign the dists with Sigstore + uses: sigstore/gh-action-sigstore-python@v1.2.3 + with: + inputs: >- + ./dist/*.tar.gz + ./dist/*.whl + - name: Create GitHub Release + env: + GITHUB_TOKEN: ${{ github.token }} + run: >- + gh release create + '${{ github.ref_name }}' + --repo '${{ github.repository }}' + --notes "" + - name: Upload artifact signatures to GitHub Release + env: + GITHUB_TOKEN: ${{ github.token }} + # Upload to GitHub Release using the `gh` CLI. + # `dist/` contains the built packages, and the + # sigstore-produced signatures and certificates. 
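(One way to sanity-check a publish from this workflow, sketched in Python rather than as part of the YAML: query PyPI's public JSON API for the project once the tag build finishes. The project name femr comes from the environment URL above; network access is assumed.)

    # Post-release sanity check (illustrative; not part of the workflow).
    import json
    import urllib.request

    with urllib.request.urlopen("https://pypi.org/pypi/femr/json") as response:
        metadata = json.load(response)

    # "version" is the latest published release; "releases" maps every
    # published version to its distribution files.
    print(metadata["info"]["version"])
    print(sorted(metadata["releases"]))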
+ run: >- + gh release upload + '${{ github.ref_name }}' dist/** + --repo '${{ github.repository }}' diff --git a/.gitignore b/.gitignore index 5bc74a06..8bbc7147 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ ignore/* *.ipynb_checkpoints* tutorials/trash tutorials/tmp_trainer +_version.py diff --git a/pyproject.toml b/pyproject.toml index 937b86ef..6b72d719 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools", "setuptools-scm"] +requires = ["setuptools >= 69.0", "setuptools-scm>=8.0"] build-backend = "setuptools.build_meta" [project] @@ -15,7 +15,7 @@ dependencies = [ "icecream == 2.1.3", "nptyping == 2.4.1", "msgpack >= 1.0.5", - "meds == 0.1.1", + "meds == 0.1.3", "torch >= 2.1.2", "transformers >= 4.25", "datasets >= 2.15", @@ -23,7 +23,10 @@ dependencies = [ "dill >= 0.3.7", ] requires-python=">3.9" -version = "0.2.2" +dynamic = ["version"] + +[tool.setuptools_scm] +version_file = "src/femr/_version.py" [project.scripts] diff --git a/src/femr/__init__.py b/src/femr/__init__.py index e69de29b..6ce5475b 100644 --- a/src/femr/__init__.py +++ b/src/femr/__init__.py @@ -0,0 +1 @@ +from femr._version import __version__ # noqa \ No newline at end of file diff --git a/src/femr/models/config.py b/src/femr/models/config.py index fd524170..40ab236c 100644 --- a/src/femr/models/config.py +++ b/src/femr/models/config.py @@ -20,6 +20,20 @@ def __init__( hidden_act: str = "gelu", **kwargs, ) -> None: + """Defines a configuration for a FEMR Transformer. + + Arguments: + vocab_size: The number of tokens in the vocabulary + is_hierarchical: Whether to use a hierarchical vocabulary. See FEMRTokenizer for more information. + hidden_size: The internal representation size + intermediate_size: The size of the FFN in the transformer layers + n_heads: The number of attention heads + n_layers: The number of transformer encoder layers + attention_width: FEMR by default uses a local attention transformer. This defines the width of those attention windows. + use_normed_ages: Whether or not to provide normalized ages as a feature to the model + use_bias: Whether or not to use bias terms in the transformer layers. Current research indicates that this should be set to False. + hidden_act: The type of activation function to use in the transformer. + """ super().__init__(**kwargs) self.vocab_size = vocab_size @@ -39,18 +53,31 @@ class FEMRTaskConfig(transformers.PretrainedConfig): def __init__(self, task_type: str = "", task_kwargs: Mapping[str, Any] = {}, **kwargs): + """A generic FEMR task definition. This holds state used for initializing a tasks.py class. + + Task.get_task_config returns the task type and kwargs used to initialize this. + + Arguments: + task_type: The name of the task. + task_kwargs: Arbitrary arguments used to store state for that task. + """ super().__init__(**kwargs) self.task_type = task_type self.task_kwargs = task_kwargs class FEMRModelConfig(transformers.PretrainedConfig): + """A model config is defined as the combination of a transformer config and a task config.""" def __init__( self, transformer_config: Optional[Dict[str, Any]] = None, task_config: Optional[Dict[str, Any]] = None, **kwargs, ): + """A combination of a transformer config and a task config. + + It is possible to initialize this with only a transformer config, in which case the model will be configured for inference only.
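(With dynamic = ["version"] and the [tool.setuptools_scm] table above, the package version is now derived from git tags at build time and written to src/femr/_version.py, which the new __init__.py re-exports. A minimal sketch of checking the derived version locally, assuming setuptools-scm>=8 is installed and the checkout has at least one tag:)

    # Sketch: inspect the version setuptools-scm derives from the git state.
    from setuptools_scm import get_version

    # On a commit tagged v0.2.3 this prints "0.2.3"; on later commits it
    # prints a dev version such as "0.2.4.dev5+g1a2b3c4".
    print(get_version())

    # After installation the same value is exposed as femr.__version__.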
+ """ super().__init__(**kwargs) if transformer_config is None: transformer_config = {} @@ -67,6 +94,9 @@ def __init__( def from_transformer_task_configs( cls, transformer_config: FEMRTransformerConfig, task_config: FEMRTaskConfig ) -> FEMRModelConfig: + """ + Combine a transformer configuration and task configuration into a model configuration. + """ if task_config is not None: task_config_dict = task_config.to_dict() else: diff --git a/src/femr/models/processor.py b/src/femr/models/processor.py index 9248116f..0ec26bb1 100644 --- a/src/femr/models/processor.py +++ b/src/femr/models/processor.py @@ -3,18 +3,37 @@ import collections import datetime import functools -from typing import Any, Dict, List, Mapping, Tuple +from typing import Any, Dict, List, Mapping, Tuple, Optional import datasets import numpy as np import torch.utils.data +import meds + import femr.hf_utils import femr.models.tokenizer import femr.pat_utils -def map_length_stats(batch, indices, *, processor, max_length): +def map_preliminary_batch_stats(batch, indices, *, processor: FEMRBatchProcessor, max_length: int): + """ + This function creates preliminary batch statistics, to be used for final batching. + + The overall problem this is trying to solve is that Patient records can be very long, so when we featurize patients we actually featurize subsequences of patients. + + AKA every patient in a batch is actually (patient_id, start_index, length), which specifies a subsequence of the patient. + + The trickiness becomes when a particular patient has multiple labels, some of which will require multiple subsequences. + + The goal of this function is to compute the list [(patient_id, start_index, length)] such that every label is covered + by at least one batch. Note that some labels will be covered multiple times. + + Note that there is a special setting for tasks that don't need exact labeling (needs_exact in tasks.py returns False). + For these patients we only generate one tuple for them, and drop some labels. + + Later code will then take [(patient_id, start_index, length)], and create actual batches. 
+ """ lengths = [] for patient_index, patient_id, events in zip(indices, batch["patient_id"], batch["events"]): @@ -23,8 +42,12 @@ def map_length_stats(batch, indices, *, processor, max_length): "events": events, } data = processor.convert_patient(patient) + + # There are no labels for this patient if data["transformer"]["label_indices"].shape[0] == 0: continue + + # We need exact batching, so we need an algorithm that precisely covers every label in the batch if data["needs_exact"]: current_start = 0 current_end = 0 @@ -48,18 +71,22 @@ def map_length_stats(batch, indices, *, processor, max_length): return [] -def agg_length_stats(lengths1, lengths2): +def agg_preliminary_batch_stats(lengths1, lengths2): + """Aggregate preliminary length statistics from the map_preliminary_batch_stats""" lengths1.extend(lengths2) return lengths1 class BatchCreator: - def __init__(self, tokenizer, task=None): + """The BatchCreator is designed to generate batches from patient data.""" + def __init__(self, tokenizer: femr.models.FEMRTokenizer, task: Optional[femr.tasks.Task] = None): + """Initialize a BatchCreator, with a tokenizer, and optionally a task.""" self.tokenizer = tokenizer self.task = task def start_batch(self): + """Start a batch.""" self.patient_ids = [] self.offsets = [] self.patient_lengths = [] @@ -82,99 +109,193 @@ def start_batch(self): if self.task is not None: self.task.start_batch() - def add_patient(self, patient, offset, max_patient_length=None): - self.offsets.append(offset) + def add_patient(self, patient: meds.Patient, offset: int = 0, max_length: Optional[int] = None): + """Add a patient to the current batch. + + Note that the two optional parameters are used to add a subset of a patient to a batch. + It is generally recommended to never manually use offset or max_length as you should rely on FEMRBatchProcessor.convert_dataset. - def process_patient_events(): - current_date = None - last_time = None + Arguments: + patient: The patient to add. + offset: The offset into the patient to featurize. + max_length: The maximum length of the batch sequence. There is no max when left at None. 
- if self.task is not None: - self.task.start_patient(patient, self.tokenizer.ontology) + """ + current_date = None + last_time = None - patient_length_index = 0 - birth = femr.pat_utils.get_patient_birthdate(patient) - self.tokenizer.start_patient() + # The overall algorithm here is a bit complex + # First we featurize the entire patient + # Then we slice the patient indices according to offset and max_length - for event in patient["events"]: - if event["time"].date() != current_date: - current_date = event["time"].date() - codes_seen_today = set() + # These are the indices of the labels into the patient vectors + per_patient_label_indices = [] - for measurement in event["measurements"]: - features, weights = self.tokenizer.get_feature_codes(event["time"], measurement) - if len(features) == 0: - continue - if all(feature in codes_seen_today for feature in features): - continue + # The ages at each index for the patient + per_patient_ages = [] - codes_seen_today |= set(features) + # The normalized age at index for the patient + per_patient_normalized_ages = [] - if patient_length_index < offset: - patient_length_index += 1 - continue + # The timestamps at each index for the patient + per_patient_timestamps = [] - if (self.task is not None) and (last_time is not None): - num_added = self.task.add_event(last_time, event["time"], features) - for _ in range(num_added): - self.label_indices.append(len(self.ages) - 1) + # For a regular tokenizer, we just have tokens + per_patient_tokens = [] - if max_patient_length is not None and (patient_length_index - offset >= max_patient_length): - return None + # For a hierarchical tokenizer, we have a more complex setup + # These are designed to match the inputs required for an EmbeddingBag. + # See PyTorch's EmbeddingBag documentation to understand what these mean. 
+ per_patient_hierarchical_tokens = [] + per_patient_hierarchical_weights = [] + per_patient_token_indices = [0] - if not self.tokenizer.is_hierarchical: - assert len(features) == 1 - self.tokens.append(features[0]) - else: - self.hierarchical_tokens.extend(features) - self.hierarchical_weights.extend(weights) - self.token_indices.append(len(self.hierarchical_tokens)) + if self.task is not None: + self.task.start_patient(patient, self.tokenizer.ontology) - self.patient_ids.append(patient["patient_id"]) - self.valid_tokens.append(True) - self.ages.append((event["time"] - birth) / datetime.timedelta(days=1)) - self.normalized_ages.append(self.tokenizer.normalize_age(event["time"] - birth)) - self.timestamps.append(event["time"].timestamp()) + birth = femr.pat_utils.get_patient_birthdate(patient) + self.tokenizer.start_patient() - patient_length_index += 1 + for event in patient["events"]: + # We want to avoid duplicate codes in the same day, so we maintain codes_seen_today + if event["time"].date() != current_date: + current_date = event["time"].date() + codes_seen_today = set() - last_time = event["time"] + for measurement in event["measurements"]: + # Get features and weights for the current event + features, weights = self.tokenizer.get_feature_codes(event["time"], measurement) - return last_time + # Ignore events with no features + if len(features) == 0: + continue - start_index = len(self.ages) - final_time = process_patient_events() + # Ignore events where all features have already occurred + if all(feature in codes_seen_today for feature in features): + continue + + codes_seen_today |= set(features) + + if (self.task is not None) and (last_time is not None): + # Now we have to consider whether or not to have labels for this time step + # The add_event function returns how many labels to assign for this time + num_added = self.task.add_event(last_time, event["time"], features) + for _ in range(num_added): + per_patient_label_indices.append(len(per_patient_ages) - 1) + + if not self.tokenizer.is_hierarchical: + assert len(features) == 1 + per_patient_tokens.append(features[0]) + else: + per_patient_hierarchical_tokens.extend(features) + per_patient_hierarchical_weights.extend(weights) + per_patient_token_indices.append(len(per_patient_hierarchical_tokens)) + + per_patient_ages.append((event["time"] - birth) / datetime.timedelta(days=1)) + per_patient_normalized_ages.append(self.tokenizer.normalize_age(event["time"] - birth)) + per_patient_timestamps.append(event["time"].timestamp()) + + last_time = event["time"] - if self.task is not None and final_time is not None: - num_added = self.task.add_event(final_time, None, []) + + if self.task is not None and last_time is not None: + num_added = self.task.add_event(last_time, None, None) for _ in range(num_added): - self.label_indices.append(len(self.ages) - 1) + per_patient_label_indices.append(len(per_patient_ages) - 1) + + # Now we want to actually add the patient data to the batch. + # This will involve some clever slicing. 
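(For intuition about the slicing below, this is how the new tests exercise it: featurizing a patient in two pieces with offset/max_length concatenates to the same batch data as featurizing it whole. A sketch assuming a tokenizer and a patient are already in scope:)

    import femr.models.processor

    creator = femr.models.processor.BatchCreator(tokenizer)

    creator.start_batch()
    creator.add_patient(patient)
    whole = creator.get_batch_data()

    half = len(whole["transformer"]["timestamps"]) // 2

    creator.start_batch()
    creator.add_patient(patient, offset=0, max_length=half)
    first = creator.get_batch_data()

    creator.start_batch()
    creator.add_patient(patient, offset=half, max_length=None)
    second = creator.get_batch_data()

    # first plus second covers the same tokens and labels as whole
    # (see assert_two_batches_equal_third in tests/models/test_batch_creator.py).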
+ + # First, let's get the length we are adding + length_found = len(per_patient_ages) + if max_length is not None: + length_to_add = min(length_found - offset, max_length) + else: + length_to_add = length_found - offset + + start_index = len(self.ages) + + # Let's add the constants first + self.valid_tokens.extend([True] * length_to_add) + self.patient_ids.extend([patient['patient_id']] * length_to_add) + self.offsets.append(offset) + self.patient_lengths.append(length_to_add) + + # Ages, normalized ages and timestamps are also easy to add + self.ages.extend(per_patient_ages[offset: offset + length_to_add]) + self.normalized_ages.extend(per_patient_normalized_ages[offset: offset + length_to_add]) + self.timestamps.extend(per_patient_timestamps[offset: offset + length_to_add]) + + if not self.tokenizer.is_hierarchical: + # Easy for a simple tokenizer + self.tokens.extend(per_patient_tokens[offset: offset + length_to_add]) + else: + # The hierarchical tokenizer is more complex since we have to shift the indices as well + # Remember, these arrays are all designed for PyTorch EmbeddingBag + + # We need to get the start and end at a particular offset + internal_start = per_patient_token_indices[offset] + internal_end = per_patient_token_indices[offset + length_to_add] + + # We need to offset the token indices to account for the existing tokens + self.token_indices.extend([len(self.hierarchical_tokens) - internal_start + value for value in per_patient_token_indices[offset + 1:offset + length_to_add + 1]]) + + self.hierarchical_tokens.extend(per_patient_hierarchical_tokens[internal_start: internal_end]) + self.hierarchical_weights.extend(per_patient_hierarchical_weights[internal_start: internal_end]) + + + # The label indices are also a bit tricky as they have to be offset accordingly. + # We also need to collect the good labels that should be sent to the final numpy arrays. + labels_to_add = [] + for i, label_index in enumerate(per_patient_label_indices): + corrected_label = label_index - offset + + if 0 <= corrected_label < length_to_add: + labels_to_add.append(i) + self.label_indices.append(start_index + corrected_label) + + if self.task is not None: + self.task.add_patient_labels(labels_to_add) - self.patient_lengths.append(len(self.ages) - start_index) def get_batch_data(self): + """Convert the batch to numpy arrays. The data structure is defined inline in this function.""" if self.tokenizer.vocab_size <= 2**15: token_dtype = np.int16 else: token_dtype = np.int32 transformer = { + # Whether or not the token is valid at this index "valid_tokens": np.array(self.valid_tokens), + + # The age of the patient in days at this index "ages": np.array(self.ages, dtype=np.float32), + + # The normalized age at this index "normalized_ages": np.array(self.normalized_ages, dtype=np.float16), + + # The timestamp (in seconds) at this index "timestamps": np.array(self.timestamps, dtype=np.int64), + + # The length of each patient's subsequence "patient_lengths": np.array(self.patient_lengths, dtype=np.int32), + + # The indices of the labels "label_indices": np.array(self.label_indices, dtype=np.int32), } if not self.tokenizer.is_hierarchical: + # For a simple tokenizer, these are simply the token indices transformer["tokens"] = np.array(self.tokens, dtype=token_dtype) else: + # See PyTorch's EmbeddingBag for what these numpy arrays mean.
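(The three hierarchical arrays map directly onto torch.nn.EmbeddingBag's (input, offsets, per_sample_weights) calling convention. A self-contained sketch with made-up values:)

    import torch
    import torch.nn as nn

    # Seven codes spread over three timesteps.
    hierarchical_tokens = torch.tensor([5, 17, 42, 7, 9, 13, 2])
    token_indices = torch.tensor([0, 3, 5])  # timestep i owns tokens[token_indices[i]:next]
    hierarchical_weights = torch.tensor([1.0, 0.5, 0.5, 1.0, 1.0, 1.0, 0.25])

    # per_sample_weights requires mode="sum".
    bag = nn.EmbeddingBag(num_embeddings=100, embedding_dim=8, mode="sum")

    # One weighted-sum embedding per timestep: output shape (3, 8).
    out = bag(hierarchical_tokens, token_indices, per_sample_weights=hierarchical_weights)
    print(out.shape)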
transformer["hierarchical_tokens"] = np.array(self.hierarchical_tokens, dtype=token_dtype) transformer["hierarchical_weights"] = np.array(self.hierarchical_weights, dtype=np.float16) transformer["token_indices"] = np.array(self.token_indices, dtype=np.int32) + # Some general metadata final = { "num_patients": len(self.patient_lengths), "num_indices": len(self.label_indices), @@ -183,6 +304,7 @@ def get_batch_data(self): "transformer": transformer, } + # Add the task data if self.task is not None and transformer["label_indices"].shape[0] > 0: final["task"] = self.task.get_batch_data() final["needs_exact"] = self.task.needs_exact() @@ -230,12 +352,30 @@ def _add_dimension(data: Any) -> Any: class FEMRBatchProcessor: + """The FEMR Batch processor creates batches for processing by a transformer.""" def __init__(self, tokenizer, task=None): self.creator = BatchCreator(tokenizer, task) - def convert_patient(self, patient, offset=0, max_patient_length=None, tensor_type=None, **formatter_kwargs): + def convert_patient(self, patient: meds.Patient, offset: int =0, max_length: Optional[int]=None, tensor_type=None, **formatter_kwargs): + """Convert a single patient to a batch. + + Note that this can also convert parts of a patient to a batch using the offset and max_length parameters. + This is useful for processing long patients. + + NOTE: This function is primarily for debugging purposes. It is recommended to use convert_dataset for maximum correctness and efficiency. + + Arguments: + patient: The patient to convert + offset: The integer offset into the patient to convert + max_length: The maximum length to convert + tensor_type: The dataset to return + formatter_kwargs: Arguments for a datasets formatter when converting datatypes + + Returns: + A batch, ready to be fed into a FEMR transformer model + """ self.creator.start_batch() - self.creator.add_patient(patient, offset=offset, max_patient_length=max_patient_length) + self.creator.add_patient(patient, offset=offset, max_length=max_length) batch_data = self.creator.get_batch_data() if tensor_type is not None: formatter = datasets.formatting.get_formatter(tensor_type, **formatter_kwargs) @@ -243,23 +383,35 @@ def convert_patient(self, patient, offset=0, max_patient_length=None, tensor_typ return batch_data def collate(self, batches: List[Mapping[str, Any]]) -> Mapping[str, Any]: + """A collate function that prepares batches for being fed into a dataloader.""" assert len(batches) == 1, "Can only have one batch when collating" return {"batch": _add_dimension(self.creator.cleanup_batch(batches[0]))} - def convert_dataset(self, dataset, tokens_per_batch: int, min_samples_per_batch: int = 4, num_proc: int = 1): + def convert_dataset(self, dataset, tokens_per_batch: int, min_patients_per_batch: int = 4, num_proc: int = 1): + """Convert an entire dataset to batches. 
+ + Arguments: + dataset: A huggingface dataset containing MEDS patients + tokens_per_batch: The number of tokens allowed per batch + min_patients_per_batch: The minimum number of patients per batch + num_proc: The number of processes to use when converting + + Returns: + A huggingface dataset object containing batches + """ if isinstance(dataset, datasets.DatasetDict): return datasets.DatasetDict( { - k: self.convert_dataset(v, tokens_per_batch, min_samples_per_batch, num_proc) + k: self.convert_dataset(v, tokens_per_batch, min_patients_per_batch, num_proc) for k, v in dataset.items() } ) - max_length = tokens_per_batch // min_samples_per_batch + max_length = tokens_per_batch // min_patients_per_batch lengths = femr.hf_utils.aggregate_over_dataset( dataset, - functools.partial(map_length_stats, processor=self, max_length=max_length), - agg_length_stats, + functools.partial(map_preliminary_batch_stats, processor=self, max_length=max_length), + agg_preliminary_batch_stats, num_proc=num_proc, batch_size=1_000, with_indices=True, diff --git a/src/femr/models/tasks.py b/src/femr/models/tasks.py index fba7035e..f7fe8bd2 100644 --- a/src/femr/models/tasks.py +++ b/src/femr/models/tasks.py @@ -75,49 +75,50 @@ def filter_dataset(self, dataset: datasets.Dataset, index: femr.index.PatientInd def start_patient(self, patient: meds.Patient, _ontology: Optional[femr.ontology.Ontology]) -> None: self.current_labels = self.label_map[patient["patient_id"]] self.current_label_index = 0 - self.patient_id = patient["patient_id"] def needs_exact(self) -> bool: return True def start_batch(self) -> None: - self.patient_ids: List[int] = [] - self.prediction_timestamps: List[float] = [] + """LabeledPatientTask currently has no per-label state.""" + pass + + def add_patient_labels(self, _patient_label_offsets: List[int]) -> None: + """As there is no per-label state, this is ignored.""" + pass def add_event( - self, current_date: datetime.datetime, next_date: datetime.datetime, next_features: Sequence[int] + self, current_date: datetime.datetime, next_date: datetime.datetime, next_features: Optional[Sequence[int]] = None ) -> int: - num_added = 0 + has_label = False + while True: if self.current_label_index == len(self.current_labels): - return num_added + break current_label = self.current_labels[self.current_label_index] is_valid = current_date <= current_label["prediction_time"] next_valid = next_date is not None and next_date <= current_label["prediction_time"] - if not is_valid: - # We don't have any valid representation for this label, so ignore - self.current_label_index += 1 - continue - if next_valid: - # Next one is valid, so break early to give it a chance next time - return num_added - else: - self.patient_ids.append(self.patient_id) - self.prediction_timestamps.append(current_label["prediction_time"].timestamp()) - num_added += 1 - self.current_label_index += 1 + # Next one is valid, so break early to give it a chance next time + break - assert False, "Should never reach end" + if is_valid: + has_label = True + self.current_label_index += 1 + else: + # The next label isn't valid, so we have to break here + break + + if has_label: + return 1 + else: + return 0 def get_batch_data(self) -> Mapping[str, np.ndarray]: - return { - "patient_ids": np.array(self.patient_ids, dtype=np.int64), - "prediction_timestamps": np.array(self.prediction_timestamps, dtype=np.int64), - } + return {} class CLMBRTask(Task): @@ -129,8 +130,8 @@ def get_task_config(self) -> femr.models.config.FEMRTaskConfig: task_type="clmbr",
task_kwargs=dict(clmbr_vocab_size=self.clmbr_vocab_size) ) - def start_patient(self, patient: meds.Patient, _ontology: Optional[femr.ontology.Ontology]) -> None: - pass + def start_patient(self, _patient: meds.Patient, _ontology: Optional[femr.ontology.Ontology]) -> None: + self.per_patient_batch_labels: List[int] = [] def needs_exact(self) -> bool: return False @@ -138,10 +139,13 @@ def needs_exact(self) -> bool: def start_batch(self) -> None: self.batch_labels: List[int] = [] + def add_patient_labels(self, patient_label_offsets: List[int]) -> None: + self.batch_labels.extend([self.per_patient_batch_labels[i] for i in patient_label_offsets]) + def add_event( - self, current_date: datetime.datetime, next_date: datetime.datetime, next_features: Sequence[int] + self, current_date: datetime.datetime, next_date: datetime.datetime, next_features: Optional[Sequence[int]] = None ) -> int: - if len(next_features) == 0: + if next_features is None: return 0 if len(next_features) != 1: @@ -152,7 +156,7 @@ def add_event( if next_feature >= self.clmbr_vocab_size: return 0 - self.batch_labels.append(next_feature) + self.per_patient_batch_labels.append(next_feature) return 1 @@ -322,6 +326,13 @@ def start_patient(self, patient: meds.Patient, ontology: Optional[femr.ontology. assert ontology self.calculator = SurvivalCalculator(ontology, patient, self.pretraining_task_codes) + self.per_patient_censor_time = [] + self.per_patient_time_sparse = { + "data": [], + "indices": [], + "indptr": [0], + } + def needs_exact(self) -> bool: return False @@ -334,6 +345,18 @@ def start_batch(self) -> None: "indptr": [0], } + def add_patient_labels(self, patient_label_offsets: List[int]) -> None: + """Add per-patient labels to the global task labels.""" + self.censor_time.extend([self.per_patient_censor_time[i] for i in patient_label_offsets]) + + for index in patient_label_offsets: + start = self.per_patient_time_sparse['indptr'][index] + end = self.per_patient_time_sparse['indptr'][index + 1] + + self.time_sparse['data'].extend(self.per_patient_time_sparse['data'][start:end]) + self.time_sparse['indices'].extend(self.per_patient_time_sparse['indices'][start:end]) + self.time_sparse['indptr'].append(len(self.time_sparse['indices'])) + def add_event( self, current_date: datetime.datetime, next_date: datetime.datetime, next_features: Sequence[int] ) -> int: @@ -346,16 +369,16 @@ def add_event( return 0 censor_seconds = censor_time.total_seconds() - self.censor_time.append(censor_seconds) + self.per_patient_censor_time.append(censor_seconds) for event_name, time in tte.items(): j = self.task_to_index_map[event_name] seconds = time.total_seconds() - self.time_sparse["data"].append(seconds) - self.time_sparse["indices"].append(j) + self.per_patient_time_sparse["data"].append(seconds) + self.per_patient_time_sparse["indices"].append(j) - self.time_sparse["indptr"].append(len(self.time_sparse["data"])) + self.per_patient_time_sparse["indptr"].append(len(self.per_patient_time_sparse["data"])) return 1 diff --git a/src/femr/models/tokenizer.py b/src/femr/models/tokenizer.py index 08c81974..a8a02185 100644 --- a/src/femr/models/tokenizer.py +++ b/src/femr/models/tokenizer.py @@ -6,7 +6,7 @@ import functools import math import os -from typing import Any, Dict, Mapping, Optional, Set, Union +from typing import Any, Dict, Mapping, Optional, Set, Union, Tuple, List import meds import msgpack @@ -385,9 +385,15 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: ) def start_patient(self): + """Compute 
per-patient statistics that are required to generate features.""" + + # This is currently a null-op, but is required for cost featurization pass - def get_feature_codes(self, time: datetime.datetime, measurement: meds.Measurement): + def get_feature_codes(self, _time: datetime.datetime, measurement: meds.Measurement) -> Tuple[List[int], Optional[List[float]]]: + """Get codes for the provided measurement and time""" + + # Note that time is currently not used in this code, but it is required for cost featurization if self.is_hierarchical: assert self.ontology is not None codes = [ @@ -431,5 +437,5 @@ def get_feature_codes(self, time: datetime.datetime, measurement: meds.Measureme else: return [], None - def normalize_age(self, age): + def normalize_age(self, age: datetime.timedelta) -> float: return (age.total_seconds() - self.dictionary["age_stats"]["mean"]) / (self.dictionary["age_stats"]["std"]) diff --git a/src/femr/models/transformer.py b/src/femr/models/transformer.py index 8ddb3a69..7e918d5e 100644 --- a/src/femr/models/transformer.py +++ b/src/femr/models/transformer.py @@ -178,12 +178,8 @@ class LabeledPatientTaskHead(nn.Module): def __init__(self, hidden_size: int): super().__init__() - def forward(self, features: torch.Tensor, batch: Mapping[str, torch.Tensor]): - return ( - batch["patient_ids"], - batch["prediction_timestamps"].cpu().numpy().astype("datetime64[s]").astype(datetime.datetime), - features, - ) + def forward(self, features: torch.Tensor, batch: Mapping[str, torch.Tensor], return_logits=False): + return 0, {} class CLMBRTaskHead(nn.Module): @@ -192,12 +188,15 @@ def __init__(self, hidden_size: int, clmbr_vocab_size: int): self.final_layer = nn.Linear(hidden_size, clmbr_vocab_size) - def forward(self, features: torch.Tensor, batch: Mapping[str, torch.Tensor]): + def forward(self, features: torch.Tensor, batch: Mapping[str, torch.Tensor], return_logits=False): logits = self.final_layer(features) labels = batch["labels"] loss = F.cross_entropy(logits, labels) - return loss, logits + if not return_logits: + logits = None + + return loss, {'logits': logits} class MOTORTaskHead(nn.Module): @@ -242,7 +241,7 @@ def forward(self, features: torch.Tensor, batch: Mapping[str, torch.Tensor], ret if not return_logits: time_dependent_logits = None - return loss, time_dependent_logits + return loss, {'time_dependent_logits': time_dependent_logits} def remove_first_dimension(data: Any) -> Any: @@ -285,23 +284,35 @@ def create_task_head(self) -> nn.Module: elif task_type == "motor": return MOTORTaskHead(hidden_size, **task_kwargs) - def forward(self, batch: Mapping[str, Any], return_loss=True): + def forward(self, batch: Mapping[str, Any], return_loss=True, return_logits=False, return_reprs=False): # Need a return_loss parameter for transformers.Trainer to work properly assert return_loss batch = remove_first_dimension(batch) features = self.transformer(batch["transformer"]) - if "task" in batch: + if "task" in batch and self.config.task_config is not None: features = features.reshape(-1, features.shape[-1]) features = features[batch["transformer"]["label_indices"], :] - return self.task_model(features, batch["task"]) + loss, result = self.task_model(features, batch["task"], return_logits=return_logits) + if return_reprs: + result["representations"] = features + if return_logits or return_reprs: + result["timestamps"] = batch["transformer"]["timestamps"][batch["transformer"]["label_indices"]] + result["patient_ids"] = batch["patient_ids"][batch["transformer"]["label_indices"]] + return 
loss, result else: - return ( - batch["patient_ids"], - batch["transformer"]["timestamps"].cpu().numpy().astype("datetime64[s]").astype(datetime.datetime), - features, - ) + loss = 0 + features = features.reshape(-1, features.shape[-1]) + features = features[batch["transformer"]["label_indices"], :] + result = { + "timestamps": batch["transformer"]["timestamps"][batch["transformer"]["label_indices"]], + "patient_ids": batch["patient_ids"][batch["transformer"]["label_indices"]], + "representations": features, + } + + return loss, result + def compute_features( @@ -327,7 +338,7 @@ def compute_features( model = model.to(device) batches = processor.convert_dataset( - filtered_data, tokens_per_batch=tokens_per_batch, min_samples_per_batch=1, num_proc=num_proc + filtered_data, tokens_per_batch=tokens_per_batch, min_patients_per_batch=1, num_proc=num_proc ) batches.set_format("pt", device=device) @@ -339,13 +350,13 @@ def compute_features( for batch in batches: batch = processor.collate([batch])["batch"] with torch.no_grad(): - patient_ids, feature_times, representations = model(batch) - all_patient_ids.append(patient_ids.cpu().numpy()) - all_feature_times.append(feature_times) - all_representations.append(representations.cpu().numpy()) + _, result = model(batch, return_reprs=True) + all_patient_ids.append(result["patient_ids"].cpu().numpy()) + all_feature_times.append(result["timestamps"].cpu().numpy()) + all_representations.append(result["representations"].cpu().numpy()) return { "patient_ids": np.concatenate(all_patient_ids), - "feature_times": np.concatenate(all_feature_times), + "feature_times": np.concatenate(all_feature_times).astype('datetime64[s]'), "features": np.concatenate(all_representations), } diff --git a/tests/models/test_batch_creator.py b/tests/models/test_batch_creator.py new file mode 100644 index 00000000..2b951e1d --- /dev/null +++ b/tests/models/test_batch_creator.py @@ -0,0 +1,183 @@ +import femr.models.processor +import femr.models.tasks +import datetime + +from femr_test_tools import create_patients_dataset + +class DummyTokenizer: + def __init__(self, is_hierarchical: bool = False): + self.is_hierarchical = is_hierarchical + self.ontology = None + self.vocab_size = 100 + + def start_patient(self): + pass + + def get_feature_codes(self, time, measurement): + if measurement['code'] == 'SNOMED/184099003': + return [1], None + else: + return [int(measurement['code'])], None + + def normalize_age(self, age): + return 0.5 + +def assert_two_batches_equal_third(batch1, batch2, batch3): + """This asserts that batch1 + batch2 = batch3""" + assert batch3['patient_ids'].tolist() == batch1['patient_ids'].tolist() + batch2['patient_ids'].tolist() + + assert batch3['transformer']['ages'].tolist() == batch1['transformer']['ages'].tolist() + batch2['transformer']['ages'].tolist() + assert batch3['transformer']['timestamps'].tolist() == batch1['transformer']['timestamps'].tolist() + batch2['transformer']['timestamps'].tolist() + + + # Checking the label indices is a bit more involved, as we have to map each index to its age/patient id and then compare those + target_label_ages = [] + target_label_patient_ids = [] + + for label_index in batch1['transformer']['label_indices'].tolist(): + target_label_ages.append(batch1['transformer']['ages'][label_index]) + target_label_patient_ids.append(batch1['patient_ids'][label_index]) + + for label_index in batch2['transformer']['label_indices'].tolist(): + target_label_ages.append(batch2['transformer']['ages'][label_index]) + 
target_label_patient_ids.append(batch2['patient_ids'][label_index]) + + actual_label_ages = [] + actual_label_patient_ids = [] + + for label_index in batch3['transformer']['label_indices'].tolist(): + actual_label_ages.append(batch3['transformer']['ages'][label_index]) + actual_label_patient_ids.append(batch3['patient_ids'][label_index]) + + assert target_label_ages == actual_label_ages + assert target_label_patient_ids == actual_label_patient_ids + + if 'tokens' in batch3['transformer']: + assert batch3['transformer']['tokens'].tolist() == batch1['transformer']['tokens'].tolist() + batch2['transformer']['tokens'].tolist() + +def test_two_patients_concat_no_task(): + tokenizer = DummyTokenizer() + + fake_patients = create_patients_dataset(10) + + fake_patient1 = fake_patients[1] + fake_patient2 = fake_patients[5] + + creator = femr.models.processor.BatchCreator(tokenizer) + + creator.start_batch() + creator.add_patient(fake_patient1) + + data_for_patient1 = creator.get_batch_data() + + creator.start_batch() + creator.add_patient(fake_patient2) + + data_for_patient2 = creator.get_batch_data() + + + creator.start_batch() + creator.add_patient(fake_patient1) + creator.add_patient(fake_patient2) + + data_for_patients = creator.get_batch_data() + + assert_two_batches_equal_third(data_for_patient1, data_for_patient2, data_for_patients) + +def test_split_patients_concat_no_task(): + tokenizer = DummyTokenizer() + + fake_patients = create_patients_dataset(10) + + fake_patient = fake_patients[1] + + creator = femr.models.processor.BatchCreator(tokenizer) + + creator.start_batch() + creator.add_patient(fake_patient) + + data_for_patient = creator.get_batch_data() + + length = len(data_for_patient['transformer']['timestamps']) + + creator.start_batch() + creator.add_patient(fake_patient, offset=0, max_length = length // 2) + + data_for_part1 = creator.get_batch_data() + + creator.start_batch() + creator.add_patient(fake_patient, offset=length // 2, max_length = None) + + data_for_part2 = creator.get_batch_data() + + assert_two_batches_equal_third(data_for_part1, data_for_part2, data_for_patient) + +def test_two_patients_concat_task(): + tokenizer = DummyTokenizer() + + fake_patients = create_patients_dataset(10) + + task = femr.models.tasks.LabeledPatientTask([ + {'patient_id': 1, 'prediction_time': datetime.datetime(2011, 7, 6)}, + {'patient_id': 1, 'prediction_time': datetime.datetime(2017, 1, 1)}, + + {'patient_id': 5, 'prediction_time': datetime.datetime(2011, 11, 6)}, + {'patient_id': 5, 'prediction_time': datetime.datetime(2017, 2, 1)} + ]) + + fake_patient1 = fake_patients[1] + fake_patient2 = fake_patients[5] + + creator = femr.models.processor.BatchCreator(tokenizer, task=task) + + creator.start_batch() + creator.add_patient(fake_patient1) + + data_for_patient1 = creator.get_batch_data() + + creator.start_batch() + creator.add_patient(fake_patient2) + + data_for_patient2 = creator.get_batch_data() + + + creator.start_batch() + creator.add_patient(fake_patient1) + creator.add_patient(fake_patient2) + + data_for_patients = creator.get_batch_data() + + assert_two_batches_equal_third(data_for_patient1, data_for_patient2, data_for_patients) + +def test_split_patients_concat_task(): + tokenizer = DummyTokenizer() + + fake_patients = create_patients_dataset(10) + + fake_patient = fake_patients[1] + + task = femr.models.tasks.LabeledPatientTask([ + {'patient_id': 1, 'prediction_time': datetime.datetime(2011, 7, 6)}, + {'patient_id': 1, 'prediction_time': datetime.datetime(2017, 1, 1)}, + ]) + + 
creator = femr.models.processor.BatchCreator(tokenizer, task=task) + + creator.start_batch() + creator.add_patient(fake_patient) + + data_for_patient = creator.get_batch_data() + + length = len(data_for_patient['transformer']['timestamps']) + + creator.start_batch() + creator.add_patient(fake_patient, offset=0, max_length = length // 2) + + data_for_part1 = creator.get_batch_data() + + creator.start_batch() + creator.add_patient(fake_patient, offset=length // 2, max_length = None) + + data_for_part2 = creator.get_batch_data() + + assert_two_batches_equal_third(data_for_part1, data_for_part2, data_for_patient) \ No newline at end of file diff --git a/tutorials/1_Ontology.ipynb b/tutorials/1_Ontology.ipynb index 3a2d8097..3a260234 100644 --- a/tutorials/1_Ontology.ipynb +++ b/tutorials/1_Ontology.ipynb @@ -47,6 +47,14 @@ "execution_count": 1, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/esteinberg/miniconda3/envs/debug_document_femr/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, { "name": "stdout", "output_type": "stream", @@ -72,18 +80,12 @@ "metadata": {}, "outputs": [ { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "b7a8199d5f61425797611df1d2f3bfcd", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Map: 0%| | 0/200 [00:00\n", " \n", " \n", - " [1200/1200 00:14, Epoch 100/100]\n", + " [1200/1200 00:13, Epoch 100/100]\n", " \n", " \n", " \n", @@ -251,303 +248,303 @@ " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " 
\n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", + " \n", + " \n", " \n", " \n", "
204.0064003.9614784.2187004.201308
403.6795003.6930923.9251003.901888
603.3704003.4425383.6400003.605798
803.0687003.2096403.3686003.318301
1002.8052002.9990433.1153003.044958
1202.5856002.8155622.8612002.788877
1402.3595002.6599152.6533002.559976
1602.2143002.5357222.4306002.361309
1802.0324002.4368632.2844002.189811
2001.9680002.3594102.1236002.049350
2201.7890002.3006652.0645001.939663
2401.8202002.2558931.9018001.849089
2601.7446002.2200371.8665001.779356
2801.6625002.1918351.7963001.724157
3001.6622002.1691141.7487001.683180
3201.6535002.1494381.7087001.649142
3401.5808002.1323901.6825001.612055
3601.5554002.1171961.6524001.589550
3801.5675002.1049751.6203001.565886
4001.5205002.0943061.6290001.551617
4201.5189002.0832271.5607001.532212
4401.5117002.0744781.5621001.517776
4601.4977002.0657981.5528001.504316
4801.4619002.0579341.5279001.489016
5001.4672002.0501421.5090001.479996
5201.4649002.0436291.5034001.466780
5401.4333002.0375241.5014001.459986
5601.4426002.0318291.4890001.452937
5801.3830002.0262411.4716001.447207
6001.4525002.0205421.4517001.436194
6201.4154002.0156861.4617001.430223
6401.4016002.0121661.4402001.427563
6601.3890002.0079001.4269001.418490
6801.3990002.0048481.4356001.412927
7001.3516002.0006321.4055001.407846
7201.3970001.9968421.4202001.405152
7401.3949001.9939961.4015001.401981
7601.3348001.9916061.4104001.393342
7801.3685001.9882881.3947001.393573
8001.3825001.9867391.3899001.390646
8201.3329001.9842291.4067001.386406
8401.3433001.9822761.3643001.384271
8601.3185001.9803631.3669001.380643
8801.3718001.9782541.4038001.380061
9001.3353001.9771331.3534001.376688
9201.3454001.9754131.3632001.374987
9401.3143001.9737121.3517001.372282
9601.3394001.9727151.3797001.370155
9801.3031001.9715541.3618001.368772
10001.3568001.9708831.3794001.368523
10201.3190001.9697001.3309001.366558
10401.3363001.9690401.3522001.366861
10601.2879001.9684501.3459001.365087
10801.3396001.9677501.3577001.364698
11001.3088001.9672231.3515001.363665
11201.3491001.9668761.3611001.362971
11401.2960001.9666111.3322001.362701
11601.3372001.9663291.3494001.362505
11801.2773001.9662161.3368001.362360
12001.3344001.9661831.3529001.362225

" @@ -587,7 +584,7 @@ "\n", " output_dir='tmp_trainer',\n", " remove_unused_columns=False,\n", - " num_train_epochs=100,\n", + " num_train_epochs=20,\n", "\n", " eval_steps=20,\n", " evaluation_strategy=\"steps\",\n", @@ -628,7 +625,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.10.14" } }, "nbformat": 4, diff --git a/tutorials/5_CLMBR Featurization And Modeling.ipynb b/tutorials/5_CLMBR Featurization And Modeling.ipynb index c5bd2789..70ad33ca 100644 --- a/tutorials/5_CLMBR Featurization And Modeling.ipynb +++ b/tutorials/5_CLMBR Featurization And Modeling.ipynb @@ -38,17 +38,15 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/esteinberg/miniconda3/envs/femrv2/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + "/home/esteinberg/miniconda3/envs/debug_document_femr/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", - "/home/esteinberg/miniconda3/envs/femrv2/lib/python3.11/site-packages/torch/cuda/__init__.py:611: UserWarning: Can't initialize NVML\n", - " warnings.warn(\"Can't initialize NVML\")\n", - "Map (num_proc=4): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 1555.73 examples/s]\n", - "Some weights of the model checkpoint at input/clmbr_model were not used when initializing FEMRModel: ['transformer.in_norm.scale', 'transformer.layers.1.norm.scale', 'task_model.final_layer.weight', 'task_model.final_layer.bias', 'transformer.out_norm.scale', 'transformer.layers.0.norm.scale']\n", + "Map (num_proc=4): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 1419.57 examples/s]\n", + "Some weights of the model checkpoint at input/clmbr_model were not used when initializing FEMRModel: ['task_model.final_layer.bias', 'task_model.final_layer.weight', 'transformer.in_norm.scale', 'transformer.layers.0.norm.scale', 'transformer.layers.1.norm.scale', 'transformer.out_norm.scale']\n", "- This IS expected if you are initializing FEMRModel from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing FEMRModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", - "Some weights of FEMRModel were not initialized from the model checkpoint at input/clmbr_model and are newly initialized: ['transformer.layers.0.norm.weight', 'transformer.layers.1.norm.weight', 'transformer.in_norm.weight', 'transformer.out_norm.weight']\n", + "Some weights of FEMRModel were not initialized from the model checkpoint at input/clmbr_model and are newly initialized: ['transformer.in_norm.weight', 'transformer.layers.0.norm.weight', 'transformer.layers.1.norm.weight', 'transformer.out_norm.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", - "Map (num_proc=4): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 453.45 examples/s]\n" + "Map (num_proc=4): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 952.94 examples/s]\n" ] }, { @@ -62,7 +60,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Generating train split: 5 examples [00:00, 5.98 examples/s]" + "Generating train split: 5 examples [00:00, 19.21 examples/s]" ] }, { @@ -191,7 +189,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "id": "bad5ad4f", "metadata": {}, "outputs": [ @@ -260,7 +258,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.10.14" } }, "nbformat": 4, diff --git a/tutorials/6_Train MOTOR.ipynb b/tutorials/6_Train MOTOR.ipynb index 4ab77e8d..5268f1a2 100644 --- a/tutorials/6_Train MOTOR.ipynb +++ b/tutorials/6_Train MOTOR.ipynb @@ -48,9 +48,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/esteinberg/miniconda3/envs/femrv2/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + "/home/esteinberg/miniconda3/envs/debug_document_femr/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", - "Map (num_proc=4): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 1512.26 examples/s]" + "Map (num_proc=4): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 1497.28 examples/s]" ] }, { @@ -66,7 +66,7 @@ "output_type": "stream", "text": [ "\n", - "Map (num_proc=4): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 170/170 [00:00<00:00, 1514.12 examples/s]" + "Map (num_proc=4): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 170/170 [00:00<00:00, 1367.32 examples/s]\n" ] }, { @@ -84,13 +84,6 @@ " })\n", "})\n" ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] } ], "source": [ @@ -133,7 +126,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Map (num_proc=4): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 170/170 [00:00<00:00, 460.40 examples/s]\n" + "Map (num_proc=4): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 170/170 [00:00<00:00, 330.27 examples/s]\n" ] } ], @@ -164,7 +157,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Map (num_proc=4): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 170/170 [00:00<00:00, 414.49 examples/s]\n" + "Map (num_proc=4): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 170/170 [00:00<00:00, 304.29 examples/s]\n" ] } ], @@ -199,7 +192,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Map (num_proc=4): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 144/144 [00:00<00:00, 348.22 examples/s]\n" + "Map (num_proc=4): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 144/144 [00:00<00:00, 270.70 examples/s]\n" ] }, { @@ -213,8 +206,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "Generating train split: 7 examples [00:00, 11.81 examples/s]\n", - "Map (num_proc=4): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:00<00:00, 68.03 examples/s]\n" + "Generating train split: 7 examples [00:00, 12.19 examples/s]\n", + "Map (num_proc=4): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:00<00:00, 46.04 examples/s]\n" ] }, { @@ -229,7 +222,7 @@ "output_type": "stream", "text": [ "Setting num_proc from 4 back to 1 for the train split to disable multiprocessing as it only contains one shard.\n", - "Generating train split: 1 examples [00:00, 76.13 examples/s]" + "Generating train split: 1 examples [00:00, 73.97 examples/s]" ] }, { @@ 
-280,14 +273,11 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/esteinberg/miniconda3/envs/femrv2/lib/python3.11/site-packages/torch/cuda/__init__.py:611: UserWarning: Can't initialize NVML\n", + "/home/esteinberg/miniconda3/envs/debug_document_femr/lib/python3.10/site-packages/torch/cuda/__init__.py:628: UserWarning: Can't initialize NVML\n", " warnings.warn(\"Can't initialize NVML\")\n", - "2024-02-15 01:17:39.969569: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", - "2024-02-15 01:17:39.969629: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", - "2024-02-15 01:17:39.971090: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", - "2024-02-15 01:17:39.981133: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", - "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", - "2024-02-15 01:17:41.176250: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n", + "/home/esteinberg/miniconda3/envs/debug_document_femr/lib/python3.10/site-packages/accelerate/accelerator.py:432: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches', 'even_batches', 'use_seedable_sampler']). 
Please pass an `accelerate.DataLoaderConfiguration` instead: \n", + "dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)\n", + " warnings.warn(\n", "Could not estimate the number of tokens of the input, floating-point operations will not be computed\n" ] }, @@ -311,178 +301,178 @@ " \n", " \n", " 20\n", - " 0.824000\n", - " 0.506734\n", + " 0.818100\n", + " 0.525866\n", " \n", " \n", " 40\n", - " 0.790300\n", - " 0.506806\n", + " 0.822800\n", + " 0.525876\n", " \n", " \n", " 60\n", - " 0.858900\n", - " 0.506890\n", + " 0.858700\n", + " 0.525893\n", " \n", " \n", " 80\n", - " 0.789600\n", - " 0.506971\n", + " 0.823100\n", + " 0.525914\n", " \n", " \n", " 100\n", - " 0.828200\n", - " 0.507057\n", + " 0.811400\n", + " 0.525938\n", " \n", " \n", " 120\n", - " 0.787200\n", - " 0.507141\n", + " 0.842300\n", + " 0.525971\n", " \n", " \n", " 140\n", - " 0.823900\n", - " 0.507234\n", + " 0.799100\n", + " 0.525999\n", " \n", " \n", " 160\n", - " 0.821700\n", - " 0.507326\n", + " 0.837400\n", + " 0.526033\n", " \n", " \n", " 180\n", - " 0.821900\n", - " 0.507422\n", + " 0.812000\n", + " 0.526067\n", " \n", " \n", " 200\n", - " 0.793500\n", - " 0.507510\n", + " 0.820400\n", + " 0.526105\n", " \n", " \n", " 220\n", - " 0.822300\n", - " 0.507599\n", + " 0.836600\n", + " 0.526146\n", " \n", " \n", " 240\n", - " 0.797900\n", - " 0.507683\n", + " 0.805800\n", + " 0.526186\n", " \n", " \n", " 260\n", - " 0.808600\n", - " 0.507770\n", + " 0.832500\n", + " 0.526230\n", " \n", " \n", " 280\n", - " 0.822000\n", - " 0.507853\n", + " 0.817100\n", + " 0.526271\n", " \n", " \n", " 300\n", - " 0.821800\n", - " 0.507933\n", + " 0.836300\n", + " 0.526312\n", " \n", " \n", " 320\n", - " 0.795300\n", - " 0.508017\n", + " 0.792800\n", + " 0.526349\n", " \n", " \n", " 340\n", - " 0.790100\n", - " 0.508093\n", + " 0.839500\n", + " 0.526394\n", " \n", " \n", " 360\n", - " 0.821200\n", - " 0.508164\n", + " 0.816300\n", + " 0.526433\n", " \n", " \n", " 380\n", - " 0.840000\n", - " 0.508235\n", + " 0.837300\n", + " 0.526467\n", " \n", " \n", " 400\n", - " 0.789600\n", - " 0.508295\n", + " 0.832900\n", + " 0.526508\n", " \n", " \n", " 420\n", - " 0.819400\n", - " 0.508359\n", + " 0.796200\n", + " 0.526538\n", " \n", " \n", " 440\n", - " 0.795600\n", - " 0.508415\n", + " 0.834400\n", + " 0.526575\n", " \n", " \n", " 460\n", - " 0.843600\n", - " 0.508471\n", + " 0.796000\n", + " 0.526606\n", " \n", " \n", " 480\n", - " 0.789400\n", - " 0.508523\n", + " 0.812800\n", + " 0.526634\n", " \n", " \n", " 500\n", - " 0.793500\n", - " 0.508568\n", + " 0.872000\n", + " 0.526664\n", " \n", " \n", " 520\n", - " 0.810000\n", - " 0.508606\n", + " 0.779700\n", + " 0.526686\n", " \n", " \n", " 540\n", - " 0.819500\n", - " 0.508646\n", + " 0.834500\n", + " 0.526710\n", " \n", " \n", " 560\n", - " 0.818500\n", - " 0.508676\n", + " 0.815000\n", + " 0.526728\n", " \n", " \n", " 580\n", - " 0.819700\n", - " 0.508704\n", + " 0.834300\n", + " 0.526747\n", " \n", " \n", " 600\n", - " 0.793100\n", - " 0.508728\n", + " 0.776000\n", + " 0.526761\n", " \n", " \n", " 620\n", - " 0.832700\n", - " 0.508747\n", + " 0.831600\n", + " 0.526774\n", " \n", " \n", " 640\n", - " 0.794800\n", - " 0.508761\n", + " 0.868200\n", + " 0.526784\n", " \n", " \n", " 660\n", - " 0.811900\n", - " 0.508772\n", + " 0.805700\n", + " 0.526790\n", " \n", " \n", " 680\n", - " 0.825900\n", - " 0.508779\n", + " 0.791600\n", + " 0.526794\n", " \n", " \n", " 700\n", - " 0.788300\n", - " 0.508781\n", + " 
0.833400\n", + " 0.526796\n", " \n", " \n", "
" @@ -493,13 +483,6 @@ }, "metadata": {}, "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Checkpoint destination directory tmp_trainer/checkpoint-500 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n" - ] } ], "source": [ @@ -572,7 +555,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.10.14" } }, "nbformat": 4, diff --git a/tutorials/7_MOTOR Featurization And Modeling.ipynb b/tutorials/7_MOTOR Featurization And Modeling.ipynb index 0fc3bd1a..e88cf164 100644 --- a/tutorials/7_MOTOR Featurization And Modeling.ipynb +++ b/tutorials/7_MOTOR Featurization And Modeling.ipynb @@ -38,17 +38,15 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/esteinberg/miniconda3/envs/femrv2/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + "/home/esteinberg/miniconda3/envs/debug_document_femr/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", - "/home/esteinberg/miniconda3/envs/femrv2/lib/python3.11/site-packages/torch/cuda/__init__.py:611: UserWarning: Can't initialize NVML\n", - " warnings.warn(\"Can't initialize NVML\")\n", - "Map (num_proc=4): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 1578.13 examples/s]\n", - "Some weights of the model checkpoint at input/motor_model were not used when initializing FEMRModel: ['task_model.final_layer.bias', 'task_model.task_embedding.weight', 'transformer.out_norm.scale', 'transformer.layers.1.norm.scale', 'transformer.in_norm.scale', 'task_model.final_layer.weight', 'task_model.task_embedding_bias.weight', 'transformer.layers.0.norm.scale']\n", + "Map (num_proc=4): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 1384.33 examples/s]\n", + "Some weights of the model checkpoint at input/motor_model were not used when initializing FEMRModel: ['task_model.final_layer.bias', 'task_model.final_layer.weight', 'task_model.task_embedding.weight', 'task_model.task_embedding_bias.weight', 'transformer.in_norm.scale', 'transformer.layers.0.norm.scale', 'transformer.layers.1.norm.scale', 'transformer.out_norm.scale']\n", "- This IS expected if you are initializing FEMRModel from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing FEMRModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", - "Some weights of FEMRModel were not initialized from the model checkpoint at input/motor_model and are newly initialized: ['transformer.in_norm.weight', 'transformer.out_norm.weight', 'transformer.layers.0.norm.weight', 'transformer.layers.1.norm.weight']\n", + "Some weights of FEMRModel were not initialized from the model checkpoint at input/motor_model and are newly initialized: ['transformer.in_norm.weight', 'transformer.layers.0.norm.weight', 'transformer.layers.1.norm.weight', 'transformer.out_norm.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", - "Map (num_proc=4): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 472.02 examples/s]\n" + "Map (num_proc=4): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 327.93 examples/s]\n" ] }, { @@ -62,7 +60,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Generating train split: 4 examples [00:00, 9.14 examples/s]" + "Generating train split: 4 examples [00:00, 5.05 examples/s]" ] }, { @@ -265,7 +263,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.10.14" } }, "nbformat": 4,
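
The Accelerate FutureWarning captured in the updated output above spells out its own fix: passing dataloader flags directly to `Accelerator` is deprecated and will be removed in Accelerate 1.0. A minimal migration sketch, assuming accelerate >= 0.28 (where `DataLoaderConfiguration` is importable from `accelerate.utils`):

    from accelerate import Accelerator
    from accelerate.utils import DataLoaderConfiguration

    # Old, deprecated style:
    #   accelerator = Accelerator(dispatch_batches=None, split_batches=False,
    #                             even_batches=True, use_seedable_sampler=True)
    # New style: bundle the same flags into one config object.
    dataloader_config = DataLoaderConfiguration(
        dispatch_batches=None,      # let Accelerate decide per dataloader
        split_batches=False,        # each process draws its own batches
        even_batches=True,          # pad the final batch so all ranks stay in sync
        use_seedable_sampler=True,  # reproducible shuffling across runs
    )
    accelerator = Accelerator(dataloader_config=dataloader_config)

The flag values shown are the ones echoed by the warning itself, so the migrated call is behavior-preserving.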
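
The stderr cell deleted in the same notebook, warning that tmp_trainer/checkpoint-500 already exists and is non-empty, is what the Hugging Face Trainer prints when a training cell is re-run on top of checkpoints left behind by a previous run. Removing the stale directory before training again avoids it; a minimal sketch, with the path taken from the removed log line:

    import shutil

    # Clear checkpoints left over from an earlier run of the training cell,
    # so Trainer does not save over a non-empty checkpoint directory.
    shutil.rmtree("tmp_trainer", ignore_errors=True)

Pointing `TrainingArguments(output_dir=...)` at a fresh directory works just as well.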
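
The weight warnings in the tutorial 7 hunk are informative rather than fatal: the MOTOR pretraining head (the `task_model.*` tensors) is intentionally dropped when the checkpoint is reused for featurization, and the norm parameters stored as `*.scale` in the checkpoint are re-initialized under the `*.weight` names the current code expects. A sketch of the load that produces them; the module path is an assumption, since only the class name FEMRModel and the checkpoint path input/motor_model appear in the log:

    import femr.models.transformer

    # from_pretrained discards checkpoint tensors the model has no slot for
    # (the task head) and freshly initializes parameters the checkpoint lacks
    # (the renamed norm weights), emitting the warnings quoted above.
    model = femr.models.transformer.FEMRModel.from_pretrained("input/motor_model")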