-
Notifications
You must be signed in to change notification settings - Fork 14
/
metrics.py
26 lines (20 loc) · 1.04 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
#==========================
# Depth Prediction Metrics
#==========================
def abs_rel_error(pred, gt, mask):
    '''Compute absolute relative difference error: mean(|pred - gt| / gt).

    Only elements where ``mask > 0`` and ``gt > 0`` contribute. Ground truth
    that is zero or negative would make the relative error undefined
    (division by zero yields ``inf`` for the whole metric), so it is skipped.

    Args:
        pred: predicted depth tensor.
        gt: ground-truth depth tensor; assumed to have the same shape as
            ``pred`` and ``mask`` (TODO confirm against callers).
        mask: tensor whose positive entries select elements to evaluate.

    Returns:
        0-d tensor with the mean absolute relative error; NaN if no element
        is valid (mean of an empty tensor).
    '''
    valid = (mask > 0) & (gt > 0)
    g = gt[valid]
    return ((pred[valid] - g).abs() / g).mean()
def sq_rel_error(pred, gt, mask):
    '''Compute squared relative difference error: mean((pred - gt)^2 / gt).

    Only elements where ``mask > 0`` and ``gt > 0`` contribute. Ground truth
    that is zero or negative would make the relative error undefined
    (division by zero yields ``inf`` for the whole metric), so it is skipped.

    Args:
        pred: predicted depth tensor.
        gt: ground-truth depth tensor; assumed to have the same shape as
            ``pred`` and ``mask`` (TODO confirm against callers).
        mask: tensor whose positive entries select elements to evaluate.

    Returns:
        0-d tensor with the mean squared relative error; NaN if no element
        is valid (mean of an empty tensor).
    '''
    valid = (mask > 0) & (gt > 0)
    g = gt[valid]
    return (((pred[valid] - g) ** 2) / g).mean()
def lin_rms_sq_error(pred, gt, mask):
    '''Return the mean squared difference between ``pred`` and ``gt``.

    Only elements where ``mask > 0`` are used. This is the linear RMS error
    without its final square root; taking the square root is left to the
    caller.
    '''
    selected = mask > 0
    residual = pred[selected] - gt[selected]
    return residual.pow(2).mean()
def log_rms_sq_error(pred, gt, mask):
    '''Return the mean squared difference of ``log(pred)`` and ``log(gt)``.

    An element contributes only if ``mask > 0`` and both ``pred`` and ``gt``
    are strictly greater than 1e-7, which keeps the logarithms finite. This
    is the log RMS error without its final square root; taking the square
    root is left to the caller.
    '''
    eps = 1e-7
    valid = (mask > 0) & (pred > eps) & (gt > eps)
    log_diff = pred[valid].log() - gt[valid].log()
    return log_diff.pow(2).mean()
def delta_inlier_ratio(pred, gt, mask, degree=1):
    '''Compute the delta inlier rate to a specified degree (def: 1).

    An element is an inlier when ``max(pred / gt, gt / pred) < 1.25 ** degree``.
    The rate is taken over elements where ``mask > 0`` and ``gt > 0``. Ground
    truth that is zero or negative makes the ratio undefined, so it is
    skipped.

    A non-positive prediction is never an inlier. Without this check, a
    negative ``pred`` yields a negative ratio that would pass the threshold.

    Args:
        pred: predicted depth tensor.
        gt: ground-truth depth tensor; assumed to have the same shape as
            ``pred`` and ``mask`` (TODO confirm against callers).
        mask: tensor whose positive entries select elements to evaluate.
        degree: exponent applied to the 1.25 threshold.

    Returns:
        0-d float tensor in [0, 1]; NaN if no element is valid.
    '''
    valid = (mask > 0) & (gt > 0)
    p = pred[valid]
    g = gt[valid]
    ratio = torch.max(p / g, g / p)
    # pred < 0 gives a negative ratio and pred == 0 gives inf; force both
    # to count as outliers rather than relying on the threshold test alone.
    inliers = (ratio < 1.25 ** degree) & (p > 0)
    return inliers.float().mean()