-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtype_hints.py
51 lines (41 loc) · 1020 Bytes
/
type_hints.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from typing import Optional, Sequence, Tuple, TypedDict, TypeVar, Union, OrderedDict
from torch import Tensor
# Generic placeholder type variable, exported for reuse by other modules.
T = TypeVar("T")

# Type variable standing in for whatever per-frame keypoint representation a
# dataset yields; it parameterises the ``ClipSample`` alias in this module.
KEYPOINT_FORMAT = TypeVar("KEYPOINT_FORMAT")
class ClipData(TypedDict):
    """Typed dictionary describing one clip: its label, span and video field.

    All four keys are required (the TypedDict is total).
    """

    label: str  # class/gloss label attached to the clip
    start: float  # start of the clip span; units not visible here (presumably seconds)
    end: float  # end of the clip span; same units as ``start``
    # NOTE(review): annotated as float; if this holds a video id or path it
    # should probably be ``str`` -- confirm against the code that builds it.
    video: float
class KeypointData(TypedDict):
    """Typed dictionary for a single keypoint-detection result entry.

    The field names resemble the per-person entries written by pose
    estimators such as AlphaPose (COCO-style results) -- confirm the source.
    """

    image_id: str  # identifier of the frame/image the detection belongs to
    category_id: int  # detection category id
    keypoints: list[float]  # flat keypoint list; presumably (x, y, score) triples -- TODO confirm
    score: float  # overall detection confidence
    box: list[float]  # bounding box coordinates; layout not visible here
    # NOTE(review): typed as floats; if these are indices, ``list[int]`` may be intended.
    idx: list[float]
class Box(TypedDict):
    """Typed dictionary for an axis-aligned rectangle given by a corner and a size."""

    x1: float  # horizontal coordinate of the anchor corner (presumably top-left)
    y1: float  # vertical coordinate of the anchor corner (presumably top-left)
    width: float  # extent along x
    height: float  # extent along y
class SignerData(TypedDict):
    """Typed dictionary grouping scores, a region of interest and keypoint entries."""

    scores: list[float]  # list of confidence scores; what each one scores is not visible here
    roi: Box  # region of interest as a ``Box``
    keypoints: list[KeypointData]  # keypoint-detection entries
class ModelCheckpoint(TypedDict):
    """Typed dictionary for a saved training checkpoint.

    Bundles model and optimizer state with the loss bookkeeping for a given
    epoch. Presumably the object passed to ``torch.save`` / returned by
    ``torch.load`` -- confirm against the training loop.
    """

    epoch: int  # epoch number at which the checkpoint was taken
    model_state_dict: OrderedDict[str, Tensor]  # parameter name -> tensor mapping
    optimizer_state_dict: dict  # optimizer state; left untyped as in the original
    train_loss: float  # training loss recorded for this checkpoint
    val_loss: float  # validation loss recorded for this checkpoint
    train_loss_hist: list[float]  # history of training losses
    val_loss_hist: list[float]  # history of validation losses
# One dataset sample for a clip, as a 3-tuple. Generic in KEYPOINT_FORMAT.
ClipSample = Tuple[
    Optional[Tensor],  # tensor payload, may be absent (presumably video frames -- confirm)
    Optional[Sequence[KEYPOINT_FORMAT]],  # keypoints in the dataset's format, may be absent
    Union[list[int], Tensor],  # integer list or tensor (presumably labels/targets -- confirm)
]
# A keypoint-model sample: a list of tensors paired with a single tensor
# (presumably (inputs, target) -- confirm against the dataset/model code).
KeypointModelSample = Tuple[list[Tensor], Tensor]