diff --git a/src/femr/models/transformer.py b/src/femr/models/transformer.py index b5260e85..4881976f 100644 --- a/src/femr/models/transformer.py +++ b/src/femr/models/transformer.py @@ -303,7 +303,7 @@ def forward(self, batch: Mapping[str, Any], return_loss=True, return_logits=Fals else: loss = 0 features = features.reshape(-1, features.shape[-1]) - if 'task' in batch: + if "task" in batch: features = features[batch["transformer"]["label_indices"], :] result = { "timestamps": batch["transformer"]["timestamps"][batch["transformer"]["label_indices"]],