diff --git a/cloudbuild_CI.yaml b/cloudbuild_CI.yaml index b81ac9708..5ec4559f8 100644 --- a/cloudbuild_CI.yaml +++ b/cloudbuild_CI.yaml @@ -56,4 +56,4 @@ steps: # - '--gs_dir bashir-variant_integration_test_runs' images: - 'gcr.io/${PROJECT_ID}/gcp-variant-transforms:${COMMIT_SHA}' -timeout: 240m +timeout: 270m diff --git a/gcp_variant_transforms/bq_to_vcf.py b/gcp_variant_transforms/bq_to_vcf.py index 55d6af16f..7fa7147c2 100644 --- a/gcp_variant_transforms/bq_to_vcf.py +++ b/gcp_variant_transforms/bq_to_vcf.py @@ -67,17 +67,25 @@ from gcp_variant_transforms.transforms import bigquery_to_variant from gcp_variant_transforms.transforms import combine_sample_ids from gcp_variant_transforms.transforms import densify_variants +from gcp_variant_transforms.transforms import sample_mapping_table + _BASE_QUERY_TEMPLATE = 'SELECT {COLUMNS} FROM `{INPUT_TABLE}`' _BQ_TO_VCF_SHARDS_JOB_NAME = 'bq-to-vcf-shards' _COMMAND_LINE_OPTIONS = [variant_transform_options.BigQueryToVcfOptions] +TABLE_SUFFIX_SEPARATOR = bigquery_util.TABLE_SUFFIX_SEPARATOR +SAMPLE_INFO_TABLE_SUFFIX = bigquery_util.SAMPLE_INFO_TABLE_SUFFIX _GENOMIC_REGION_TEMPLATE = ('({REFERENCE_NAME_ID}="{REFERENCE_NAME_VALUE}" AND ' '{START_POSITION_ID}>={START_POSITION_VALUE} AND ' '{END_POSITION_ID}<={END_POSITION_VALUE})') _VCF_FIXED_COLUMNS = ['#CHROM', 'POS', 'ID', 'REF', 'ALT', 'QUAL', 'FILTER', 'INFO', 'FORMAT'] -_VCF_VERSION_LINE = '##fileformat=VCFv4.3\n' +_VCF_VERSION_LINE = ( + vcf_header_io.FILE_FORMAT_HEADER_TEMPLATE.format(VERSION='4.3') + '\n') +_SAMPLE_INFO_QUERY_TEMPLATE = ( + 'SELECT sample_id, sample_name, file_path FROM ' + '`{PROJECT_ID}.{DATASET_ID}.{TABLE_NAME}`') def run(argv=None): @@ -164,30 +172,63 @@ def _bigquery_to_vcf_shards( `vcf_header_file_path`. """ schema = _get_schema(known_args.input_table) - # TODO(allieychen): Modify the SQL query with the specified sample_ids. 
- query = _get_bigquery_query(known_args, schema) - logging.info('Processing BigQuery query %s:', query) - bq_source = bigquery.BigQuerySource(query=query, - validate=True, - use_standard_sql=True) + variant_query = _get_variant_query(known_args, schema) + logging.info('Processing BigQuery query %s:', variant_query) + project_id, dataset_id, table_id = bigquery_util.parse_table_reference( + known_args.input_table) + bq_variant_source = bigquery.BigQuerySource(query=variant_query, + validate=True, + use_standard_sql=True) annotation_names = _extract_annotation_names(schema) + + base_table_id = bigquery_util.get_table_base_name(table_id) + sample_query = _SAMPLE_INFO_QUERY_TEMPLATE.format( + PROJECT_ID=project_id, + DATASET_ID=dataset_id, + TABLE_NAME=bigquery_util.compose_table_name(base_table_id, + SAMPLE_INFO_TABLE_SUFFIX)) + bq_sample_source = bigquery.BigQuerySource(query=sample_query, + validate=True, + use_standard_sql=True) with beam.Pipeline(options=beam_pipeline_options) as p: variants = (p - | 'ReadFromBigQuery ' >> beam.io.Read(bq_source) + | 'ReadFromBigQuery ' >> beam.io.Read(bq_variant_source) | bigquery_to_variant.BigQueryToVariant(annotation_names)) + sample_table_rows = ( + p + | 'ReadFromSampleTable' >> beam.io.Read(bq_sample_source)) if known_args.sample_names: - sample_ids = (p - | transforms.Create(known_args.sample_names, - reshuffle=False) - | beam.combiners.ToList()) + temp_sample_names = (p + | transforms.Create(known_args.sample_names, + reshuffle=False)) else: - sample_ids = (variants - | 'CombineSampleIds' >> - combine_sample_ids.SampleIdsCombiner( - known_args.preserve_sample_order)) - # TODO(tneymanov): Add logic to extract sample names from sample IDs by - # joining with sample id-name mapping table, once that code is implemented. - sample_names = sample_ids + # Get sample names from sample IDs in the variants and sort. 
+ id_to_name_hash_table = ( + sample_table_rows + | 'SampleIdToNameDict' >> sample_mapping_table.SampleIdToNameDict()) + temp_sample_ids = (variants + | 'CombineSampleIds' >> + combine_sample_ids.SampleIdsCombiner( + known_args.preserve_sample_order)) + temp_sample_names = ( + temp_sample_ids + | 'GetSampleNames' >> + sample_mapping_table.GetSampleNames( + beam.pvalue.AsSingleton(id_to_name_hash_table)) + | 'CombineToList' >> beam.combiners.ToList() + | 'SortSampleNames' >> beam.ParDo(sorted)) + + name_to_id_hash_table = ( + sample_table_rows + | 'SampleNameToIdDict' >> sample_mapping_table.SampleNameToIdDict()) + sample_ids = ( + temp_sample_names + | 'GetSampleIds' >> + sample_mapping_table.GetSampleIds( + beam.pvalue.AsSingleton(name_to_id_hash_table)) + | 'CombineSortedSampleIds' >> beam.combiners.ToList()) + sample_names = temp_sample_names | beam.combiners.ToList() + _ = (sample_names | 'GenerateVcfDataHeader' >> beam.ParDo(_write_vcf_header_with_sample_names, @@ -196,7 +237,8 @@ def _bigquery_to_vcf_shards( header_file_path)) _ = (variants - | densify_variants.DensifyVariants(beam.pvalue.AsSingleton(sample_ids)) + | densify_variants.DensifyVariants( + beam.pvalue.AsSingleton(sample_ids)) | 'PairVariantWithKey' >> beam.Map(_pair_variant_with_key, known_args.number_of_bases_per_shard) | 'GroupVariantsByKey' >> beam.GroupByKey() @@ -216,7 +258,7 @@ def _get_schema(input_table): return table.schema -def _get_bigquery_query(known_args, schema): +def _get_variant_query(known_args, schema): # type: (argparse.Namespace, bigquery_v2.TableSchema) -> str """Returns a BigQuery query for the interested regions.""" columns = _get_query_columns(schema) diff --git a/gcp_variant_transforms/bq_to_vcf_test.py b/gcp_variant_transforms/bq_to_vcf_test.py index 2f0659b70..4f5640d89 100644 --- a/gcp_variant_transforms/bq_to_vcf_test.py +++ b/gcp_variant_transforms/bq_to_vcf_test.py @@ -61,7 +61,7 @@ def test_write_vcf_data_header(self): content = f.readlines() 
self.assertEqual(content, expected_content) - def test_get_bigquery_query_no_region(self): + def test_get_variant_query_no_region(self): args = self._create_mock_args( input_table='my_bucket:my_dataset.my_table', genomic_regions=None) @@ -71,11 +71,11 @@ def test_get_bigquery_query_no_region(self): type=bigquery_util.TableFieldConstants.TYPE_STRING, mode=bigquery_util.TableFieldConstants.MODE_NULLABLE, description='Reference name.')) - self.assertEqual(bq_to_vcf._get_bigquery_query(args, schema), + self.assertEqual(bq_to_vcf._get_variant_query(args, schema), 'SELECT reference_name FROM ' '`my_bucket.my_dataset.my_table`') - def test_get_bigquery_query_with_regions(self): + def test_get_variant_query_with_regions(self): args_1 = self._create_mock_args( input_table='my_bucket:my_dataset.my_table', genomic_regions=['c1:1,000-2,000', 'c2']) @@ -98,7 +98,7 @@ def test_get_bigquery_query_with_regions(self): 'OR (reference_name="c2" AND start_position>=0 AND ' 'end_position<=9223372036854775807)' ) - self.assertEqual(bq_to_vcf._get_bigquery_query(args_1, schema), + self.assertEqual(bq_to_vcf._get_variant_query(args_1, schema), expected_query) def test_get_query_columns(self): diff --git a/gcp_variant_transforms/libs/bigquery_util.py b/gcp_variant_transforms/libs/bigquery_util.py index a7acd2be8..3bf1d16af 100644 --- a/gcp_variant_transforms/libs/bigquery_util.py +++ b/gcp_variant_transforms/libs/bigquery_util.py @@ -41,6 +41,8 @@ _MAX_BQ_NUM_PARTITIONS = 4000 _RANGE_END_SIG_DIGITS = 4 _RANGE_INTERVAL_SIG_DIGITS = 1 +_TOTAL_BASE_PAIRS_SIG_DIGITS = 4 +_PARTITION_SIZE_SIG_DIGITS = 1 START_POSITION_COLUMN = 'start_position' _BQ_CREATE_PARTITIONED_TABLE_COMMAND = ( diff --git a/gcp_variant_transforms/options/variant_transform_options.py b/gcp_variant_transforms/options/variant_transform_options.py index 0edc7216d..fc75a5c25 100644 --- a/gcp_variant_transforms/options/variant_transform_options.py +++ b/gcp_variant_transforms/options/variant_transform_options.py @@ -24,6 +24,9 
@@ from gcp_variant_transforms.libs import bigquery_util from gcp_variant_transforms.libs import variant_sharding +TABLE_SUFFIX_SEPARATOR = bigquery_util.TABLE_SUFFIX_SEPARATOR +SAMPLE_INFO_TABLE_SUFFIX = bigquery_util.SAMPLE_INFO_TABLE_SUFFIX + class VariantTransformsOptions(object): """Base class for defining groups of options for Variant Transforms. @@ -215,8 +218,7 @@ def validate(self, parsed_args, client=None): dataset_id) all_output_tables = [] all_output_tables.append( - bigquery_util.compose_table_name( - table_id, bigquery_util.SAMPLE_INFO_TABLE_SUFFIX)) + bigquery_util.compose_table_name(table_id, SAMPLE_INFO_TABLE_SUFFIX)) sharding = variant_sharding.VariantSharding( parsed_args.sharding_config_path) num_shards = sharding.get_num_shards() @@ -587,6 +589,31 @@ def add_arguments(self, parser): 'extracted variants to have the same sample ordering (usually ' 'true for tables from single VCF file import).')) + def validate(self, parsed_args, client=None): + if not client: + credentials = GoogleCredentials.get_application_default().create_scoped( + ['https://www.googleapis.com/auth/bigquery']) + client = bigquery.BigqueryV2(credentials=credentials) + + project_id, dataset_id, table_id = bigquery_util.parse_table_reference( + parsed_args.input_table) + if not bigquery_util.table_exist(client, project_id, dataset_id, table_id): + raise ValueError('Table {}:{}.{} does not exist.'.format( + project_id, dataset_id, table_id)) + if table_id.count(TABLE_SUFFIX_SEPARATOR) != 1: + raise ValueError( + 'Input table {} is malformed - exactly one suffix separator "{}" is ' + 'required'.format(parsed_args.input_table, + TABLE_SUFFIX_SEPARATOR)) + base_table_id = table_id[:table_id.find(TABLE_SUFFIX_SEPARATOR)] + sample_table_id = bigquery_util.compose_table_name(base_table_id, + SAMPLE_INFO_TABLE_SUFFIX) + + if not bigquery_util.table_exist(client, project_id, dataset_id, + sample_table_id): + raise ValueError('Sample table {}:{}.{} does not exist.'.format( + project_id, 
dataset_id, sample_table_id)) + def _validate_inputs(parsed_args): if ((parsed_args.input_pattern and parsed_args.input_file) or diff --git a/gcp_variant_transforms/testing/asserts.py b/gcp_variant_transforms/testing/asserts.py index 006e4ed5c..46c9d1a09 100644 --- a/gcp_variant_transforms/testing/asserts.py +++ b/gcp_variant_transforms/testing/asserts.py @@ -59,6 +59,20 @@ def _has_sample_ids(variants): return _has_sample_ids +def dict_values_equal(expected_dict): + """Verifies that dictionary is the same as expected.""" + def _items_equal(actual_dict): + actual = actual_dict[0] + for k in expected_dict: + if k not in actual or expected_dict[k] != actual[k]: + raise BeamAssertException( + 'Failed assert: %s == %s' % (expected_dict, actual)) + if len(expected_dict) != len(actual): + raise BeamAssertException( + 'Failed assert: %s == %s' % (expected_dict, actual)) + return _items_equal + + def header_vars_equal(expected): def _vars_equal(actual): expected_vars = [vars(header) for header in expected] diff --git a/gcp_variant_transforms/testing/data/vcf/bq_to_vcf/expected_output/no_options.vcf b/gcp_variant_transforms/testing/data/vcf/bq_to_vcf/expected_output/no_options.vcf index 6e249a927..05a2e10ba 100644 --- a/gcp_variant_transforms/testing/data/vcf/bq_to_vcf/expected_output/no_options.vcf +++ b/gcp_variant_transforms/testing/data/vcf/bq_to_vcf/expected_output/no_options.vcf @@ -5,8 +5,8 @@ ##INFO= ##INFO= ##INFO= -##FORMAT= ##FORMAT= +##FORMAT= ##FORMAT= #CHROM POS ID REF ALT QUAL FILTER INFO FORMAT NA00001 NA00002 NA00003 19 1234567 microsat1 GTCT G,GTACT 50.0 PASS AA=G;NS=3;DP=9 GT:DP:GQ 0/1:4:35 0/2:2:17 1/1:3:40 diff --git a/gcp_variant_transforms/testing/data/vcf/bq_to_vcf/expected_output/option_allow_incompatible_schema.vcf b/gcp_variant_transforms/testing/data/vcf/bq_to_vcf/expected_output/option_allow_incompatible_schema.vcf index 255c5bfa5..c39e793b0 100644 --- 
a/gcp_variant_transforms/testing/data/vcf/bq_to_vcf/expected_output/option_allow_incompatible_schema.vcf +++ b/gcp_variant_transforms/testing/data/vcf/bq_to_vcf/expected_output/option_allow_incompatible_schema.vcf @@ -6,8 +6,8 @@ ##INFO= ##INFO= ##INFO= -##FORMAT= ##FORMAT= +##FORMAT= ##FORMAT= ##FORMAT= #CHROM POS ID REF ALT QUAL FILTER INFO FORMAT NA00001 NA00002 NA00003 diff --git a/gcp_variant_transforms/testing/data/vcf/bq_to_vcf/expected_output/option_customized_export.vcf b/gcp_variant_transforms/testing/data/vcf/bq_to_vcf/expected_output/option_customized_export.vcf index 1ee3068b5..a344cda42 100644 --- a/gcp_variant_transforms/testing/data/vcf/bq_to_vcf/expected_output/option_customized_export.vcf +++ b/gcp_variant_transforms/testing/data/vcf/bq_to_vcf/expected_output/option_customized_export.vcf @@ -5,8 +5,8 @@ ##INFO= ##INFO= ##INFO= -##FORMAT= ##FORMAT= +##FORMAT= ##FORMAT= #CHROM POS ID REF ALT QUAL FILTER INFO FORMAT NA00001 NA00003 19 1234567 microsat1 GTCT G,GTACT 50.0 PASS AA=G;NS=3;DP=9 GT:DP:GQ 0/1:4:35 1/1:3:40 diff --git a/gcp_variant_transforms/testing/data/vcf/bq_to_vcf/expected_output/option_preserve_sample_order.vcf b/gcp_variant_transforms/testing/data/vcf/bq_to_vcf/expected_output/option_preserve_sample_order.vcf index 029479dbc..0ae4a7a46 100644 --- a/gcp_variant_transforms/testing/data/vcf/bq_to_vcf/expected_output/option_preserve_sample_order.vcf +++ b/gcp_variant_transforms/testing/data/vcf/bq_to_vcf/expected_output/option_preserve_sample_order.vcf @@ -5,8 +5,8 @@ ##INFO= ##INFO= ##INFO= -##FORMAT= ##FORMAT= +##FORMAT= ##FORMAT= #CHROM POS ID REF ALT QUAL FILTER INFO FORMAT NA00001 NA00002 NA00003 NA00004 19 1234567 microsat1;microsat2 GTCT G,GTACT 50.0 PASS AA=G;NS=2;DP=9 GT:DP:GQ 0/1:4:35 0/2:2:17 1/1:3:40 .:.:. 
diff --git a/gcp_variant_transforms/testing/integration/bq_to_vcf_tests/no_options.json b/gcp_variant_transforms/testing/integration/bq_to_vcf_tests/no_options.json index ffaaaf1fa..eea5b326a 100644 --- a/gcp_variant_transforms/testing/integration/bq_to_vcf_tests/no_options.json +++ b/gcp_variant_transforms/testing/integration/bq_to_vcf_tests/no_options.json @@ -1,7 +1,7 @@ [ { "test_name": "bq-to-vcf-no-options", - "input_table": "gcp-variant-transforms-test:bq_to_vcf_integration_tests.4_0_new_schema", + "input_table": "gcp-variant-transforms-test:bq_to_vcf_integration_tests.4_0__suffix", "output_file_name": "bq_to_vcf_no_options.vcf", "runner": "DirectRunner", "expected_output_file": "gcp_variant_transforms/testing/data/vcf/bq_to_vcf/expected_output/no_options.vcf" diff --git a/gcp_variant_transforms/testing/integration/bq_to_vcf_tests/option_allow_incompatible_schema.json b/gcp_variant_transforms/testing/integration/bq_to_vcf_tests/option_allow_incompatible_schema.json index 90d626a7e..c32d05c44 100644 --- a/gcp_variant_transforms/testing/integration/bq_to_vcf_tests/option_allow_incompatible_schema.json +++ b/gcp_variant_transforms/testing/integration/bq_to_vcf_tests/option_allow_incompatible_schema.json @@ -1,7 +1,7 @@ [ { "test_name": "bq-to-vcf-option-allow-incompatible-schema", - "input_table": "gcp-variant-transforms-test:bq_to_vcf_integration_tests.4_2_new_schema", + "input_table": "gcp-variant-transforms-test:bq_to_vcf_integration_tests.4_2__suffix", "output_file_name": "bq_to_vcf_option_allow_incompatible_schema.vcf", "allow_incompatible_schema": true, "runner": "DirectRunner", diff --git a/gcp_variant_transforms/testing/integration/bq_to_vcf_tests/option_customized_export.json b/gcp_variant_transforms/testing/integration/bq_to_vcf_tests/option_customized_export.json index c69663c54..ccf0119bb 100644 --- a/gcp_variant_transforms/testing/integration/bq_to_vcf_tests/option_customized_export.json +++ 
b/gcp_variant_transforms/testing/integration/bq_to_vcf_tests/option_customized_export.json @@ -1,7 +1,7 @@ [ { "test_name": "bq-to-vcf-option-customized-export", - "input_table": "gcp-variant-transforms-test:bq_to_vcf_integration_tests.4_0_new_schema", + "input_table": "gcp-variant-transforms-test:bq_to_vcf_integration_tests.4_0__suffix", "output_file_name": "bq_to_vcf_option_customized_export.vcf", "genomic_regions": "19:1234566-1234570 20:14369-17330", "sample_names": "NA00001 NA00003", diff --git a/gcp_variant_transforms/testing/integration/bq_to_vcf_tests/option_number_of_bases_per_shard.json b/gcp_variant_transforms/testing/integration/bq_to_vcf_tests/option_number_of_bases_per_shard.json index 7702d6ab2..20786523d 100644 --- a/gcp_variant_transforms/testing/integration/bq_to_vcf_tests/option_number_of_bases_per_shard.json +++ b/gcp_variant_transforms/testing/integration/bq_to_vcf_tests/option_number_of_bases_per_shard.json @@ -1,10 +1,10 @@ [ { "test_name": "bq-to-vcf-option-number-of-bases-per-shard", - "input_table": "gcp-variant-transforms-test:bq_to_vcf_integration_tests.platinum_NA12877_hg38_10K_lines_new_schema", + "input_table": "gcp-variant-transforms-test:bq_to_vcf_integration_tests.platinum_NA12877_hg38_10K_lines__suffix", "output_file_name": "bq_to_vcf_option_number_of_bases_per_shard.vcf", "number_of_bases_per_shard": 100000, "runner": "DataflowRunner", - "expected_output_file": "gs://gcp-variant-transforms-testfiles/bq_to_vcf_expected_output/platinum_NA12877_hg38_10K_lines.vcf" + "expected_output_file": "gs://gcp-variant-transforms-testfiles/bq_to_vcf_expected_output/platinum_NA12877_hg38_10K_lines_v2.vcf" } ] diff --git a/gcp_variant_transforms/testing/integration/bq_to_vcf_tests/option_preserve_sample_order.json b/gcp_variant_transforms/testing/integration/bq_to_vcf_tests/option_preserve_sample_order.json index 7b4cab18e..45842cc97 100644 --- a/gcp_variant_transforms/testing/integration/bq_to_vcf_tests/option_preserve_sample_order.json +++ 
b/gcp_variant_transforms/testing/integration/bq_to_vcf_tests/option_preserve_sample_order.json @@ -1,7 +1,7 @@ [ { "test_name": "bq-to-vcf-option-preserve-call-names-order", - "input_table": "gcp-variant-transforms-test:bq_to_vcf_integration_tests.merge_option_move_to_calls_new_schema", + "input_table": "gcp-variant-transforms-test:bq_to_vcf_integration_tests.merge_option_move_to_calls__suffix", "output_file_name": "bq_to_vcf_option_preserve_sample_order.vcf", "preserve_sample_order": false, "runner": "DirectRunner", diff --git a/gcp_variant_transforms/testing/integration/bq_to_vcf_tests/option_representative_header_file.json b/gcp_variant_transforms/testing/integration/bq_to_vcf_tests/option_representative_header_file.json index 7a415d098..a660866ef 100644 --- a/gcp_variant_transforms/testing/integration/bq_to_vcf_tests/option_representative_header_file.json +++ b/gcp_variant_transforms/testing/integration/bq_to_vcf_tests/option_representative_header_file.json @@ -1,7 +1,7 @@ [ { "test_name": "bq-to-vcf-option-representative-header-file", - "input_table": "gcp-variant-transforms-test:bq_to_vcf_integration_tests.4_0_new_schema", + "input_table": "gcp-variant-transforms-test:bq_to_vcf_integration_tests.4_0__suffix", "representative_header_file": "gs://gcp-variant-transforms-testfiles/small_tests/valid-4.0.vcf", "preserve_sample_order": true, "output_file_name": "bq_to_vcf_option_representative_header_file.vcf", diff --git a/gcp_variant_transforms/transforms/combine_sample_ids.py b/gcp_variant_transforms/transforms/combine_sample_ids.py index cf57a2ac0..3b19708bd 100644 --- a/gcp_variant_transforms/transforms/combine_sample_ids.py +++ b/gcp_variant_transforms/transforms/combine_sample_ids.py @@ -70,12 +70,10 @@ def expand(self, pcoll): | 'RemoveDuplicates' >> beam.RemoveDuplicates() | 'Combine' >> beam.combiners.ToList() | 'ExtractUniqueSampleIds' - >> beam.ParDo(self._extract_unique_sample_ids) - | beam.combiners.ToList()) + >> 
beam.ParDo(self._extract_unique_sample_ids)) else: return (pcoll | 'GetSampleIds' >> beam.FlatMap(self._get_sample_ids) | 'RemoveDuplicates' >> beam.RemoveDuplicates() | 'Combine' >> beam.combiners.ToList() - | 'SortSampleIds' >> beam.ParDo(sorted) - | beam.combiners.ToList()) + | 'SortSampleIds' >> beam.ParDo(sorted)) diff --git a/gcp_variant_transforms/transforms/combine_sample_ids_test.py b/gcp_variant_transforms/transforms/combine_sample_ids_test.py index 601734861..85420a7ee 100644 --- a/gcp_variant_transforms/transforms/combine_sample_ids_test.py +++ b/gcp_variant_transforms/transforms/combine_sample_ids_test.py @@ -16,6 +16,7 @@ import unittest +from apache_beam import combiners from apache_beam import transforms from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import assert_that @@ -48,8 +49,8 @@ def test_sample_ids_combiner_pipeline_preserve_sample_order_error(self): pipeline | transforms.Create(variants) | 'CombineSampleIds' >> - combine_sample_ids.SampleIdsCombiner( - preserve_sample_order=True)) + combine_sample_ids.SampleIdsCombiner(preserve_sample_order=True) + | combiners.ToList()) with self.assertRaises(ValueError): pipeline.run() @@ -76,8 +77,8 @@ def test_sample_ids_combiner_pipeline_preserve_sample_order(self): pipeline | transforms.Create(variants) | 'CombineSampleIds' >> - combine_sample_ids.SampleIdsCombiner( - preserve_sample_order=True)) + combine_sample_ids.SampleIdsCombiner(preserve_sample_order=True) + | combiners.ToList()) assert_that(combined_sample_ids, equal_to([sample_ids])) pipeline.run() @@ -99,8 +100,8 @@ def test_sample_ids_combiner_pipeline(self): combined_sample_ids = ( pipeline | transforms.Create(variants) - | 'CombineSampleIds' >> - combine_sample_ids.SampleIdsCombiner()) + | 'CombineSampleIds' >> combine_sample_ids.SampleIdsCombiner() + | combiners.ToList()) assert_that(combined_sample_ids, equal_to([sample_ids])) pipeline.run() @@ -112,7 +113,7 @@ def 
test_sample_ids_combiner_pipeline_duplicate_sample_ids(self): _ = ( pipeline | transforms.Create(variants) - | 'CombineSampleIds' >> - combine_sample_ids.SampleIdsCombiner()) + | 'CombineSampleIds' >> combine_sample_ids.SampleIdsCombiner() + | combiners.ToList()) with self.assertRaises(ValueError): pipeline.run() diff --git a/gcp_variant_transforms/transforms/sample_mapping_table.py b/gcp_variant_transforms/transforms/sample_mapping_table.py new file mode 100644 index 000000000..6dd5e2f85 --- /dev/null +++ b/gcp_variant_transforms/transforms/sample_mapping_table.py @@ -0,0 +1,88 @@ +# Copyright 2020 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""PTransforms for mapping between sample IDs and sample names via the sample info table.""" + +from typing import Dict, List # pylint: disable=unused-import + +import apache_beam as beam + +from gcp_variant_transforms.libs import sample_info_table_schema_generator + +SAMPLE_ID_COLUMN = sample_info_table_schema_generator.SAMPLE_ID +SAMPLE_NAME_COLUMN = sample_info_table_schema_generator.SAMPLE_NAME + + +class SampleIdToNameDict(beam.PTransform): + """Generate Id-to-Name hashing table from sample info table.""" + + def _extract_id_name(self, row): + sample_id = row[SAMPLE_ID_COLUMN] + sample_name = row[SAMPLE_NAME_COLUMN] + return (sample_id, sample_name) + + def expand(self, pcoll): + return (pcoll + | 'ExtractIdNameTuples' >> beam.Map(self._extract_id_name) + | 'CombineToDict' >> beam.combiners.ToDict()) + + +class SampleNameToIdDict(beam.PTransform): + """Generate Name-to-ID hashing table from sample info table.""" + + def _extract_id_name(self, row): + sample_id = row[SAMPLE_ID_COLUMN] + sample_name = row[SAMPLE_NAME_COLUMN] + return (sample_name, sample_id) + + def expand(self, pcoll): + return (pcoll + | 'ExtractNameIdTuples' >> beam.Map(self._extract_id_name) + | 'CombineToDict' >> beam.combiners.ToDict()) + +class GetSampleNames(beam.PTransform): + """Looks up sample_names corresponding to the given sample_ids""" + + def __init__(self, id_to_name_dict): + # type: (Dict[int, str]) -> None + self._id_to_name_dict = id_to_name_dict + + def _get_sample_name(self, sample_id, id_to_name_dict): + # type: (int, Dict[int, str]) -> str + if sample_id in id_to_name_dict: + return id_to_name_dict[sample_id] + raise ValueError('Sample ID `{}` was not found.'.format(sample_id)) + + def expand(self, pcoll): + return (pcoll + | 'Generate ID to Name Mapping' + >> beam.Map(self._get_sample_name, self._id_to_name_dict)) + +class GetSampleIds(beam.PTransform): + """Looks up sample_ids corresponding to the given sample_names""" + + def __init__(self, name_to_id_dict): + # type: 
(Dict[str, int]) -> None +    self._name_to_id_dict = name_to_id_dict + +  def _get_sample_id(self, sample_name, name_to_id_dict): +    # type: (str, Dict[str, int]) -> int +    if sample_name in name_to_id_dict: +      return name_to_id_dict[sample_name] +    raise ValueError('Sample `{}` was not found.'.format(sample_name)) + +  def expand(self, pcoll): +    return (pcoll +            | 'Generate Name to ID Mapping' +            >> beam.Map(self._get_sample_id, self._name_to_id_dict)) diff --git a/gcp_variant_transforms/transforms/sample_mapping_table_test.py b/gcp_variant_transforms/transforms/sample_mapping_table_test.py new file mode 100644 index 000000000..5d2d0f242 --- /dev/null +++ b/gcp_variant_transforms/transforms/sample_mapping_table_test.py @@ -0,0 +1,117 @@ +# Copyright 2020 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for the sample_mapping_table module.""" + +from __future__ import absolute_import + +import unittest + +from apache_beam import combiners +from apache_beam.pvalue import AsSingleton +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that +from apache_beam.transforms import Create + +from gcp_variant_transforms.testing import asserts +from gcp_variant_transforms.transforms.sample_mapping_table import GetSampleIds +from gcp_variant_transforms.transforms.sample_mapping_table import GetSampleNames +from gcp_variant_transforms.transforms.sample_mapping_table import SAMPLE_ID_COLUMN +from gcp_variant_transforms.transforms.sample_mapping_table import SAMPLE_NAME_COLUMN +from gcp_variant_transforms.transforms.sample_mapping_table import SampleIdToNameDict +from gcp_variant_transforms.transforms.sample_mapping_table import SampleNameToIdDict + + + +def _generate_bq_row(sample_id, sample_name): + return {SAMPLE_ID_COLUMN: sample_id, + SAMPLE_NAME_COLUMN: sample_name} + +BQ_ROWS = [_generate_bq_row(1, 'N01'), + _generate_bq_row(2, 'N02'), + _generate_bq_row(3, 'N03')] + +class SampleIdToNameDictTest(unittest.TestCase): + """Test cases for the ``SampleIdToNameDict`` transform.""" + + def test_sample_table_to_dict(self): + expected_dict = {1: 'N01', + 2: 'N02', + 3: 'N03'} + + pipeline = TestPipeline() + hash_table = ( + pipeline + | Create(BQ_ROWS) + | 'GenerateHashTable' >> SampleIdToNameDict()) + assert_that(hash_table, asserts.dict_values_equal(expected_dict)) + pipeline.run() + +class SampleNameToIdDictTest(unittest.TestCase): + """Test cases for the ``SampleNameToIdDict`` transform.""" + + def test_sample_table_to_dict(self): + expected_dict = {'N01': 1, 'N02': 2, 'N03': 3} + + pipeline = TestPipeline() + hash_table = ( + pipeline + | Create(BQ_ROWS) + | 'GenerateHashTable' >> SampleNameToIdDict()) + + assert_that(hash_table, asserts.dict_values_equal(expected_dict)) + + pipeline.run() + +class 
GetSampleNamesTest(unittest.TestCase): + """Test cases for the ``GetSampleNames`` transform.""" + + def test_get_sample_names(self): + hash_dict = {1: 'N01', 2: 'N02', 3: 'N03', 4: 'N04'} + sample_ids = [1, 2, 3, 4] + expected_sample_names = ['N01', 'N02', 'N03', 'N04'] + + pipeline = TestPipeline() + hash_dict_pc = ( + pipeline + | 'CreateHashDict' >> Create(hash_dict) + | combiners.ToDict()) + sample_names = ( + pipeline + | Create(sample_ids) + | 'GetSampleNames' >> GetSampleNames(AsSingleton(hash_dict_pc))) + + assert_that(sample_names, asserts.items_equal(expected_sample_names)) + pipeline.run() + +class GetSampleIdsTest(unittest.TestCase): + """Test cases for the ``GetSampleIds`` transform.""" + + def test_get_sample_ids(self): + hash_dict = {'N01': 1, 'N02': 2, 'N03': 3, 'N04': 4} + sample_names = ['N01', 'N02', 'N03', 'N04'] + expected_sample_ids = [1, 2, 3, 4] + + pipeline = TestPipeline() + hash_dict_pc = ( + pipeline + | 'CreateHashDict' >> Create(hash_dict) + | combiners.ToDict()) + sample_ids = ( + pipeline + | Create(sample_names) + | 'GetSampleIds' >> GetSampleIds(AsSingleton(hash_dict_pc))) + + assert_that(sample_ids, asserts.items_equal(expected_sample_ids)) + pipeline.run() diff --git a/gcp_variant_transforms/vcf_to_bq.py b/gcp_variant_transforms/vcf_to_bq.py index 44d2e6edd..222aacd7e 100644 --- a/gcp_variant_transforms/vcf_to_bq.py +++ b/gcp_variant_transforms/vcf_to_bq.py @@ -187,7 +187,8 @@ def _shard_variants(known_args, pipeline_args, pipeline_mode): known_args.all_patterns, p, known_args, pipeline_mode) sample_ids = (variants | 'CombineSampleIds' >> - combine_sample_ids.SampleIdsCombiner()) + combine_sample_ids.SampleIdsCombiner() + | 'CombineToList' >> beam.combiners.ToList()) # TODO(tneymanov): Annotation pipeline currently stores sample IDs instead # of sample names in the the sharded VCF files, which would lead to double # hashing of samples. Needs to be fixed ASAP.