test.py
"""Testing script for the visual grounding pipeline based on segmentation and CLIP."""
import argparse
import torch
from torch.utils.data import random_split
from modules.pipelines.clipseg import ClipSeg
from modules.pipelines.clipssd import ClipSSD
from modules.pipelines.detrclip import DetrClip
from modules.pipelines.mdetr import MDETRvg
from modules.pipelines.yoloclip import YoloClip
from modules.refcocog import RefCOCOg
from modules.utilities import visual_grounding_test, get_best_device
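
# Pipelines supported by this script; each name corresponds to one of the
# pipeline classes imported above.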
supported_pipelines = ["yoloclip", "clipseg", "detrclip", "clipssd", "mdetr"]


def main(args):
    if args.pipeline not in supported_pipelines:
        raise ValueError(f"Pipeline `{args.pipeline}` not supported. Supported pipelines are: {supported_pipelines}.")

    device = get_best_device()

    dataset = RefCOCOg(ds_path=args.datapath)
    test_ds = RefCOCOg(ds_path=args.datapath, split='test')

    # Keep a reference to the category list before any split: random_split wraps
    # the dataset in a Subset, which does not expose the `categories` attribute.
    categories = dataset.categories

    if args.red_dataset is not None:
        print(f"[INFO] Reducing dataset to {args.red_dataset * 100}% of its original size.")
        keep = args.red_dataset
        dataset, _ = random_split(dataset, [int(keep * len(dataset)), len(dataset) - int(keep * len(dataset))])
        test_ds, _ = random_split(test_ds, [int(keep * len(test_ds)), len(test_ds) - int(keep * len(test_ds))])

    print(f"[INFO] Dataset size: {len(dataset)}")
    print(f"[INFO] Test split size: {len(test_ds)}")
    if args.clip_version is None:
        args.clip_version = "RN50"
        print(f"[INFO] No CLIP version specified. Using {args.clip_version}.")
    else:
        print(f"[INFO] Using CLIP version: {args.clip_version}")
    if args.pipeline == "yoloclip":
        if args.yolo_version is None:
            args.yolo_version = "yolov8x"
            print(f"[INFO] No YOLO version specified. Using {args.yolo_version}.")

        pipeline = YoloClip(categories,
                            clip_ver=args.clip_version,
                            yolo_ver=args.yolo_version,
                            device=device)

    elif args.pipeline == "clipseg":
        if args.seg_method is None:
            args.seg_method = "w"
            print("[INFO] No segmentation method specified. Using Watershed.")

        if args.n_segments is None:
            args.n_segments = (4, 8, 16, 32)
            print(f"[INFO] No number of segments specified. Using {args.n_segments}.")

        if args.threshold is None:
            args.threshold = 0.75
            print(f"[INFO] No threshold specified. Using {args.threshold}.")

        pipeline = ClipSeg(categories,
                           clip_ver=args.clip_version,
                           method=args.seg_method,
                           n_segments=args.n_segments,
                           q=args.threshold,
                           quiet=True,
                           device=device)

    elif args.pipeline == "detrclip":
        pipeline = DetrClip(categories,
                            clip_ver=args.clip_version,
                            device=device)

    elif args.pipeline == "clipssd":
        if args.confidence_t is None:
            raise ValueError(f"Pipeline `{args.pipeline}` needs the following argument: `confidence_t`.")

        pipeline = ClipSSD(categories,
                           clip_ver=args.clip_version,
                           confidence_t=args.confidence_t,
                           device=device)

    elif args.pipeline == "mdetr":
        pipeline = MDETRvg(categories,
                           clip_ver=args.clip_version,
                           device=device)
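
    # Optionally load a fine-tuned CLIP checkpoint into the pipeline before testing.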
    if args.clip_pth is not None:
        checkpoint = torch.load(args.clip_pth, map_location=device)
        pipeline.clip_model.load_state_dict(checkpoint['model_state_dict'])
        print(f"[INFO] Fine-tuned CLIP model loaded from {args.clip_pth}.")

    print("[INFO] Starting test\n")

    visual_grounding_test(pipeline, test_ds, logging=args.logging)
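

# Command-line entry point: the flags below mirror the options consumed in main().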
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Test the visual grounding pipeline.',
                                     formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=50))

    parser.add_argument('-p', '--pipeline', type=str,
                        help='Pipeline to test (yoloclip, clipseg, detrclip, clipssd or mdetr).')
    parser.add_argument('-dp', '--datapath', type=str, default="dataset/refcocog",
                        help='Path to the dataset.')
    parser.add_argument('-lg', '--logging', action='store_true',
                        help='Whether to log the results or not.')
    parser.add_argument('-rd', '--red_dataset', type=float, default=None,
                        help='Fraction of the dataset to keep (e.g. 0.1 keeps one tenth of the samples).')
    parser.add_argument('-cv', '--clip_version', type=str,
                        help='CLIP version to use (RN50, RN101, ViT-L/14).')
    parser.add_argument('-yv', '--yolo_version', type=str,
                        help='YOLO version to use (yolov5s, yolov8x) [only for yoloclip].')
    parser.add_argument('-sm', '--seg_method', type=str,
                        help='Method to use for segmentation (`s` for SLIC or `w` for Watershed) [only for clipseg].')
    parser.add_argument('-ns', '--n_segments', type=int, nargs='+',
                        help='Number of segments to use for segmentation [only for clipseg].')
    parser.add_argument('-ts', '--threshold', type=float,
                        help='Threshold for filtering the CLIP heatmap [only for clipseg].')
    parser.add_argument('-ds', '--downsampling', type=int,
                        help='Heatmap downsampling factor [only for clipseg].')
    parser.add_argument('-ct', '--confidence_t', type=float,
                        help='Confidence threshold for Single Shot Detection [only for clipssd].')
    parser.add_argument('-cp', '--clip_pth', type=str,
                        help='Path to a fine-tuned CLIP model state-dict.')

    args = parser.parse_args()

    main(args)