diff --git a/cloudbuild_CI.yaml b/cloudbuild_CI.yaml index f102e6646..f1581502b 100644 --- a/cloudbuild_CI.yaml +++ b/cloudbuild_CI.yaml @@ -43,7 +43,6 @@ steps: - '--image_tag ${COMMIT_SHA}' - '--run_unit_tests' - '--run_preprocessor_tests' - - '--run_bq_to_vcf_tests' - '--run_all_tests' - '--test_name_prefix cloud-ci-' id: 'test-gcp-variant-transforms-docker' diff --git a/gcp_variant_transforms/bq_to_vcf.py b/gcp_variant_transforms/bq_to_vcf.py index 55d6af16f..e18496428 100644 --- a/gcp_variant_transforms/bq_to_vcf.py +++ b/gcp_variant_transforms/bq_to_vcf.py @@ -67,14 +67,19 @@ from gcp_variant_transforms.transforms import bigquery_to_variant from gcp_variant_transforms.transforms import combine_sample_ids from gcp_variant_transforms.transforms import densify_variants +from gcp_variant_transforms.transforms import sample_mapping_table -_BASE_QUERY_TEMPLATE = 'SELECT {COLUMNS} FROM `{INPUT_TABLE}`' + +_BASE_QUERY_TEMPLATE = 'SELECT {COLUMNS} FROM `{INPUT_TABLE}__{CHROM}`' _BQ_TO_VCF_SHARDS_JOB_NAME = 'bq-to-vcf-shards' _COMMAND_LINE_OPTIONS = [variant_transform_options.BigQueryToVcfOptions] -_GENOMIC_REGION_TEMPLATE = ('({REFERENCE_NAME_ID}="{REFERENCE_NAME_VALUE}" AND ' - '{START_POSITION_ID}>={START_POSITION_VALUE} AND ' +_FULL_INPUT_TABLE = '{TABLE}__{SUFFIX}' +_GENOMIC_REGION_TEMPLATE = ('({START_POSITION_ID}>={START_POSITION_VALUE} AND ' '{END_POSITION_ID}<={END_POSITION_VALUE})') +_SAMPLE_INFO_QUERY_TEMPLATE = ( + 'SELECT sample_id, sample_name, file_path ' + 'FROM `{PROJECT_ID}.{DATASET_ID}.{TABLE_ID}_sample_info`') _VCF_FIXED_COLUMNS = ['#CHROM', 'POS', 'ID', 'REF', 'ALT', 'QUAL', 'FILTER', 'INFO', 'FORMAT'] _VCF_VERSION_LINE = '##fileformat=VCFv4.3\n' @@ -117,7 +122,8 @@ def run(argv=None): '{}_meta_info.vcf'.format(unique_temp_id)) _write_vcf_meta_info(known_args.input_table, known_args.representative_header_file, - known_args.allow_incompatible_schema) + known_args.allow_incompatible_schema, + known_args.genomic_region) 
_bigquery_to_vcf_shards(known_args, options, @@ -136,12 +142,13 @@ def run(argv=None): def _write_vcf_meta_info(input_table, representative_header_file, - allow_incompatible_schema): - # type: (str, str, bool) -> None + allow_incompatible_schema, + genomic_region): + # type: (str, str, bool, str) -> None """Writes the meta information generated from BigQuery schema.""" header_fields = ( schema_converter.generate_header_fields_from_schema( - _get_schema(input_table), allow_incompatible_schema)) + _get_schema(input_table, genomic_region), allow_incompatible_schema)) write_header_fn = vcf_header_io.WriteVcfHeaderFn(representative_header_file) write_header_fn.process(header_fields, _VCF_VERSION_LINE) @@ -163,31 +170,57 @@ def _bigquery_to_vcf_shards( Also, it writes the meta info and data header with the sample names to `vcf_header_file_path`. """ - schema = _get_schema(known_args.input_table) + schema = _get_schema(known_args.input_table, known_args.genomic_region) # TODO(allieychen): Modify the SQL query with the specified sample_ids. 
query = _get_bigquery_query(known_args, schema) logging.info('Processing BigQuery query %s:', query) + project_id, dataset_id, table_id = bigquery_util.parse_table_reference( + known_args.input_table) bq_source = bigquery.BigQuerySource(query=query, validate=True, use_standard_sql=True) annotation_names = _extract_annotation_names(schema) + + sample_query = _SAMPLE_INFO_QUERY_TEMPLATE.format(PROJECT_ID=project_id, + DATASET_ID=dataset_id, + TABLE_ID=table_id) + bq_sample_source = bigquery.BigQuerySource(query=sample_query, + validate=True, + use_standard_sql=True) with beam.Pipeline(options=beam_pipeline_options) as p: variants = (p | 'ReadFromBigQuery ' >> beam.io.Read(bq_source) | bigquery_to_variant.BigQueryToVariant(annotation_names)) + sample_table_rows = ( + p + | 'ReadFromSampleTable' >> beam.io.Read(bq_sample_source)) if known_args.sample_names: - sample_ids = (p - | transforms.Create(known_args.sample_names, - reshuffle=False) - | beam.combiners.ToList()) + hash_table = ( + sample_table_rows + | 'SampleNameToIdDict' >> sample_mapping_table.SampleNameToIdDict()) + sample_names = (p + | transforms.Create(known_args.sample_names, + reshuffle=False)) + sample_ids = (sample_names + | 'GetSampleIds' >> + sample_mapping_table.GetSampleIds( + beam.pvalue.AsSingleton(hash_table)) + | 'CombineSampleIds' >> beam.combiners.ToList()) + sample_names = sample_names | beam.combiners.ToList() else: + hash_table = ( + sample_table_rows + | 'SampleIdToNameDict' >> sample_mapping_table.SampleIdToNameDict()) sample_ids = (variants | 'CombineSampleIds' >> combine_sample_ids.SampleIdsCombiner( known_args.preserve_sample_order)) - # TODO(tneymanov): Add logic to extract sample names from sample IDs by - # joining with sample id-name mapping table, once that code is implemented. 
- sample_names = sample_ids + sample_names = (sample_ids + | 'GetSampleNames' >> + sample_mapping_table.GetSampleNames( + beam.pvalue.AsSingleton(hash_table)) + | 'CombineSampleNames' >> beam.combiners.ToList()) + sample_ids = sample_ids | beam.combiners.ToList() _ = (sample_names | 'GenerateVcfDataHeader' >> beam.ParDo(_write_vcf_header_with_sample_names, @@ -204,10 +237,11 @@ def _bigquery_to_vcf_shards( | vcfio.WriteVcfDataLines()) -def _get_schema(input_table): - # type: (str) -> bigquery_v2.TableSchema +def _get_schema(input_table, genomic_region): + # type: (str, str) -> bigquery_v2.TableSchema + ref, _, _ = genomic_region_parser.parse_genomic_region(genomic_region) project_id, dataset_id, table_id = bigquery_util.parse_table_reference( - input_table) + _FULL_INPUT_TABLE.format(TABLE=input_table, SUFFIX=ref)) credentials = (client.GoogleCredentials.get_application_default(). create_scoped(['https://www.googleapis.com/auth/bigquery'])) bigquery_client = bigquery_v2.BigqueryV2(credentials=credentials) @@ -220,21 +254,19 @@ def _get_bigquery_query(known_args, schema): # type: (argparse.Namespace, bigquery_v2.TableSchema) -> str """Returns a BigQuery query for the interested regions.""" columns = _get_query_columns(schema) + ref, start, end = genomic_region_parser.parse_genomic_region( + known_args.genomic_region) base_query = _BASE_QUERY_TEMPLATE.format( COLUMNS=', '.join(columns), INPUT_TABLE='.'.join( - bigquery_util.parse_table_reference(known_args.input_table))) + bigquery_util.parse_table_reference(known_args.input_table)), + CHROM=ref) conditions = [] - if known_args.genomic_regions: - for region in known_args.genomic_regions: - ref, start, end = genomic_region_parser.parse_genomic_region(region) - conditions.append(_GENOMIC_REGION_TEMPLATE.format( - REFERENCE_NAME_ID=bigquery_util.ColumnKeyConstants.REFERENCE_NAME, - REFERENCE_NAME_VALUE=ref, - START_POSITION_ID=bigquery_util.ColumnKeyConstants.START_POSITION, - START_POSITION_VALUE=start, - 
END_POSITION_ID=bigquery_util.ColumnKeyConstants.END_POSITION, - END_POSITION_VALUE=end)) + conditions.append(_GENOMIC_REGION_TEMPLATE.format( + START_POSITION_ID=bigquery_util.ColumnKeyConstants.START_POSITION, + START_POSITION_VALUE=start, + END_POSITION_ID=bigquery_util.ColumnKeyConstants.END_POSITION, + END_POSITION_VALUE=end)) if not conditions: return base_query diff --git a/gcp_variant_transforms/bq_to_vcf_test.py b/gcp_variant_transforms/bq_to_vcf_test.py index 2f0659b70..2ccbbfb30 100644 --- a/gcp_variant_transforms/bq_to_vcf_test.py +++ b/gcp_variant_transforms/bq_to_vcf_test.py @@ -64,21 +64,23 @@ def test_write_vcf_data_header(self): def test_get_bigquery_query_no_region(self): args = self._create_mock_args( input_table='my_bucket:my_dataset.my_table', - genomic_regions=None) + genomic_region='chr1') schema = bigquery.TableSchema() schema.fields.append(bigquery.TableFieldSchema( name=bigquery_util.ColumnKeyConstants.REFERENCE_NAME, type=bigquery_util.TableFieldConstants.TYPE_STRING, mode=bigquery_util.TableFieldConstants.MODE_NULLABLE, description='Reference name.')) - self.assertEqual(bq_to_vcf._get_bigquery_query(args, schema), - 'SELECT reference_name FROM ' - '`my_bucket.my_dataset.my_table`') + self.assertEqual( + bq_to_vcf._get_bigquery_query(args, schema), + 'SELECT reference_name FROM ' + '`my_bucket.my_dataset.my_table__chr1` ' + 'WHERE (start_position>=0 AND end_position<=9223372036854775807)') def test_get_bigquery_query_with_regions(self): args_1 = self._create_mock_args( input_table='my_bucket:my_dataset.my_table', - genomic_regions=['c1:1,000-2,000', 'c2']) + genomic_region='c1:1,000-2,000') schema = bigquery.TableSchema() schema.fields.append(bigquery.TableFieldSchema( name=bigquery_util.ColumnKeyConstants.REFERENCE_NAME, @@ -93,10 +95,8 @@ def test_get_bigquery_query_with_regions(self): 'of the string of reference bases.'))) expected_query = ( 'SELECT reference_name, start_position FROM ' - '`my_bucket.my_dataset.my_table` WHERE ' - 
'(reference_name="c1" AND start_position>=1000 AND end_position<=2000) ' - 'OR (reference_name="c2" AND start_position>=0 AND ' - 'end_position<=9223372036854775807)' + '`my_bucket.my_dataset.my_table__c1` WHERE ' + '(start_position>=1000 AND end_position<=2000)' ) self.assertEqual(bq_to_vcf._get_bigquery_query(args_1, schema), expected_query) diff --git a/gcp_variant_transforms/options/variant_transform_options.py b/gcp_variant_transforms/options/variant_transform_options.py index b3dbe9aa0..0e90fb59b 100644 --- a/gcp_variant_transforms/options/variant_transform_options.py +++ b/gcp_variant_transforms/options/variant_transform_options.py @@ -541,6 +541,22 @@ def add_arguments(self, parser): required=True, help=('BigQuery table that will be loaded to VCF. It must be in the ' 'format of (PROJECT:DATASET.TABLE).')) + parser.add_argument( + '--genomic_region', + required=True, default=None, + help=('A genomic region to load from BigQuery. ' + 'The format of the genomic region should be ' + 'REFERENCE_NAME:START_POSITION-END_POSITION or REFERENCE_NAME if ' + 'the full chromosome is requested. Only variants within ' + 'this region will be loaded. The chromosome identifier should be ' + 'identical to the one provided in the config file when the tables ' + 'were being created. For example, ' + '`--genomic_region chr2:1000-2000` will load all variants in ' + '`chr2` with `start_position` in `[1000,2000)` from BigQuery. ' + 'If the table with suffix `my_chrom3` was imported, ' + '`--genomic_region my_chrom3` would return all the variants in ' + 'that shard. This flag must be specified to indicate the table ' + 'shard that needs to be exported to VCF file.')) parser.add_argument( '--number_of_bases_per_shard', type=int, default=1000000, @@ -558,18 +574,6 @@ 'repeated INFO field will have `Number=.`).
It is recommended to ' 'provide this file to specify the most accurate and complete ' 'meta-information in the VCF file.')) - parser.add_argument( - '--genomic_regions', - default=None, nargs='+', - help=('A list of genomic regions (separated by a space) to load from ' 'BigQuery. The format of each genomic region should be ' 'REFERENCE_NAME:START_POSITION-END_POSITION or REFERENCE_NAME if ' 'the full chromosome is requested. Only variants matching at ' 'least one of these regions will be loaded. For example, ' '`--genomic_regions chr1 chr2:1000-2000` will load all variants ' 'in `chr1` and all variants in `chr2` with `start_position` in ' '`[1000,2000)` from BigQuery. If this flag is not specified, all ' 'variants will be loaded.')) parser.add_argument( '--sample_names', default=None, nargs='+', @@ -596,6 +600,8 @@ def add_arguments(self, parser): def _validate_inputs(parsed_args): + #ref, start, end = genomic_region_parser.parse_genomic_region( + # parsed_args.genomic_region) if ((parsed_args.input_pattern and parsed_args.input_file) or (not parsed_args.input_pattern and not parsed_args.input_file)): raise ValueError('Exactly one of input_pattern and input_file has to be ' diff --git a/gcp_variant_transforms/testing/asserts.py b/gcp_variant_transforms/testing/asserts.py index 006e4ed5c..add341fdb 100644 --- a/gcp_variant_transforms/testing/asserts.py +++ b/gcp_variant_transforms/testing/asserts.py @@ -59,6 +59,15 @@ def _has_sample_ids(variants): return _has_sample_ids +def dict_values_equal(expected_dict): + """Verifies that dictionary is the same as expected.""" + def _items_equal(actual_dict): + if expected_dict != actual_dict[0]: + raise BeamAssertException( + 'Failed assert: %s == %s' % (expected_dict, actual_dict[0])) + return _items_equal + + def header_vars_equal(expected): def _vars_equal(actual): expected_vars = [vars(header) for header in expected] diff --git a/gcp_variant_transforms/transforms/combine_sample_ids.py
b/gcp_variant_transforms/transforms/combine_sample_ids.py index cf57a2ac0..3b19708bd 100644 --- a/gcp_variant_transforms/transforms/combine_sample_ids.py +++ b/gcp_variant_transforms/transforms/combine_sample_ids.py @@ -70,12 +70,10 @@ def expand(self, pcoll): | 'RemoveDuplicates' >> beam.RemoveDuplicates() | 'Combine' >> beam.combiners.ToList() | 'ExtractUniqueSampleIds' - >> beam.ParDo(self._extract_unique_sample_ids) - | beam.combiners.ToList()) + >> beam.ParDo(self._extract_unique_sample_ids)) else: return (pcoll | 'GetSampleIds' >> beam.FlatMap(self._get_sample_ids) | 'RemoveDuplicates' >> beam.RemoveDuplicates() | 'Combine' >> beam.combiners.ToList() - | 'SortSampleIds' >> beam.ParDo(sorted) - | beam.combiners.ToList()) + | 'SortSampleIds' >> beam.ParDo(sorted)) diff --git a/gcp_variant_transforms/transforms/combine_sample_ids_test.py b/gcp_variant_transforms/transforms/combine_sample_ids_test.py index 601734861..85420a7ee 100644 --- a/gcp_variant_transforms/transforms/combine_sample_ids_test.py +++ b/gcp_variant_transforms/transforms/combine_sample_ids_test.py @@ -16,6 +16,7 @@ import unittest +from apache_beam import combiners from apache_beam import transforms from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import assert_that @@ -48,8 +49,8 @@ def test_sample_ids_combiner_pipeline_preserve_sample_order_error(self): pipeline | transforms.Create(variants) | 'CombineSampleIds' >> - combine_sample_ids.SampleIdsCombiner( - preserve_sample_order=True)) + combine_sample_ids.SampleIdsCombiner(preserve_sample_order=True) + | combiners.ToList()) with self.assertRaises(ValueError): pipeline.run() @@ -76,8 +77,8 @@ def test_sample_ids_combiner_pipeline_preserve_sample_order(self): pipeline | transforms.Create(variants) | 'CombineSampleIds' >> - combine_sample_ids.SampleIdsCombiner( - preserve_sample_order=True)) + combine_sample_ids.SampleIdsCombiner(preserve_sample_order=True) + | combiners.ToList()) 
assert_that(combined_sample_ids, equal_to([sample_ids])) pipeline.run() @@ -99,8 +100,8 @@ def test_sample_ids_combiner_pipeline(self): combined_sample_ids = ( pipeline | transforms.Create(variants) - | 'CombineSampleIds' >> - combine_sample_ids.SampleIdsCombiner()) + | 'CombineSampleIds' >> combine_sample_ids.SampleIdsCombiner() + | combiners.ToList()) assert_that(combined_sample_ids, equal_to([sample_ids])) pipeline.run() @@ -112,7 +113,7 @@ def test_sample_ids_combiner_pipeline_duplicate_sample_ids(self): _ = ( pipeline | transforms.Create(variants) - | 'CombineSampleIds' >> - combine_sample_ids.SampleIdsCombiner()) + | 'CombineSampleIds' >> combine_sample_ids.SampleIdsCombiner() + | combiners.ToList()) with self.assertRaises(ValueError): pipeline.run() diff --git a/gcp_variant_transforms/transforms/sample_mapping_table.py b/gcp_variant_transforms/transforms/sample_mapping_table.py new file mode 100644 index 000000000..ce93c9a80 --- /dev/null +++ b/gcp_variant_transforms/transforms/sample_mapping_table.py @@ -0,0 +1,93 @@ +# Copyright 2020 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""PTransforms for mapping between sample IDs and sample names.""" + +from typing import Dict, List, Tuple # pylint: disable=unused-import + +import apache_beam as beam + +import gcp_variant_transforms.libs.hashing_util + +SAMPLE_ID_COLUMN = 'sample_id' +SAMPLE_NAME_COLUMN = 'sample_name' +FILE_PATH_COLUMN = 'file_path' +WITH_FILE_SAMPLE_TEMPLATE = "{FILE_PATH}/{SAMPLE_NAME}" + + +class SampleIdToNameDict(beam.PTransform): + """Builds a dict mapping sample_id to (sample_name, file_path) from BigQuery rows.""" + + def _convert_bq_row(self, row): + sample_id = row[SAMPLE_ID_COLUMN] + sample_name = row[SAMPLE_NAME_COLUMN] + file_path = row[FILE_PATH_COLUMN] + return (sample_id, (sample_name, file_path)) + + def expand(self, pcoll): + return (pcoll + | 'BigQueryToMapping' >> beam.Map(self._convert_bq_row) + | 'CombineToDict' >> beam.combiners.ToDict()) + + +class SampleNameToIdDict(beam.PTransform): + """Builds a dict mapping sample_name to sample_id from BigQuery rows.""" + + def _convert_bq_row(self, row): + sample_id = row[SAMPLE_ID_COLUMN] + sample_name = row[SAMPLE_NAME_COLUMN] + return (sample_name, sample_id) + + def expand(self, pcoll): + return (pcoll + | 'BigQueryToMapping' >> beam.Map(self._convert_bq_row) + | 'CombineToDict' >> beam.combiners.ToDict()) + +class GetSampleNames(beam.PTransform): + """Transforms sample_ids to sample_names.""" + + def __init__(self, hash_table): + # type: (Dict[int, Tuple[str, str]]) -> None + self._hash_table = hash_table + + def _get_encoding_type(self, hash_table): + # type: (Dict[int, Tuple[str, str]]) -> bool + sample_names = [c[0] for c in hash_table.values()] + return len(sample_names) == len(set(sample_names)) + + def _get_sample_name(self, sample_id, hash_table): + # type: (int, Dict[int, Tuple[str, str]]) -> str + sample = hash_table[sample_id] + if self._get_encoding_type(hash_table): + return sample[0] + else: + return WITH_FILE_SAMPLE_TEMPLATE.format(SAMPLE_NAME=sample[0], + FILE_PATH=sample[1]) + + def expand(self, pcoll): + return
pcoll | beam.Map(self._get_sample_name, self._hash_table) + +class GetSampleIds(beam.PTransform): + """Transforms sample_names to sample_ids.""" + + def __init__(self, hash_table): + # type: (Dict[str, int]) -> None + self._hash_table = hash_table + + def _get_sample_id(self, sample_name, hash_table): + # type: (str, Dict[str, int]) -> int + return hash_table[sample_name] + + def expand(self, pcoll): + return pcoll | beam.Map(self._get_sample_id, self._hash_table) diff --git a/gcp_variant_transforms/transforms/sample_mapping_table_test.py b/gcp_variant_transforms/transforms/sample_mapping_table_test.py new file mode 100644 index 000000000..8dc8b6fc2 --- /dev/null +++ b/gcp_variant_transforms/transforms/sample_mapping_table_test.py @@ -0,0 +1,148 @@ +# Copyright 2020 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +"""Tests for sample_mapping_table module.""" + +from __future__ import absolute_import + +import unittest + +from apache_beam import combiners +from apache_beam.pvalue import AsSingleton +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that +from apache_beam.transforms import Create + +from gcp_variant_transforms.testing import asserts +from gcp_variant_transforms.transforms.sample_mapping_table import GetSampleIds +from gcp_variant_transforms.transforms.sample_mapping_table import GetSampleNames +from gcp_variant_transforms.transforms.sample_mapping_table import SAMPLE_ID_COLUMN +from gcp_variant_transforms.transforms.sample_mapping_table import SAMPLE_NAME_COLUMN +from gcp_variant_transforms.transforms.sample_mapping_table import SampleIdToNameDict +from gcp_variant_transforms.transforms.sample_mapping_table import SampleNameToIdDict +from gcp_variant_transforms.transforms.sample_mapping_table import FILE_PATH_COLUMN + + + +def _generate_bq_row(sample_id, sample_name, file_path): + return {SAMPLE_ID_COLUMN: sample_id, + SAMPLE_NAME_COLUMN: sample_name, + FILE_PATH_COLUMN: file_path} + +BQ_ROWS = [_generate_bq_row(1, 'N01', 'file1'), + _generate_bq_row(2, 'N02', 'file2'), + _generate_bq_row(3, 'N03', 'file3')] + +class SampleIdToNameDictTest(unittest.TestCase): + """Test cases for the ``SampleIdToNameDict`` transform.""" + + def test_sample_table_to_dict(self): + expected_dict = {1: ('N01', 'file1'), + 2: ('N02', 'file2'), + 3: ('N03', 'file3')} + + pipeline = TestPipeline() + hash_table = ( + pipeline + | Create(BQ_ROWS) + | 'GenerateHashTable' >> SampleIdToNameDict()) + + assert_that(hash_table, asserts.dict_values_equal(expected_dict)) + pipeline.run() + +class SampleNameToIdDictTest(unittest.TestCase): + """Test cases for the ``SampleNameToIdDict`` transform.""" + + def test_sample_table_to_dict(self): + expected_dict = {'N01': 1, 'N02': 2, 'N03': 3} + + pipeline = TestPipeline() + hash_table = ( + pipeline
+ | Create(BQ_ROWS) + | 'GenerateHashTable' >> SampleNameToIdDict()) + + assert_that(hash_table, asserts.dict_values_equal(expected_dict)) + + pipeline.run() + +class GetSampleNamesTest(unittest.TestCase): + """Test cases for the ``GetSampleNames`` transform.""" + + def _sample_name_with_file(self, sample_name, file_path): + return "{}/{}".format(file_path, sample_name) + + def test_get_sample_names_with_file(self): + hash_dict = {1: ('N01', 'file1'), + 2: ('N02', 'file1'), + 3: ('N01', 'file2'), + 4: ('N02', 'file2')} + sample_ids = [1, 2, 3, 4] + expected_sample_names = [self._sample_name_with_file('N01', 'file1'), + self._sample_name_with_file('N02', 'file1'), + self._sample_name_with_file('N01', 'file2'), + self._sample_name_with_file('N02', 'file2')] + + pipeline = TestPipeline() + hash_dict_pc = ( + pipeline + | 'CreateHashDict' >> Create(hash_dict) + | combiners.ToDict()) + sample_names = ( + pipeline + | Create(sample_ids) + | 'GetSampleNames' >> GetSampleNames(AsSingleton(hash_dict_pc))) + + assert_that(sample_names, asserts.items_equal(expected_sample_names)) + pipeline.run() + + def test_get_sample_names_without_file(self): + hash_dict = {1: ('N01', 'file1'), + 2: ('N02', 'file2')} + sample_ids = [1, 2] + expected_sample_names = ['N01', 'N02'] + + pipeline = TestPipeline() + hash_dict_pc = ( + pipeline + | 'CreateHashDict' >> Create(hash_dict) + | combiners.ToDict()) + sample_names = ( + pipeline + | Create(sample_ids) + | 'GetSampleNames' >> GetSampleNames(AsSingleton(hash_dict_pc))) + + assert_that(sample_names, asserts.items_equal(expected_sample_names)) + pipeline.run() + +class GetSampleIdsTest(unittest.TestCase): + """Test cases for the ``GetSampleIds`` transform.""" + + def test_get_sample_ids(self): + hash_dict = {'N01': 1, 'N02': 2, 'N03': 3, 'N04': 4} + sample_names = ['N01', 'N02', 'N03', 'N04'] + expected_sample_ids = [1, 2, 3, 4] + + pipeline = TestPipeline() + hash_dict_pc = ( + pipeline + | 'CreateHashDict' >> Create(hash_dict) + |
combiners.ToDict()) + sample_ids = ( + pipeline + | Create(sample_names) + | 'GetSampleIds' >> GetSampleIds(AsSingleton(hash_dict_pc))) + + assert_that(sample_ids, asserts.items_equal(expected_sample_ids)) + pipeline.run()