-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTest.py
56 lines (46 loc) · 1.65 KB
/
Test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os
import shutil
import numpy as np
from PIL import ImageFile, Image, ImageOps
import matplotlib.pyplot as plt
from dataloaders.data_util.utils import get_rice_encode, decode_segmap
# Absolute directory of this script; mask paths are resolved relative to it.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))

# A dataset: folder holding the ground-truth mask images. The trailing slash
# is kept so a bare file name can be appended to it directly.
masks = ROOT_DIR + '/masks/'

# Mask file names, listed once at import time and sorted lexicographically.
m_img_files = sorted(os.listdir(masks))
# Sentinel meaning "use get_rice_encode()". A sentinel is needed because None
# already means "infer the mapping from the image". It also keeps the default
# mapping from being built once at import and shared by every call.
_RICE_ENCODE_DEFAULT = object()


def _infer_color_codes(img):
    """Map every distinct pixel value of *img* to a label, in sorted order."""
    if img.ndim == 3:
        # One row per pixel; unique rows are the distinct colours.
        colours = np.unique(img.reshape(-1, img.shape[-1]), axis=0)
        return {tuple(colour): i for i, colour in enumerate(colours.tolist())}
    return {value: i for i, value in enumerate(np.unique(img).tolist())}


def mask_to_class(img, color_codes=_RICE_ENCODE_DEFAULT, one_hot_encode=False):
    """Convert a colour-coded segmentation mask into a per-pixel label map.

    Args:
        img: Mask array of shape (H, W, C), where each pixel is a colour
            vector, or (H, W), where each pixel is a scalar code.
        color_codes: Mapping from pixel value to integer label. Keys are
            colour tuples for (H, W, C) masks and scalars for (H, W) masks.
            When omitted, the mapping returned by ``get_rice_encode()`` is
            used. When ``None``, a mapping is inferred from the distinct
            pixel values of *img*, numbered in sorted order.
        one_hot_encode: If true, return an (H, W, len(color_codes)) float
            array whose channel ``c`` is 1.0 where the label equals ``c``.
            This assumes the labels are 0..len(color_codes)-1.

    Returns:
        ``(labels, color_codes)``. ``labels`` is an int array of shape (H, W),
        or the one-hot float array, with -1 (all-zero in one-hot form) for
        pixels matching no key. ``color_codes`` is the mapping actually used.
    """
    if color_codes is _RICE_ENCODE_DEFAULT:
        color_codes = get_rice_encode()
    img = np.asarray(img)
    if color_codes is None:
        color_codes = _infer_color_codes(img)
    n_labels = len(color_codes)

    result = np.full(img.shape[:2], -1, dtype=int)
    for rgb, idx in color_codes.items():
        matches = img == np.asarray(rgb)
        if matches.ndim == 3:
            # A colour pixel matches only when every channel matches.
            matches = matches.all(axis=-1)
        # Later entries win if two keys match the same pixel (dict order).
        result[matches] = idx

    if one_hot_encode:
        # Compare each label against 0..n_labels-1 along a new trailing axis.
        result = (result[..., np.newaxis] == np.arange(n_labels)).astype(float)
    return result, color_codes
#
# count = 0
# import cv2
# file = m_img_files
# for file in m_img_files:
# _target = Image.open(masks + str(file))
# _tmp = np.array(_target, dtype=np.uint8)
# print(_tmp.shape)
#
# img, _ = mask_to_class(_tmp)
# img = decode_segmap(img, dataset='rweeds', plot=False)
#
# segmap = np.array(img * 255).astype(np.uint8)
#
# rgb_img = cv2.resize(segmap, (_tmp.shape[1], _tmp.shape[0]),
# interpolation=cv2.INTER_NEAREST)
# bgr = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2BGR)
# cv2.imwrite("masks/results/" + str(count) + "_result.png", bgr)
# count = count + 1
#