Skip to content

Commit

Permalink
Minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
EthanSteinberg committed Mar 21, 2024
1 parent bfcac77 commit 8bca94f
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/femr/models/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def add_patient(self, patient: meds.Patient, offset: int = 0, max_length: Option

per_patient_ages.append((event["time"] - birth) / datetime.timedelta(days=1))
per_patient_normalized_ages.append(self.tokenizer.normalize_age(event["time"] - birth))
per_patient_timestamps.append(event["time"].timestamp())
per_patient_timestamps.append(event["time"].replace(tzinfo=datetime.timezone.utc).timestamp())

last_time = event["time"]

Expand Down
19 changes: 13 additions & 6 deletions src/femr/models/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,12 +303,19 @@ def forward(self, batch: Mapping[str, Any], return_loss=True, return_logits=Fals
else:
loss = 0
features = features.reshape(-1, features.shape[-1])
features = features[batch["transformer"]["label_indices"], :]
result = {
"timestamps": batch["transformer"]["timestamps"][batch["transformer"]["label_indices"]],
"patient_ids": batch["patient_ids"][batch["transformer"]["label_indices"]],
"representations": features,
}
if 'task' in batch:
features = features[batch["transformer"]["label_indices"], :]
result = {
"timestamps": batch["transformer"]["timestamps"][batch["transformer"]["label_indices"]],
"patient_ids": batch["patient_ids"][batch["transformer"]["label_indices"]],
"representations": features,
}
else:
result = {
"timestamps": batch["transformer"]["timestamps"],
"patient_ids": batch["patient_ids"],
"representations": features,
}

return loss, result

Expand Down

0 comments on commit 8bca94f

Please sign in to comment.