Skip to content

Commit

Permalink
Merge pull request #387 from jeromekelleher/refactor-match-tsinfer
Browse files: browse the repository at this point in the history
Refactor match_tsinfer
  • Loading branch information
jeromekelleher authored Nov 6, 2024
2 parents 0982c7b + 21b4e9f commit 977c2be
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 190 deletions.
115 changes: 58 additions & 57 deletions sc2ts/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,24 @@

logger = logging.getLogger(__name__)


# Common arguments/options

# Shared -k/--num-mismatches option. It is defined once at module level so
# that each command decorated with it exposes the same flag, default and
# help text.
num_mismatches = click.option(
    "-k",
    "--num-mismatches",
    type=float,
    default=3,
    show_default=True,
    help="Number of mismatches to accept in favour of recombination",
)
# Shared on/off flag pair. The flag is enabled by default and can be
# disabled on the command line with --no-deletions-as-missing.
deletions_as_missing = click.option(
    "--deletions-as-missing/--no-deletions-as-missing",
    help="Treat all deletions as missing data when matching haplotypes",
    default=True,
    show_default=True,
)

# Wall-clock timestamp captured once, when this module is first executed.
# NOTE(review): presumably used later to report elapsed time — the usage is
# not visible here; confirm against the rest of the file.
__before = time.time()


Expand Down Expand Up @@ -364,13 +382,8 @@ def summarise_base(ts, date, progress):
@click.argument("metadata", type=click.Path(exists=True, dir_okay=False))
@click.argument("matches", type=click.Path(exists=True, dir_okay=False))
@click.argument("output_ts", type=click.Path(dir_okay=False))
@click.option(
"--num-mismatches",
default=3,
show_default=True,
type=float,
help="Number of mismatches to accept in favour of recombination",
)
@num_mismatches
@deletions_as_missing
@click.option(
"--hmm-cost-threshold",
default=5,
Expand Down Expand Up @@ -419,12 +432,6 @@ def summarise_base(ts, date, progress):
type=int,
help="Number of days in the past to reconsider potential matches",
)
@click.option(
"--deletions-as-missing/--no-deletions-as-missing",
default=True,
help="Treat all deletions as missing data when matching haplotypes",
show_default=True,
)
@click.option(
"--max-daily-samples",
default=None,
Expand Down Expand Up @@ -541,12 +548,7 @@ def extend(
@click.command()
@click.argument("alignment_db")
@click.argument("ts_file")
@click.option(
"--deletions-as-missing/--no-deletions-as-missing",
default=True,
help="Treat all deletions as missing data when matching haplotypes",
show_default=True,
)
@deletions_as_missing
@click.option("-v", "--verbose", count=True)
def validate(alignment_db, ts_file, deletions_as_missing, verbose):
"""
Expand Down Expand Up @@ -649,16 +651,15 @@ def _match_worker(work):
)
logger.info(f"Start: {msg}")
ts = tszip.load(work.ts_path)
mu, rho = sc2ts.solve_num_mismatches(work.num_mismatches)
matches = sc2ts.match_tsinfer(
sc2ts.match_tsinfer(
samples=work.samples,
ts=ts,
mu=mu,
rho=rho,
num_mismatches=work.num_mismatches,
mismatch_threshold=100,
# FIXME!
deletions_as_missing=False,
num_threads=0,
show_progress=False,
# Maximum possible precision
likelihood_threshold=1e-200,
mirror_coordinates=work.direction == "reverse",
)
runs = []
Expand All @@ -668,18 +669,26 @@ def _match_worker(work):
strain=sample.strain,
num_mismatches=work.num_mismatches,
direction=work.direction,
match=matches[sample.strain],
match=sample.hmm_match,
)
)
logger.info(f"Finish: {msg}")
return runs


@click.command()
@click.command(name="match")
@click.argument("alignments_path", type=click.Path(exists=True, dir_okay=False))
@click.argument("ts_path", type=click.Path(exists=True, dir_okay=False))
@click.argument("strains", nargs=-1)
@click.option("--num-mismatches", default=3, type=int, help="num-mismatches")
@num_mismatches
@deletions_as_missing
@click.option(
"--mismatch-threshold",
type=int,
default=100,
show_default=True,
help="Set the HMM likelihood threshold to this number of mutations",
)
@click.option(
"--direction",
type=click.Choice(["forward", "reverse"]),
Expand All @@ -695,11 +704,13 @@ def _match_worker(work):
@click.option("--progress/--no-progress", default=True)
@click.option("-v", "--verbose", count=True)
@click.option("-l", "--log-file", default=None, type=click.Path(dir_okay=False))
def run_match(
def _match(
alignments_path,
ts_path,
strains,
num_mismatches,
deletions_as_missing,
mismatch_threshold,
direction,
num_threads,
progress,
Expand All @@ -726,26 +737,24 @@ def run_match(
if sample.haplotype is None:
raise ValueError(f"No alignment stored for {sample.strain}")

mu, rho = sc2ts.solve_num_mismatches(num_mismatches)
matches = sc2ts.match_tsinfer(
sc2ts.match_tsinfer(
samples=samples,
ts=ts,
mu=mu,
rho=rho,
num_mismatches=num_mismatches,
deletions_as_missing=deletions_as_missing,
mismatch_threshold=mismatch_threshold,
num_threads=num_threads,
show_progress=progress,
progress_title=progress_title,
progress_phase="HMM",
# Maximum possible precision
likelihood_threshold=1e-200,
mirror_coordinates=direction == "reverse",
)
for strain in strains:
for sample in samples:
run = HmmRun(
strain=strain,
strain=sample.strain,
num_mismatches=num_mismatches,
direction=direction,
match=matches[strain],
match=sample.hmm_match,
)
print(run.asjson())

Expand Down Expand Up @@ -773,14 +782,7 @@ def find_previous_date_path(date, path_pattern):
@click.argument("alignments", type=click.Path(exists=True, dir_okay=False))
@click.argument("ts", type=click.Path(exists=True, dir_okay=False))
@click.argument("path_pattern")
@click.option(
"-k",
"--num-mismatches",
default=[3],
type=int,
multiple=True,
help="num-mismatches",
)
@num_mismatches
@click.option(
"--num-threads",
default=0,
Expand All @@ -790,7 +792,7 @@ def find_previous_date_path(date, path_pattern):
@click.option("--progress/--no-progress", default=True)
@click.option("-v", "--verbose", count=True)
@click.option("-l", "--log-file", default=None, type=click.Path(dir_okay=False))
def run_rematch_recombinants(
def rematch_recombinants(
alignments,
ts,
path_pattern,
Expand Down Expand Up @@ -842,15 +844,14 @@ def run_rematch_recombinants(
work = []
for recombinant, samples in recombinant_to_samples.items():
for direction in ["forward", "reverse"]:
for k in num_mismatches:
work.append(
MatchWork(
recombinant_to_path[recombinant],
samples,
num_mismatches=k,
direction=direction,
)
work.append(
MatchWork(
recombinant_to_path[recombinant],
samples,
num_mismatches=num_mismatches,
direction=direction,
)
)

bar = sc2ts.get_progress(None, progress_title, "HMM", progress, total=len(work))

Expand Down Expand Up @@ -892,6 +893,6 @@ def cli():
cli.add_command(list_dates)
cli.add_command(extend)
cli.add_command(validate)
cli.add_command(run_match)
cli.add_command(run_rematch_recombinants)
cli.add_command(_match)
cli.add_command(rematch_recombinants)
cli.add_command(tally_lineages)
Loading

0 comments on commit 977c2be

Please sign in to comment.