-
Notifications
You must be signed in to change notification settings - Fork 1
/
run.py
executable file
·162 lines (129 loc) · 5.89 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
#!/usr/bin/env python
"""Dynverse wrapper that infers a trajectory with scanpy's PAGA."""
import dynclipy
# Read the dynverse task (counts, parameters, priors, output path).
task = dynclipy.main()
# avoid errors due to no $DISPLAY environment variable available when running sc.pl.paga
import matplotlib
matplotlib.use('Agg')
import pandas as pd
import numpy as np
import h5py
import json
# `scanpy.api` was deprecated and then removed in scanpy 1.5; fall back to the
# top-level module (which exposes the same pp/tl/pl namespaces) when it is gone.
try:
    import scanpy.api as sc
except ImportError:
    import scanpy as sc
import anndata
import numba
import warnings
import time
# wall-clock checkpoints reported back to dynverse for timing benchmarks
checkpoints = {}
# ____________________________________________________________________________
# Load data ####

# Pull the expression matrix and the method parameters out of the task.
counts = task["counts"]
parameters = task["parameters"]

# Root cell prior for pseudotime; dynverse may supply several candidates,
# in which case the first one is used.
start_id = task["priors"]["start_id"]
if isinstance(start_id, list):
    start_id = start_id[0]

# Optional prior grouping of cells; None means we cluster ourselves later on.
groups_id = task["priors"].get("groups_id")

# create dataset: attach any prior grouping as a precomputed "louvain"
# clustering so the downstream steps can consume it unchanged
adata = anndata.AnnData(counts)
if groups_id is not None:
    obs = pd.DataFrame(groups_id)
    obs.index = groups_id["cell_id"]
    obs["louvain"] = obs["group_id"].astype("category")
    adata.obs = obs

checkpoints["method_afterpreproc"] = time.time()
# ____________________________________________________________________________
# Basic preprocessing ####

# normalisation & filtering
if parameters["filter_features"]:
    if counts.shape[1] < 100:
        print("You have less than 100 features, but the filter_features parameter is true. This will likely result in an error. Disable filter_features to avoid this")
    # Zheng et al. 2017 recipe: library-size normalisation plus selection of
    # the most variable genes, capped at the number of genes actually present.
    sc.pp.recipe_zheng17(adata, n_top_genes=min(2000, counts.shape[1]))

# precalculating some dimensionality reductions
sc.tl.pca(adata, n_comps=parameters["n_comps"])
with warnings.catch_warnings():
    warnings.simplefilter('ignore', numba.errors.NumbaDeprecationWarning)
    sc.pp.neighbors(adata, n_neighbors=parameters["n_neighbors"])

# denoise the graph by recomputing it in the first few diffusion components
if parameters["n_dcs"] != 0:
    sc.tl.diffmap(adata, n_comps=parameters["n_dcs"])
# ____________________________________________________________________________
# Cluster, infer trajectory, infer pseudotime, compute dimension reduction ###

# add grouping if not provided (prior groups, when given, were stored as
# "louvain" during data loading)
if groups_id is None:
    sc.tl.louvain(adata, resolution=parameters["resolution"])

# run paga: builds the cluster-level connectivity graph in adata.uns["paga"]
sc.tl.paga(adata)

# compute a layout for the paga graph
# - this simply uses a Fruchterman-Reingold layout, a tree layout or any other
#   popular graph layout is also possible
# - to obtain a clean visual representation, one can discard low-confidence edges
#   using the parameter threshold
# NOTE(review): called for its side effect of computing node positions
# (presumably consumed by init_pos='paga' below); the figure itself is never
# shown (Agg backend, show=False)
sc.pl.paga(adata, threshold=0.01, layout='fr', show=False)

# run dpt for pseudotime information that is overlayed with paga
# iroot = integer position of the prior start cell within adata.obs
adata.uns['iroot'] = np.where(adata.obs.index == start_id)[0][0]
if parameters["n_dcs"] == 0:
    # no diffusion map was computed during preprocessing -> compute one now
    sc.tl.diffmap(adata)
sc.tl.dpt(adata, n_dcs = min(adata.obsm['X_diffmap'].shape[1], 10))

# run umap for a dimension-reduced embedding, use the positions of the paga
# graph to initialize this embedding
if parameters["embedding_type"] == 'umap':
    sc.tl.umap(adata, init_pos='paga')
    dimred_name = 'X_umap'
else:
    # any other value is passed through as a draw_graph layout (e.g. 'fa', 'fr')
    sc.tl.draw_graph(adata, init_pos='paga')
    dimred_name = "X_draw_graph_" + parameters["embedding_type"]

checkpoints["method_aftermethod"] = time.time()
# ____________________________________________________________________________
# Process & save output ####

# grouping: cluster assignment per cell
grouping = pd.DataFrame({"cell_id": adata.obs.index, "group_id": adata.obs.louvain})

# milestone network: upper triangle of the paga cluster-cluster connectivity
# matrix, keeping only edges at or above the connectivity cutoff
milestone_network = pd.DataFrame(
    np.triu(adata.uns["paga"]["connectivities"].todense(), k=0),
    index=adata.obs.louvain.cat.categories,
    columns=adata.obs.louvain.cat.categories
).stack().reset_index()
milestone_network.columns = ["from", "to", "length"]
# boolean mask instead of a string-built query(): same filter, no dependence
# on str() producing a parseable literal for the cutoff value
milestone_network = milestone_network[
    milestone_network["length"] >= parameters["connectivity_cutoff"]
].reset_index(drop=True)
milestone_network["directed"] = False

# dimred: the chosen embedding as comp_1..comp_k columns plus cell ids
dimred = pd.DataFrame(adata.obsm[dimred_name])
dimred.columns = ["comp_" + str(i + 1) for i in range(dimred.shape[1])]
dimred["cell_id"] = adata.obs.index

# branch progressions: the scaled dpt_pseudotime within every cluster
# replace unreachable pseudotime (inf) with the maximal pseudotime; done on
# adata.obs itself so the branch lengths and average pseudotimes computed
# below also see the finite values (the original code relied on aliasing here)
adata.obs["dpt_pseudotime"] = adata.obs["dpt_pseudotime"].replace([np.inf, -np.inf], 1)
branch_progressions = adata.obs.copy()  # copy: don't add helper columns to adata.obs
branch_progressions["percentage"] = branch_progressions.groupby("louvain")["dpt_pseudotime"].apply(
    lambda x: (x - x.min()) / (x.max() - x.min())
).fillna(0.5)  # zero-range clusters divide 0/0 -> NaN; treat as mid progression
branch_progressions["cell_id"] = adata.obs.index
# np.str was removed from NumPy (deprecated 1.20, removed 1.24); the builtin
# str is the documented drop-in replacement
branch_progressions["branch_id"] = branch_progressions["louvain"].astype(str)
branch_progressions = branch_progressions[["cell_id", "branch_id", "percentage"]]

# branches:
# - length = difference between max and min dpt_pseudotime within every cluster
# - directed = not yet correctly inferred
branches = adata.obs.groupby("louvain").apply(
    lambda x: x["dpt_pseudotime"].max() - x["dpt_pseudotime"].min()
).reset_index()
branches.columns = ["branch_id", "length"]
branches["branch_id"] = branches["branch_id"].astype(str)
branches["directed"] = True

# branch network: determine order of from and to based on difference in average pseudotime
branch_network = milestone_network[["from", "to"]].copy()  # copy: .at writes below must not hit milestone_network
average_pseudotime = adata.obs.groupby("louvain")["dpt_pseudotime"].mean()
for i, (branch_from, branch_to) in enumerate(zip(branch_network["from"], branch_network["to"])):
    if average_pseudotime[branch_from] > average_pseudotime[branch_to]:
        # orient each edge from the earlier (lower average pseudotime) cluster
        # to the later one; positional i is a valid label after reset_index above
        branch_network.at[i, "to"] = branch_from
        branch_network.at[i, "from"] = branch_to

# save
dataset = dynclipy.wrap_data(cell_ids=adata.obs.index)
dataset.add_branch_trajectory(
    grouping=grouping,
    branch_progressions=branch_progressions,
    branches=branches,
    branch_network=branch_network
)
dataset.add_dimred(dimred=dimred)
dataset.add_timings(checkpoints)
dataset.write_output(task["output"])