-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
48 lines (38 loc) · 1.96 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import numpy as np
from segmentation.utils.scribble2mask import FBRS
from propagation.STM.main import STM_Model
class model():
    """Interactive video object segmentation driver.

    Combines an FBRS scribble-to-mask model (turns user scribbles on one
    frame into a mask) with an STM model (refines that mask and propagates
    masks across the video). The per-frame masks live in
    ``self.current_masks``.
    """

    def __init__(self, frames, n_objects, memory_size, fbrs_gpu, stm_gpu,
                 *, stm_weights="propagation/STM/STM_weights.pth"):
        """Allocate mask storage and construct the FBRS and STM models.

        Args:
            frames: Video array whose first three dimensions are read as
                (n_frames, height, width). Presumably (T, H, W, 3) RGB --
                TODO confirm against the caller.
            n_objects: Presumably the number of objects to segment;
                forwarded to FBRS and STM.
            memory_size: Forwarded to ``STM_Model``.
            fbrs_gpu: Forwarded to ``FBRS`` (presumably a device selector).
            stm_gpu: Forwarded to ``STM_Model`` (presumably a device
                selector).
            stm_weights: Path of the STM checkpoint handed to
                ``STM_Model``. Defaults to
                ``"propagation/STM/STM_weights.pth"``, a relative path.
        """
        self.frames = frames
        self.memory_size = memory_size
        self.n_objects = n_objects
        self.n_frames, self.height, self.width = self.frames.shape[:3]
        # One uint8 label map per frame, initially all zeros.
        self.current_masks = np.zeros(self.frames.shape[:3], dtype=np.uint8)
        # Frame indices the user has annotated, in annotation order.
        self.annotated_frames = []
        # Attribute name (sic) kept as-is for any external readers.
        # NOTE(review): nothing in this class ever sets this flag to True, so
        # the reset in run_propagation has no effect unless outside code sets
        # it; confirm whether the intended initial value is True.
        self.first_interation = False
        self.fbrs = FBRS(fbrs_gpu, visualizer=None, external=False)
        self.stm = STM_Model(stm_weights, memory_size, stm_gpu)

    def run_interaction(self, scribbles, range):
        """Convert one frame's scribbles into a mask, then refine it with STM.

        Args:
            scribbles: Scribble description; ``scribbles['annotated_frame']``
                is the index of the annotated frame. The whole object is
                passed through to ``FBRS.scribble2mask``.
            range: Passed unchanged to ``STM_Model.self_refine``; its meaning
                is defined there. (The name shadows the builtin but is kept
                because callers may pass it by keyword.)

        Side effects:
            Appends the frame index to ``self.annotated_frames`` unless it
            equals the most recent entry, and overwrites
            ``self.current_masks[target]`` with the refined mask.
        """
        target = scribbles['annotated_frame']
        # Only consecutive repeats are collapsed: re-annotating an earlier
        # frame after a different one appends it again.
        if not self.annotated_frames or self.annotated_frames[-1] != target:
            self.annotated_frames.append(target)
        annotated_mask = self.fbrs.scribble2mask(
            self.current_masks[target], scribbles, self.frames[target],
            target, self.n_objects, self.first_interation)
        # Store the FBRS result first: self_refine reads it from current_masks.
        self.current_masks[target] = annotated_mask
        refined_mask = self.stm.self_refine(
            self.frames, self.current_masks, self.n_objects,
            self.annotated_frames, range)
        self.current_masks[target] = refined_mask

    def run_propagation(self, range):
        """Propagate the annotated masks to the other frames with STM.

        Does nothing if no frame has been annotated yet. Otherwise clears
        ``first_interation`` and replaces ``self.current_masks`` with a copy
        of the array returned by ``STM_Model.propagate``.

        Args:
            range: Passed unchanged to ``STM_Model.propagate``; its meaning
                is defined there.
        """
        if not self.annotated_frames:
            return
        self.first_interation = False
        new_masks = self.stm.propagate(self.frames, self.current_masks,
                                       self.n_objects, self.annotated_frames,
                                       range)
        # Presumably copied so current_masks never aliases an array that STM
        # still owns -- confirm against STM_Model.propagate.
        self.current_masks = np.copy(new_masks)