diff --git a/src/femr/models/processor.py b/src/femr/models/processor.py index 52342f07..1d3dc88e 100644 --- a/src/femr/models/processor.py +++ b/src/femr/models/processor.py @@ -201,7 +201,7 @@ def add_patient(self, patient: meds.Patient, offset: int = 0, max_length: Option per_patient_ages.append((event["time"] - birth) / datetime.timedelta(days=1)) per_patient_normalized_ages.append(self.tokenizer.normalize_age(event["time"] - birth)) - per_patient_timestamps.append(event["time"].timestamp()) + per_patient_timestamps.append(event["time"].replace(tzinfo=datetime.timezone.utc).timestamp()) last_time = event["time"] diff --git a/src/femr/models/transformer.py b/src/femr/models/transformer.py index 7b95eccc..b5260e85 100644 --- a/src/femr/models/transformer.py +++ b/src/femr/models/transformer.py @@ -303,12 +303,19 @@ def forward(self, batch: Mapping[str, Any], return_loss=True, return_logits=Fals else: loss = 0 features = features.reshape(-1, features.shape[-1]) - features = features[batch["transformer"]["label_indices"], :] - result = { - "timestamps": batch["transformer"]["timestamps"][batch["transformer"]["label_indices"]], - "patient_ids": batch["patient_ids"][batch["transformer"]["label_indices"]], - "representations": features, - } + if 'task' in batch: + features = features[batch["transformer"]["label_indices"], :] + result = { + "timestamps": batch["transformer"]["timestamps"][batch["transformer"]["label_indices"]], + "patient_ids": batch["patient_ids"][batch["transformer"]["label_indices"]], + "representations": features, + } + else: + result = { + "timestamps": batch["transformer"]["timestamps"], + "patient_ids": batch["patient_ids"], + "representations": features, + } return loss, result