Skip to content

Commit

Permalink
Adjust the commit to always assume a 1-1 mapping between sample names and IDs, and modify integration tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
tneymanov committed Apr 9, 2020
1 parent aeff351 commit bb8baef
Show file tree
Hide file tree
Showing 11 changed files with 73 additions and 173 deletions.
39 changes: 20 additions & 19 deletions gcp_variant_transforms/bq_to_vcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@
_BQ_TO_VCF_SHARDS_JOB_NAME = 'bq-to-vcf-shards'
_COMMAND_LINE_OPTIONS = [variant_transform_options.BigQueryToVcfOptions]
TABLE_SUFFIX_SEPARATOR = bigquery_util.TABLE_SUFFIX_SEPARATOR
SAMPLE_TABLE_SUFFIX = bigquery_util.TABLE_SUFFIX
SAMPLE_TABLE_SUFFIX_SEPARATOR = bigquery_util.SAMPLE_TABLE_SUFFIX_SEPARATOR
SAMPLE_INFO_TABLE_SUFFIX = bigquery_util.SAMPLE_INFO_TABLE_SUFFIX
_FULL_INPUT_TABLE = '{TABLE}' + TABLE_SUFFIX_SEPARATOR + '{SUFFIX}'
_GENOMIC_REGION_TEMPLATE = ('({REFERENCE_NAME_ID}="{REFERENCE_NAME_VALUE}" AND '
'{START_POSITION_ID}>={START_POSITION_VALUE} AND '
Expand All @@ -87,7 +86,7 @@
_SAMPLE_INFO_QUERY_TEMPLATE = (
'SELECT sample_id, sample_name, file_path '
'FROM `{PROJECT_ID}.{DATASET_ID}.{BASE_TABLE_ID}' +
SAMPLE_TABLE_SUFFIX_SEPARATOR + SAMPLE_TABLE_SUFFIX + '`')
TABLE_SUFFIX_SEPARATOR + SAMPLE_INFO_TABLE_SUFFIX + '`')


def run(argv=None):
Expand Down Expand Up @@ -174,12 +173,11 @@ def _bigquery_to_vcf_shards(
`vcf_header_file_path`.
"""
schema = _get_schema(known_args.input_table)
# TODO(allieychen): Modify the SQL query with the specified sample_ids.
query = _get_bigquery_query(known_args, schema)
query = _get_variant_query(known_args, schema)
logging.info('Processing BigQuery query %s:', query)
project_id, dataset_id, table_id = bigquery_util.parse_table_reference(
known_args.input_table)
bq_source = bigquery.BigQuerySource(query=query,
bq_variant_source = bigquery.BigQuerySource(query=query,
validate=True,
use_standard_sql=True)
annotation_names = _extract_annotation_names(schema)
Expand All @@ -193,13 +191,13 @@ def _bigquery_to_vcf_shards(
use_standard_sql=True)
with beam.Pipeline(options=beam_pipeline_options) as p:
variants = (p
| 'ReadFromBigQuery ' >> beam.io.Read(bq_source)
| 'ReadFromBigQuery ' >> beam.io.Read(bq_variant_source)
| bigquery_to_variant.BigQueryToVariant(annotation_names))
sample_table_rows = (
p
| 'ReadFromSampleTable' >> beam.io.Read(bq_sample_source))
if known_args.sample_names:
names_to_ids = (
hash_table = (
sample_table_rows
| 'SampleNameToIdDict' >> sample_mapping_table.SampleNameToIdDict())
sample_names = (p
Expand All @@ -208,21 +206,24 @@ def _bigquery_to_vcf_shards(
sample_ids = (sample_names
| 'GetSampleIds' >>
sample_mapping_table.GetSampleIds(
beam.pvalue.AsSingleton(names_to_ids)))
beam.pvalue.AsSingleton(hash_table))
| 'CombineSampleIds' >> beam.combiners.ToList())
sample_names = sample_names | beam.combiners.ToList()
else:
hash_table = (
sample_table_rows
| 'SampleIdToNameDict' >> sample_mapping_table.SampleIdToNameDict())
sample_ids = (variants
| 'CombineSampleIds' >>
combine_sample_ids.SampleIdsCombiner(
known_args.preserve_sample_order))
ids_to_names = (
sample_table_rows
| 'SampleIdToNameDict' >> sample_mapping_table.SampleIdToNameDict())
sample_names = (sample_ids
| 'GetSampleNames' >>
sample_mapping_table.GetSampleNames(
beam.pvalue.AsSingleton(ids_to_names))
| 'CombineSampleNames' >> beam.combiners.ToList())
sample_ids = sample_ids | beam.combiners.ToList()
sample_names = (sample_ids
| 'GetSampleNames' >>
sample_mapping_table.GetSampleNames(
beam.pvalue.AsSingleton(hash_table))
| 'CombineSampleNames' >> beam.combiners.ToList())
sample_ids = sample_ids | beam.combiners.ToList()

_ = (sample_names
| 'GenerateVcfDataHeader' >>
beam.ParDo(_write_vcf_header_with_sample_names,
Expand Down Expand Up @@ -251,7 +252,7 @@ def _get_schema(input_table):
return table.schema


def _get_bigquery_query(known_args, schema):
def _get_variant_query(known_args, schema):
# type: (argparse.Namespace, bigquery_v2.TableSchema) -> str
"""Returns a BigQuery query for the interested regions."""
columns = _get_query_columns(schema)
Expand Down
8 changes: 4 additions & 4 deletions gcp_variant_transforms/bq_to_vcf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_write_vcf_data_header(self):
content = f.readlines()
self.assertEqual(content, expected_content)

def test_get_bigquery_query_no_region(self):
def test_get_variant_query_no_region(self):
args = self._create_mock_args(
input_table='my_bucket:my_dataset.my_table',
genomic_regions=None)
Expand All @@ -71,11 +71,11 @@ def test_get_bigquery_query_no_region(self):
type=bigquery_util.TableFieldConstants.TYPE_STRING,
mode=bigquery_util.TableFieldConstants.MODE_NULLABLE,
description='Reference name.'))
self.assertEqual(bq_to_vcf._get_bigquery_query(args, schema),
self.assertEqual(bq_to_vcf._get_variant_query(args, schema),
'SELECT reference_name FROM '
'`my_bucket.my_dataset.my_table`')

def test_get_bigquery_query_with_regions(self):
def test_get_variant_query_with_regions(self):
args_1 = self._create_mock_args(
input_table='my_bucket:my_dataset.my_table',
genomic_regions=['c1:1,000-2,000', 'c2'])
Expand All @@ -98,7 +98,7 @@ def test_get_bigquery_query_with_regions(self):
'OR (reference_name="c2" AND start_position>=0 AND '
'end_position<=9223372036854775807)'
)
self.assertEqual(bq_to_vcf._get_bigquery_query(args_1, schema),
self.assertEqual(bq_to_vcf._get_variant_query(args_1, schema),
expected_query)

def test_get_query_columns(self):
Expand Down
16 changes: 6 additions & 10 deletions gcp_variant_transforms/options/variant_transform_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
from gcp_variant_transforms.libs import variant_sharding

TABLE_SUFFIX_SEPARATOR = bigquery_util.TABLE_SUFFIX_SEPARATOR
SAMPLE_TABLE_SUFFIX = bigquery_util.TABLE_SUFFIX
SAMPLE_TABLE_SUFFIX_SEPARATOR = bigquery_util.SAMPLE_TABLE_SUFFIX_SEPARATOR
SAMPLE_INFO_TABLE_SUFFIX = bigquery_util.SAMPLE_INFO_TABLE_SUFFIX


class VariantTransformsOptions(object):
Expand Down Expand Up @@ -219,8 +218,7 @@ def validate(self, parsed_args, client=None):
dataset_id)
all_output_tables = []
all_output_tables.append(
bigquery_util.compose_table_name(
table_id, bigquery_util.SAMPLE_INFO_TABLE_SUFFIX))
bigquery_util.compose_table_name(table_id, SAMPLE_INFO_TABLE_SUFFIX))
sharding = variant_sharding.VariantSharding(
parsed_args.sharding_config_path)
num_shards = sharding.get_num_shards()
Expand Down Expand Up @@ -599,19 +597,17 @@ def validate(self, parsed_args, client=None):

project_id, dataset_id, table_id = bigquery_util.parse_table_reference(
parsed_args.input_table)
if not bigquery_util.table_exist(client, project_id, dataset_id, table_id):
raise ValueError('Table {}:{}.{} does not exist.'.format(
project_id, dataset_id, table_id))
if table_id.count(TABLE_SUFFIX_SEPARATOR) != 1:
raise ValueError(
'Input table {} is malformed - exactly one suffix separator "{}" is '
'required'.format(parsed_args.input_table,
TABLE_SUFFIX_SEPARATOR))
base_table_id = table_id[:table_id.find(TABLE_SUFFIX_SEPARATOR)]
sample_table_id = (
base_table_id + SAMPLE_TABLE_SUFFIX_SEPARATOR + SAMPLE_TABLE_SUFFIX)
bigquery_util.raise_error_if_dataset_not_exists(client, project_id,
dataset_id)
if not bigquery_util.table_exist(client, project_id, dataset_id, table_id):
raise ValueError('Table {}:{}.{} does not exist.'.format(
project_id, dataset_id, table_id))
base_table_id + TABLE_SUFFIX_SEPARATOR + SAMPLE_INFO_TABLE_SUFFIX)

if not bigquery_util.table_exist(client, project_id, dataset_id,
sample_table_id):
Expand Down
2 changes: 1 addition & 1 deletion gcp_variant_transforms/testing/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def dict_values_equal(expected_dict):
def _items_equal(actual_dict):
actual = actual_dict[0]
for k in expected_dict:
if k not in actual or set(expected_dict[k]) != set(actual[k]):
if k not in actual or expected_dict[k] != actual[k]:
raise BeamAssertException(
'Failed assert: %s == %s' % (expected_dict, actual))
if len(expected_dict) != len(actual):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[
{
"test_name": "bq-to-vcf-no-options",
"input_table": "gcp-variant-transforms-test:bq_to_vcf_integration_tests.4_0_new_schema",
"input_table": "gcp-variant-transforms-test:bq_to_vcf_integration_tests.4_0__suffix",
"output_file_name": "bq_to_vcf_no_options.vcf",
"runner": "DirectRunner",
"expected_output_file": "gcp_variant_transforms/testing/data/vcf/bq_to_vcf/expected_output/no_options_chr20.vcf"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[
{
"test_name": "bq-to-vcf-option-allow-incompatible-schema",
"input_table": "gcp-variant-transforms-test:bq_to_vcf_integration_tests.4_2_new_schema",
"input_table": "gcp-variant-transforms-test:bq_to_vcf_integration_tests.4_2__suffix",
"output_file_name": "bq_to_vcf_option_allow_incompatible_schema.vcf",
"allow_incompatible_schema": true,
"runner": "DirectRunner",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[
{
"test_name": "bq-to-vcf-option-customized-export",
"input_table": "gcp-variant-transforms-test:bq_to_vcf_integration_tests.4_0_new_schema",
"input_table": "gcp-variant-transforms-test:bq_to_vcf_integration_tests.4_0__suffix",
"output_file_name": "bq_to_vcf_option_customized_export.vcf",
"genomic_regions": "19:1234566-1234570 20:14369-17330",
"sample_names": "NA00001 NA00003",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[
{
"test_name": "bq-to-vcf-option-number-of-bases-per-shard",
"input_table": "gcp-variant-transforms-test:bq_to_vcf_integration_tests.platinum_NA12877_hg38_10K_lines_new_schema",
"input_table": "gcp-variant-transforms-test:bq_to_vcf_integration_tests.platinum_NA12877_hg38_10K_lines__suffix",
"output_file_name": "bq_to_vcf_option_number_of_bases_per_shard.vcf",
"number_of_bases_per_shard": 100000,
"runner": "DataflowRunner",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[
{
"test_name": "bq-to-vcf-option-representative-header-file",
"input_table": "gcp-variant-transforms-test:bq_to_vcf_integration_tests.4_0_new_schema",
"input_table": "gcp-variant-transforms-test:bq_to_vcf_integration_tests.4_0__suffix",
"representative_header_file": "gs://gcp-variant-transforms-testfiles/small_tests/valid-4.0.vcf",
"preserve_sample_order": true,
"output_file_name": "bq_to_vcf_option_representative_header_file.vcf",
Expand Down
94 changes: 22 additions & 72 deletions gcp_variant_transforms/transforms/sample_mapping_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
SAMPLE_NAME_COLUMN = 'sample_name'
FILE_PATH_COLUMN = 'file_path'
WITH_FILE_SAMPLE_TEMPLATE = "{FILE_PATH}/{SAMPLE_NAME}"
SAMPLE_NAME_TEMPLATE = "{SAMPLE_NAME}_{IND}"


class SampleIdToNameDict(beam.PTransform):
Expand All @@ -31,75 +30,13 @@ class SampleIdToNameDict(beam.PTransform):
def _convert_bq_row(self, row):
sample_id = row[SAMPLE_ID_COLUMN]
sample_name = row[SAMPLE_NAME_COLUMN]
file_path = row[FILE_PATH_COLUMN]
return (sample_id, (sample_name, file_path))

def _process_hash_table(self, hash_table):
sample_names = []
file_ind = {}
ind = 0
for value in hash_table.values():
sample_names.append(value[0])
if value[1] not in file_ind:
ind += 1
file_ind[value[1]] = ind

parsed_dict = {}
if len(sample_names) == len(set(sample_names)):
for k, v in hash_table.items():
parsed_dict[k] = v[0]
else:
for k, v in hash_table.items():
parsed_dict[k] = SAMPLE_NAME_TEMPLATE.format(SAMPLE_NAME=v[0],
IND=file_ind[v[1]])
return parsed_dict
return (sample_id, sample_name)

def expand(self, pcoll):
return (pcoll
| 'BigQueryToMapping' >> beam.Map(self._convert_bq_row)
| 'CombineToDict' >> beam.combiners.ToDict()
| 'ProcessDict' >> beam.Map(self._process_hash_table))
| 'CombineToDict' >> beam.combiners.ToDict())

class GetSampleNames(beam.PTransform):
"""Transforms sample_ids to sample_names"""

def __init__(self, hash_table):
# type: (Dict[int, Tuple(str, str)]) -> None
self._hash_table = hash_table

def _get_sample_id(self, sample_id, hash_table):
# type: (int, Dict[int, Tuple(str, str)]) -> str
return hash_table[sample_id]

def expand(self, pcoll):
return pcoll | beam.Map(self._get_sample_id, self._hash_table)

class ToDictAccumulateCombineFn(beam.CombineFn):
"""CombineFn to create dictionary, but appending values for same keys."""

def create_accumulator(self):
return dict()

def add_input(self, accumulator, element):
key, value = element
if key in accumulator:
accumulator[key].append(value)
else:
accumulator[key] = [value]
return accumulator

def merge_accumulators(self, accumulators):
result = dict()
for a in accumulators:
for k, v in a.items():
if k in result:
result[k].extend(v)
else:
result[k] = v
return result

def extract_output(self, accumulator):
return accumulator

class SampleNameToIdDict(beam.PTransform):
"""Transforms BigQuery table rows to PCollection of `Variant`."""
Expand All @@ -112,8 +49,22 @@ def _convert_bq_row(self, row):
def expand(self, pcoll):
return (pcoll
| 'BigQueryToMapping' >> beam.Map(self._convert_bq_row)
| 'CombineToDict' >> beam.CombineGlobally(
ToDictAccumulateCombineFn()))
| 'CombineToDict' >> beam.combiners.ToDict())

class GetSampleNames(beam.PTransform):
"""Transforms sample_ids to sample_names"""

def __init__(self, hash_table):
# type: (Dict[int, Tuple(str, str)]) -> None
self._hash_table = hash_table

def _get_sample_id(self, sample_id, hash_table):
# type: (int, Dict[int, Tuple(str, str)]) -> str
sample = hash_table[sample_id]
return sample

def expand(self, pcoll):
return pcoll | beam.Map(self._get_sample_id, self._hash_table)

class GetSampleIds(beam.PTransform):
"""Transform sample_names to sample_ids"""
Expand All @@ -124,10 +75,9 @@ def __init__(self, hash_table):

def _get_sample_name(self, sample_name, hash_table):
# type: (str, Dict[str, int]) -> int
return list(set(hash_table[sample_name]))

def print_row(self, row):
return row
if sample_name in hash_table:
return hash_table[sample_name]
raise ValueError('Sample `{}` was not found.'.format(sample_name))

def expand(self, pcoll):
return pcoll | beam.FlatMap(self._get_sample_name, self._hash_table)
return pcoll | beam.Map(self._get_sample_name, self._hash_table)
Loading

0 comments on commit bb8baef

Please sign in to comment.