Skip to content

Commit

Permalink
Merge pull request #398 from hyanwong/mut-plotting
Browse files Browse the repository at this point in the history
Make mutation labelling more flexible
  • Loading branch information
jeromekelleher authored Nov 8, 2024
2 parents 87c99fd + 0708a8e commit 1905fb3
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 34 deletions.
76 changes: 43 additions & 33 deletions sc2ts/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -1619,6 +1619,7 @@ def draw_subtree(
time_scale="rank",
title=None,
mutation_labels=None,
append_mutation_recurrence=None,
size=None,
style="",
symbol_size=4,
Expand Down Expand Up @@ -1656,8 +1657,10 @@ def draw_subtree(
be placed on the right, and condensed into a dotted line? Default: ``True``
:param mutation_labels dict: A dictionary mapping mutation IDs to labels. If not
provided, mutation labels are generated automatically, in the form
``{inherited_state}{position}{derived_state}``, with counts added if there
are recurrent mutations.
``{inherited_state}{position}{derived_state}``
:params append_mutation_dupes bool: If True (default), append a count to the
mutation label indicating the number of other such mutations above the
shown nodes that are at the same position and to the same derived state.
:param time_scale str: As for the ``time_scale`` parameter of `draw_svg()`, but
defaults to "rank".
Expand All @@ -1673,6 +1676,8 @@ def draw_subtree(
position = 21563 # pick the start of the spike
if size is None:
size = (1000, 1000)
if append_mutation_recurrence is None:
append_mutation_recurrence = True
if remove_clones:
# TODO
raise NotImplementedError("remove_clones not implemented")
Expand Down Expand Up @@ -1751,38 +1756,43 @@ def draw_subtree(
shown_nodes.append(u)
prev = u

if mutation_labels is None:
mutation_labels = collections.defaultdict(list)
multiple_mutations = []
reverted_mutations = []
use_mutations = np.where(np.isin(ts.mutations_node, shown_nodes))[0]
sites = ts.mutations_site[use_mutations]
for mut_id in use_mutations:
# TODO Viz the recurrent mutations
mut = ts.mutation(mut_id)
site = ts.site(mut.site)
if np.sum(sites == site.id) > 1:
multiple_mutations.append(mut.id)
inherited_state = site.ancestral_state
if mut.parent >= 0:
parent = ts.mutation(mut.parent)
inherited_state = parent.derived_state
parent_inherited_state = site.ancestral_state
if parent.parent >= 0:
parent_inherited_state = ts.mutation(
parent.parent
).derived_state
if parent_inherited_state == mut.derived_state:
reverted_mutations.append(mut.id)
# Reverse map label name to mutation id, so we can count duplicates
label = f"{inherited_state}{int(site.position)}{mut.derived_state}"
mutation_labels[label].append(mut.id)
multiple_mutations = []
reverted_mutations = []
recurrent_mutations = collections.defaultdict(list)
mut_labels = {} if mutation_labels is None else mutation_labels.copy()
use_mutations = np.where(np.isin(ts.mutations_node, shown_nodes))[0]
sites = ts.mutations_site[use_mutations]
for mut_id in use_mutations:
# TODO Viz the recurrent mutations
mut = ts.mutation(mut_id)
site = ts.site(mut.site)
if np.sum(sites == site.id) > 1:
multiple_mutations.append(mut.id)
inherited_state = site.ancestral_state
if mut.parent >= 0:
parent = ts.mutation(mut.parent)
inherited_state = parent.derived_state
parent_inherited_state = site.ancestral_state
if parent.parent >= 0:
parent_inherited_state = ts.mutation(
parent.parent
).derived_state
if parent_inherited_state == mut.derived_state:
reverted_mutations.append(mut.id)
pos = int(site.position)
recurrent_mutations[(pos, mut.derived_state)].append(mut.id)
if mutation_labels is None:
mut_labels[mut.id] = f"{inherited_state}{pos}{mut.derived_state}"
# If more than one mutation has the same label, add a prefix with the counts
mutation_labels = {
m_id: label + (f" ({i+1}/{len(ids)})" if len(ids) > 1 else "")
for label, ids in mutation_labels.items()
for i, m_id in enumerate(ids)
if append_mutation_recurrence:
num_recurrent = {
m_id: (i+1, len(ids))
for ids in recurrent_mutations.values() for i, m_id in enumerate(ids)
if len(ids) > 1
}
for m_id, (i, n) in num_recurrent.items():
if m_id in mut_labels:
mut_labels[m_id] += f" ({i}/{n})"
# some default styles
styles = [
"".join(f".n{u} > .sym {{fill: cyan}}" for u in tracked_nodes),
Expand Down Expand Up @@ -1819,7 +1829,7 @@ def draw_subtree(
title=title,
size=size,
order=order,
mutation_labels=mutation_labels,
mutation_labels=mut_labels,
all_edge_mutations=True,
symbol_size=symbol_size,
pack_untracked_polytomies=pack_untracked_polytomies,
Expand Down
2 changes: 1 addition & 1 deletion sc2ts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def sort_mutation_label(s):
s = (
s.replace("$", "")
.replace(r"\bf", "")
.replace("\it", "")
.replace(r"\it", "")
.replace("{", "")
.replace("}", "")
)
Expand Down

0 comments on commit 1905fb3

Please sign in to comment.