@staticmethod
def get_sorter_name(data_asset_id: str | codeocean.data_asset.DataAsset) -> str:
    """
    Get the name of the spike sorter (e.g. `'kilosort2_5'`) used to create the
    sorted data asset.

    Tries to find `sorter_name` in the following json files, in order, for any
    probe:
    - `processing.json` (in root of asset)
    - `si_folder.json` (in `spikesorted` dir)
    - `sorting.json` (in `postprocessed` dir)
    - `params.json` (in root of asset, for older assets)

    Raises `ValueError` if none of the json files exist, or if none contain the
    `sorter_name` key, either of which indicates that the asset data is
    incomplete due to the sorting pipeline failing for all probes.

    Raises `AssertionError` if `sorter_name` is present in a json file but the
    file's layout is not one this function knows how to read (a signal that
    the lookup logic needs updating, not that the asset is incomplete).

    Examples
    --------
    `sorter_name` in:

    processing.json['processing_pipeline']['data_processes'][index]['parameters']['sorter_name']:
    >>> aind_session.ecephys.get_sorter_name('921a186a-d8ff-4efc-8e1a-891fde8cd394')
    'kilosort2_5'

    processing.json['data_processes'][index]['parameters']['sorter_name']:
    >>> aind_session.ecephys.get_sorter_name('205fc2d0-5f00-468f-a82d-47c94afcd40c')
    'kilosort2_5'


    processing.json['data_processes'][index]['parameters'] has no 'sorter_name' key:
    >>> aind_session.ecephys.get_sorter_name('bd0ad804-4a33-4613-9d6c-6281e442bade')
    'kilosort2_5'

    Incomplete data in sorted asset (sorting failed for all probes):
    >>> aind_session.ecephys.get_sorter_name('779eb3ca-9652-4476-a283-0022e2d3f1e4')
    Traceback (most recent call last):
    ...
    ValueError: Sorting data are incomplete for
        data_asset_id='779eb3ca-9652-4476-a283-0022e2d3f1e4' (pipeline likely failed) - cannot get sorter name
    """
    source_dir = aind_session.utils.codeocean_utils.get_data_asset_source_dir(
        aind_session.utils.codeocean_utils.get_normalized_uuid(data_asset_id)
    )

    def _get_sorter_name_from_processing_json(source_dir: upath.UPath) -> str:
        """Read `sorter_name` from the 'Spike sorting' entry of `processing.json`.

        Raises `FileNotFoundError` if the file is missing and `KeyError` if it
        has no `sorter_name`; both are treated by the caller as "try the next
        source".
        """
        processing_path = source_dir / "processing.json"
        if not processing_path.exists():
            raise FileNotFoundError(f"No 'processing.json' found in {source_dir}")
        processing_text = processing_path.read_text()
        # Cheap substring pre-check before parsing. The key appears in the raw
        # JSON as `"sorter_name":` (closing quote before the colon), so only
        # the bare name is searched for - 'sorter_name:' can never match.
        if "sorter_name" not in processing_text:
            raise KeyError(f"No 'sorter_name' value found in processing.json for {data_asset_id=}")
        processing: dict = json.loads(processing_text)
        # Two known layouts: data processes nested under 'processing_pipeline',
        # or directly at the top level.
        if "processing_pipeline" in processing:
            data_processes = processing["processing_pipeline"]["data_processes"]
        else:
            assert "data_processes" in processing, f"Fix method of getting sorter name: 'data_processes' not in processing.json for {data_asset_id=}"
            data_processes = processing["data_processes"]
        for p in data_processes:
            if isinstance(p, list):
                # Entry is itself a list of process dicts: take its
                # 'Spike sorting' dict (or an empty dict if it has none) and
                # stop looking.
                sorting: dict = next(
                    (d for d in p if d.get("name") == "Spike sorting"),
                    {},
                )
                break
            else:
                if p.get("name") == "Spike sorting":
                    sorting = p
                    break
        else:
            # Loop finished without finding any 'Spike sorting' entry.
            raise AssertionError(
                f"Fix method of getting sorter name: 'sorter_name' is in processing.json, but not in expected location for {data_asset_id=}"
            )
        assert (
            "parameters" in sorting
        ), f"Fix method of getting sorter name: 'parameters' not in 'Spike sorting' data process in processing.json for {data_asset_id=}"
        if "sorter_name" not in sorting["parameters"]:
            raise KeyError("No 'sorter_name' key found in sorting parameters in processing.json")
        sorter_name: str = sorting["parameters"]["sorter_name"]
        logger.debug(f"Found sorter name in processing.json: {sorter_name}")
        return sorter_name

    def _get_sorter_name_from_sorted_folders(source_dir: upath.UPath) -> str:
        """Read `sorter_name` from the first per-probe json file that has it.

        Searches `si_folder.json` files under `spikesorted`, then
        `sorting.json` files under `postprocessed`. Raises `FileNotFoundError`
        if no such files exist and `KeyError` if none mention `sorter_name`.
        """
        json_paths = []
        for json_path in itertools.chain(
            (source_dir / "spikesorted").rglob("si_folder.json"),
            (source_dir / "postprocessed").rglob("sorting.json"),
        ):
            json_paths.append(json_path)
            info = json_path.read_text()
            if "sorter_name" in info:
                sorter_name: str = json.loads(info)["annotations"]["__sorting_info__"]["params"]["sorter_name"]
                logger.debug(f"Found sorter name in {json_path.name}: {sorter_name}")
                # Any probe will do: return as soon as a name is found
                # (previously the value was logged but never returned).
                return sorter_name
        # Reaching here means no file yielded a sorter name.
        if not json_paths:
            raise FileNotFoundError(f"No 'processing.json', 'si_folder.json', or 'sorting.json' files found - asset {data_asset_id} likely contains incomplete data")
        raise KeyError(f"Fix method of getting sorter name: 'sorter_name' not a value in {set(p.name for p in json_paths)} for {data_asset_id=}")

    def _get_sorter_name_from_params_json(source_dir: upath.UPath) -> str:
        """Read `sorter_name` from `params.json['spikesorting']` (older assets).

        Raises `FileNotFoundError` if the file is missing and `KeyError` if it
        does not mention `sorter_name`.
        """
        params_path = source_dir / "params.json"
        if not params_path.exists():
            raise FileNotFoundError(f"No 'params.json' found in {source_dir}")
        params_text = params_path.read_text()
        if "sorter_name" not in params_text:
            raise KeyError(f"No 'sorter_name' value found in {params_path.name}")
        params: dict = json.loads(params_text)
        assert params, f"Fix method of getting sorter name: {params=} for {data_asset_id=}"
        assert "spikesorting" in params, f"Fix method of getting sorter name: 'spikesorting' not in {params_path.name} for {data_asset_id=}"
        assert "sorter_name" in params["spikesorting"], f"Fix method of getting sorter name: 'sorter_name' not in 'spikesorting' in {params_path.name} for {data_asset_id=}"
        sorter_name = params["spikesorting"]["sorter_name"]
        logger.debug(f"Found sorter name in params.json: {sorter_name}")
        return sorter_name

    # Try each source in priority order. A missing file or missing key just
    # means "fall through to the next source"; any other exception propagates.
    with contextlib.suppress(FileNotFoundError, KeyError):
        return _get_sorter_name_from_processing_json(source_dir)
    with contextlib.suppress(FileNotFoundError, KeyError):
        return _get_sorter_name_from_sorted_folders(source_dir)
    with contextlib.suppress(FileNotFoundError, KeyError):
        return _get_sorter_name_from_params_json(source_dir)
    raise ValueError(f"Sorting data are incomplete for {data_asset_id=!r} (pipeline likely failed) - cannot get sorter name")