Commit: Update

EthanSteinbergPrealize committed Mar 21, 2024
1 parent 250f9ee commit bfcac77
Showing 7 changed files with 159 additions and 112 deletions.
2 changes: 1 addition & 1 deletion src/femr/__init__.py
@@ -1 +1 @@
-from femr._version import __version__ # noqa
+from femr._version import __version__  # noqa
22 changes: 12 additions & 10 deletions src/femr/models/config.py
@@ -21,19 +21,19 @@ def __init__(
**kwargs,
) -> None:
"""Defined a configuration for a FEMR Transformer.
Arguments:
vocab_size: The number of tokens in the vocabulary
-is_hierarchical: Whether to use a hierarchical vocabulary. See FEMRTokenizer for more information.
+is_hierarchical: Whether to use a hierarchical vocabulary. See FEMRTokenizer for more information
hidden_size: The internal representation size
intermediate_size: The size of the FFN in the transformer layers
n_heads: The number of attention heads
n_layers: The number of transformer encoder layers
-attention_width: FEMR by default uses a local attention transformer. This defines the width of those attention windows.
+attention_width: FEMR by default uses a local attention transformer with a width defined here
use_normed_ages: Whether or not to provide normalized ages as a feature to the model
-use_bias: Whether or not to use bias terms in the transformer layers. Current research indicates that this should be set to False.
-hidden_act: The type of activation function to use in the transformer.
-"""
+use_bias: Whether or not to use bias terms in the transformer layers
+hidden_act: The type of activation function to use in the transformer
+"""
super().__init__(**kwargs)

self.vocab_size = vocab_size
@@ -56,9 +56,9 @@ def __init__(self, task_type: str = "", task_kwargs: Mapping[str, Any] = {}, **k
"""A generic FEMR task definition. This holds state used for initalizing a tasks.py class.
Task.get_task_config returns the task type and kwargs used to initialize this.
Arguments:
-task_type: The name of the task.
+task_type: The name of the task.
task_kwargs: Arbitrary arguments used to store state for that task.
"""
super().__init__(**kwargs)
@@ -68,15 +68,17 @@ def __init__(self, task_type: str = "", task_kwargs: Mapping[str, Any] = {}, **k

class FEMRModelConfig(transformers.PretrainedConfig):
"""A model config is defined as the combination of a transformer config and a task config."""

def __init__(
self,
transformer_config: Optional[Dict[str, Any]] = None,
task_config: Optional[Dict[str, Any]] = None,
**kwargs,
):
"""A combination of a transformer config and a task config.
-It is possible to initialize this with only a transformer config, in which case the model will be configured for inference only.
+It is possible to initialize this with only a transformer config, in which
+case the model will be configured for inference only.
"""
super().__init__(**kwargs)
if transformer_config is None:
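For orientation, here is a hedged sketch (not part of this commit) of how the two configs combine. The parameter names mirror the docstrings above, every value is purely illustrative, and FEMRModelConfig is assumed to accept plain dicts for its sub-configs, as its signature suggests.

import femr.models.config

# Transformer settings; every value here is an invented example, not a
# recommendation. attention_width is the local attention window size.
transformer_config = {
    "vocab_size": 32768,
    "is_hierarchical": False,
    "hidden_size": 768,
    "intermediate_size": 3072,
    "n_heads": 12,
    "n_layers": 6,
    "attention_width": 512,
    "use_normed_ages": True,
    "use_bias": False,
    "hidden_act": "gelu",
}

# Per the docstring above, task_config may be left out entirely, in which
# case the model is configured for inference only.
config = femr.models.config.FEMRModelConfig(
    transformer_config=transformer_config,
    task_config=None,
)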
93 changes: 53 additions & 40 deletions src/femr/models/processor.py
@@ -3,14 +3,13 @@
import collections
import datetime
import functools
-from typing import Any, Dict, List, Mapping, Tuple, Optional
+from typing import Any, Dict, List, Mapping, Optional, Tuple

import datasets
+import meds
import numpy as np
import torch.utils.data

-import meds

import femr.hf_utils
import femr.models.tokenizer
import femr.pat_utils
@@ -20,16 +19,21 @@ def map_preliminary_batch_stats(batch, indices, *, processor: FEMRBatchProcessor
"""
This function creates preliminary batch statistics, to be used for final batching.
-The overall problem this is trying to solve is that Patient records can be very long, so when we featurize patients we actually featurize subsequences of patients.
+The overall problem this is trying to solve is that Patient records can be very long, so
+when we featurize patients we actually featurize subsequences of patients.
-AKA every patient in a batch is actually (patient_id, start_index, length), which specifies a subsequence of the patient.
+AKA every patient in a batch is actually (patient_id, start_index, length), which
+specifies a subsequence of the patient.
-The trickiness becomes when a particular patient has multiple labels, some of which will require multiple subsequences.
+The trickiness becomes when a particular patient has multiple labels, some
+of which will require multiple subsequences.
-The goal of this function is to compute the list [(patient_id, start_index, length)] such that every label is covered
-by at least one batch. Note that some labels will be covered multiple times.
+The goal of this function is to compute the list [(patient_id, start_index, length)] such
+that every label is covered by at least one batch. Note that some labels will be covered multiple times.
-Note that there is a special setting for tasks that don't need exact labeling (needs_exact in tasks.py returns False).
+Note that there is a special setting for tasks that don't need exact labeling
+(needs_exact in tasks.py returns False).
For these patients we only generate one tuple for them, and drop some labels.
Later code will then take [(patient_id, start_index, length)], and create actual batches.
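To make the tuple scheme concrete, here is a small invented illustration (not produced by this code; all numbers are hypothetical):

# Invented numbers: a 10-event patient, a model that takes at most 4
# tokens per subsequence, and labels at event indices 2 and 9.
patient_id = 17
max_length = 4
label_positions = [2, 9]

# Each (patient_id, start_index, length) tuple is one subsequence; every
# label must land inside at least one of them.
subsequences = [
    (patient_id, 0, 4),  # covers indices 0-3, including the label at 2
    (patient_id, 6, 4),  # covers indices 6-9, including the label at 9
]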
@@ -80,7 +84,8 @@ def agg_preliminary_batch_stats(lengths1, lengths2):

class BatchCreator:
"""The BatchCreator is designed to generate batches from patient data."""
-def __init__(self, tokenizer: femr.models.FEMRTokenizer, task: Optional[femr.tasks.Task] = None):
+
+def __init__(self, tokenizer: femr.models.tokenizer.FEMRTokenizer, task: Optional[femr.models.tasks.Task] = None):
"""Initialize a BatchCreator, with a tokenizer, and optionally a task."""
self.tokenizer = tokenizer
self.task = task
@@ -111,20 +116,21 @@ def start_batch(self):

def add_patient(self, patient: meds.Patient, offset: int = 0, max_length: Optional[int] = None):
"""Add a patient to the current batch.
Note that the two optional parameters are used to add a subset of a patient to a batch.
-It is generally recommended to never manually use offset or max_length as you should rely on FEMRBatchProcessor.convert_dataset.
+It is generally recommended to never manually use offset or max_length as
+you should rely on FEMRBatchProcessor.convert_dataset.
Arguments:
patient: The patient to add.
offset: The offset into the patient to featurize.
-max_length: The maximum length of the batch sequence. There is no max when left at None.
+max_length: The maximum length of the batch sequence. There is no max when left at None.
"""
current_date = None
last_time = None


# The overall algorithm here is a bit complex
# First we featurize the entire patient
# Then we slice the patient indices according to offset and max_length
@@ -147,9 +153,9 @@ def add_patient(self, patient: meds.Patient, offset: int = 0, max_length: Option
# For a hierarchical tokenizer, we have a more complex setup
# These are designed to match the inputs required for an EmbeddingBag.
# See PyTorch's EmbeddingBag documentation to understand what these mean.
-per_patient_hierarchical_tokens = []
-per_patient_hierarchical_weights = []
-per_patient_token_indices = [0]
+per_patient_hierarchical_tokens: List[int] = []
+per_patient_hierarchical_weights: List[float] = []
+per_patient_token_indices: List[int] = [0]

if self.task is not None:
self.task.start_patient(patient, self.tokenizer.ontology)
@@ -188,6 +194,7 @@ def add_patient(self, patient: meds.Patient, offset: int = 0, max_length: Option
assert len(features) == 1
per_patient_tokens.append(features[0])
else:
+assert weights is not None
per_patient_hierarchical_tokens.extend(features)
per_patient_hierarchical_weights.extend(weights)
per_patient_token_indices.append(len(per_patient_hierarchical_tokens))
@@ -198,7 +205,6 @@ def add_patient(self, patient: meds.Patient, offset: int = 0, max_length: Option

last_time = event["time"]


if self.task is not None and last_time is not None:
num_added = self.task.add_event(last_time, None, None)
for _ in range(num_added):
@@ -218,32 +224,36 @@ def add_patient(self, patient: meds.Patient, offset: int = 0, max_length: Option

# Let's add the constants first
self.valid_tokens.extend([True] * length_to_add)
-self.patient_ids.extend([patient['patient_id']] * length_to_add)
+self.patient_ids.extend([patient["patient_id"]] * length_to_add)
self.offsets.append(offset)
self.patient_lengths.append(length_to_add)

# Ages, normalized ages and timestamps are also easy to add
-self.ages.extend(per_patient_ages[offset: offset + length_to_add])
-self.normalized_ages.extend(per_patient_normalized_ages[offset: offset + length_to_add])
-self.timestamps.extend(per_patient_timestamps[offset: offset + length_to_add])
+self.ages.extend(per_patient_ages[offset : offset + length_to_add])
+self.normalized_ages.extend(per_patient_normalized_ages[offset : offset + length_to_add])
+self.timestamps.extend(per_patient_timestamps[offset : offset + length_to_add])

if not self.tokenizer.is_hierarchical:
# Easy for simple tokenizer
-self.tokens.extend(per_patient_tokens[offset: offset + length_to_add])
+self.tokens.extend(per_patient_tokens[offset : offset + length_to_add])
else:
# Hierarchical tokenizer is more complex since we have to shift the indices as well
# Remember, these arrays are all designed for PyTorch EmbeddingBag

# We need to get the start and end at a particular offset
internal_start = per_patient_token_indices[offset]
internal_end = per_patient_token_indices[offset + length_to_add]

-# We need to offset the token indices to account for the existing tokens
-self.token_indices.extend([len(self.hierarchical_tokens) - internal_start + value for value in per_patient_token_indices[offset + 1:offset + length_to_add + 1]])
-
-self.hierarchical_tokens.extend(per_patient_hierarchical_tokens[internal_start: internal_end])
-self.hierarchical_weights.extend(per_patient_hierarchical_weights[internal_start: internal_end])
+# We need to offset the token indices to account for the existing tokens
+self.token_indices.extend(
+[
+len(self.hierarchical_tokens) - internal_start + value
+for value in per_patient_token_indices[offset + 1 : offset + length_to_add + 1]
+]
+)
+
+self.hierarchical_tokens.extend(per_patient_hierarchical_tokens[internal_start:internal_end])
+self.hierarchical_weights.extend(per_patient_hierarchical_weights[internal_start:internal_end])

# The label indices are also a bit tricky as they have to be offset accordingly.
# We also need to collect good labels that should be sent to the final numpy arrays.
@@ -258,7 +268,6 @@ def add_patient(self, patient: meds.Patient, offset: int = 0, max_length: Option
if self.task is not None:
self.task.add_patient_labels(labels_to_add)


def get_batch_data(self):
"""Convert the batch to numpy arrays. The data structure is defined inline in this function."""
if self.tokenizer.vocab_size <= 2**15:
@@ -269,19 +278,14 @@ def get_batch_data(self):
transformer = {
# Whether or not the token is valid at this index
"valid_tokens": np.array(self.valid_tokens),
-
# The age of the patient in days at this index
"ages": np.array(self.ages, dtype=np.float32),
-
# The normalized ages at this index
"normalized_ages": np.array(self.normalized_ages, dtype=np.float16),
-
# The timestamp (in seconds) at this index
"timestamps": np.array(self.timestamps, dtype=np.int64),
-
# The length of the patient
"patient_lengths": np.array(self.patient_lengths, dtype=np.int32),
-
# The indices of the labels
"label_indices": np.array(self.label_indices, dtype=np.int32),
}
@@ -353,24 +357,33 @@ def _add_dimension(data: Any) -> Any:

class FEMRBatchProcessor:
"""The FEMR Batch processor creates batches for processing by a transformer."""

def __init__(self, tokenizer, task=None):
self.creator = BatchCreator(tokenizer, task)

-def convert_patient(self, patient: meds.Patient, offset: int =0, max_length: Optional[int]=None, tensor_type=None, **formatter_kwargs):
+def convert_patient(
+self,
+patient: meds.Patient,
+offset: int = 0,
+max_length: Optional[int] = None,
+tensor_type=None,
+**formatter_kwargs,
+):
"""Convert a single patient to a batch.
Note that this can also convert parts of a patient to a batch using the offset and max_length parameters.
This is useful for processing long patients.
-NOTE: This function is primarily for debugging purposes. It is recommended to use convert_dataset for maximum correctness and efficiency.
+NOTE: This function is primarily for debugging purposes. It is
+recommended to use convert_dataset for maximum correctness and efficiency.
Arguments:
patient: The patient to convert
offset: The integer offset into the patient to convert
max_length: The maximum length to convert
tensor_type: The dataset to return
formatter_kwargs: Arguments for a datasets formatter when converting datatypes
Returns:
A batch, ready to be fed into a FEMR transformer model
"""
@@ -389,13 +402,13 @@ def collate(self, batches: List[Mapping[str, Any]]) -> Mapping[str, Any]:

def convert_dataset(self, dataset, tokens_per_batch: int, min_patients_per_batch: int = 4, num_proc: int = 1):
"""Convert an entire dataset to batches.
Arguments:
dataset: A huggingface dataset containing MEDS patients
tokens_per_batch: The number of tokens allowed per batch
min_patients_per_batch: The minimum number of patients per batch
num_proc: The number of processors to use when converting
Returns:
A huggingface dataset object containing batches
"""
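The EmbeddingBag layout used by the hierarchical branch above can be hard to picture from the diff alone. Here is a minimal sketch, not part of the commit, with invented values; only the layout (flat token ids, per-token weights, and bag-boundary offsets) mirrors what BatchCreator accumulates:

import torch

# mode="sum" is required for per_sample_weights in torch.nn.EmbeddingBag
embedding_bag = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode="sum")

# Three positions, each expanding to a weighted bag of hierarchical tokens
hierarchical_tokens = torch.tensor([1, 3, 5, 2, 7, 7, 8])
hierarchical_weights = torch.tensor([1.0, 0.5, 0.5, 1.0, 1.0, 0.3, 0.7])
token_indices = torch.tensor([0, 3, 5])  # bag i spans tokens[indices[i]:indices[i+1]]

embeddings = embedding_bag(
    hierarchical_tokens,
    offsets=token_indices,
    per_sample_weights=hierarchical_weights,
)
print(embeddings.shape)  # torch.Size([3, 4]), one embedding per position

This also shows why add_patient has to shift per_patient_token_indices when appending to an existing batch: the offsets index into the flat token array, so they must account for tokens already accumulated.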
40 changes: 28 additions & 12 deletions src/femr/models/tasks.py
@@ -34,13 +34,20 @@ def start_batch(self) -> None:
def start_patient(self, patient: meds.Patient, ontology: Optional[femr.ontology.Ontology]) -> None:
...

+@abc.abstractmethod
+def add_patient_labels(self, patient_label_offsets: List[int]) -> None:
+...
+
@abc.abstractmethod
def needs_exact(self) -> bool:
...

@abc.abstractmethod
def add_event(
-self, current_date: datetime.datetime, next_date: datetime.datetime, next_features: Sequence[int]
+self,
+current_date: datetime.datetime,
+next_date: Optional[datetime.datetime],
+next_features: Optional[Sequence[int]],
) -> int:
...

@@ -88,7 +95,10 @@ def add_patient_labels(self, _patient_label_offsets: List[int]) -> None:
pass

def add_event(
-self, current_date: datetime.datetime, next_date: datetime.datetime, next_features: Optional[Sequence[int]] = None
+self,
+current_date: datetime.datetime,
+next_date: Optional[datetime.datetime],
+next_features: Optional[Sequence[int]] = None,
) -> int:
has_label = False

@@ -111,7 +121,7 @@ def add_event(
else:
# The next label isn't valid, so we have to break here
break

if has_label:
return 1
else:
@@ -143,7 +153,10 @@ def add_patient_labels(self, patient_label_offsets: List[int]) -> None:
self.batch_labels.extend([self.per_patient_batch_labels[i] for i in patient_label_offsets])

def add_event(
-self, current_date: datetime.datetime, next_date: datetime.datetime, next_features: Optional[Sequence[int]] = None
+self,
+current_date: datetime.datetime,
+next_date: Optional[datetime.datetime],
+next_features: Optional[Sequence[int]] = None,
) -> int:
if next_features is None:
return 0
@@ -326,8 +339,8 @@ def start_patient(self, patient: meds.Patient, ontology: Optional[femr.ontology.
assert ontology
self.calculator = SurvivalCalculator(ontology, patient, self.pretraining_task_codes)

-self.per_patient_censor_time = []
-self.per_patient_time_sparse = {
+self.per_patient_censor_time: List[float] = []
+self.per_patient_time_sparse: Dict[str, List[float]] = {
"data": [],
"indices": [],
"indptr": [0],
@@ -350,15 +363,18 @@ def add_patient_labels(self, patient_label_offsets: List[int]) -> None:
self.censor_time.extend([self.per_patient_censor_time[i] for i in patient_label_offsets])

for index in patient_label_offsets:
-start = self.per_patient_time_sparse['indptr'][index]
-end = self.per_patient_time_sparse['indptr'][index + 1]
+start = int(self.per_patient_time_sparse["indptr"][index])
+end = int(self.per_patient_time_sparse["indptr"][index + 1])

-self.time_sparse['data'].extend(self.per_patient_time_sparse['data'][start:end])
-self.time_sparse['indices'].extend(self.per_patient_time_sparse['indices'][start:end])
-self.time_sparse['indptr'].append(len(self.time_sparse['indices']))
+self.time_sparse["data"].extend(self.per_patient_time_sparse["data"][start:end])
+self.time_sparse["indices"].extend(self.per_patient_time_sparse["indices"][start:end])
+self.time_sparse["indptr"].append(len(self.time_sparse["indices"]))

def add_event(
-self, current_date: datetime.datetime, next_date: datetime.datetime, next_features: Sequence[int]
+self,
+current_date: datetime.datetime,
+next_date: Optional[datetime.datetime],
+next_features: Optional[Sequence[int]] = None,
) -> int:
if not should_make_survival_prediction(current_date, next_date):
return 0
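Taken together, the Task interface after this commit has five methods visible in the diff: start_batch, start_patient, add_patient_labels, needs_exact, and add_event. Below is a hedged sketch of a conforming task. EventCountTask is hypothetical and does not subclass femr.models.tasks.Task directly (the ABC may declare further methods, such as the get_task_config mentioned in config.py); it only mirrors the signatures shown above:

import datetime
from typing import Any, List, Optional, Sequence


class EventCountTask:
    """Hypothetical task mirroring the Task methods visible in this diff."""

    def start_batch(self) -> None:
        # Reset per-batch label storage
        self.batch_labels: List[int] = []

    def start_patient(self, patient: Any, ontology: Any) -> None:
        # Reset per-patient label storage
        self.per_patient_labels: List[int] = []

    def add_patient_labels(self, patient_label_offsets: List[int]) -> None:
        # Keep only the labels whose subsequences made it into the batch
        self.batch_labels.extend(self.per_patient_labels[i] for i in patient_label_offsets)

    def needs_exact(self) -> bool:
        # False lets the batcher drop some labels for very long patients
        return False

    def add_event(
        self,
        current_date: datetime.datetime,
        next_date: Optional[datetime.datetime],
        next_features: Optional[Sequence[int]] = None,
    ) -> int:
        # next_date and next_features are None when called for the final event
        if next_date is None:
            return 0
        self.per_patient_labels.append(len(next_features or []))
        return 1  # one label was generated at this position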