Skip to content

Commit

Permalink
Scale new instances to new frame size (#1568)
Browse files Browse the repository at this point in the history
* Fix typehinting in `AddInstance`

* brought over changes from my own branch

* added suggestions

* Ensured google style comments

---------

Co-authored-by: roomrys <[email protected]>
Co-authored-by: sidharth srinath <[email protected]>
  • Loading branch information
3 people authored Nov 1, 2023
1 parent dbe14a8 commit cb82d36
Showing 1 changed file with 37 additions and 18 deletions.
55 changes: 37 additions & 18 deletions sleap/gui/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class which inherits from `AppCommand` (or a more specialized class such as
from enum import Enum
from glob import glob
from pathlib import Path, PurePath
from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type, Union
from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type, Union, cast

import attr
import cv2
Expand Down Expand Up @@ -2879,13 +2879,13 @@ def do_action(cls, context: CommandContext, params: dict):
@staticmethod
def create_new_instance(
context: CommandContext,
from_predicted: bool,
copy_instance: Optional[Instance],
from_predicted: Optional[PredictedInstance],
copy_instance: Optional[Union[Instance, PredictedInstance]],
mark_complete: bool,
init_method: str,
location: Optional[QtCore.QPoint],
from_prev_frame: bool,
):
) -> Instance:
"""Create new instance."""

# Now create the new instance
Expand Down Expand Up @@ -2913,14 +2913,15 @@ def create_new_instance(

# If we're copying a predicted instance or from another frame, copy the track
if hasattr(copy_instance, "score") or from_prev_frame:
copy_instance = cast(Union[PredictedInstance, Instance], copy_instance)
new_instance.track = copy_instance.track

return new_instance

@staticmethod
def fill_missing_nodes(
context: CommandContext,
copy_instance: Optional[Instance],
copy_instance: Optional[Union[Instance, PredictedInstance]],
init_method: str,
new_instance: Instance,
location: Optional[QtCore.QPoint],
Expand Down Expand Up @@ -2967,10 +2968,10 @@ def fill_missing_nodes(
@staticmethod
def set_visible_nodes(
context: CommandContext,
copy_instance: Optional[Instance],
copy_instance: Optional[Union[Instance, PredictedInstance]],
new_instance: Instance,
mark_complete: bool,
) -> Tuple[Instance, bool]:
) -> bool:
"""Sets visible nodes for new instance.
Args:
Expand All @@ -2988,15 +2989,28 @@ def set_visible_nodes(

has_missing_nodes = False

# go through each node in skeleton
# Calculate scale factor for getting new x and y values.
old_size_width = copy_instance.frame.video.shape[2]
old_size_height = copy_instance.frame.video.shape[1]
new_size_width = new_instance.frame.video.shape[2]
new_size_height = new_instance.frame.video.shape[1]
scale_width = new_size_width / old_size_width
scale_height = new_size_height / old_size_height

# Go through each node in skeleton.
for node in context.state["skeleton"].node_names:
# if we're copying from a skeleton that has this node
# If we're copying from a skeleton that has this node.
if node in copy_instance and not copy_instance[node].isnan():
# just copy x, y, and visible
# we don't want to copy a PredictedPoint or score attribute
# Ensure x, y inside current frame, then copy x, y, and visible.
# We don't want to copy a PredictedPoint or score attribute.
x_old = copy_instance[node].x
y_old = copy_instance[node].y
x_new = x_old * scale_width
y_new = y_old * scale_height

new_instance[node] = Point(
x=copy_instance[node].x,
y=copy_instance[node].y,
x=x_new,
y=y_new,
visible=copy_instance[node].visible,
complete=mark_complete,
)
Expand All @@ -3007,18 +3021,22 @@ def set_visible_nodes(

@staticmethod
def find_instance_to_copy_from(
context: CommandContext, copy_instance: Optional[Instance], init_method: bool
) -> Tuple[Optional[Instance], bool, bool]:
context: CommandContext,
copy_instance: Optional[Union[Instance, PredictedInstance]],
init_method: bool,
) -> Tuple[
Optional[Union[Instance, PredictedInstance]], Optional[PredictedInstance], bool
]:
"""Find instance to copy from.
Args:
context: The command context.
copy_instance: The instance to copy from.
copy_instance: The `Instance` to copy from.
init_method: The initialization method.
Returns:
The instance to copy from, whether it's from a predicted instance, and
whether it's from a previous frame.
The instance to copy from, the predicted instance (if it is from a predicted
instance, else None), and whether it's from a previous frame.
"""

from_predicted = copy_instance
Expand Down Expand Up @@ -3071,6 +3089,7 @@ def find_instance_to_copy_from(
from_prev_frame = True

from_predicted = from_predicted if hasattr(from_predicted, "score") else None
from_predicted = cast(Optional[PredictedInstance], from_predicted)

return copy_instance, from_predicted, from_prev_frame

Expand Down

0 comments on commit cb82d36

Please sign in to comment.