-
Notifications
You must be signed in to change notification settings - Fork 35
/
Copy pathutils.py
194 lines (168 loc) · 6.31 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
# -*- coding:utf-8 -*-
import os
import cv2
import pandas as pd
import numpy as np
import configparser as cp
import matplotlib.pyplot as plt
# Root directory of the annotated raw recordings. The trailing slash is
# significant: this prefix is removed from file paths by plain string
# replacement when output paths are built.
RAW_DATA_PATH = '/home/tony/fall_research/fall_data/MobiAct_Dataset_v2.0/Annotated Data/'

# Activity code -> integer class id. FOL, FKL, BSC and SDL all share id 0
# (presumably the fall classes -- confirm against the dataset documentation).
Label = {
    'STD': 1,
    'WAL': 2,
    'JOG': 3,
    'JUM': 4,
    'STU': 5,
    'STN': 6,
    'SCH': 7,
    'SIT': 8,
    'CHU': 9,
    'LYI': 10,
    'FOL': 0,
    'FKL': 0,
    'BSC': 0,
    'SDL': 0,
    'CSI': 15,
    'CSO': 16,
}
def extract_data(data_file, sampling_frequency):
    """
    Write a down-sampled, numerically labelled copy of one raw recording.

    The CSV is read with its first column as the index. The text label held
    in column position 10 is replaced by its integer id from ``Label``, every
    ``int(sampling_frequency / 50)``-th row is kept, and column positions
    1-6 plus 10 are written to ``./dataset/raw/`` followed by ``data_file``
    with the ``RAW_DATA_PATH`` prefix removed.

    :param data_file: path of the raw CSV file
    :param sampling_frequency: sampling rate of the raw recording; values
        below 50 keep every row instead of producing a zero step
    :return: None
    :raises KeyError: if the file contains a label that is not in ``Label``
    """
    data = pd.read_csv(data_file, index_col=0)

    # Translate the whole label column in one pass instead of cell by cell.
    label_column = data.columns[10]
    raw_labels = data.iloc[:, 10]
    label_ids = raw_labels.map(Label)
    missing = label_ids.isna()
    if missing.any():
        # Series.map silently yields NaN for unknown keys; keep the strict
        # behaviour of a direct dictionary lookup.
        unknown = sorted(set(raw_labels[missing].astype(str)))
        raise KeyError('unknown label(s) in {}: {}'.format(data_file, unknown))
    data[label_column] = label_ids.astype(int).values

    # A step of 0 would make the row selection fail, so never go below 1.
    step = max(1, int(sampling_frequency / 50))
    sampled = data.iloc[::step, [1, 2, 3, 4, 5, 6, 10]]

    # Create the directory of the path that is actually written to, so the
    # two can never disagree.
    save_path = './dataset/raw/' + data_file.replace(RAW_DATA_PATH, '')
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    sampled.to_csv(save_path, index=False)
def find_all_data_and_extract(path):
    """
    Recursively scan ``path`` and run ``extract_data`` on each CSV file.

    An entry is treated as a CSV file when it is a regular file whose name
    contains ``csv``; it is extracted with a sampling frequency of 200.
    Every entry that is not a regular file is descended into.

    :param path: directory to scan
    :return: None
    """
    if not os.path.exists(path):
        print('路径存在问题:', path)
        return None

    for entry in os.listdir(path):
        entry_path = path + "/" + entry
        if not os.path.isfile(entry_path):
            # Not a regular file: recurse into it.
            find_all_data_and_extract(entry_path)
        elif 'csv' in entry:
            extract_data(entry_path, 200)
def parser_cfg_file(cfg_file):
    """
    Collect the options of the ``net`` and ``train`` sections of a config file.

    Sections are visited in the order they appear in the file, so when both
    sections define the same option the later one wins. All other sections
    are ignored.

    :param cfg_file: path of the configuration file
    :return: dict mapping each option name to its string value
    """
    config = cp.ConfigParser()
    config.read(cfg_file)

    wanted_sections = ('net', 'train')
    content_params = {}
    for section in config.sections():
        if section not in wanted_sections:
            continue
        for option in config.options(section):
            content_params[option] = config.get(section, option)
    return content_params
def show_data(data, name=None):
    '''
    Plot the accelerometer and gyroscope channels of a recording.

    Two stacked panels are drawn on figure 1: ``acc_x/acc_y/acc_z`` on top
    and ``gyro_x/gyro_y/gyro_z`` below, both against the row number.

    :param data: DataFrame with the six channel columns named above
    :param name: file to save the figure to; when None the figure is shown
        on screen instead
    :return: None
    '''
    num = data.acc_x.size
    x = np.arange(num)

    def draw_panel(position, title, channels):
        # One panel: three curves labelled x / y / z plus a legend.
        plt.subplot(2, 1, position)
        plt.title(title)
        for axis_label, column in zip('xyz', channels):
            plt.plot(x, getattr(data, column), label=axis_label)
        plt.legend()

    plt.figure(1, figsize=(100, 60))

    # Panel 1: accelerometer data.
    draw_panel(1, 'acc', ('acc_x', 'acc_y', 'acc_z'))
    # Tick marks spaced num / 10 rows apart, shared by both panels.
    x_flag = np.arange(0, num, num / 10)
    plt.xticks(x_flag)

    # Panel 2: gyroscope data.
    draw_panel(2, 'gyro', ('gyro_x', 'gyro_y', 'gyro_z'))
    plt.xticks(x_flag)

    if name is not None:
        plt.savefig(name)
    else:
        plt.show()
    plt.close()
def kalman_filter(data):
    """
    Smooth the first six columns of ``data`` with a 6-state Kalman filter.

    The measurement and transition matrices are identity matrices, the
    process noise covariance is ``0.003 * I`` and the measurement noise
    covariance is ``I``. Each row is fed to the filter as a measurement and
    is replaced by the prediction made right after that correction.

    :param data: DataFrame whose first six columns are numeric (presumably
        acc_x..gyro_z -- confirm with the caller); modified in place
    :return: the same DataFrame, with its first six columns filtered
    """
    state_dim = 6
    identity = np.eye(state_dim, dtype=np.float32)

    kalman = cv2.KalmanFilter(state_dim, state_dim)
    kalman.measurementMatrix = identity.copy()
    kalman.transitionMatrix = identity.copy()
    kalman.processNoiseCov = identity * 0.003
    kalman.measurementNoiseCov = identity * 1

    row_num = len(data)
    if row_num == 0:
        # Nothing to filter; avoid a zero-row block assignment.
        return data

    # Read all measurements once. Row i is only ever read before it is
    # overwritten, so snapshotting up front does not change the result.
    measurements = np.ascontiguousarray(
        data.iloc[:, 0:state_dim].values, dtype=np.float32)
    filtered = np.empty((row_num, state_dim), dtype=np.float64)
    for i, measurement in enumerate(measurements):
        kalman.correct(measurement.reshape(state_dim, 1))
        # predict() returns a (6, 1) column; store it as a flat row of
        # scalars rather than assigning 1-element arrays to single cells.
        filtered[i] = kalman.predict()[:, 0]

    # One block write instead of six cell writes per row.
    data.iloc[:, 0:state_dim] = filtered
    return data
def find_all_data_and_filtrate(path):
    """
    Recursively scan ``path`` and Kalman-filter each CSV file in place.

    An entry is treated as a CSV file when it is a regular file whose name
    contains ``csv``; it is read, passed through ``kalman_filter`` and
    written back to the same path without the index column. Every entry
    that is not a regular file is descended into.

    :param path: directory to scan
    :return: None
    """
    if not os.path.exists(path):
        print('路径存在问题:', path)
        return None

    for entry in os.listdir(path):
        entry_path = path + "/" + entry
        if not os.path.isfile(entry_path):
            find_all_data_and_filtrate(entry_path)
            continue
        if 'csv' not in entry:
            continue
        # Overwrite the source file with its filtered version.
        filtered = kalman_filter(pd.read_csv(entry_path))
        filtered.to_csv(entry_path, index=False)
def main():
    """Run the Kalman filtering pass over every CSV below ``./dataset/kalman/``."""
    # Disabled alternative stage: extract the dataset from the raw recordings.
    # find_all_data_and_extract(RAW_DATA_PATH)
    find_all_data_and_filtrate('./dataset/kalman/')


if __name__ == '__main__':
    main()
# if os.path.exists('./dataset/train/BSC_1_1_annotated.csv') == False:
# print('./dataset/train/BSC_1_1_annotated.csv', '文件不存在!')
# data = pd.read_csv('./dataset/train/BSC_1_1_annotated.csv')
#
# #show_data(data)
# data = kalman_filter(data)
# data.to_csv('./dataset/train/BSC_1_1_annotated.csv', index=False)
# #show_data(data)
# # a = data.iloc[4:5,0]
# # print(a)
# data = pd.read_csv('./dataset/train/STU_1_1_annotated.csv')
#
# show_data(data)