Skip to content

Commit

Permalink
remove heuristics
Browse files Browse the repository at this point in the history
  • Loading branch information
liopeer committed Jan 10, 2025
1 parent f6ad8c7 commit 378abce
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 29 deletions.
35 changes: 15 additions & 20 deletions lightly/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,21 @@


def pool_masked(
source: Tensor, mask: Tensor, reduce: str = "mean", num_cls: Optional[int] = None
source: Tensor, mask: Tensor, num_cls: int, reduce: str = "mean"
) -> Tensor:
"""Reduce image feature maps (B, C, H, W) or (C, H, W) according to an integer
index given by `mask` (B, H, W) or (H, W).
"""Reduce image feature maps :math:`(B, C, H, W)` or :math:`(C, H, W)` according to
an integer index given by `mask` :math:`(B, H, W)` or :math:`(H, W)`.
Args:
source: Float tensor of shape (B, C, H, W) or (C, H, W) to be reduced.
mask: Integer tensor of shape (B, H, W) or (H, W) containing the integer indices.
reduce: The reduction operation to be applied, one of 'prod', 'mean', 'amax' or
'amin'. Defaults to 'mean'.
num_cls: The number of classes in the possible masks. If None, the number of classes
is inferred from the unique elements in `mask`. This is useful when not all
classes are present in the mask.
source: Float tensor of shape :math:`(B, C, H, W)` or :math:`(C, H, W)` to be
reduced.
mask: Integer tensor of shape :math:`(B, H, W)` or :math:`(H, W)` containing the
integer indices.
num_cls: The number of classes in the possible masks.
Returns:
A tensor of shape (B, C, N) or (C, N) where N is the number of unique elements
in `mask` or `num_cls` if specified.
A tensor of shape :math:`(B, C, N)` or :math:`(C, N)` where :math:`N`
corresponds to `num_cls`.
"""
if source.dim() == 3:
return _mask_reduce(source, mask, reduce, num_cls)
Expand All @@ -55,29 +53,26 @@ def pool_masked(


def _mask_reduce(
source: Tensor, mask: Tensor, reduce: str = "mean", num_cls: Optional[int] = None
source: Tensor, mask: Tensor, num_cls: int, reduce: str = "mean"
) -> Tensor:
output = _mask_reduce_batched(
source.unsqueeze(0), mask.unsqueeze(0), num_cls=num_cls
source.unsqueeze(0), mask.unsqueeze(0), num_cls=num_cls, reduce=reduce
)
return output.squeeze(0)


def _mask_reduce_batched(
source: Tensor, mask: Tensor, num_cls: Optional[int] = None
source: Tensor, mask: Tensor, num_cls: int, reduce: str = "mean"
) -> Tensor:
b, c, h, w = source.shape
if num_cls is None:
cls = mask.unique(sorted=True)
else:
cls = torch.arange(num_cls, device=mask.device)
cls = torch.arange(num_cls, device=mask.device)
num_cls = cls.size(0)
# create output tensor
output = source.new_zeros((b, c, num_cls)) # (B C N)
mask = mask.unsqueeze(1).expand(-1, c, -1, -1).view(b, c, -1) # (B C HW)
source = source.view(b, c, -1) # (B C HW)
output.scatter_reduce_(
dim=2, index=mask, src=source, reduce="mean", include_self=False
dim=2, index=mask, src=source, reduce=reduce, include_self=False
) # (B C N)
# scatter_reduce_ produces NaNs if the count is zero
output = torch.nan_to_num(output, nan=0.0)
Expand Down
9 changes: 0 additions & 9 deletions tests/models/test_ModelUtils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,6 @@ def test_masked_pooling_manual(
assert out_manual.shape == (1, 3, 2)
assert (out_manual == expected_result2[:, :2]).all()

def test_masked_pooling_auto(
self, feature_map2: Tensor, mask2: Tensor, expected_result2: Tensor
) -> None:
out_auto = pool_masked(
feature_map2.unsqueeze(0), mask2.unsqueeze(0), num_cls=None
)
assert out_auto.shape == (1, 3, 2)
assert (out_auto == expected_result2[:, :2]).all()

# Type ignore because untyped decorator makes function untyped.
@pytest.mark.parametrize(
"feature_map, mask, expected_result",
Expand Down

0 comments on commit 378abce

Please sign in to comment.