-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcontrol.py
228 lines (208 loc) · 9.77 KB
/
control.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
import argparse
import pickle as pk
import threading
import time

import numpy as np
import serial

from filter import BandpassFilter1D, NotchFilter1D
from processing import MeanShift1D, Detrend1D, Resample1D, Normalize1D
from feature_extraction import *
from Ax12 import Ax12
# process signal of each channel
def process_signal1d(x, raw_fs=1000, low_fs=1, high_fs=120, notch_fs=60, Q=20, window_size=250, step_size=50, target_fs=512):
    """Run the preprocessing chain on one channel and return the result.

    Steps, in order: mean correction (MeanShift1D) -> band-pass filter
    (order 4) -> notch filter -> detrend (detrend_type='locreg') ->
    resample from raw_fs to target_fs -> rectify (absolute value) ->
    normalize (norm_type='min_max').

    @param x: signal of a single channel
    @param raw_fs: original sampling rate
    @param low_fs: low cutoff frequency of the band-pass filter
    @param high_fs: high cutoff frequency of the band-pass filter
    @param notch_fs: notch cutoff frequency
    @param Q: Q factor of the notch filter
    @param window_size: window size for detrending
    @param step_size: step size for detrending
    @param target_fs: target sampling rate for the resampling step
    @return: the signal after the final normalization step
    """
    centered = MeanShift1D.apply(x)
    band_limited = BandpassFilter1D.apply(centered, low_fs, high_fs, order=4, fs=raw_fs)
    notched = NotchFilter1D.apply(band_limited, notch_fs, Q=Q, fs=raw_fs)
    detrended = Detrend1D.apply(notched, detrend_type='locreg', window_size=window_size, step_size=step_size)
    resampled = Resample1D.apply(detrended, raw_fs, target_fs)
    rectified = abs(resampled)
    return Normalize1D.apply(rectified, norm_type='min_max')
# process multi-channel signal
def process_signalnd(x, raw_fs=1000, low_fs=1, high_fs=120, notch_fs=60, Q=20, window_size=250, step_size=50, target_fs=512):
    """Apply process_signal1d independently to every channel (column) of x.

    @param x: 2-D array of shape (n, num_channels); column i is channel i
    @param raw_fs: original sampling rate
    @param low_fs: low cutoff frequency
    @param high_fs: high cutoff frequency
    @param notch_fs: notch cutoff frequency
    @param Q: Q factor
    @param window_size: window size for detrending
    @param step_size: step size for detrending
    @param target_fs: target sampling rate for resampling step
    @return: array whose i-th column (axis 1) is process_signal1d(x[:, i], ...);
             an empty 1-D array when x has no channels
    """
    num_channels = x.shape[1]
    if num_channels == 0:
        # Keep the historical result for a channel-less input.
        return np.array([])
    processed_channels = [
        process_signal1d(x[:, i], raw_fs, low_fs, high_fs, notch_fs, Q, window_size, step_size, target_fs)
        for i in range(num_channels)
    ]
    # Stack once along a new axis 1 (same result as expand_dims(axis=1) + hstack),
    # instead of re-copying the accumulated matrix on every loop iteration.
    return np.stack(processed_channels, axis=1)
# The class to connect the electrodes through serial
class SerialPort:
    """Read multi-channel samples from a serial device, classify them, and drive a controller.

    serial_read loops forever: it reads a window of samples, preprocesses it,
    extracts features, classifies them and stores the label in self.action.
    serial_send maps self.action ('0', '1' or '2') to a controller motion.
    """

    def __init__(self, port='COM1', baud=115200, cls=None, pca=None, controller=None, num_channels=4, interval=1000, timeout=0.1):
        """
        @param port: serial port name passed to serial.Serial
        @param baud: baud rate passed to serial.Serial
        @param cls: classifier exposing predict()
        @param pca: optional transformer exposing transform(); skipped when None
        @param controller: object exposing bow/shake/up(motor_pos=...)
        @param num_channels: number of values expected on each serial line
        @param interval: number of lines (samples) per classification window
        @param timeout: delay in seconds slept at the start of each serial_send call
        """
        super(SerialPort, self).__init__()
        self.port = serial.Serial(port, baud)
        self.signal = None
        self.interval = interval
        self.cls = cls
        self.num_channels = num_channels
        self.timeout = timeout
        self.feature_window_size = 50  # Please modify as your setting
        self.concat = True  # Please change as your setting
        self.avg_pool = True  # Please change as your setting
        self.pca = pca
        self.controller = controller
        # Initialised here so serial_send can run before serial_read has set it.
        self.action = '0'

    def serial_open(self):
        """Open the serial port if it is not already open."""
        if not self.port.isOpen():
            self.port.open()

    def serial_close(self):
        """Close the serial port."""
        self.port.close()

    def serial_send(self):
        """Sleep self.timeout seconds, then perform the motion mapped to self.action.

        Unrecognised action values are ignored.
        """
        print('Send action...')
        time.sleep(self.timeout)
        action = self.action  # read once; serial_read may update it from another thread
        if action == '0':
            # define your '0' action
            self.controller.bow(motor_pos=[220, 300, 200, 512])  # set motor positions as your setting
        elif action == '1':
            # define your '1' action
            self.controller.shake(motor_pos=[0, 512, 500, 0])  # set motor positions as your setting
        elif action == '2':
            # define your '2' action
            self.controller.up(motor_pos=[0, 200, 512, 128])  # set motor positions as your setting

    @staticmethod
    def _parse_line(raw, num_channels):
        """Decode one serial line into num_channels floats, or return None if it is malformed."""
        # 'replace' turns undecodable bytes into U+FFFD so float() rejects the line
        # instead of silently gluing the surrounding digits together.
        fields = raw.decode('utf-8', errors='replace').split()
        if len(fields) != num_channels:
            return None
        try:
            return [float(field) for field in fields]
        except ValueError:
            return None

    def _read_window(self):
        """Block until self.interval well-formed lines are read; return an (interval, num_channels) array."""
        rows = []
        while len(rows) < self.interval:
            raw = self.port.readline()
            row = self._parse_line(raw, self.num_channels)
            if row is None:
                # e.g. a partial line right after the port opens: drop it instead
                # of crashing the reader thread.
                print('Skipping malformed serial line: %r' % (raw,))
                continue
            rows.append(row)
        return np.array(rows, dtype=float).reshape(self.interval, self.num_channels)

    def _extract_features(self, signal_processed):
        """Build the (1, n_features) feature row passed to the optional PCA and the classifier."""
        # change your feature as your setting
        extractors = (MaxPeak, Mean, Variance, StandardDeviation, Skewness, Kurtosis, RootMeanSquare)
        per_feature = [
            extractor.apply(signal_processed, self.feature_window_size).T.flatten()
            for extractor in extractors
        ]
        if self.concat:
            feature = np.hstack(per_feature)
        elif self.avg_pool:
            # average pooling across feature types
            feature = np.vstack(per_feature).mean(axis=0)
        else:
            # max pooling across feature types
            feature = np.vstack(per_feature).max(axis=0)
        return np.expand_dims(feature, axis=0)

    @staticmethod
    def _label_to_action(y_preds):
        """Convert predict() output for a single sample into the action string used by serial_send."""
        # str() of the whole prediction array gives e.g. '[1]', which never matches
        # '0'/'1'/'2'; use the single predicted label instead.
        return str(np.asarray(y_preds).ravel()[0])

    def serial_read(self):
        """Repeatedly read a window, preprocess it, classify it and update self.action. Never returns."""
        print('Receiving signal...')
        self.action = '0'
        while True:
            signal = self._read_window()
            # please change parameters as your settings
            signal_processed = process_signalnd(signal, raw_fs=1000, low_fs=10, high_fs=120, notch_fs=60, Q=20, window_size=512, step_size=50, target_fs=512)
            feature = self._extract_features(signal_processed)
            if self.pca is not None:
                feature = self.pca.transform(feature)
            y_preds = self.cls.predict(feature)
            self.action = self._label_to_action(y_preds)
class RobotController:
    """Drive a set of motors through the project's Ax12 wrapper."""

    def __init__(self, port='COM3', baud=9600, num_motors=4):
        """
        @param port: device name assigned to Ax12.DEVICENAME before connecting
        @param baud: baud rate assigned to Ax12.BAUDRATE before connecting
        @param num_motors: number of motors; motor i is created as Ax12(i), i in 0..num_motors-1
        """
        self.AX12 = Ax12
        # Settings are assigned on the Ax12 class itself, before connect() is called.
        self.AX12.DEVICENAME = port
        self.AX12.BAUDRATE = baud
        self.AX12.connect()
        self.num_motors = num_motors
        # NOTE(review): motors are addressed with IDs 0..num_motors-1 — confirm these
        # match the servo IDs configured on the bus.
        self.dxl_motors = [self.AX12(i) for i in range(self.num_motors)]

    def _move(self, motor_pos, speed=200, delay=0.1):
        """Set moving speed and goal position for every motor, sleeping `delay` seconds after each.

        @param motor_pos: sequence of goal positions, one per motor (extra entries are ignored)
        @param speed: value passed to set_moving_speed for every motor
        @param delay: seconds to sleep after commanding each motor
        @raises ValueError: if motor_pos has fewer than num_motors entries
        """
        if len(motor_pos) < self.num_motors:
            raise ValueError('motor_pos needs %d positions, got %d' % (self.num_motors, len(motor_pos)))
        for motor, position in zip(self.dxl_motors, motor_pos):
            motor.set_moving_speed(speed)
            motor.set_goal_position(position)
            time.sleep(delay)

    def init_pos(self, motor_pos, speed=200, delay=0.1):
        """Initialize positions for each motor.

        @param motor_pos: a list of position of motors
        """
        self._move(motor_pos, speed, delay)

    def bow(self, motor_pos, speed=200, delay=0.1):
        """
        Define the method name with your action by yourself. This method is just an example
        @param motor_pos: a list of position of motors
        """
        self._move(motor_pos, speed, delay)

    def shake(self, motor_pos, speed=200, delay=0.1):
        """
        Define the method name with your action by yourself. This method is just an example
        @param motor_pos: a list of position of motors
        """
        self._move(motor_pos, speed, delay)

    def up(self, motor_pos, speed=200, delay=0.1):
        """
        Define the method name with your action by yourself. This method is just an example
        @param motor_pos: a list of position of motors
        """
        self._move(motor_pos, speed, delay)

    def disconnect(self):
        """Disable torque on motor 0 and disconnect the robot."""
        # NOTE(review): only motor 0 has its torque disabled; the others are left
        # as-is — confirm this is intended.
        self.dxl_motors[0].set_torque_enable(0)
        self.AX12.disconnect()
def main():
    """Parse CLI arguments, connect robot and serial device, and run the send loop until Ctrl-C."""
    # Set command line arguments
    parser = argparse.ArgumentParser(description='Real-time robot-arm controlling')
    parser.add_argument('--arduport', type=str, default='COM1', help='COM port for arduino')
    parser.add_argument('--ardubaud', type=int, default=115200, help='Baud rate for arduino')
    parser.add_argument('--axport', type=str, default='COM3', help='COM port for dynamix 12')
    parser.add_argument('--axbaud', type=int, default=9600, help='Baud rate for dynamix 12')
    parser.add_argument('--num-motors', type=int, default=4, help='Number of motors')
    parser.add_argument('--channels', type=int, default=4, help='The number of channels')
    parser.add_argument('--segment', type=int, default=1000, help='Segmentation interval')
    parser.add_argument('--timeout', type=float, default=1, help='Time out')
    args = parser.parse_args()
    # define robot
    controller = RobotController(port=args.axport, baud=args.axbaud, num_motors=args.num_motors)
    controller.init_pos([220, 512, 300, 512])
    # SECURITY: pickle.load can execute arbitrary code; only load model files you trust.
    # define classifier
    with open('svc.pkl', 'rb') as f:
        cls = pk.load(f)
    # define pre-processing pca
    with open('pca.pkl', 'rb') as f:
        pca = pk.load(f)
    # Setup serial line (argparse stores --arduport/--ardubaud as arduport/ardubaud)
    mserial = SerialPort(args.arduport, args.ardubaud, cls, pca, controller, args.channels, args.segment, args.timeout)
    # serial_read never returns, so a daemon thread keeps it from blocking process exit.
    t1 = threading.Thread(target=mserial.serial_read, daemon=True)
    t1.start()
    try:
        while True:
            mserial.serial_send()
    except KeyboardInterrupt:
        print('Press Ctrl-C to terminate while statement')
    finally:
        mserial.serial_close()
        controller.disconnect()


# The original guard compared the literal '__name__' to '__main__' (always False).
if __name__ == '__main__':
    main()