-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathutils.py
115 lines (91 loc) · 3.49 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import base64
import os
import re
import shutil
from datetime import datetime
from io import BytesIO

import numpy as np
import torch
from PIL import Image
def resize(image: Image.Image, size: int = 512, mode='long_edge', sample=Image.LANCZOS) -> Image.Image:
    """
    Resize an image so its reference edge is ``size`` pixels, keeping the aspect ratio.

    After scaling, each dimension is rounded down to a multiple of 8 and
    clamped to a minimum of 8, so the requested target is never zero-sized
    (which the plain round-down would produce for very elongated images or
    for ``size`` below 8).

    :param image: PIL.Image
    :param size: int, target length in pixels of the reference edge
    :param mode: str, 'long_edge' or 'short_edge' - which edge is scaled to ``size``
    :param sample: Sample method, e.g. PIL.Image.LANCZOS or PIL.Image.NEAREST
    :return: Resized PIL.Image
    :raises AssertionError: if ``mode`` is not one of the two supported values
    """
    assert mode in ['long_edge', 'short_edge'], "mode must be 'long_edge' or 'short_edge'"
    w, h = image.size
    ratio = size / (max(w, h) if mode == 'long_edge' else min(w, h))
    w, h = int(w * ratio), int(h * ratio)
    # Make sure the size is divisible by 8, but never let a dimension collapse
    # to 0 (e.g. 1000x10 at size=512 scales to 512x5, which would round to 512x0).
    w = max(8, w - w % 8)
    h = max(8, h - h % 8)
    return image.resize((w, h), sample)
def image_to_list(image: Image.Image):
    """
    Wrap an image's pixel data in a JSON-friendly dict.

    :param image: PIL.Image
    :return: dict with a single key 'image' holding the pixels as nested lists
    """
    pixels = np.array(image)
    nested = pixels.tolist()
    return {'image': nested}
def list_to_image(image_list: dict):
    """
    Rebuild a PIL image from a dict carrying pixel data under the 'image' key.

    The pixel data is cast to uint8. A 2-D array is converted to mode 'L',
    a 3-D array to mode 'RGB'.

    :param image_list: dict with key 'image' holding nested lists of pixel values
    :return: PIL.Image
    :raises ValueError: if the pixel array is neither 2-D nor 3-D
    """
    pixels = np.array(image_list['image']).astype(np.uint8)
    rank = pixels.ndim
    if rank == 3:
        return Image.fromarray(pixels).convert('RGB')
    if rank == 2:
        return Image.fromarray(pixels).convert('L')
    raise ValueError(f"Unknown image shape: {pixels.shape}")
def decode_json(data: dict):
    """
    Replace encoded images in a request dict with PIL images, in place.

    A value counts as an encoded image when it is a dict containing the key
    'image'. Such values are decoded when they appear directly as a top-level
    value or as an element of a top-level list; deeper nesting is not visited.

    :param data: dict to decode; it is modified in place
    :return: the same dict object that was passed in
    """
    for key in data:
        value = data[key]
        if isinstance(value, dict) and 'image' in value:
            # Re-assigning an existing key does not resize the dict, so it is
            # safe while iterating.
            data[key] = list_to_image(value)
            continue
        if isinstance(value, list):
            for idx, entry in enumerate(value):
                if isinstance(entry, dict) and 'image' in entry:
                    value[idx] = list_to_image(entry)
    return data
def encode_json(data: dict):
    """
    Make a response dict JSON-friendly.

    PIL images that appear directly as a top-level value, or as an element of
    a top-level list, are replaced by their list encoding; this replacement
    happens in place on ``data``. Top-level entries whose value is None are
    then left out of the returned dict.

    :param data: dict to encode
    :return: a new dict without the None-valued entries
    """
    for key in data:
        value = data[key]
        if isinstance(value, Image.Image):
            data[key] = image_to_list(value)
            continue
        if isinstance(value, list):
            for idx, entry in enumerate(value):
                if isinstance(entry, Image.Image):
                    value[idx] = image_to_list(entry)
    # remove None
    return {key: value for key, value in data.items() if value is not None}
def encode_pil_to_base64(image: Image.Image):
    """
    Serialize an image as PNG and return it as a base64 string.

    :param image: PIL.Image
    :return: str, base64 text of the PNG bytes (no data-URI prefix)
    """
    with BytesIO() as buffer:
        image.save(buffer, format='PNG', save_all=True)
        png_bytes = buffer.getvalue()
    return base64.b64encode(png_bytes).decode('utf-8')
def decode_pil_from_base64(base64_str):
    """
    Open a PIL image from base64 text.

    A leading ``data:image/...;base64,`` prefix is stripped before decoding,
    so both bare base64 and data-URI strings are accepted.

    :param base64_str: str, base64-encoded image, optionally as a data URI
    :return: PIL.Image opened from the decoded bytes
    """
    payload = re.sub('^data:image/.+;base64,', '', base64_str)
    # Decode before touching PIL so malformed base64 fails first.
    stream = BytesIO(base64.b64decode(payload))
    return Image.open(stream)
def move_to_cache(image_path, cache_path="./cache"):
    """
    Move a file into a per-day cache folder, renamed after the current time.

    The destination is ``<cache_path>/<YYYYMMDD>/<HHMMSS><ext>``, where
    ``<ext>`` is the extension of ``image_path``. The day folder is created
    if it does not exist.

    The move is done with ``shutil.move`` rather than a shell ``mv`` command,
    so paths containing spaces or shell metacharacters are handled correctly
    and a failed move raises instead of being silently ignored.

    :param image_path: path of the file to move
    :param cache_path: root directory of the cache
    :return: absolute path of the moved file
    :raises OSError: if the file cannot be moved (e.g. the source is missing)
    """
    # Read the clock once so the folder name and the file name cannot straddle
    # midnight.
    now = datetime.now()

    # create a folder named with date
    day_dir = os.path.join(cache_path, now.strftime("%Y%m%d"))
    os.makedirs(day_dir, exist_ok=True)

    # rename as timestamp HHMMSS and move to cache folder
    # NOTE: the name has one-second resolution, so two files with the same
    # extension moved within the same second map to the same destination.
    timestamp = now.strftime("%H%M%S")
    extension = os.path.splitext(image_path)[1]
    target = os.path.join(day_dir, f'{timestamp}{extension}')
    shutil.move(image_path, target)

    # return the abspath
    return os.path.abspath(target)
def assemble_response(data: dict):
    """
    Encode a result dict, log a one-line summary, and stamp it as a response.

    The dict is passed through ``encode_json``; a summary line is printed in
    which string values are shown verbatim and every other value is shown as
    'Image'. The keys "status" (200) and "time" (the current local time) are
    then set on the encoded dict, overwriting any existing entries.

    :param data: dict of result values
    :return: the encoded dict with "status" and "time" added
    """
    payload = encode_json(data)
    stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    summary = []
    for key, value in payload.items():
        shown = value if isinstance(value, str) else 'Image'
        summary.append(f"{key}: {shown}")
    print(f"[{stamp}] " + ", ".join(summary))

    payload["status"] = 200
    payload["time"] = stamp
    return payload
def torch_gc(device):
    """
    Run torch's CUDA cache cleanup with ``device`` selected.

    Does nothing when CUDA is not available. Otherwise calls
    ``torch.cuda.empty_cache()`` and ``torch.cuda.ipc_collect()`` inside a
    ``torch.cuda.device(device)`` context.

    :param device: device selector passed to ``torch.cuda.device``
    :return: None
    """
    if not torch.cuda.is_available():
        return
    with torch.cuda.device(device):
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()