-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathaveragePrecision.py
159 lines (144 loc) · 6.42 KB
/
averagePrecision.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#!/usr/bin/env python
"""
Script for computing average precision.
Use `averagePrecision.py -h` to see an auto-generated description of advanced options.
"""
import argparse
import os
import shutil
import errno
from tqdm import tqdm
import numpy as np
from genomeloader.wrapper import BedWrapper, BedGraphWrapper, NarrowPeakWrapper, BroadPeakWrapper
from sklearn.metrics import auc
import matplotlib
matplotlib.use('Agg')
from matplotlib import style
import matplotlib.pyplot as plt
def get_args(argv=None):
    """Parse the command-line options for the average-precision evaluation.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the
            original behavior.

    Returns:
        argparse.Namespace with attributes ``abedgraph``, ``anarrowpeak``,
        ``abroadpeak`` (exactly one of these is set, the others are None),
        ``b``, ``threshold``, ``blacklist``, ``output``, ``chroms``,
        ``wholegenome`` and ``autox``.
    """
    # Everything after the first line of the module docstring becomes the help
    # epilog. Fall back to '' when there is no docstring (e.g. under
    # `python -OO`) instead of crashing on None.strip().
    epilog = '\n'.join((__doc__ or '').strip().split('\n')[1:]).strip()
    parser = argparse.ArgumentParser(description='Evaluating predictions with AP metric.',
                                     epilog=epilog,
                                     formatter_class=argparse.RawTextHelpFormatter)

    # Exactly one prediction file must be given, in one of three formats.
    predictions = parser.add_mutually_exclusive_group(required=True)
    predictions.add_argument('-abg', '--abedgraph',
                             help='BEDGraph of predicted intervals.', type=str)
    predictions.add_argument('-anp', '--anarrowpeak',
                             help='NarrowPeak of predicted intervals.', type=str)
    predictions.add_argument('-abp', '--abroadpeak',
                             help='BroadPeak of predicted intervals.', type=str)

    parser.add_argument('-b', '--b', required=True,
                        help='BED of ground truth intervals.', type=str)
    parser.add_argument('-t', '--threshold', required=False, default=0.5,
                        help='IOU threshold (default: 0.5).', type=float)
    parser.add_argument('-bl', '--blacklist', required=False,
                        default=None,
                        help='Blacklist BED file.', type=str)
    parser.add_argument('-o', '--output', required=False, default=None,
                        help='Output directory (optional).', type=str)

    # Chromosome selection: an explicit list (the default), the whole genome,
    # or autosomes plus chrX. At most one of these may be given explicitly.
    regions = parser.add_mutually_exclusive_group(required=False)
    regions.add_argument('-c', '--chroms', type=str, nargs='+',
                         default=['chr1', 'chr8', 'chr21'],
                         help='Chromosome(s) to evaluate on.')
    regions.add_argument('-wg', '--wholegenome', action='store_true', default=False,
                         help='Evaluate on the whole genome.')
    regions.add_argument('-ax', '--autox', action='store_true', default=False,
                         help='Evaluate on autosomes and X chromosome.')
    return parser.parse_args(argv)
def _best_gt_match(gt_genomic_interval_tree, chrom, start, end, iou_threshold):
    """Find the best still-unmatched ground-truth interval for one prediction.

    Looks up the interval tree for ``chrom``. Among the ground-truth
    intervals it reports as overlapping ``start``..``end``, picks the one
    with the highest IoU, provided that IoU is at least ``iou_threshold``.
    Candidates are visited in (begin, end) order, so IoU ties are broken
    deterministically: the earliest interval wins.

    Returns:
        ``(chrom_gt_tree, interval)`` for the chosen match, or
        ``(None, None)`` when nothing qualifies.
    """
    try:
        chrom_gt_tree = gt_genomic_interval_tree[chrom]
    except KeyError:
        # Chromosome without any ground-truth tree: the prediction cannot match.
        # NOTE(review): assumes the mapping raises KeyError for unknown chroms.
        return None, None
    best_interval = None
    best_iou = None
    candidates = sorted(chrom_gt_tree.overlap(start, end), key=lambda iv: (iv.begin, iv.end))
    for candidate in candidates:
        candidate_iou = iou(start, end, candidate.begin, candidate.end)
        if candidate_iou >= iou_threshold and (best_iou is None or candidate_iou > best_iou):
            best_interval, best_iou = candidate, candidate_iou
    if best_interval is None:
        return None, None
    return chrom_gt_tree, best_interval


def _precision_recall_curve(predictions_df, data_col, gt_genomic_interval_tree, num_gt_peaks,
                            iou_threshold):
    """Greedily match ranked predictions to ground truth and build the PR curve.

    Predictions are ranked by ``data_col`` (highest first). Each one is
    matched to at most one ground-truth interval. A matched interval is
    removed from its tree so it cannot be matched again.

    Returns:
        ``(thresholds, recalls, precisions)`` as numpy arrays with one entry
        per distinct score, ordered by ascending score. Each entry holds the
        cumulative recall/precision after every prediction with that score
        has been counted.
    """
    # Use a stable sort so tied scores keep file order and results are
    # reproducible. The sort is not in place, so the wrapper's own DataFrame
    # is left untouched.
    ranked = predictions_df.sort_values(by=data_col, ascending=False, kind='mergesort')
    thresholds = []
    recalls = []
    precisions = []
    true_positives = 0
    false_positives = 0
    for row in tqdm(iterable=ranked.itertuples(), total=len(ranked)):
        thresholds.append(getattr(row, data_col))
        chrom_gt_tree, match = _best_gt_match(gt_genomic_interval_tree,
                                              getattr(row, 'chrom'),
                                              getattr(row, 'chromStart'),
                                              getattr(row, 'chromEnd'),
                                              iou_threshold)
        if match is None:
            false_positives += 1
        else:
            # Consume the matched ground-truth interval.
            chrom_gt_tree.remove(match)
            true_positives += 1
        recalls.append(1.0 * true_positives / num_gt_peaks)
        precisions.append(1.0 * true_positives / (true_positives + false_positives))
    thresholds = np.array(thresholds)
    recalls = np.array(recalls)
    precisions = np.array(precisions)
    # Scores were visited in descending order, so each distinct score occupies
    # a contiguous run. Keep the cumulative point at the END of each run.
    thresholds, first_indices, counts = np.unique(thresholds, return_index=True,
                                                  return_counts=True)
    last_indices = first_indices + counts - 1
    return thresholds, recalls[last_indices], precisions[last_indices]


def _save_outputs(output, recalls, precisions):
    """(Re)create directory ``output`` and write the PR plot and curve arrays into it.

    WARNING: an existing directory named ``output`` is deleted first.
    """
    plt.ioff()
    style.use('ggplot')
    try:
        os.makedirs(output)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            # Permission errors, bad paths, etc. must not be silently ignored.
            raise
        # The directory already exists: wipe it so no stale results remain.
        shutil.rmtree(output)
        os.makedirs(output)
    plt.plot(recalls, precisions)
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.savefig(output + '/pr.pdf')
    plt.close()
    np.save(output + '/recalls.npy', recalls)
    np.save(output + '/precisions.npy', precisions)


def main():
    """Compute the average precision (AP) of predicted peaks against ground truth.

    Steps:
      1. Read the ground-truth BED and the prediction file (BEDGraph,
         NarrowPeak or BroadPeak).
      2. Optionally subtract blacklist intervals from both.
      3. Restrict both to the selected chromosomes.
      4. Build a precision-recall curve using IoU >= ``--threshold`` as the
         match criterion.
      5. Print AP (the area under that curve, via ``sklearn.metrics.auc``)
         and the Jaccard index of the two interval sets.

    If ``--output`` is given, that directory is recreated (an existing one
    is DELETED) and ``pr.pdf``, ``recalls.npy`` and ``precisions.npy`` are
    written into it.

    Raises:
        SystemExit: if no ground-truth or no predicted intervals remain
            after blacklist and chromosome filtering.
    """
    args = get_args()
    output = args.output
    iou_threshold = args.threshold

    b = BedWrapper(args.b)
    if args.abedgraph is not None:
        a = BedGraphWrapper(args.abedgraph)
    elif args.anarrowpeak is not None:
        a = NarrowPeakWrapper(args.anarrowpeak)
    else:
        a = BroadPeakWrapper(args.abroadpeak)
    # Name of the column used to rank predictions (a.data_col looks 1-based).
    data_col = a.col_names[a.data_col - 1]

    # Clip away the parts of both interval sets that overlap blacklist intervals.
    if args.blacklist is not None:
        blacklist = BedWrapper(args.blacklist)
        new_b_bt = b.bt.subtract(blacklist.bt)
        b = BedWrapper(new_b_bt.fn)
        new_a_bt = a.bt.subtract(blacklist.bt)
        a = BedWrapper(new_a_bt.fn, col_names=a.col_names, channel_last=a.channel_last,
                       data_col=a.data_col, dtype=a.dtype)

    # Keep only the chromosomes under evaluation (taken as the "test" split).
    if not args.wholegenome:
        if args.autox:
            chroms = ['chr1', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16',
                      'chr17', 'chr18', 'chr19', 'chr2', 'chr20', 'chr21', 'chr22', 'chr3',
                      'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chrX']
        else:
            chroms = args.chroms
        _, _, a = a.train_valid_test_split(valid_chroms=[], test_chroms=chroms)
        _, _, b = b.train_valid_test_split(valid_chroms=[], test_chroms=chroms)

    num_gt_peaks = len(b)
    num_pr_peaks = len(a)
    # Without these guards the script dies with ZeroDivisionError (no ground
    # truth) or an opaque auc() error (no predictions).
    if num_gt_peaks == 0:
        raise SystemExit('No ground-truth intervals left to evaluate against.')
    if num_pr_peaks == 0:
        raise SystemExit('No predicted intervals left to evaluate.')

    _, recalls, precisions = _precision_recall_curve(a.df, data_col, b.genomic_interval_tree,
                                                     num_gt_peaks, iou_threshold)

    ap = auc(recalls, precisions)
    print('The average precision is %f' % ap)
    jaccard = a.bt.jaccard(b.bt)['jaccard']
    print('The Jaccard index is %f' % jaccard)

    if output is not None:
        _save_outputs(output, recalls, precisions)
def iou(a_start, a_end, b_start, b_end):
    """Return the intersection-over-union of intervals a and b.

    The argument order of the two intervals does not matter. Intervals that
    merely touch (one's end equals the other's start) have zero intersection.

    Args:
        a_start, a_end: Bounds of the first interval.
        b_start, b_end: Bounds of the second interval.

    Returns:
        A value in [0, 1]. It is 0.0 when the intervals do not overlap,
        which includes the case where both are empty.
    """
    # Handles partial overlap AND containment. The previous formula
    # (a_end - b_start) / (b_end - a_start) gave values > 1 when one interval
    # contained the other.
    intersection = min(a_end, b_end) - max(a_start, b_start)
    if intersection <= 0:
        return 0.0
    # A positive intersection implies union >= intersection > 0, so there is
    # no division by zero.
    union = (a_end - a_start) + (b_end - b_start) - intersection
    return 1.0 * intersection / union
# Script entry point: run the evaluation only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()