-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
106 lines (91 loc) · 3.67 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from network import *
from helper import *
import config as conf
# The explicit imports below deliberately follow the star imports so that a
# name re-exported by `network`/`helper` cannot shadow them.
import argparse
import json
import os
import sys
import time

import cv2
import matplotlib.pyplot as plt
# Command-line flags. Each flag switches on one action in main(); several
# flags may be combined in a single run.
# NOTE: parsing happens at import time, so importing this module reads
# sys.argv (and argparse exits the process on unrecognized flags).
parser = argparse.ArgumentParser()
parser.add_argument('-train', help='Start a training session. Specify the epochs',
                    type=int, metavar='EPOCHS')
parser.add_argument('-test', help='Test all images and get result', action='store_true')
parser.add_argument('-webcam', help='Streaming a webcam', action='store_true')
parser.add_argument('-newbatches', help='make new batch files (Training purpose)',
                    action='store_true')
parser.add_argument('-video', help='train and/or test an entire video', action='store_true')
args = parser.parse_args()
# Video file extensions handled by the -video mode (matched at the end of the name).
_VIDEO_EXTENSIONS = ('.mp4', '.avi')
# Folder scanned for videos to test in -video mode.
_TEST_VIDEO_DIR = 'data/test/video/'


def _is_video_file(filename):
    """Return True if *filename* ends with one of ``_VIDEO_EXTENSIONS`` (case-insensitive)."""
    return filename.lower().endswith(_VIDEO_EXTENSIONS)


def _run_videos(model):
    """Optionally train *model* on the training videos, then test every test video.

    Asks on stdin whether to train first. Training files come from
    ``get_file_lists()``; test files are the videos found in ``_TEST_VIDEO_DIR``.

    Returns:
        The model, retrained if the user chose to train.
    """
    prompt = input('Do you want to train some videos? [y/n] ').strip()
    if prompt in conf.YES:
        for filename, target in get_file_lists():
            if not _is_video_file(filename):
                continue
            # A [0, 1] target selects the with_mask training folder; anything
            # else is read from the no_mask folder.
            if target == [0, 1]:
                folder, with_mask = 'data/train/with_mask/', True
            else:
                folder, with_mask = 'data/train/no_mask/', False
            # do_train returns the updated model, which feeds the next video.
            model = VideoAgent(model, folder + filename,
                               in_test_folder=False).do_train(with_mask)
        print('Training ends. Starting testing...')
        print('--------------------------------------')

    try:
        entries = sorted(os.listdir(_TEST_VIDEO_DIR))
    except FileNotFoundError:
        # A missing folder is reported the same way as an empty one below.
        entries = []
    # Files containing 'out_' are skipped; presumably they are outputs written
    # by earlier do_test() runs -- TODO confirm against VideoAgent.
    videos = [name for name in entries
              if _is_video_file(name) and 'out_' not in name]
    if not videos:
        print(f'No videos found in {_TEST_VIDEO_DIR}. Please add any')
    for vid in videos:
        VideoAgent(model, vid).do_test()
    return model


def _run_webcams(model):
    """Start one WebcamAgent per camera in ``conf.webcams_filename`` and show their frames.

    The file is expected to hold a JSON list of objects with ``url`` and
    ``name`` keys. The process exits with status 1 if the file is missing or
    cannot be read/parsed. When ``conf.show_webcam_images`` is set, frames are
    displayed until 'q' is pressed in a window; otherwise this idles until
    interrupted.
    """
    if not os.path.isfile(conf.webcams_filename):
        print(f"'{conf.webcams_filename}' does not exist. Please create one and/or adjust the config file")
        sys.exit(1)

    print(f"Loading {conf.webcams_filename}")
    try:
        with open(conf.webcams_filename, encoding='utf-8') as f:
            cams = json.load(f)
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print("Failed to load file. Exit")
        sys.exit(1)

    agents = [WebcamAgent(model, cam['url'], cam['name'], run_thread_start=True)
              for cam in cams]
    if not agents:
        # With no window open, cv2.waitKey never sees 'q', so the display loop
        # below could never be left.
        print(f"No webcams listed in {conf.webcams_filename}")
        return

    # NOTE(review): agent threads are never stopped here; confirm WebcamAgent
    # shuts them down (or uses daemon threads) so the process can exit.
    while True:
        if not conf.show_webcam_images:
            # Agents were started with run_thread_start=True, so presumably they
            # keep working in the background; just wait (Ctrl+C to stop).
            time.sleep(1)
            continue
        for agent in agents:
            if agent.cam_alive and agent.proceed_img is not None:
                cv2.imshow(agent.name, agent.proceed_img)
            else:
                # Placeholder frame while the stream is down or has no image yet.
                cv2.imshow(agent.name, conf.no_video_stream)
        if (cv2.waitKey(1) & 0xFF) == ord('q'):
            cv2.destroyAllWindows()
            break


def main():
    """Run every action enabled by the command-line flags in ``args``.

    Actions run in a fixed order (newbatches, train, test, video, webcam); any
    step that produces a new model hands it on to the later steps.
    """
    model = load()
    if args.newbatches:
        make_training_batch_files()
    if args.train:
        model = Trainer(model, args.train).run()
    if args.test:
        Tester(model).run()
    if args.video:
        model = _run_videos(model)
    if args.webcam:
        _run_webcams(model)
if __name__ == "__main__":
    # perf_counter is monotonic, so the reported duration is not skewed by
    # wall-clock adjustments the way differences of time.time() can be.
    start_time = time.perf_counter()
    try:
        main()
    except KeyboardInterrupt:
        print("Keyboard interrupt. Exit")
        # 130 = 128 + SIGINT, the conventional status for a Ctrl+C'd process.
        sys.exit(130)
    print(f"Program took {time.perf_counter() - start_time:.2f}s")
    print("Goodbye")