forked from monemati/PX4-ROS2-Gazebo-YOLOv8
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path uav_camera_det_ssd.py
102 lines (83 loc) · 3.53 KB
/
uav_camera_det_ssd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# Pytorch SSD library
from vision.ssd.vgg_ssd import create_vgg_ssd, create_vgg_ssd_predictor
from vision.ssd.mobilenetv1_ssd import create_mobilenetv1_ssd, create_mobilenetv1_ssd_predictor
from vision.ssd.mobilenetv1_ssd_lite import create_mobilenetv1_ssd_lite, create_mobilenetv1_ssd_lite_predictor
from vision.ssd.squeezenet_ssd_lite import create_squeezenet_ssd_lite, create_squeezenet_ssd_lite_predictor
from vision.ssd.mobilenet_v2_ssd_lite import create_mobilenetv2_ssd_lite, create_mobilenetv2_ssd_lite_predictor
from vision.utils.misc import Timer
# Import the necessary libraries
import rclpy # Python library for ROS 2
from rclpy.node import Node # Handles the creation of nodes
from sensor_msgs.msg import Image # Image is the message type
from cv_bridge import CvBridge # Package to convert between ROS and OpenCV Images
import cv2 # OpenCV library
# Model configuration. `net_type` selects which SSD architecture is built and
# must match the architecture the checkpoint at `model_path` was trained with.
# Relative paths are resolved against the current working directory.
net_type = "mb1-ssd"
model_path = "ssd_models/trafic_small_512/model.pth"
label_path = "ssd_models/trafic_small_512/labels.txt"

# Maps each supported `net_type` to its (network factory, predictor factory)
# pair, using the pytorch-ssd builders imported at the top of this file.
_SSD_BUILDERS = {
    "vgg16-ssd": (create_vgg_ssd, create_vgg_ssd_predictor),
    "mb1-ssd": (create_mobilenetv1_ssd, create_mobilenetv1_ssd_predictor),
    "mb1-ssd-lite": (create_mobilenetv1_ssd_lite,
                     create_mobilenetv1_ssd_lite_predictor),
    "mb2-ssd-lite": (create_mobilenetv2_ssd_lite,
                     create_mobilenetv2_ssd_lite_predictor),
    "sq-ssd-lite": (create_squeezenet_ssd_lite,
                    create_squeezenet_ssd_lite_predictor),
}

# One class name per line; a detection's label id indexes into this list.
with open(label_path, encoding="utf-8") as label_file:
    class_names = [name.strip() for name in label_file]

try:
    _create_net, _create_predictor = _SSD_BUILDERS[net_type]
except KeyError:
    raise ValueError(
        f"Unsupported net_type {net_type!r}; "
        f"expected one of {sorted(_SSD_BUILDERS)}"
    ) from None

# Built once at import time so every camera frame reuses the same model.
net = _create_net(len(class_names), is_test=True)
net.load(model_path)
predictor = _create_predictor(net, candidate_size=200)
class ImageSubscriber(Node):
    """ROS 2 node that runs SSD object detection on incoming camera frames.

    Each ``sensor_msgs/Image`` received on the subscribed topic is converted
    to an OpenCV BGR image and passed (as RGB) through the module-level SSD
    ``predictor``. The frame is annotated with one rectangle and one
    ``"<class>: <prob>"`` label per detection, then shown in an OpenCV window
    titled ``'Detected Frame'``.
    """

    def __init__(self, topic='camera', queue_size=10, top_k=10,
                 prob_threshold=0.2):
        """Create the node, its image subscription and the CvBridge.

        Args:
            topic: Image topic to subscribe to.
            queue_size: History depth of the subscription.
            top_k: Second positional argument of ``predictor.predict``
                (presumably pytorch-ssd's ``top_k`` -- confirm).
            prob_threshold: Third positional argument of ``predictor.predict``
                (presumably pytorch-ssd's ``prob_threshold`` -- confirm).
        """
        super().__init__('image_subscriber')
        self._top_k = top_k
        self._prob_threshold = prob_threshold
        # Stored on the node so the subscription remains reachable.
        self.subscription = self.create_subscription(
            Image,
            topic,
            self.listener_callback,
            queue_size)
        # Converts between ROS Image messages and OpenCV images.
        self.br = CvBridge()

    def listener_callback(self, data):
        """Detect objects in one camera frame and display the annotated result.

        Args:
            data: The incoming ``sensor_msgs.msg.Image`` message.
        """
        logger = self.get_logger()
        logger.info('Receiving video frame')
        # cv_bridge yields BGR (OpenCV's channel order). The SSD predictor is
        # fed an RGB copy, while drawing and display use the BGR frame.
        image = self.br.imgmsg_to_cv2(data, desired_encoding="bgr8")
        rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        boxes, labels, probs = predictor.predict(
            rgb_image, self._top_k, self._prob_threshold)

        for box, label_id, prob in zip(boxes, labels, probs):
            logger.debug(f'Detection box: {box}')
            x1, y1, x2, y2 = (int(coord) for coord in box[:4])
            cv2.rectangle(image, (x1, y1), (x2, y2), (255, 255, 0), 4)
            label = f"{class_names[label_id]}: {prob:.2f}"
            cv2.putText(image, label,
                        (x1 + 20, y1 + 40),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        1,              # font scale
                        (255, 0, 255),
                        2)              # line thickness

        cv2.imshow('Detected Frame', image)
        # waitKey lets the HighGUI window process events and repaint.
        cv2.waitKey(1)
def main(args=None):
    """Run the SSD detection node until it is interrupted or shut down.

    Args:
        args: Command-line arguments forwarded to ``rclpy.init``.
    """
    rclpy.init(args=args)
    image_subscriber = ImageSubscriber()
    try:
        # Blocks, invoking listener_callback for each received frame.
        rclpy.spin(image_subscriber)
    except KeyboardInterrupt:
        # Ctrl+C is the normal way to stop this node; exit without a traceback.
        pass
    finally:
        # Always release the node and the OpenCV window, even on interrupt.
        image_subscriber.destroy_node()
        cv2.destroyAllWindows()
        # Avoid shutting down a context that has already been shut down.
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == '__main__':
    main()