From 3108681349681dc9e5f3141313546f440527fde0 Mon Sep 17 00:00:00 2001
From: Manasa Venkatakrishnan
Date: Wed, 11 Dec 2024 17:40:02 -0800
Subject: [PATCH] chore: Supporting dataset level job queuing

---
 ingestion_tools/scripts/enqueue_runs.py | 96 +++++++++++++++++--------
 1 file changed, 66 insertions(+), 30 deletions(-)

diff --git a/ingestion_tools/scripts/enqueue_runs.py b/ingestion_tools/scripts/enqueue_runs.py
index 17819610a..36e355503 100644
--- a/ingestion_tools/scripts/enqueue_runs.py
+++ b/ingestion_tools/scripts/enqueue_runs.py
@@ -360,6 +360,12 @@ def db_import(
     multiple=False,
     help="Exclude runs matching this regex. If not specified, all runs are processed",
 )
+@click.option(
+    "--queue-datasets/--queue-runs",
+    type=bool,
+    default=False,
+    help="Queue jobs for datasets instead of runs",
+)
 @ingest_common_options
 @enqueue_common_options
 @click.pass_context
@@ -374,6 +380,7 @@ def queue(
     force_overwrite: bool,
     swipe_wdl_key: str,
     skip_until_run_name: str,
+    queue_datasets: bool,
     **kwargs,
 ):
     handle_common_options(ctx, kwargs)
@@ -396,21 +403,63 @@ def queue(
     filter_datasets = [re.compile(pattern) for pattern in kwargs.get("filter_dataset_name", [])]
     exclude_datasets = [re.compile(pattern) for pattern in kwargs.get("exclude_dataset_name", [])]
+    wdl_args = {
+        "config_file": config_file,
+        "input_bucket": input_bucket,
+        "output_path": output_path,
+    }
+
     # Always iterate over depostions, datasets and runs.
     for deposition in DepositionImporter.finder(config):
         print(f"Processing deposition: {deposition.name}")
         datasets = DatasetImporter.finder(config, deposition=deposition)
-        for dataset in datasets:
-            if list(filter(lambda x: x.match(dataset.name), exclude_datasets)):
-                print(f"Excluding {dataset.name}..")
-                continue
-            if filter_datasets and not list(filter(lambda x: x.match(dataset.name), filter_datasets)):
-                print(f"Skipping {dataset.name}..")
-                continue
-            print(f"Processing dataset: {dataset.name}")
-            runs = RunImporter.finder(config, dataset=dataset)
-            futures = []
-            with ProcessPoolExecutor(max_workers=ctx.obj["parallelism"]) as workerpool:
+        futures = []
+
+        with ProcessPoolExecutor(max_workers=ctx.obj["parallelism"]) as workerpool:
+            for dataset in datasets:
+                if list(filter(lambda x: x.match(dataset.name), exclude_datasets)):
+                    print(f"Excluding {dataset.name}..")
+                    continue
+                if filter_datasets and not list(filter(lambda x: x.match(dataset.name), filter_datasets)):
+                    print(f"Skipping {dataset.name}..")
+                    continue
+                print(f"Processing dataset: {dataset.name}")
+
+                per_dataset_args = {}
+                # Don't copy over dataset and run name filters to the queued jobs - they're intended to be
+                # batched into 1-dataset chunks.
+                excluded_args = ["filter_dataset_name"]
+                for k, v in kwargs.items():
+                    if any(substring in k for substring in excluded_args):
+                        break
+                    per_dataset_args[k] = v
+
+                dataset_id = dataset.name
+                deposition_id = deposition.name
+                ds_execution_name = f"{int(time.time())}-dep{deposition_id}-ds{dataset_id}"
+                if queue_datasets:
+                    new_args = to_args(
+                        import_everything=import_everything,
+                        write_mrc=write_mrc,
+                        write_zarr=write_zarr,
+                        force_overwrite=force_overwrite,
+                        **per_dataset_args,
+                    ) # make a copy
+                    new_args.append(f"--filter-dataset-name '^{dataset.name}$'")
+                    futures.append(
+                        workerpool.submit(
+                            partial(
+                                run_job,
+                                ds_execution_name,
+                                {**{"flags": " ".join(new_args)}, **wdl_args},
+                                swipe_wdl_key=swipe_wdl_key,
+                                **ctx.obj,
+                            ),
+                        ),
+                    )
+                    continue
+
+                runs = RunImporter.finder(config, dataset=dataset)
                 for run in runs:
                     if skip_run and not skip_run_until_regex.match(run.name):
                         print(f"Skipping {run.name}..")
@@ -425,14 +474,10 @@ def queue(
                         continue
                     print(f"Processing {run.name}...")
-                    per_run_args = {}
-                    # Don't copy over dataset and run name filters to the queued jobs - they're intended to be
+                    per_run_args = dict(per_dataset_args)
+                    # Don't copy over run name filters to the queued jobs - they're intended to be
                     # batched into 1-run chunks.
-                    excluded_args = ["filter_dataset_name", "filter_run_name"]
-                    for k, v in kwargs.items():
-                        if any(substring in k for substring in excluded_args):
-                            break
-                        per_run_args[k] = v
+                    per_run_args.pop("filter_run_name", None)
                     new_args = to_args(
                         import_everything=import_everything,
                         write_mrc=write_mrc,
@@ -443,32 +488,23 @@ def queue(
                         write_zarr=write_zarr,
                         force_overwrite=force_overwrite,
                         **per_run_args,
                     ) # make a copy
                     new_args.append(f"--filter-dataset-name '^{dataset.name}$'")
                     new_args.append(f"--filter-run-name '^{run.name}$'")
-                    dataset_id = dataset.name
-                    deposition_id = deposition.name
-                    execution_name = f"{int(time.time())}-dep{deposition_id}-ds{dataset_id}-run{run.name}"
+                    execution_name = f"{int(time.time())}-{ds_execution_name}-run{run.name}"
                     # execution name greater than 80 chars causes boto ValidationException
                     if len(execution_name) > 80:
                         execution_name = execution_name[-80:]
-
-                    wdl_args = {
-                        "config_file": config_file,
-                        "input_bucket": input_bucket,
-                        "output_path": output_path,
-                        "flags": " ".join(new_args),
-                    }
                     futures.append(
                         workerpool.submit(
                             partial(
                                 run_job,
                                 execution_name,
-                                wdl_args,
+                                {**{"flags": " ".join(new_args)}, **wdl_args},
                                 swipe_wdl_key=swipe_wdl_key,
                                 **ctx.obj,
                             ),
                         ),
                     )
-            wait(futures)
+        wait(futures)
 
 
 class OrderedSyncFilters(click.Command):