diff --git a/sc2ts/info.py b/sc2ts/info.py index 50cefa1..2fe0292 100644 --- a/sc2ts/info.py +++ b/sc2ts/info.py @@ -1619,6 +1619,7 @@ def draw_subtree( time_scale="rank", title=None, mutation_labels=None, + append_mutation_recurrence=None, size=None, style="", symbol_size=4, @@ -1656,8 +1657,10 @@ def draw_subtree( be placed on the right, and condensed into a dotted line? Default: ``True`` :param mutation_labels dict: A dictionary mapping mutation IDs to labels. If not provided, mutation labels are generated automatically, in the form - ``{inherited_state}{position}{derived_state}``, with counts added if there - are recurrent mutations. + ``{inherited_state}{position}{derived_state}`` + :params append_mutation_dupes bool: If True (default), append a count to the + mutation label indicating the number of other such mutations above the + shown nodes that are at the same position and to the same derived state. :param time_scale str: As for the ``time_scale`` parameter of `draw_svg()`, but defaults to "rank". @@ -1673,6 +1676,8 @@ def draw_subtree( position = 21563 # pick the start of the spike if size is None: size = (1000, 1000) + if append_mutation_recurrence is None: + append_mutation_recurrence = True if remove_clones: # TODO raise NotImplementedError("remove_clones not implemented") @@ -1751,38 +1756,43 @@ def draw_subtree( shown_nodes.append(u) prev = u - if mutation_labels is None: - mutation_labels = collections.defaultdict(list) - multiple_mutations = [] - reverted_mutations = [] - use_mutations = np.where(np.isin(ts.mutations_node, shown_nodes))[0] - sites = ts.mutations_site[use_mutations] - for mut_id in use_mutations: - # TODO Viz the recurrent mutations - mut = ts.mutation(mut_id) - site = ts.site(mut.site) - if np.sum(sites == site.id) > 1: - multiple_mutations.append(mut.id) - inherited_state = site.ancestral_state - if mut.parent >= 0: - parent = ts.mutation(mut.parent) - inherited_state = parent.derived_state - parent_inherited_state = site.ancestral_state - if parent.parent >= 0: - parent_inherited_state = ts.mutation( - parent.parent - ).derived_state - if parent_inherited_state == mut.derived_state: - reverted_mutations.append(mut.id) - # Reverse map label name to mutation id, so we can count duplicates - label = f"{inherited_state}{int(site.position)}{mut.derived_state}" - mutation_labels[label].append(mut.id) + multiple_mutations = [] + reverted_mutations = [] + recurrent_mutations = collections.defaultdict(list) + mut_labels = {} if mutation_labels is None else mutation_labels.copy() + use_mutations = np.where(np.isin(ts.mutations_node, shown_nodes))[0] + sites = ts.mutations_site[use_mutations] + for mut_id in use_mutations: + # TODO Viz the recurrent mutations + mut = ts.mutation(mut_id) + site = ts.site(mut.site) + if np.sum(sites == site.id) > 1: + multiple_mutations.append(mut.id) + inherited_state = site.ancestral_state + if mut.parent >= 0: + parent = ts.mutation(mut.parent) + inherited_state = parent.derived_state + parent_inherited_state = site.ancestral_state + if parent.parent >= 0: + parent_inherited_state = ts.mutation( + parent.parent + ).derived_state + if parent_inherited_state == mut.derived_state: + reverted_mutations.append(mut.id) + pos = int(site.position) + recurrent_mutations[(pos, mut.derived_state)].append(mut.id) + if mutation_labels is None: + mut_labels[mut.id] = f"{inherited_state}{pos}{mut.derived_state}" # If more than one mutation has the same label, add a prefix with the counts - mutation_labels = { - m_id: label + (f" ({i+1}/{len(ids)})" if len(ids) > 1 else "") - for label, ids in mutation_labels.items() - for i, m_id in enumerate(ids) + if append_mutation_recurrence: + num_recurrent = { + m_id: (i+1, len(ids)) + for ids in recurrent_mutations.values() for i, m_id in enumerate(ids) + if len(ids) > 1 } + for m_id, (i, n) in num_recurrent.items(): + if m_id in mut_labels: + mut_labels[m_id] += f" ({i}/{n})" # some default styles styles = [ "".join(f".n{u} > .sym {{fill: cyan}}" for u in tracked_nodes), @@ -1819,7 +1829,7 @@ def draw_subtree( title=title, size=size, order=order, - mutation_labels=mutation_labels, + mutation_labels=mut_labels, all_edge_mutations=True, symbol_size=symbol_size, pack_untracked_polytomies=pack_untracked_polytomies, diff --git a/sc2ts/utils.py b/sc2ts/utils.py index 353da7a..717dfc0 100644 --- a/sc2ts/utils.py +++ b/sc2ts/utils.py @@ -355,7 +355,7 @@ def sort_mutation_label(s): s = ( s.replace("$", "") .replace(r"\bf", "") - .replace("\it", "") + .replace(r"\it", "") .replace("{", "") .replace("}", "") )