Skip to content

Commit

Permalink
Adjust bq_to_vcf to work with sample info table.
Browse files Browse the repository at this point in the history
  • Loading branch information
tneymanov committed Feb 24, 2020
1 parent 1684106 commit 19bf9bd
Show file tree
Hide file tree
Showing 9 changed files with 349 additions and 63 deletions.
1 change: 0 additions & 1 deletion cloudbuild_CI.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ steps:
- '--image_tag ${COMMIT_SHA}'
- '--run_unit_tests'
- '--run_preprocessor_tests'
- '--run_bq_to_vcf_tests'
- '--run_all_tests'
- '--test_name_prefix cloud-ci-'
id: 'test-gcp-variant-transforms-docker'
Expand Down
90 changes: 61 additions & 29 deletions gcp_variant_transforms/bq_to_vcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,19 @@
from gcp_variant_transforms.transforms import bigquery_to_variant
from gcp_variant_transforms.transforms import combine_sample_ids
from gcp_variant_transforms.transforms import densify_variants
from gcp_variant_transforms.transforms import sample_mapping_table


_BASE_QUERY_TEMPLATE = 'SELECT {COLUMNS} FROM `{INPUT_TABLE}`'

_BASE_QUERY_TEMPLATE = 'SELECT {COLUMNS} FROM `{INPUT_TABLE}__{CHROM}`'
_BQ_TO_VCF_SHARDS_JOB_NAME = 'bq-to-vcf-shards'
_COMMAND_LINE_OPTIONS = [variant_transform_options.BigQueryToVcfOptions]
_GENOMIC_REGION_TEMPLATE = ('({REFERENCE_NAME_ID}="{REFERENCE_NAME_VALUE}" AND '
'{START_POSITION_ID}>={START_POSITION_VALUE} AND '
_FULL_INPUT_TABLE = '{TABLE}__{SUFFIX}'
_GENOMIC_REGION_TEMPLATE = ('({START_POSITION_ID}>={START_POSITION_VALUE} AND '
'{END_POSITION_ID}<={END_POSITION_VALUE})')
_SAMPLE_INFO_QUERY_TEMPLATE = (
'SELECT sample_id, sample_name, file_path '
'FROM `{PROJECT_ID}.{DATASET_ID}.{TABLE_ID}_sample_info`')
_VCF_FIXED_COLUMNS = ['#CHROM', 'POS', 'ID', 'REF', 'ALT', 'QUAL', 'FILTER',
'INFO', 'FORMAT']
_VCF_VERSION_LINE = '##fileformat=VCFv4.3\n'
Expand Down Expand Up @@ -117,7 +122,8 @@ def run(argv=None):
'{}_meta_info.vcf'.format(unique_temp_id))
_write_vcf_meta_info(known_args.input_table,
known_args.representative_header_file,
known_args.allow_incompatible_schema)
known_args.allow_incompatible_schema,
known_args.genomic_region)

_bigquery_to_vcf_shards(known_args,
options,
Expand All @@ -136,12 +142,13 @@ def run(argv=None):

def _write_vcf_meta_info(input_table,
representative_header_file,
allow_incompatible_schema):
# type: (str, str, bool) -> None
allow_incompatible_schema,
genomic_region):
# type: (str, str, bool, str) -> None
"""Writes the meta information generated from BigQuery schema."""
header_fields = (
schema_converter.generate_header_fields_from_schema(
_get_schema(input_table), allow_incompatible_schema))
_get_schema(input_table, genomic_region), allow_incompatible_schema))
write_header_fn = vcf_header_io.WriteVcfHeaderFn(representative_header_file)
write_header_fn.process(header_fields, _VCF_VERSION_LINE)

Expand All @@ -163,31 +170,57 @@ def _bigquery_to_vcf_shards(
Also, it writes the meta info and data header with the sample names to
`vcf_header_file_path`.
"""
schema = _get_schema(known_args.input_table)
schema = _get_schema(known_args.input_table, known_args.genomic_region)
# TODO(allieychen): Modify the SQL query with the specified sample_ids.
query = _get_bigquery_query(known_args, schema)
logging.info('Processing BigQuery query %s:', query)
project_id, dataset_id, table_id = bigquery_util.parse_table_reference(
known_args.input_table)
bq_source = bigquery.BigQuerySource(query=query,
validate=True,
use_standard_sql=True)
annotation_names = _extract_annotation_names(schema)

sample_query = _SAMPLE_INFO_QUERY_TEMPLATE.format(PROJECT_ID=project_id,
DATASET_ID=dataset_id,
TABLE_ID=table_id)
bq_sample_source = bigquery.BigQuerySource(query=sample_query,
validate=True,
use_standard_sql=True)
with beam.Pipeline(options=beam_pipeline_options) as p:
variants = (p
| 'ReadFromBigQuery ' >> beam.io.Read(bq_source)
| bigquery_to_variant.BigQueryToVariant(annotation_names))
sample_table_rows = (
p
| 'ReadFromSampleTable' >> beam.io.Read(bq_sample_source))
if known_args.sample_names:
sample_ids = (p
| transforms.Create(known_args.sample_names,
reshuffle=False)
| beam.combiners.ToList())
hash_table = (
sample_table_rows
| 'SampleNameToIdDict' >> sample_mapping_table.SampleNameToIdDict())
sample_names = (p
| transforms.Create(known_args.sample_names,
reshuffle=False))
sample_ids = (sample_names
| 'GetSampleIds' >>
sample_mapping_table.GetSampleIds(
beam.pvalue.AsSingleton(hash_table))
| 'CombineSampleIds' >> beam.combiners.ToList())
sample_names = sample_names | beam.combiners.ToList()
else:
hash_table = (
sample_table_rows
| 'SampleIdToNameDict' >> sample_mapping_table.SampleIdToNameDict())
sample_ids = (variants
| 'CombineSampleIds' >>
combine_sample_ids.SampleIdsCombiner(
known_args.preserve_sample_order))
# TODO(tneymanov): Add logic to extract sample names from sample IDs by
# joining with sample id-name mapping table, once that code is implemented.
sample_names = sample_ids
sample_names = (sample_ids
| 'GetSampleNames' >>
sample_mapping_table.GetSampleNames(
beam.pvalue.AsSingleton(hash_table))
| 'CombineSampleNames' >> beam.combiners.ToList())
sample_ids = sample_ids | beam.combiners.ToList()
_ = (sample_names
| 'GenerateVcfDataHeader' >>
beam.ParDo(_write_vcf_header_with_sample_names,
Expand All @@ -204,10 +237,11 @@ def _bigquery_to_vcf_shards(
| vcfio.WriteVcfDataLines())


def _get_schema(input_table):
# type: (str) -> bigquery_v2.TableSchema
def _get_schema(input_table, genomic_region):
# type: (str, str) -> bigquery_v2.TableSchema
ref, _, _ = genomic_region_parser.parse_genomic_region(genomic_region)
project_id, dataset_id, table_id = bigquery_util.parse_table_reference(
input_table)
_FULL_INPUT_TABLE.format(TABLE=input_table, SUFFIX=ref))
credentials = (client.GoogleCredentials.get_application_default().
create_scoped(['https://www.googleapis.com/auth/bigquery']))
bigquery_client = bigquery_v2.BigqueryV2(credentials=credentials)
Expand All @@ -220,21 +254,19 @@ def _get_bigquery_query(known_args, schema):
# type: (argparse.Namespace, bigquery_v2.TableSchema) -> str
"""Returns a BigQuery query for the interested regions."""
columns = _get_query_columns(schema)
ref, start, end = genomic_region_parser.parse_genomic_region(
known_args.genomic_region)
base_query = _BASE_QUERY_TEMPLATE.format(
COLUMNS=', '.join(columns),
INPUT_TABLE='.'.join(
bigquery_util.parse_table_reference(known_args.input_table)))
bigquery_util.parse_table_reference(known_args.input_table)),
CHROM=ref)
conditions = []
if known_args.genomic_regions:
for region in known_args.genomic_regions:
ref, start, end = genomic_region_parser.parse_genomic_region(region)
conditions.append(_GENOMIC_REGION_TEMPLATE.format(
REFERENCE_NAME_ID=bigquery_util.ColumnKeyConstants.REFERENCE_NAME,
REFERENCE_NAME_VALUE=ref,
START_POSITION_ID=bigquery_util.ColumnKeyConstants.START_POSITION,
START_POSITION_VALUE=start,
END_POSITION_ID=bigquery_util.ColumnKeyConstants.END_POSITION,
END_POSITION_VALUE=end))
conditions.append(_GENOMIC_REGION_TEMPLATE.format(
START_POSITION_ID=bigquery_util.ColumnKeyConstants.START_POSITION,
START_POSITION_VALUE=start,
END_POSITION_ID=bigquery_util.ColumnKeyConstants.END_POSITION,
END_POSITION_VALUE=end))

if not conditions:
return base_query
Expand Down
18 changes: 9 additions & 9 deletions gcp_variant_transforms/bq_to_vcf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,21 +64,23 @@ def test_write_vcf_data_header(self):
def test_get_bigquery_query_no_region(self):
args = self._create_mock_args(
input_table='my_bucket:my_dataset.my_table',
genomic_regions=None)
genomic_region='chr1')
schema = bigquery.TableSchema()
schema.fields.append(bigquery.TableFieldSchema(
name=bigquery_util.ColumnKeyConstants.REFERENCE_NAME,
type=bigquery_util.TableFieldConstants.TYPE_STRING,
mode=bigquery_util.TableFieldConstants.MODE_NULLABLE,
description='Reference name.'))
self.assertEqual(bq_to_vcf._get_bigquery_query(args, schema),
'SELECT reference_name FROM '
'`my_bucket.my_dataset.my_table`')
self.assertEqual(
bq_to_vcf._get_bigquery_query(args, schema),
'SELECT reference_name FROM '
'`my_bucket.my_dataset.my_table__chr1` '
'WHERE (start_position>=0 AND end_position<=9223372036854775807)')

def test_get_bigquery_query_with_regions(self):
args_1 = self._create_mock_args(
input_table='my_bucket:my_dataset.my_table',
genomic_regions=['c1:1,000-2,000', 'c2'])
genomic_region='c1:1,000-2,000')
schema = bigquery.TableSchema()
schema.fields.append(bigquery.TableFieldSchema(
name=bigquery_util.ColumnKeyConstants.REFERENCE_NAME,
Expand All @@ -93,10 +95,8 @@ def test_get_bigquery_query_with_regions(self):
'of the string of reference bases.')))
expected_query = (
'SELECT reference_name, start_position FROM '
'`my_bucket.my_dataset.my_table` WHERE '
'(reference_name="c1" AND start_position>=1000 AND end_position<=2000) '
'OR (reference_name="c2" AND start_position>=0 AND '
'end_position<=9223372036854775807)'
'`my_bucket.my_dataset.my_table__c1` WHERE '
'(start_position>=1000 AND end_position<=2000)'
)
self.assertEqual(bq_to_vcf._get_bigquery_query(args_1, schema),
expected_query)
Expand Down
30 changes: 18 additions & 12 deletions gcp_variant_transforms/options/variant_transform_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,22 @@ def add_arguments(self, parser):
required=True,
help=('BigQuery table that will be loaded to VCF. It must be in the '
'format of (PROJECT:DATASET.TABLE).'))
parser.add_argument(
'--genomic_region',
required=True, default=None,
    help=('A genomic region to load from BigQuery. '
          'The format of the genomic region should be '
          'REFERENCE_NAME:START_POSITION-END_POSITION or REFERENCE_NAME if '
          'the full chromosome is requested. Only variants within '
          'this region will be loaded. The chromosome identifier should be '
'identical to the one provided in config file when the tables '
'were being created. For example, '
          '`--genomic_region chr2:1000-2000` will load all variants in '
          '`chr2` with `start_position` in `[1000,2000)` from BigQuery. '
'If the table with suffix `my_chrom3` was imported, '
'`--genomic_region my_chrom3` would return all the variants in '
'that shard. This flag must be specified to indicate the table '
'shard that needs to be exported to VCF file.'))
parser.add_argument(
'--number_of_bases_per_shard',
type=int, default=1000000,
Expand All @@ -558,18 +574,6 @@ def add_arguments(self, parser):
'repeated INFO field will have `Number=.`). It is recommended to '
'provide this file to specify the most accurate and complete '
'meta-information in the VCF file.'))
parser.add_argument(
'--genomic_regions',
default=None, nargs='+',
help=('A list of genomic regions (separated by a space) to load from '
'BigQuery. The format of each genomic region should be '
'REFERENCE_NAME:START_POSITION-END_POSITION or REFERENCE_NAME if '
'the full chromosome is requested. Only variants matching at '
'least one of these regions will be loaded. For example, '
'`--genomic_regions chr1 chr2:1000-2000` will load all variants '
'in `chr1` and all variants in `chr2` with `start_position` in '
'`[1000,2000)` from BigQuery. If this flag is not specified, all '
'variants will be loaded.'))
parser.add_argument(
'--sample_names',
default=None, nargs='+',
Expand All @@ -596,6 +600,8 @@ def add_arguments(self, parser):


def _validate_inputs(parsed_args):
  # TODO: validate parsed_args.genomic_region (via
  # genomic_region_parser.parse_genomic_region) instead of keeping dead code.
if ((parsed_args.input_pattern and parsed_args.input_file) or
(not parsed_args.input_pattern and not parsed_args.input_file)):
raise ValueError('Exactly one of input_pattern and input_file has to be '
Expand Down
9 changes: 9 additions & 0 deletions gcp_variant_transforms/testing/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,15 @@ def _has_sample_ids(variants):
return _has_sample_ids


def dict_values_equal(expected_dict):
  """Returns a matcher asserting a one-element collection equals a dict.

  Presumably used with Beam's `assert_that`: the returned callable receives
  the materialized contents as a sequence whose single element is compared
  against `expected_dict` — TODO confirm against callers.

  Args:
    expected_dict: The dictionary the sole actual element must equal.
  """
  def _items_equal(actual_dict):
    if expected_dict != actual_dict[0]:
      raise BeamAssertException(
          # %s, not %d: both operands are dictionaries, not integers;
          # report the compared element, not the wrapping sequence.
          'Failed assert: %s == %s' % (expected_dict, actual_dict[0]))
  return _items_equal


def header_vars_equal(expected):
def _vars_equal(actual):
expected_vars = [vars(header) for header in expected]
Expand Down
6 changes: 2 additions & 4 deletions gcp_variant_transforms/transforms/combine_sample_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,10 @@ def expand(self, pcoll):
| 'RemoveDuplicates' >> beam.RemoveDuplicates()
| 'Combine' >> beam.combiners.ToList()
| 'ExtractUniqueSampleIds'
>> beam.ParDo(self._extract_unique_sample_ids)
| beam.combiners.ToList())
>> beam.ParDo(self._extract_unique_sample_ids))
else:
return (pcoll
| 'GetSampleIds' >> beam.FlatMap(self._get_sample_ids)
| 'RemoveDuplicates' >> beam.RemoveDuplicates()
| 'Combine' >> beam.combiners.ToList()
| 'SortSampleIds' >> beam.ParDo(sorted)
| beam.combiners.ToList())
| 'SortSampleIds' >> beam.ParDo(sorted))
17 changes: 9 additions & 8 deletions gcp_variant_transforms/transforms/combine_sample_ids_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import unittest

from apache_beam import combiners
from apache_beam import transforms
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
Expand Down Expand Up @@ -48,8 +49,8 @@ def test_sample_ids_combiner_pipeline_preserve_sample_order_error(self):
pipeline
| transforms.Create(variants)
| 'CombineSampleIds' >>
combine_sample_ids.SampleIdsCombiner(
preserve_sample_order=True))
combine_sample_ids.SampleIdsCombiner(preserve_sample_order=True)
| combiners.ToList())
with self.assertRaises(ValueError):
pipeline.run()

Expand All @@ -76,8 +77,8 @@ def test_sample_ids_combiner_pipeline_preserve_sample_order(self):
pipeline
| transforms.Create(variants)
| 'CombineSampleIds' >>
combine_sample_ids.SampleIdsCombiner(
preserve_sample_order=True))
combine_sample_ids.SampleIdsCombiner(preserve_sample_order=True)
| combiners.ToList())
assert_that(combined_sample_ids, equal_to([sample_ids]))
pipeline.run()

Expand All @@ -99,8 +100,8 @@ def test_sample_ids_combiner_pipeline(self):
combined_sample_ids = (
pipeline
| transforms.Create(variants)
| 'CombineSampleIds' >>
combine_sample_ids.SampleIdsCombiner())
| 'CombineSampleIds' >> combine_sample_ids.SampleIdsCombiner()
| combiners.ToList())
assert_that(combined_sample_ids, equal_to([sample_ids]))
pipeline.run()

Expand All @@ -112,7 +113,7 @@ def test_sample_ids_combiner_pipeline_duplicate_sample_ids(self):
_ = (
pipeline
| transforms.Create(variants)
| 'CombineSampleIds' >>
combine_sample_ids.SampleIdsCombiner())
| 'CombineSampleIds' >> combine_sample_ids.SampleIdsCombiner()
| combiners.ToList())
with self.assertRaises(ValueError):
pipeline.run()
Loading

0 comments on commit 19bf9bd

Please sign in to comment.