-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathupload_image_data.py
96 lines (73 loc) · 2.22 KB
/
upload_image_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# Creates vectors of images and uploads them to the Weaviate db
from pathlib import Path
import os
import io
import base64
import weaviate
from PIL import Image
from tqdm import tqdm
from models.dinov2 import DinoV2Embed
def setup_batch(client):
    """
    Apply the batching settings used by this script to the client's batch.

    client: object whose ``batch.configure`` accepts the keyword settings
    below (here a Weaviate client).
    """
    batch_settings = {
        'batch_size': 100,
        'dynamic': True,
        'timeout_retries': 3,
        'callback': None,
    }
    client.batch.configure(**batch_settings)
def delete_images(client):
    """
    Remove all images from vector db

    client: Weaviate client; its batch context is used for the delete.

    A single batch delete only removes as many objects as the server's
    per-query limit allows, so the delete is repeated until a pass is no
    longer capped by that limit or makes no progress.
    """
    # Matches every 'Image' object whose filepath is not the literal 'x',
    # which is how this script selects "everything".
    match_all = {
        'operator': 'NotEqual',
        'path': ['filepath'],
        'valueString': 'x'
    }
    with client.batch as batch:
        while True:
            response = batch.delete_objects(
                class_name='Image',
                where=match_all,
                output='verbose'
            )
            results = response.get('results') if isinstance(response, dict) else None
            if not results:
                # Unknown response shape: behave like a single delete call.
                break
            deleted = results.get('successful', 0)
            matched = results.get('matches', 0)
            limit = results.get('limit')
            # Stop when nothing was deleted (avoids looping forever on
            # persistent failures) or when this pass was not capped.
            if not deleted or limit is None or matched < limit:
                break
def img_to_base64(img_path):
    """
    Return the image at img_path, converted to RGB, as a base64 string.

    img_path: path of the image file, handed to PIL's Image.open.
    Returns: the RGB image re-saved in the file's original format,
    base64-encoded and decoded to a UTF-8 str.
    """
    # The context manager closes the underlying file even if reading or
    # converting the image fails.
    with Image.open(img_path) as source:
        img_format = source.format
        rgb_img = source.convert('RGB')  # PIL.Image.Image
    buffer = io.BytesIO()
    rgb_img.save(buffer, format=img_format)
    return base64.b64encode(buffer.getvalue()).decode('utf-8')
def import_data(client, source_path, model=None):
    """
    Process all images and upload their vectors into the db

    client: Weaviate client; objects are added through its batch context.
    source_path: directory searched recursively for '*.jpg' files.
    model: optional embedder exposing ``embed(path)``. A new DinoV2Embed
        is created when omitted; pass one in to reuse it across calls.
    """
    if model is None:
        model = DinoV2Embed()
    with client.batch as batch:
        # rglob() is already recursive, so no '**/' prefix is needed.
        for img_path in Path(source_path).rglob('*.jpg'):
            if not img_path.is_file():
                continue
            tqdm.write(f'IMG PATH: {img_path}')
            img_vector = model.embed(img_path)
            img_base64 = img_to_base64(img_path)
            data_properties = {
                'image': img_base64,
                'filepath': str(img_path)
            }
            batch.add_data_object(data_properties, 'Image', vector=img_vector)
if __name__ == '__main__':
    # Fall back to a local instance when WEAVIATE_URL is unset or empty.
    WEAVIATE_URL = os.getenv('WEAVIATE_URL') or 'http://localhost:8080'

    # Check the dataset before touching the db: delete_images() below wipes
    # the stored images, so a missing directory must not get that far.
    p = Path('dataset')
    if not p.is_dir():
        raise SystemExit(f'Dataset directory not found: {p}')

    client = weaviate.Client(WEAVIATE_URL)
    setup_batch(client)
    delete_images(client)

    # Looks for subdirs inside the dataset directory and imports each one
    for child in tqdm(p.iterdir(), disable=None):
        if not child.is_dir():
            # Plain files directly under the dataset directory are skipped.
            continue
        tqdm.write(f'DIR: {child}')
        import_data(client, child)