diff --git a/apiv2/db_import/common/config.py b/apiv2/db_import/common/config.py index d838f78e1..04424da2e 100644 --- a/apiv2/db_import/common/config.py +++ b/apiv2/db_import/common/config.py @@ -54,21 +54,27 @@ def load_deposition_map(self) -> None: self.deposition_map[item.id] = item @lru_cache # noqa - def get_alignment_by_path(self, path: str) -> int | None: + def get_alignment_by_path(self, path: str, run_id: int) -> int | None: session = self.get_db_session() - for item in session.scalars( - sa.select(models.Alignment).where(models.Alignment.s3_alignment_metadata == path), - ).all(): + where_clauses = [ + models.Alignment.run_id == run_id, + models.Alignment.s3_alignment_metadata == path, + ] + for item in session.scalars(sa.select(models.Alignment).where(sa.and_(*where_clauses))).all(): return item.id @lru_cache # noqa - def get_tiltseries_by_path(self, path: str) -> int | None: + def get_tiltseries_by_path(self, path: str, run_id: int) -> int | None: session = self.get_db_session() # '_' is a wildcard character in sql LIKE queries, so we need to escape them! escaped_path = os.path.dirname(path).replace("_", "\\_") path = os.path.join(self.s3_prefix, escaped_path, "%") try: - item = session.scalars(sa.select(models.Tiltseries).where(models.Tiltseries.s3_mrc_file.like(path))).one() + where_clauses = [ + models.Tiltseries.run_id == run_id, + models.Tiltseries.s3_mrc_file.like(path), + ] + item = session.scalars(sa.select(models.Tiltseries).where(sa.and_(*where_clauses))).one() return item.id except NoResultFound: # We have a few runs that (erroneously) are missing tiltseries. diff --git a/apiv2/db_import/importers/alignment.py b/apiv2/db_import/importers/alignment.py index 404195c74..986c324bd 100644 --- a/apiv2/db_import/importers/alignment.py +++ b/apiv2/db_import/importers/alignment.py @@ -46,12 +46,16 @@ class AlignmentItem(ItemDBImporter): model_class = models.Alignment def load_computed_fields(self): + run_id = self.input_data["run"].id if self.model_args["alignment_method"] == "undefined": self.model_args["alignment_method"] = None self.model_args["s3_alignment_metadata"] = self.get_s3_url(self.input_data["file"]) self.model_args["https_alignment_metadata"] = self.get_https_url(self.input_data["file"]) if self.input_data.get("tiltseries_path"): - self.model_args["tiltseries_id"] = self.config.get_tiltseries_by_path(self.input_data["tiltseries_path"]) + self.model_args["tiltseries_id"] = self.config.get_tiltseries_by_path( + self.input_data["tiltseries_path"], + run_id, + ) self.model_args["run_id"] = self.input_data["run"].id diff --git a/apiv2/db_import/importers/annotation.py b/apiv2/db_import/importers/annotation.py index 7a44fd552..0d45bdee6 100644 --- a/apiv2/db_import/importers/annotation.py +++ b/apiv2/db_import/importers/annotation.py @@ -75,7 +75,9 @@ def load_computed_fields(self): self.model_args["annotation_shape_id"] = self.input_data["annotation_shape"].id self.model_args["tomogram_voxel_spacing_id"] = self.input_data["tomogram_voxel_spacing"].id if alignment_path: - self.model_args["alignment_id"] = self.config.get_alignment_by_path(alignment_path) + self.model_args["alignment_id"] = self.config.get_alignment_by_path( + alignment_path, self.input_data["run"].id, + ) self.model_args["source"] = self.calculate_source() self.model_args["s3_path"] = self.get_s3_url(self.input_data["path"]) self.model_args["https_path"] = self.get_https_url(self.input_data["path"]) diff --git a/apiv2/db_import/importers/tomogram.py b/apiv2/db_import/importers/tomogram.py index 5639a9630..395a14ffb 100644 --- a/apiv2/db_import/importers/tomogram.py +++ b/apiv2/db_import/importers/tomogram.py @@ -53,10 +53,11 @@ def generate_neuroglancer_data(self, path) -> str | None: def load_computed_fields(self): https_prefix = self.config.https_prefix + run_id = self.input_data["run"].id extra_data = { "ctf_corrected": bool(self.input_data.get("ctf_corrected")), "tomogram_voxel_spacing_id": self.input_data["tomogram_voxel_spacing"].id, - "run_id": self.input_data["run"].id, + "run_id": run_id, "fiducial_alignment_status": normalize_fiducial_alignment( self.input_data.get("fiducial_alignment_status", False), ), @@ -68,7 +69,7 @@ def load_computed_fields(self): "https_mrc_file": self.get_https_url(self.input_data["mrc_file"]), # TODO: Add alignment_id once we have an alignment importer. "alignment_id": self.config.get_alignment_by_path( - self.get_s3_url(self.input_data["alignment_metadata_path"]), + self.get_s3_url(self.input_data["alignment_metadata_path"]), run_id, ), "key_photo_url": None, "key_photo_thumbnail_url": None,