diff --git a/examples/pooled_texture_model.ipynb b/examples/pooled_texture_model.ipynb index c0bb7551..4c3152a6 100644 --- a/examples/pooled_texture_model.ipynb +++ b/examples/pooled_texture_model.ipynb @@ -28,28 +28,30 @@ } ], "source": [ - "import plenoptic as po\n", - "import torch\n", - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", "import einops\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import torch\n", + "\n", + "import plenoptic as po\n", + "\n", "# so that relative sizes of axes created by po.imshow and others look right\n", - "plt.rcParams['figure.dpi'] = 72\n", + "plt.rcParams[\"figure.dpi\"] = 72\n", "\n", "# set seed for reproducibility\n", "po.tools.set_seed(1)\n", "\n", "# use GPU if available\n", - "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "%matplotlib inline\n", "\n", "\n", "# Animation-related settings\n", - "plt.rcParams['animation.html'] = 'html5'\n", + "plt.rcParams[\"animation.html\"] = \"html5\"\n", "# use single-threaded ffmpeg for animation writer\n", - "plt.rcParams['animation.writer'] = 'ffmpeg'\n", - "plt.rcParams['animation.ffmpeg_args'] = ['-threads', '1']" + "plt.rcParams[\"animation.writer\"] = \"ffmpeg\"\n", + "plt.rcParams[\"animation.ffmpeg_args\"] = [\"-threads\", \"1\"]" ] }, { @@ -111,9 +113,9 @@ " mask_sz = int(256 // len(mask))\n", " for j, m in enumerate(mask):\n", " if i == 0:\n", - " m[..., j*mask_sz:(j+1)*mask_sz, :] = 1\n", + " m[..., j * mask_sz : (j + 1) * mask_sz, :] = 1\n", " else:\n", - " m[..., j*mask_sz:(j+1)*mask_sz] = 1\n", + " m[..., j * mask_sz : (j + 1) * mask_sz] = 1\n", " masks.append(torch.cat(mask, 0))\n", "print([m.shape for m in masks])" ] @@ -144,8 +146,8 @@ } ], "source": [ - "masked = einops.einsum(*masks, img, 'm0 h w, m1 h w, b c h w -> b c m0 m1 h w')\n", - "po.imshow(einops.rearrange(masked, 'b c m0 m1 h w -> b (c m0 m1) h w'));" + "masked = einops.einsum(*masks, img, \"m0 h w, m1 h w, b c h w -> b c m0 m1 h w\")\n", + "po.imshow(einops.rearrange(masked, \"b c m0 m1 h w -> b (c m0 m1) h w\"));" ] }, { @@ -179,7 +181,7 @@ "for i, m in enumerate(mask):\n", " x = int(i // np.sqrt(len(mask)))\n", " y = int(i % np.sqrt(len(mask)))\n", - " m[..., x*mask_sz:(x+1)*mask_sz, y*mask_sz:(y+1)*mask_sz] = 1\n", + " m[..., x * mask_sz : (x + 1) * mask_sz, y * mask_sz : (y + 1) * mask_sz] = 1\n", "single_masks = torch.cat(mask, 0)\n", "single_masks.shape" ] @@ -202,8 +204,8 @@ } ], "source": [ - "masked = einops.einsum(single_masks, img, 'm0 h w, b c h w -> b c m0 h w')\n", - "po.imshow(einops.rearrange(masked, 'b c m0 h w -> b (c m0) h w'));" + "masked = einops.einsum(single_masks, img, \"m0 h w, b c h w -> b c m0 h w\")\n", + "po.imshow(einops.rearrange(masked, \"b c m0 h w -> b (c m0) h w\"));" ] }, { @@ -237,7 +239,7 @@ } ], "source": [ - "einops.einsum(*masks, 'm0 h w, m1 h w -> m0 m1')" + "einops.einsum(*masks, \"m0 h w, m1 h w -> m0 m1\")" ] }, { @@ -269,8 +271,8 @@ } ], "source": [ - "masks = [m/64 for m in masks]\n", - "einops.einsum(*masks, 'm0 h w, m1 h w -> m0 m1')" + "masks = [m / 64 for m in masks]\n", + "einops.einsum(*masks, \"m0 h w, m1 h w -> m0 m1\")" ] }, { @@ -444,7 +446,11 @@ } ], "source": [ - "met = po.synth.MetamerCTF(img, ps_mask, loss_function=po.tools.optim.l2_norm,)" + "met = po.synth.MetamerCTF(\n", + " img,\n", + " ps_mask,\n", + " loss_function=po.tools.optim.l2_norm,\n", + ")" ] }, { @@ -470,7 +476,9 @@ } ], "source": [ - 
"met.synthesize(max_iter=500, store_progress=10, change_scale_criterion=None, ctf_iters_to_check=7)" + "met.synthesize(\n", + " max_iter=500, store_progress=10, change_scale_criterion=None, ctf_iters_to_check=7\n", + ")" ] }, { @@ -507,7 +515,9 @@ } ], "source": [ - "po.synth.metamer.plot_synthesis_status(met, width_ratios={'plot_representation_error': 2});" + "po.synth.metamer.plot_synthesis_status(\n", + " met, width_ratios={\"plot_representation_error\": 2}\n", + ");" ] }, { @@ -7428,7 +7438,7 @@ } ], "source": [ - "po.synth.metamer.animate(met, width_ratios={'plot_representation_error': 2})" + "po.synth.metamer.animate(met, width_ratios={\"plot_representation_error\": 2})" ] }, { @@ -7458,7 +7468,11 @@ ], "source": [ "img = po.data.einstein().to(DEVICE)\n", - "met = po.synth.MetamerCTF(img, ps_mask, loss_function=po.tools.optim.l2_norm,)" + "met = po.synth.MetamerCTF(\n", + " img,\n", + " ps_mask,\n", + " loss_function=po.tools.optim.l2_norm,\n", + ")" ] }, { @@ -7476,7 +7490,9 @@ } ], "source": [ - "met.synthesize(max_iter=500, store_progress=10, change_scale_criterion=None, ctf_iters_to_check=7)" + "met.synthesize(\n", + " max_iter=500, store_progress=10, change_scale_criterion=None, ctf_iters_to_check=7\n", + ")" ] }, { @@ -7505,7 +7521,9 @@ } ], "source": [ - "po.synth.metamer.plot_synthesis_status(met, width_ratios={'plot_representation_error': 2});" + "po.synth.metamer.plot_synthesis_status(\n", + " met, width_ratios={\"plot_representation_error\": 2}\n", + ");" ] }, { @@ -7528,7 +7546,8 @@ "outputs": [], "source": [ "import sys\n", - "sys.path.append('../../pooling-windows/')\n", + "\n", + "sys.path.append(\"../../pooling-windows/\")\n", "import pooling" ] }, @@ -7556,7 +7575,7 @@ } ], "source": [ - "pw = pooling.PoolingWindows(.5, img.shape[-2:], max_eccentricity=5)\n", + "pw = pooling.PoolingWindows(0.5, img.shape[-2:], max_eccentricity=5)\n", "pooling_windows = [pw.ecc_windows[0], pw.angle_windows[0]]" ] }, @@ -7617,7 +7636,9 @@ } ], "source": [ - "plt.stem(po.to_numpy(einops.einsum(*pooling_windows, 'm1 h w, m2 h w -> m1 m2').flatten()));" + "plt.stem(\n", + " po.to_numpy(einops.einsum(*pooling_windows, \"m1 h w, m2 h w -> m1 m2\").flatten())\n", + ");" ] }, { @@ -7695,7 +7716,9 @@ ], "source": [ "pooling_windows = [pw.ecc_windows[0][:-1], pw.angle_windows[0]]\n", - "plt.stem(po.to_numpy(einops.einsum(*pooling_windows, 'm1 h w, m2 h w -> m1 m2').flatten()));" + "plt.stem(\n", + " po.to_numpy(einops.einsum(*pooling_windows, \"m1 h w, m2 h w -> m1 m2\").flatten())\n", + ");" ] }, { @@ -7723,7 +7746,11 @@ "metadata": {}, "outputs": [], "source": [ - "met = po.synth.MetamerCTF(img, ps_mask, loss_function=po.tools.optim.l2_norm,)" + "met = po.synth.MetamerCTF(\n", + " img,\n", + " ps_mask,\n", + " loss_function=po.tools.optim.l2_norm,\n", + ")" ] }, { @@ -7741,7 +7768,9 @@ } ], "source": [ - "met.synthesize(max_iter=500, store_progress=10, change_scale_criterion=None, ctf_iters_to_check=7)" + "met.synthesize(\n", + " max_iter=500, store_progress=10, change_scale_criterion=None, ctf_iters_to_check=7\n", + ")" ] }, { @@ -7762,7 +7791,9 @@ } ], "source": [ - "po.synth.metamer.plot_synthesis_status(met, width_ratios={'plot_representation_error': 2});" + "po.synth.metamer.plot_synthesis_status(\n", + " met, width_ratios={\"plot_representation_error\": 2}\n", + ");" ] }, { @@ -14163,7 +14194,7 @@ } ], "source": [ - "po.synth.metamer.animate(met, width_ratios={'plot_representation_error': 2})" + "po.synth.metamer.animate(met, width_ratios={\"plot_representation_error\": 2})" ] }, { @@ 
-14192,7 +14223,7 @@ } ], "source": [ - "fig = po.imshow([img, met.metamer], title=['Target image', 'Metamer']);\n", + "fig = po.imshow([img, met.metamer], title=[\"Target image\", \"Metamer\"])\n", "for ax in fig.axes:\n", " pw.plot_windows(ax, subset=False)" ] diff --git a/src/plenoptic/simulate/models/portilla_simoncelli_masked.py b/src/plenoptic/simulate/models/portilla_simoncelli_masked.py index 527fd1cf..cabb5a54 100644 --- a/src/plenoptic/simulate/models/portilla_simoncelli_masked.py +++ b/src/plenoptic/simulate/models/portilla_simoncelli_masked.py @@ -6,8 +6,9 @@ images have the same values for all PS texture stats, humans should consider them as members of the same family of textures. """ + from collections import OrderedDict -from typing import List, Optional, Tuple, Union +from typing import Literal import einops import matplotlib as mpl @@ -17,17 +18,18 @@ import torch.fft import torch.nn as nn from torch import Tensor -from typing_extensions import Literal from ...tools import signal from ...tools.conv import blur_downsample from ...tools.data import to_numpy from ...tools.display import clean_stem_plot, clean_up_axes, update_stem from ...tools.validate import validate_input +from ..canonical_computations.steerable_pyramid_freq import ( + SCALES_TYPE as PYR_SCALES_TYPE, +) from ..canonical_computations.steerable_pyramid_freq import SteerablePyramidFreq -from ..canonical_computations.steerable_pyramid_freq import SCALES_TYPE as PYR_SCALES_TYPE -SCALES_TYPE = Union[Literal["pixel_statistics"], PYR_SCALES_TYPE] +SCALES_TYPE = Literal["pixel_statistics"] | PYR_SCALES_TYPE class PortillaSimoncelliMasked(nn.Module): @@ -82,8 +84,8 @@ class PortillaSimoncelliMasked(nn.Module): def __init__( self, - image_shape: Tuple[int, int], - mask: List[Tensor], + image_shape: tuple[int, int], + mask: list[Tensor], n_scales: int = 4, n_orientations: int = 4, spatial_corr_width: int = 9, @@ -91,48 +93,56 @@ def __init__( super().__init__() self.image_shape = image_shape - if (any([(image_shape[-1] / 2**i) % 2 for i in range(n_scales)]) or - any([(image_shape[-2] / 2**i) % 2 for i in range(n_scales)])): - raise ValueError("Because of how the Portilla-Simoncelli model handles " - "multiscale representations, it only works with images" - " whose shape can be divided by 2 `n_scales` times.") + if any([(image_shape[-1] / 2**i) % 2 for i in range(n_scales)]) or any( + [(image_shape[-2] / 2**i) % 2 for i in range(n_scales)] + ): + raise ValueError( + "Because of how the Portilla-Simoncelli model handles " + "multiscale representations, it only works with images" + " whose shape can be divided by 2 `n_scales` times." + ) if any([m.ndim != 3 for m in mask]): raise ValueError("All masks must be 3d!") if any([m.shape[-2:] != image_shape for m in mask]): - raise ValueError("Last two dimensions of mask must be height and width and must match image_shape!") + raise ValueError( + "Last two dimensions of mask must be height and width" + " and must match image_shape!" + ) if any([m.min() < 0 for m in mask]): raise ValueError("All masks must be non-negative!") # we need to downsample the masks for each scale, plus one additional scale for # the reconstructed lowpass image - for i in range(n_scales+1): + for i in range(n_scales + 1): if i == 0: scale_mask = mask else: # multiply by the factor of four in order to keep the sum # approximately equal across scales. 
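The factor-of-four compensation described in the comment above can be sanity-checked in isolation before reading the loop below. A minimal sketch, using `torch.nn.functional.avg_pool2d` as a stand-in for plenoptic's `blur_downsample` (so the arithmetic is exact rather than approximate): each halving of height and width cuts a mask's pixel count, and therefore its sum, by a factor of four, which the `4**i` correction undoes; the model splits the exponent across the masks, `4**(i / len(mask))`, because the masks are later multiplied together.

```python
import torch
import torch.nn.functional as F

# Stand-in check (avg_pool2d instead of blur_downsample): downsampling by
# 2**i reduces the number of pixels, and hence the mask's sum, by a factor
# of 4**i, which multiplying by 4**i undoes.
mask = torch.rand(1, 1, 64, 64)
for i in range(1, 4):
    down = F.avg_pool2d(mask, kernel_size=2**i)
    # the two printed sums match at every scale
    print(i, round(mask.sum().item(), 3), round((4**i * down).sum().item(), 3))
```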
- scale_mask = [4**(i/len(mask)) * blur_downsample(m.unsqueeze(0), - i, scale_filter=True).squeeze(0) - for m in mask] + scale_mask = [ + 4 ** (i / len(mask)) + * blur_downsample(m.unsqueeze(0), i, scale_filter=True).squeeze(0) + for m in mask + ] for j, m in enumerate(scale_mask): # it's possible negative values will get introduced by the downsampling # above, in which case we remove them, since they mess up our # computations. in particular, they could result in negative variance # values. - self.register_buffer(f'_mask_{j}_scale_{i}', m.clip(min=0)) + self.register_buffer(f"_mask_{j}_scale_{i}", m.clip(min=0)) # these indices are used to create the einsum expressions - self._mask_input_idx = ', '.join([f'm{i} h w' for i in range(len(mask))]) + self._mask_input_idx = ", ".join([f"m{i} h w" for i in range(len(mask))]) self._n_masks = len(mask) self._mask_output_idx = f"{' '.join([f'm{i}' for i in range(len(mask))])}" self.spatial_corr_width = spatial_corr_width self.n_scales = n_scales self.n_orientations = n_orientations # these are each lists of tensors of shape (batch, channel, n_autocorrs, height, - # width), one per scale, where n_autocorrs is approximately spatial_corr_width^2 / - # 2 + # width), one per scale, where n_autocorrs is approximately + # spatial_corr_width^2 / 2 rolls_h, rolls_w = self._create_autocorr_idx(spatial_corr_width, image_shape) for i, (h, w) in enumerate(zip(rolls_h, rolls_w)): - self.register_buffer(f'_autocorr_rolls_h_scale_{i}', h) - self.register_buffer(f'_autocorr_rolls_w_scale_{i}', w) + self.register_buffer(f"_autocorr_rolls_h_scale_{i}", h) + self.register_buffer(f"_autocorr_rolls_w_scale_{i}", w) self._n_autocorrs = rolls_h[0].shape[3] self._pyr = SteerablePyramidFreq( self.image_shape, @@ -154,13 +164,17 @@ def __init__( # Dictionary defining necessary statistics, that is, those that are not # redundant - self._necessary_stats_dict = self._create_necessary_stats_dict(scales_shape_dict) + self._necessary_stats_dict = self._create_necessary_stats_dict( + scales_shape_dict + ) # turn this into tensor we can use in forward pass. first into a # boolean mask... - _necessary_stats_mask = einops.pack(list(self._necessary_stats_dict.values()), '*')[0] + _necessary_stats_mask = einops.pack( + list(self._necessary_stats_dict.values()), "*" + )[0] # then into a tensor of indices _necessary_stats_mask = torch.where(_necessary_stats_mask)[0] - self.register_buffer('_necessary_stats_mask', _necessary_stats_mask) + self.register_buffer("_necessary_stats_mask", _necessary_stats_mask) # This array is composed of the following values: 'pixel_statistics', # 'residual_lowpass', 'residual_highpass' and integer values from 0 to @@ -168,9 +182,13 @@ def __init__( # returned by this object's forward method. It must be a numpy array so # we can have a mixture of ints and strs (and so we can use np.in1d # later) - self._representation_scales = einops.pack(list(scales_shape_dict.values()), '*')[0] + self._representation_scales = einops.pack( + list(scales_shape_dict.values()), "*" + )[0] # just select the scales of the necessary stats. 
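Stepping back to the `_mask_input_idx` and `_mask_output_idx` strings built above: because the number of masks is variable, every einsum pattern in this module is assembled at runtime. A self-contained sketch of that pattern with two hypothetical mask banks (the shapes are illustrative):

```python
import einops
import torch

# Rebuild the index strings the way __init__ does, for two toy mask banks.
masks = [torch.rand(3, 8, 8), torch.rand(4, 8, 8)]
mask_input_idx = ", ".join(f"m{i} h w" for i in range(len(masks)))  # "m0 h w, m1 h w"
mask_output_idx = " ".join(f"m{i}" for i in range(len(masks)))      # "m0 m1"

img = torch.rand(1, 1, 8, 8)
# weighted average of the image under every combination of the two mask banks
expr = f"{mask_input_idx}, b c h w -> b c {mask_output_idx}"
print(einops.einsum(*masks, img, expr).shape)  # torch.Size([1, 1, 3, 4])
```

A single expression applies every mask combination at once, which is why the rest of the file never needs an explicit loop over masks.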
- self._representation_scales = self._representation_scales[self._necessary_stats_mask] + self._representation_scales = self._representation_scales[ + self._necessary_stats_mask + ] # There are two types of computations where we add a little epsilon to help with # stability: # - division of one statistic by another @@ -183,22 +201,32 @@ def __init__( def mask(self): # inspired by # https://discuss.pytorch.org/t/why-no-nn-bufferlist-like-function-for-registered-buffer-tensor/18884/10 - return [[getattr(self, f'_mask_{j}_scale_{i}')for j in range(self._n_masks)] - for i in range(self.n_scales+1)] + return [ + [getattr(self, f"_mask_{j}_scale_{i}") for j in range(self._n_masks)] + for i in range(self.n_scales + 1) + ] @property def _autocorr_rolls_h(self): # inspired by # https://discuss.pytorch.org/t/why-no-nn-bufferlist-like-function-for-registered-buffer-tensor/18884/10 - return [getattr(self, f'_autocorr_rolls_h_scale_{i}') for i in range(self.n_scales+1)] + return [ + getattr(self, f"_autocorr_rolls_h_scale_{i}") + for i in range(self.n_scales + 1) + ] @property def _autocorr_rolls_w(self): # inspired by # https://discuss.pytorch.org/t/why-no-nn-bufferlist-like-function-for-registered-buffer-tensor/18884/10 - return [getattr(self, f'_autocorr_rolls_w_scale_{i}') for i in range(self.n_scales+1)] - - def _create_autocorr_idx(self, spatial_corr_width, image_shape) -> Tuple[List[Tensor], List[Tensor]]: + return [ + getattr(self, f"_autocorr_rolls_w_scale_{i}") + for i in range(self.n_scales + 1) + ] + + def _create_autocorr_idx( + self, spatial_corr_width, image_shape + ) -> tuple[list[Tensor], list[Tensor]]: """Create indices used to shift images when computing autocorrelation. The autocorrelation of ``img`` is the product of ``img`` with itself shifted by @@ -223,39 +251,67 @@ def _create_autocorr_idx(self, spatial_corr_width, image_shape) -> Tuple[List[Te Returns ------- rolls_h, rolls_w : - List of tensors of shape ``(1, 1, n_orientations, n_autocorrs, height, width)`` - giving the shifts along the height (``shape[-2]``) and width (``shape[-1]``) - dimensions required for computing the autocorrelations. Each entry in the - list corresponds to a different scale, and thus height and width decrease. + List of tensors of shape ``(1, 1, n_orientations, n_autocorrs, height, + width)`` giving the shifts along the height (``shape[-2]``) and width + (``shape[-1]``) dimensions required for computing the autocorrelations. Each + entry in the list corresponds to a different scale, and thus height and + width decrease. """ # because of the symmetry of autocorrelation, in order to generate all # autocorrelations, we only need the lower triangle (so that we take the # autocorrelation between the image and itself shifted 1 pixel to the left, but # not also shifted 1 pixel to the right)... 
- half_width = (spatial_corr_width-1) // 2 - autocorr_shift_vals = [i - half_width for i in - np.tril_indices(spatial_corr_width)] + half_width = (spatial_corr_width - 1) // 2 + autocorr_shift_vals = [ + i - half_width for i in np.tril_indices(spatial_corr_width) + ] # if spatial_corr_width is even, then we also need these shifts: if np.mod(spatial_corr_width, 2) == 0: - autocorr_shift_vals[0] = np.concatenate([np.zeros(spatial_corr_width, dtype=int)-half_width, autocorr_shift_vals[0]], 0) - autocorr_shift_vals[1] = np.concatenate([np.arange(spatial_corr_width, dtype=int)-half_width, autocorr_shift_vals[1]], 0) + autocorr_shift_vals[0] = np.concatenate( + [ + np.zeros(spatial_corr_width, dtype=int) - half_width, + autocorr_shift_vals[0], + ], + 0, + ) + autocorr_shift_vals[1] = np.concatenate( + [ + np.arange(spatial_corr_width, dtype=int) - half_width, + autocorr_shift_vals[1], + ], + 0, + ) # and up to the central element on the diagonal. - idx = [i!=j or i<0 for i, j in zip(*autocorr_shift_vals)] + idx = [i != j or i < 0 for i, j in zip(*autocorr_shift_vals)] # put the (0, 0) shift, which corresponds to the variance, at the very end, so # we know where it is - autocorr_shift_vals = [np.concatenate([i[idx], np.zeros(1, dtype=int)], 0) - for i in autocorr_shift_vals] + autocorr_shift_vals = [ + np.concatenate([i[idx], np.zeros(1, dtype=int)], 0) + for i in autocorr_shift_vals + ] img_h, img_w = image_shape rolls_h, rolls_w = [], [] # need one additional scale, since we compute the autocorrelation of the # reconstructed lowpass images as well - for _ in range(self.n_scales+1): - arange_h = torch.arange(img_h).view((1, 1, 1, img_h, 1)).repeat((1, 1, self.n_orientations, 1, img_h)) - arange_w = torch.arange(img_w).view((1, 1, 1, 1, img_w)).repeat((1, 1, self.n_orientations, img_w, 1)) - rolls_h.append(torch.stack([arange_h.roll(i, -2) for i in autocorr_shift_vals[0]], 3)) - rolls_w.append(torch.stack([arange_w.roll(i, -1) for i in autocorr_shift_vals[1]], 3)) + for _ in range(self.n_scales + 1): + arange_h = ( + torch.arange(img_h) + .view((1, 1, 1, img_h, 1)) + .repeat((1, 1, self.n_orientations, 1, img_h)) + ) + arange_w = ( + torch.arange(img_w) + .view((1, 1, 1, 1, img_w)) + .repeat((1, 1, self.n_orientations, img_w, 1)) + ) + rolls_h.append( + torch.stack([arange_h.roll(i, -2) for i in autocorr_shift_vals[0]], 3) + ) + rolls_w.append( + torch.stack([arange_w.roll(i, -1) for i in autocorr_shift_vals[1]], 3) + ) img_h = int(img_h // 2) img_w = int(img_w // 2) return rolls_h, rolls_w @@ -294,16 +350,18 @@ def _create_scales_shape_dict(self) -> OrderedDict: """ shape_dict = OrderedDict() # There are 6 pixel statistics - shape_dict['pixel_statistics'] = np.array(4*['pixel_statistics']) + shape_dict["pixel_statistics"] = np.array(4 * ["pixel_statistics"]) # These are the basic building blocks of the scale assignments for many # of the statistics calculated by the PortillaSimoncelli model. 
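The shift bookkeeping in `_create_autocorr_idx` above is easier to see at a small, concrete size. A sketch for a hypothetical `spatial_corr_width` of 3: autocorrelation is symmetric (shifting by `(+i, +j)` or `(-i, -j)` yields the same statistic), so only the lower-triangular shifts are kept, with the `(0, 0)` shift, i.e. the variance, placed last so its position is known.

```python
import numpy as np

# Enumerate the non-redundant autocorrelation shifts for a 3x3 neighborhood,
# mirroring the tril_indices logic above.
width = 3
half = (width - 1) // 2
shifts_h, shifts_w = (i - half for i in np.tril_indices(width))
# keep everything up to (but excluding) the central element of the diagonal
keep = [h != w or h < 0 for h, w in zip(shifts_h, shifts_w)]
shifts = [(int(h), int(w)) for h, w in zip(shifts_h[keep], shifts_w[keep])]
shifts.append((0, 0))  # the variance goes in the known, final position
print(shifts)  # [(-1, -1), (0, -1), (1, -1), (1, 0), (0, 0)]
```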
scales = np.arange(self.n_scales) # the cross-scale correlations exclude the coarsest scale - scales_without_coarsest = np.arange(self.n_scales-1) + scales_without_coarsest = np.arange(self.n_scales - 1) # the statistics computed on the reconstructed bandpass images have an # extra scale corresponding to the lowpass residual - scales_with_lowpass = np.array(scales.tolist() + ["residual_lowpass"], dtype=object) + scales_with_lowpass = np.array( + scales.tolist() + ["residual_lowpass"], dtype=object + ) # now we go through each statistic in order and create a dummy array # full of 1s with the same shape as the actual statistic (excluding the @@ -312,47 +370,56 @@ def _create_scales_shape_dict(self) -> OrderedDict: # arrays above to turn those 1s into values describing the # corresponding scale. - auto_corr_mag = np.ones((self._n_autocorrs-1, self.n_orientations, self.n_scales), dtype=int) + auto_corr_mag = np.ones( + (self._n_autocorrs - 1, self.n_orientations, self.n_scales), dtype=int + ) # this rearrange call is turning scales from 1d with shape (n_scales, ) # to 4d with shape (1, 1, n_scales, 1), so that it matches # auto_corr_mag. the following rearrange calls do similar. - auto_corr_mag *= einops.rearrange(scales, 's -> 1 1 s') - shape_dict['auto_correlation_magnitude'] = auto_corr_mag + auto_corr_mag *= einops.rearrange(scales, "s -> 1 1 s") + shape_dict["auto_correlation_magnitude"] = auto_corr_mag - shape_dict['skew_reconstructed'] = scales_with_lowpass + shape_dict["skew_reconstructed"] = scales_with_lowpass - shape_dict['kurtosis_reconstructed'] = scales_with_lowpass + shape_dict["kurtosis_reconstructed"] = scales_with_lowpass - auto_corr = np.ones((self._n_autocorrs-1, self.n_scales+1), dtype=object) - auto_corr *= einops.rearrange(scales_with_lowpass, 's -> 1 s') - shape_dict['auto_correlation_reconstructed'] = auto_corr + auto_corr = np.ones((self._n_autocorrs - 1, self.n_scales + 1), dtype=object) + auto_corr *= einops.rearrange(scales_with_lowpass, "s -> 1 s") + shape_dict["auto_correlation_reconstructed"] = auto_corr - shape_dict['std_reconstructed'] = scales_with_lowpass + shape_dict["std_reconstructed"] = scales_with_lowpass - cross_orientation_corr_mag = np.ones((self.n_orientations, self.n_orientations, - self.n_scales), dtype=int) - cross_orientation_corr_mag *= einops.rearrange(scales, 's -> 1 1 s') - shape_dict['cross_orientation_correlation_magnitude'] = cross_orientation_corr_mag + cross_orientation_corr_mag = np.ones( + (self.n_orientations, self.n_orientations, self.n_scales), dtype=int + ) + cross_orientation_corr_mag *= einops.rearrange(scales, "s -> 1 1 s") + shape_dict["cross_orientation_correlation_magnitude"] = ( + cross_orientation_corr_mag + ) mags_std = np.ones((self.n_orientations, self.n_scales), dtype=int) - mags_std *= einops.rearrange(scales, 's -> 1 s') - shape_dict['magnitude_std'] = mags_std + mags_std *= einops.rearrange(scales, "s -> 1 s") + shape_dict["magnitude_std"] = mags_std - cross_scale_corr_mag = np.ones((self.n_orientations, self.n_orientations, - self.n_scales-1), dtype=int) - cross_scale_corr_mag *= einops.rearrange(scales_without_coarsest, 's -> 1 1 s') - shape_dict['cross_scale_correlation_magnitude'] = cross_scale_corr_mag + cross_scale_corr_mag = np.ones( + (self.n_orientations, self.n_orientations, self.n_scales - 1), dtype=int + ) + cross_scale_corr_mag *= einops.rearrange(scales_without_coarsest, "s -> 1 1 s") + shape_dict["cross_scale_correlation_magnitude"] = cross_scale_corr_mag - cross_scale_corr_real = 
np.ones((self.n_orientations, 2*self.n_orientations, - self.n_scales-1), dtype=int) - cross_scale_corr_real *= einops.rearrange(scales_without_coarsest, 's -> 1 1 s') - shape_dict['cross_scale_correlation_real'] = cross_scale_corr_real + cross_scale_corr_real = np.ones( + (self.n_orientations, 2 * self.n_orientations, self.n_scales - 1), dtype=int + ) + cross_scale_corr_real *= einops.rearrange(scales_without_coarsest, "s -> 1 1 s") + shape_dict["cross_scale_correlation_real"] = cross_scale_corr_real - shape_dict['var_highpass_residual'] = np.array(["residual_highpass"]) + shape_dict["var_highpass_residual"] = np.array(["residual_highpass"]) return shape_dict - def _create_necessary_stats_dict(self, scales_shape_dict: OrderedDict) -> OrderedDict: + def _create_necessary_stats_dict( + self, scales_shape_dict: OrderedDict + ) -> OrderedDict: """Create mask specifying the necessary statistics. Some of the statistics computed by the model are redundant, due to @@ -382,10 +449,9 @@ def _create_necessary_stats_dict(self, scales_shape_dict: OrderedDict) -> Ordere # for cross_orientation_correlation_magnitude (because we've normalized # this matrix to be true cross-correlations, the diagonals are all 1, # like for the auto-correlations) - triu_inds = torch.triu_indices(self.n_orientations, - self.n_orientations) + triu_inds = torch.triu_indices(self.n_orientations, self.n_orientations) for k, v in mask_dict.items(): - if k == 'cross_orientation_correlation_magnitude': + if k == "cross_orientation_correlation_magnitude": # Symmetry M_{i,j} = M_{j,i}. # Start with all True, then place False in redundant stats. mask = torch.ones(v.shape, dtype=torch.bool) @@ -396,9 +462,7 @@ def _create_necessary_stats_dict(self, scales_shape_dict: OrderedDict) -> Ordere mask_dict[k] = mask return mask_dict - def forward( - self, image: Tensor, scales: Optional[List[SCALES_TYPE]] = None - ) -> Tensor: + def forward(self, image: Tensor, scales: list[SCALES_TYPE] | None = None) -> Tensor: r"""Generate Texture Statistics representation of an image. Note that separate batches and channels are analyzed in parallel. @@ -445,7 +509,9 @@ def forward( # real_pyr_coeffs, which contain the demeaned magnitude of the pyramid # coefficients and the real part of the pyramid coefficients # respectively. - mag_pyr_coeffs, real_pyr_coeffs = self._compute_intermediate_representations(pyr_coeffs) + mag_pyr_coeffs, real_pyr_coeffs = self._compute_intermediate_representations( + pyr_coeffs + ) # Then, the reconstructed lowpass image at each scale. (this is a list # of length n_scales+1 containing tensors of shape (batch, channel, @@ -463,69 +529,93 @@ def forward( pixel_stats = self._compute_pixel_stats(self.mask[0], image) # Compute the central autocorrelation of the coefficient magnitudes. This is a - # tensor of shape: (batch, channel, masks, n_autocorrs, n_orientations, n_scales). + # tensor of shape: (batch, channel, masks, n_autocorrs, n_orientations, + # n_scales). autocorr_mags, mags_var = self._compute_autocorr(self.mask, mag_pyr_coeffs) # mags_var is the variance of the magnitude coefficients at each scale (it's an # intermediary of the computation of the auto-correlations). We take the square # root to get the standard deviation. 
After this, mags_std will have shape # (batch, channel, masks, n_orientations, n_scale) - mags_std = einops.rearrange((mags_var + self._stability_epsilon).sqrt(), - f'b c {self._mask_output_idx} o s -> b c ({self._mask_output_idx}) o s') + mags_std = einops.rearrange( + (mags_var + self._stability_epsilon).sqrt(), + f"b c {self._mask_output_idx} o s -> b c ({self._mask_output_idx}) o s", + ) # Compute the central autocorrelation of the reconstructed lowpass images at # each scale (and their variances). autocorr_recon is a tensor of shape (batch, # channel, masks, n_autocorrs, n_scales+1), and var_recon is a tensor of shape # (batch, channel, masks, n_scales+1) - autocorr_recon, var_recon = self._compute_autocorr(self.mask, reconstructed_images) + autocorr_recon, var_recon = self._compute_autocorr( + self.mask, reconstructed_images + ) # Compute the standard deviation, skew, and kurtosis of each reconstructed # lowpass image. std_recon, skew_recon, and kurtosis_recon will all end up as # tensors of shape (batch, channel, masks, n_scales+1) - std_recon = einops.rearrange((var_recon + self._stability_epsilon).sqrt(), - f'b c {self._mask_output_idx} s -> b c ({self._mask_output_idx}) s') - skew_recon, kurtosis_recon = self._compute_skew_kurtosis_recon(self.mask, - reconstructed_images, - var_recon) + std_recon = einops.rearrange( + (var_recon + self._stability_epsilon).sqrt(), + f"b c {self._mask_output_idx} s -> b c ({self._mask_output_idx}) s", + ) + skew_recon, kurtosis_recon = self._compute_skew_kurtosis_recon( + self.mask, reconstructed_images, var_recon + ) # Compute the cross-orientation correlations between the magnitude # coefficients at each scale. this will be a tensor of shape (batch, # channel, n_orientations, n_orientations, n_scales) - cross_ori_corr_mags = self._compute_cross_correlation(self.mask, mag_pyr_coeffs, mag_pyr_coeffs, - mags_var, mags_var) + cross_ori_corr_mags = self._compute_cross_correlation( + self.mask, mag_pyr_coeffs, mag_pyr_coeffs, mags_var, mags_var + ) # If we have more than one scale, compute the cross-scale correlations if self.n_scales != 1: # First, double the phase the coefficients, so we can correctly # compute correlations across scales. - phase_doubled_mags, phase_doubled_sep = self._double_phase_pyr_coeffs(pyr_coeffs) + phase_doubled_mags, phase_doubled_sep = self._double_phase_pyr_coeffs( + pyr_coeffs + ) # Compute the cross-scale correlations between the magnitude # coefficients. For each coefficient, we're correlating it with the # coefficients at the next-coarsest scale. this will be a tensor of # shape (batch, channel, n_orientations, n_orientations, # n_scales-1) - cross_scale_corr_mags = self._compute_cross_correlation(self.mask, mag_pyr_coeffs[:-1], - phase_doubled_mags, - mags_var[..., :-1]) + cross_scale_corr_mags = self._compute_cross_correlation( + self.mask, mag_pyr_coeffs[:-1], phase_doubled_mags, mags_var[..., :-1] + ) # Compute the cross-scale correlations between the real # coefficients and the real and imaginary coefficients at the next # coarsest scale. 
this will be a tensor of shape (batch, channel, # n_orientations, 2*n_orientations, n_scales-1) - cross_scale_corr_real = self._compute_cross_correlation(self.mask, real_pyr_coeffs[:-1], phase_doubled_sep) + cross_scale_corr_real = self._compute_cross_correlation( + self.mask, real_pyr_coeffs[:-1], phase_doubled_sep + ) # Compute the variance of the highpass residual - var_highpass_residual = einops.einsum(*self.mask[0], highpass.pow(2), - f"{self._mask_input_idx}, b c h w -> b c {self._mask_output_idx}") - var_highpass_residual = einops.rearrange(var_highpass_residual, - f'b c {self._mask_output_idx} -> b c ({self._mask_output_idx})') + var_highpass_residual = einops.einsum( + *self.mask[0], + highpass.pow(2), + f"{self._mask_input_idx}, b c h w -> b c {self._mask_output_idx}", + ) + var_highpass_residual = einops.rearrange( + var_highpass_residual, + f"b c {self._mask_output_idx} -> b c ({self._mask_output_idx})", + ) # Now, combine all these stats together, first into a list - all_stats = [pixel_stats, autocorr_mags, skew_recon, - kurtosis_recon, autocorr_recon, std_recon, - cross_ori_corr_mags, mags_std] + all_stats = [ + pixel_stats, + autocorr_mags, + skew_recon, + kurtosis_recon, + autocorr_recon, + std_recon, + cross_ori_corr_mags, + mags_std, + ] if self.n_scales != 1: all_stats += [cross_scale_corr_mags, cross_scale_corr_real] all_stats += [var_highpass_residual] # And then pack them into a 3d tensor - representation_tensor, pack_info = einops.pack(all_stats, 'b c m *') + representation_tensor, pack_info = einops.pack(all_stats, "b c m *") # the only time when this is None is during testing, when we make sure # that our assumptions are all valid. @@ -535,7 +625,9 @@ def forward( self._pack_info = pack_info else: # Throw away all redundant statistics - representation_tensor = representation_tensor.index_select(-1, self._necessary_stats_mask) + representation_tensor = representation_tensor.index_select( + -1, self._necessary_stats_mask + ) # Return the subset of stats corresponding to the specified scale. if scales is not None: @@ -544,7 +636,7 @@ def forward( return representation_tensor def remove_scales( - self, representation_tensor: Tensor, scales_to_keep: List[SCALES_TYPE] + self, representation_tensor: Tensor, scales_to_keep: list[SCALES_TYPE] ) -> Tensor: """Remove statistics not associated with scales. @@ -598,7 +690,7 @@ def convert_to_tensor(self, representation_dict: OrderedDict) -> Tensor: Convert tensor representation to dictionary. """ - rep = einops.pack(list(representation_dict.values()), 'b c *')[0] + rep = einops.pack(list(representation_dict.values()), "b c *")[0] # then get rid of all the nans / unnecessary stats return rep.index_select(-1, self._necessary_stats_mask) @@ -618,7 +710,7 @@ def convert_to_dict(self, representation_tensor: Tensor) -> OrderedDict: Returns ------- - rep + rep Dictionary of representation, with informative keys. See Also -------- convert_to_tensor Convert dictionary representation to tensor. """ if representation_tensor.shape[-1] != len(self._representation_scales): raise ValueError( "representation tensor is the wrong length (expected " - f"{len(self._representation_scales)} but got {representation_tensor.shape[-1]})!" + f"{len(self._representation_scales)} but got " + f"{representation_tensor.shape[-1]})!" " Did you remove some of the scales? (i.e., by setting " "scales in the forward pass)? convert_to_dict does not " "support such tensors."
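The pack-then-mask step in `forward` (and its inverse in `convert_to_dict`) can be exercised with toy shapes. A minimal sketch with two hypothetical statistics: `einops.pack` flattens heterogeneous trailing dimensions onto one axis while recording how to undo it, and `index_select` with a precomputed index tensor then drops the redundant entries, exactly the role `_necessary_stats_mask` plays above.

```python
import einops
import torch

# Two fake statistics with different trailing shapes, sharing (b, c, m) axes.
stats = [torch.rand(1, 1, 2, 4), torch.rand(1, 1, 2, 3, 5)]
packed, pack_info = einops.pack(stats, "b c m *")
print(packed.shape)  # torch.Size([1, 1, 2, 19]) -- 4 + 3*5 entries

# Hypothetical indices of the "necessary" statistics; index_select keeps
# only those along the packed axis.
necessary = torch.tensor([0, 2, 5, 18])
print(packed.index_select(-1, necessary).shape)  # torch.Size([1, 1, 2, 4])

# pack_info lets einops.unpack restore the original per-statistic shapes.
unpacked = einops.unpack(packed, pack_info, "b c m *")
print([u.shape for u in unpacked])
```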
@@ -643,10 +736,13 @@ def convert_to_dict(self, representation_tensor: Tensor) -> OrderedDict: # found in representation_tensor and all the other dimensions # determined by the values in necessary_stats_dict. shape = (*representation_tensor.shape[:3], *v.shape) - new_v = torch.nan * torch.ones(shape, dtype=representation_tensor.dtype, - device=representation_tensor.device) + new_v = torch.nan * torch.ones( + shape, + dtype=representation_tensor.dtype, + device=representation_tensor.device, + ) # v.sum() gives the number of necessary elements from this stat - this_stat_vec = representation_tensor[..., n_filled:n_filled+v.sum()] + this_stat_vec = representation_tensor[..., n_filled : n_filled + v.sum()] # use boolean indexing to put the values from new_stat_vec in the # appropriate place new_v[..., v] = this_stat_vec @@ -654,7 +750,9 @@ def convert_to_dict(self, representation_tensor: Tensor) -> OrderedDict: n_filled += v.sum() return rep - def _compute_pyr_coeffs(self, image: Tensor) -> Tuple[OrderedDict, List[Tensor], Tensor, Tensor]: + def _compute_pyr_coeffs( + self, image: Tensor + ) -> tuple[OrderedDict, list[Tensor], Tensor, Tensor]: """Compute pyramid coefficients of image. Note that the residual lowpass has been demeaned independently for each @@ -687,19 +785,21 @@ def _compute_pyr_coeffs(self, image: Tensor) -> Tuple[OrderedDict, List[Tensor], """ pyr_coeffs = self._pyr.forward(image) # separate out the residuals and demean the residual lowpass - lowpass = pyr_coeffs['residual_lowpass'] + lowpass = pyr_coeffs["residual_lowpass"] lowpass = lowpass - lowpass.mean(dim=(-2, -1), keepdim=True) - pyr_coeffs['residual_lowpass'] = lowpass - highpass = pyr_coeffs['residual_highpass'] + pyr_coeffs["residual_lowpass"] = lowpass + highpass = pyr_coeffs["residual_highpass"] # This is a list of tensors, one for each scale, where each tensor is # of shape (batch, channel, n_orientations, height, width) (note that # height and width halves on each scale) - coeffs_list = [torch.stack([pyr_coeffs[(i, j)] for j in range(self.n_orientations)], 2) - for i in range(self.n_scales)] + coeffs_list = [ + torch.stack([pyr_coeffs[(i, j)] for j in range(self.n_orientations)], 2) + for i in range(self.n_scales) + ] return pyr_coeffs, coeffs_list, highpass, lowpass - def _compute_pixel_stats(self, mask: List[Tensor], image: Tensor) -> Tensor: + def _compute_pixel_stats(self, mask: list[Tensor], image: Tensor) -> Tensor: """Compute the pixel stats: first four moments. 
Note that for the masked version, these are the *non-central* moments, i.e., @@ -720,15 +820,22 @@ def _compute_pixel_stats(self, mask: List[Tensor], image: Tensor) -> Tensor: non-central moments """ - weighted_avg_expr = f"{self._mask_input_idx}, b c h w -> b c {self._mask_output_idx}" + weighted_avg_expr = ( + f"{self._mask_input_idx}, b c h w -> b c {self._mask_output_idx}" + ) mean = einops.einsum(*mask, image, weighted_avg_expr) var = einops.einsum(*mask, image.pow(2), weighted_avg_expr) skew = einops.einsum(*mask, image.pow(3), weighted_avg_expr) kurtosis = einops.einsum(*mask, image.pow(4), weighted_avg_expr) - return einops.rearrange([mean, var, skew, kurtosis], f'stats b c {self._mask_output_idx} -> b c ({self._mask_output_idx}) stats') + return einops.rearrange( + [mean, var, skew, kurtosis], + f"stats b c {self._mask_output_idx} -> b c ({self._mask_output_idx}) stats", + ) @staticmethod - def _compute_intermediate_representations(pyr_coeffs: Tensor) -> Tuple[List[Tensor], List[Tensor]]: + def _compute_intermediate_representations( + pyr_coeffs: Tensor, + ) -> tuple[list[Tensor], list[Tensor]]: """Compute useful intermediate representations. These representations are: @@ -759,12 +866,18 @@ def _compute_intermediate_representations(pyr_coeffs: Tensor) -> Tuple[List[Tens """ magnitude_pyr_coeffs = [coeff.abs() for coeff in pyr_coeffs] - magnitude_means = [mag.mean((-2, -1), keepdim=True) for mag in magnitude_pyr_coeffs] - magnitude_pyr_coeffs = [mag - mn for mag, mn in zip(magnitude_pyr_coeffs, magnitude_means)] + magnitude_means = [ + mag.mean((-2, -1), keepdim=True) for mag in magnitude_pyr_coeffs + ] + magnitude_pyr_coeffs = [ + mag - mn for mag, mn in zip(magnitude_pyr_coeffs, magnitude_means) + ] real_pyr_coeffs = [coeff.real for coeff in pyr_coeffs] return magnitude_pyr_coeffs, real_pyr_coeffs - def _reconstruct_lowpass_at_each_scale(self, pyr_coeffs_dict: OrderedDict) -> List[Tensor]: + def _reconstruct_lowpass_at_each_scale( + self, pyr_coeffs_dict: OrderedDict + ) -> list[Tensor]: """Reconstruct the lowpass unoriented image at each scale. The autocorrelation, standard deviation, skew, and kurtosis of each of @@ -786,9 +899,11 @@ def _reconstruct_lowpass_at_each_scale(self, pyr_coeffs_dict: OrderedDict) -> Li widths. """ - reconstructed_images = [self._pyr.recon_pyr(pyr_coeffs_dict, levels=['residual_lowpass'])] + reconstructed_images = [ + self._pyr.recon_pyr(pyr_coeffs_dict, levels=["residual_lowpass"]) + ] # go through scales backwards - for lev in range(self.n_scales-1, -1, -1): + for lev in range(self.n_scales - 1, -1, -1): recon = self._pyr.recon_pyr(pyr_coeffs_dict, levels=[lev]) reconstructed_images.append(recon + reconstructed_images[-1]) # now downsample as necessary, so that these end up the same size as @@ -796,11 +911,15 @@ def _reconstruct_lowpass_at_each_scale(self, pyr_coeffs_dict: OrderedDict) -> Li # in order to approximately equalize the steerable pyramid coefficient # values across scales. 
This could also be handled by making the # pyramid tight frame - reconstructed_images[:-1] = [signal.shrink(r, 2**(self.n_scales-i)) * 4**(self.n_scales-i) - for i, r in enumerate(reconstructed_images[:-1])] + reconstructed_images[:-1] = [ + signal.shrink(r, 2 ** (self.n_scales - i)) * 4 ** (self.n_scales - i) + for i, r in enumerate(reconstructed_images[:-1]) + ] return reconstructed_images - def _compute_autocorr(self, mask: List[Tensor], coeffs_list: List[Tensor]) -> Tuple[Tensor, Tensor]: + def _compute_autocorr( + self, mask: list[Tensor], coeffs_list: list[Tensor] + ) -> tuple[Tensor, Tensor]: """Compute the autocorrelation of some statistics. Parameters @@ -831,27 +950,38 @@ def _compute_autocorr(self, mask: List[Tensor], coeffs_list: List[Tensor]) -> Tu """ if coeffs_list[0].ndim == 5: - dims = 'o' + dims = "o" rolls_h = self._autocorr_rolls_h rolls_w = self._autocorr_rolls_w var_dim = -2 elif coeffs_list[0].ndim == 4: - dims = '' + dims = "" rolls_h = [r[:, :, 0] for r in self._autocorr_rolls_h] rolls_w = [r[:, :, 0] for r in self._autocorr_rolls_w] var_dim = -1 else: - raise ValueError("coeffs_list must contain tensors of either 4 or 5 dimensions!") - autocorr_expr = f"{self._mask_input_idx}, b c {dims} h w, b c {dims} shift h w -> b c {self._mask_output_idx} shift {dims}" + raise ValueError( + "coeffs_list must contain tensors of either 4 or 5 dimensions!" + ) + autocorr_expr = ( + f"{self._mask_input_idx}, b c {dims} h w, " + f"b c {dims} shift h w ->" + f" b c {self._mask_output_idx} shift {dims}" + ) acs = [] vars = [] # iterate through scales - for coeff, rolls_h, rolls_w, scale_mask in zip(coeffs_list, rolls_h, rolls_w, mask): + for coeff, rolls_h, rolls_w, scale_mask in zip( + coeffs_list, rolls_h, rolls_w, mask + ): # the following two lines are equivalent to having two for loops over # range(-spatial_corr_width//2, spatial_corr_width//2) and using roll along # the last two indices, but is much more efficient, especially on the gpu. 
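The equivalence claimed in the comment above (precomputed index tensors plus `gather` reproduce `torch.roll`) is easy to verify on a toy tensor; a sketch for a single shift along one axis:

```python
import torch

# gather with a rolled arange reproduces torch.roll: idx[i, j] = (j - shift) % n,
# so x.gather(-1, idx)[i, j] == x[i, (j - shift) % n] == x.roll(shift, -1)[i, j].
x = torch.arange(12.0).reshape(3, 4)
shift = 1
idx = torch.arange(4).expand(3, 4).roll(shift, -1)
print(torch.equal(x.gather(-1, idx), x.roll(shift, -1)))  # True
```

Stacking one such index tensor per shift is what lets the model apply all `n_autocorrs` shifts in one batched gather instead of a Python loop of rolls.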
- rolled_coeff = einops.repeat(coeff, f'b c {dims} h w -> b c {dims} shift h w', - shift=self._n_autocorrs) + rolled_coeff = einops.repeat( + coeff, + f"b c {dims} h w -> b c {dims} shift h w", + shift=self._n_autocorrs, + ) rolled_coeff = rolled_coeff.gather(-2, rolls_h).gather(-1, rolls_w) autocorr = einops.einsum(*scale_mask, coeff, rolled_coeff, autocorr_expr) # this returns a view of autocorr that just selects out the variance, while @@ -859,16 +989,31 @@ def _compute_autocorr(self, mask: List[Tensor], coeffs_list: List[Tensor]) -> Tu # shift, which corresponds to the variance, as the last element var = torch.narrow(autocorr, var_dim, -1, 1) # and then drop the variance from here - acs.append(torch.narrow(autocorr, var_dim, 0, self._n_autocorrs-1) - / (var + self._stability_epsilon)) + acs.append( + torch.narrow(autocorr, var_dim, 0, self._n_autocorrs - 1) + / (var + self._stability_epsilon) + ) vars.append(var) - acs = einops.rearrange(acs, f'scales b c {self._mask_output_idx} shifts {dims} -> b c ({self._mask_output_idx}) shifts {dims} scales') - vars = einops.rearrange(vars, f'scales b c {self._mask_output_idx} shifts {dims} -> b c {self._mask_output_idx} {dims} (shifts scales)', shifts=1) + acs = einops.rearrange( + acs, + ( + f"scales b c {self._mask_output_idx} shifts {dims} -> " + f"b c ({self._mask_output_idx}) shifts {dims} scales" + ), + ) + vars = einops.rearrange( + vars, + ( + f"scales b c {self._mask_output_idx} shifts {dims} -> " + f"b c {self._mask_output_idx} {dims} (shifts scales)" + ), + shifts=1, + ) return acs, vars - def _compute_skew_kurtosis_recon(self, mask: List[Tensor], - reconstructed_images: List[Tensor], - var_recon: Tensor) -> Tuple[Tensor, Tensor]: + def _compute_skew_kurtosis_recon( + self, mask: list[Tensor], reconstructed_images: list[Tensor], var_recon: Tensor + ) -> tuple[Tensor, Tensor]: """Compute the skew and kurtosis of each lowpass reconstructed image. For each scale, if the ratio of its variance to the original image's @@ -897,23 +1042,62 @@ def _compute_skew_kurtosis_recon(self, mask: List[Tensor], ``reconstructed_images``. 
""" - var_recon = einops.rearrange(var_recon, f'b c {self._mask_output_idx} scales -> b c ({self._mask_output_idx}) scales') + var_recon = einops.rearrange( + var_recon, + ( + f"b c {self._mask_output_idx} scales -> " + f"b c ({self._mask_output_idx}) scales" + ), + ) skew_recon = [] kurtosis_recon = [] for img, scale_mask in zip(reconstructed_images, mask): - skew_recon.append(einops.einsum(*scale_mask, img.pow(3), f"{self._mask_input_idx}, b c h w -> b c {self._mask_output_idx}")) - kurtosis_recon.append(einops.einsum(*scale_mask, img.pow(4), f"{self._mask_input_idx}, b c h w -> b c {self._mask_output_idx}")) - skew_recon = einops.rearrange(skew_recon, f'scales b c {self._mask_output_idx} -> b c ({self._mask_output_idx}) scales') - kurtosis_recon = einops.rearrange(kurtosis_recon, f'scales b c {self._mask_output_idx} -> b c ({self._mask_output_idx}) scales') + skew_recon.append( + einops.einsum( + *scale_mask, + img.pow(3), + ( + f"{self._mask_input_idx}, b c h w -> " + f"b c {self._mask_output_idx}" + ), + ) + ) + kurtosis_recon.append( + einops.einsum( + *scale_mask, + img.pow(4), + ( + f"{self._mask_input_idx}, b c h w -> " + f"b c {self._mask_output_idx}" + ), + ) + ) + skew_recon = einops.rearrange( + skew_recon, + ( + f"scales b c {self._mask_output_idx} ->" + f" b c ({self._mask_output_idx}) scales" + ), + ) + kurtosis_recon = einops.rearrange( + kurtosis_recon, + ( + f"scales b c {self._mask_output_idx} -> " + f"b c ({self._mask_output_idx}) scales" + ), + ) skew_recon = skew_recon / (var_recon.pow(1.5) + self._stability_epsilon) kurtosis_recon = kurtosis_recon / (var_recon.pow(2) + self._stability_epsilon) return skew_recon, kurtosis_recon - def _compute_cross_correlation(self, mask: List[Tensor], - coeffs_tensor: List[Tensor], - coeffs_tensor_other: List[Tensor], - coeffs_var: Optional[Tensor] = None, - coeffs_other_var: Optional[Tensor] = None) -> Tensor: + def _compute_cross_correlation( + self, + mask: list[Tensor], + coeffs_tensor: list[Tensor], + coeffs_tensor_other: list[Tensor], + coeffs_var: Tensor | None = None, + coeffs_other_var: Tensor | None = None, + ) -> Tensor: """Compute cross-correlations. Parameters @@ -938,10 +1122,22 @@ def _compute_cross_correlation(self, mask: List[Tensor], """ covars = [] - covar_expr = f'{self._mask_input_idx}, b c o1 h w, b c o2 h w -> b c {self._mask_output_idx} o1 o2' - var_expr = f'{self._mask_input_idx}, b c o1 h w, b c o1 h w -> b c {self._mask_output_idx} o1' - outer_prod_expr = f'b c {self._mask_output_idx} o1, b c {self._mask_output_idx} o2 -> b c {self._mask_output_idx} o1 o2' - for i, (scale_mask, coeff, coeff_other) in enumerate(zip(mask, coeffs_tensor, coeffs_tensor_other)): + covar_expr = ( + f"{self._mask_input_idx}, b c o1 h w, b c o2 h w ->" + f" b c {self._mask_output_idx} o1 o2" + ) + var_expr = ( + f"{self._mask_input_idx}, b c o1 h w, b c o1 h w ->" + f" b c {self._mask_output_idx} o1" + ) + outer_prod_expr = ( + f"b c {self._mask_output_idx} o1, " + f"b c {self._mask_output_idx} o2 ->" + f" b c {self._mask_output_idx} o1 o2" + ) + for i, (scale_mask, coeff, coeff_other) in enumerate( + zip(mask, coeffs_tensor, coeffs_tensor_other) + ): # compute the covariance covar = einops.einsum(*scale_mask, coeff, coeff_other, covar_expr) # Then normalize it to get the Pearson product-moment correlation @@ -949,27 +1145,34 @@ def _compute_cross_correlation(self, mask: List[Tensor], # https://numpy.org/doc/stable/reference/generated/numpy.corrcoef.html. 
if coeffs_var is None: # First, compute the variances of each coeff - coeff_var = einops.einsum(*scale_mask, coeff, coeff, - var_expr) + coeff_var = einops.einsum(*scale_mask, coeff, coeff, var_expr) else: coeff_var = coeffs_var[..., i] if coeffs_other_var is None: # First, compute the variances of each coeff - coeff_other_var = einops.einsum(*scale_mask, coeff_other, coeff_other, - var_expr) + coeff_other_var = einops.einsum( + *scale_mask, coeff_other, coeff_other, var_expr + ) else: coeff_other_var = coeffs_other_var[..., i] # Then compute the outer product of those variances. - var_outer_prod = einops.einsum(coeff_var, coeff_other_var, - outer_prod_expr) + var_outer_prod = einops.einsum(coeff_var, coeff_other_var, outer_prod_expr) # And the sqrt of this is what we use to normalize the covariance # into the cross-correlation std_outer_prod = (var_outer_prod + self._stability_epsilon).sqrt() covars.append(covar / (std_outer_prod + self._stability_epsilon)) - return einops.rearrange(covars, f'scales b c {self._mask_output_idx} o1 o2 -> b c ({self._mask_output_idx}) o1 o2 scales') + return einops.rearrange( + covars, + ( + f"scales b c {self._mask_output_idx} o1 o2 ->" + f" b c ({self._mask_output_idx}) o1 o2 scales" + ), + ) @staticmethod - def _double_phase_pyr_coeffs(pyr_coeffs: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]: + def _double_phase_pyr_coeffs( + pyr_coeffs: list[Tensor], + ) -> tuple[list[Tensor], list[Tensor]]: """Upsample and double the phase of pyramid coefficients. Parameters @@ -1005,21 +1208,24 @@ def _double_phase_pyr_coeffs(pyr_coeffs: List[Tensor]) -> Tuple[List[Tensor], Li doubled_phase = signal.expand(coeff, 2) / 4.0 doubled_phase = signal.modulate_phase(doubled_phase, 2) doubled_phase_mag = doubled_phase.abs() - doubled_phase_mag = doubled_phase_mag - doubled_phase_mag.mean((-2, -1), keepdim=True) + doubled_phase_mag = doubled_phase_mag - doubled_phase_mag.mean( + (-2, -1), keepdim=True + ) doubled_phase_mags.append(doubled_phase_mag) - doubled_phase_sep.append(einops.pack([doubled_phase.real, doubled_phase.imag], - 'b c * h w')[0]) + doubled_phase_sep.append( + einops.pack([doubled_phase.real, doubled_phase.imag], "b c * h w")[0] + ) return doubled_phase_mags, doubled_phase_sep def plot_representation( - self, - data: Tensor, - ax: Optional[plt.Axes] = None, - figsize: Tuple[float, float] = (15, 5), - ylim: Optional[Union[Tuple[float, float], Literal[False]]] = False, - batch_idx: int = 0, - title: Optional[str] = None, - ) -> Tuple[plt.Figure, List[plt.Axes]]: + self, + data: Tensor, + ax: plt.Axes | None = None, + figsize: tuple[float, float] = (15, 5), + ylim: tuple[float, float] | Literal[False] | None = False, + batch_idx: int = 0, + title: str | None = None, + ) -> tuple[plt.Figure, list[plt.Axes]]: r"""Plot the representation in a human viewable format -- stem plots with data separated out by statistic type. @@ -1050,7 +1256,7 @@ def plot_representation( norm) If self.n_scales > 1, we also have: - + - cross_scale_correlation_magnitude: the cross-correlations between the pyramid coefficient magnitude at one scale and the same orientation at the next-coarsest scale (summarized using Euclidean norm). @@ -1136,38 +1342,42 @@ def plot_representation( return fig, axes def _representation_for_plotting(self, rep: OrderedDict) -> OrderedDict: - r"""Convert the data into a dictionary representation that is more convenient for plotting. + r"""Convert the data into a more convenient representation for plotting. 
Intended as a helper function for plot_representation. """ data = OrderedDict() - data["pixels+var_highpass"] = torch.cat([rep.pop("pixel_statistics"), - rep.pop("var_highpass_residual")], -1) + data["pixels+var_highpass"] = torch.cat( + [rep.pop("pixel_statistics"), rep.pop("var_highpass_residual")], -1 + ) data["std+skew+kurtosis recon"] = torch.cat( ( rep.pop("std_reconstructed"), rep.pop("skew_reconstructed"), rep.pop("kurtosis_reconstructed"), - ), -1 + ), + -1, ) - data['magnitude_std'] = rep.pop('magnitude_std').flatten(1) + data["magnitude_std"] = rep.pop("magnitude_std").flatten(1) # want to plot these in a specific order - all_keys = ['auto_correlation_reconstructed', - 'auto_correlation_magnitude', - 'cross_orientation_correlation_magnitude', - 'cross_scale_correlation_magnitude', - 'cross_scale_correlation_real'] + all_keys = [ + "auto_correlation_reconstructed", + "auto_correlation_magnitude", + "cross_orientation_correlation_magnitude", + "cross_scale_correlation_magnitude", + "cross_scale_correlation_real", + ] if set(rep.keys()) != set(all_keys): raise ValueError("representation has unexpected keys!") for k in all_keys: # if we only have one scale, no cross-scale stats - if k.startswith('cross_scale') and self.n_scales == 1: + if k.startswith("cross_scale") and self.n_scales == 1: continue # these will then be 2d, with masks on the first dimension - if k == 'cross_orientation_correlation_magnitude': + if k == "cross_orientation_correlation_magnitude": # this one has nans in it (indicating unnecessary stats), and so we need # to compute the L2 norm ourselves data[k] = rep[k].pow(2).nansum((1, 2)).sqrt() @@ -1177,11 +1387,11 @@ def _representation_for_plotting(self, rep: OrderedDict) -> OrderedDict: return data def update_plot( - self, - axes: List[plt.Axes], - data: Tensor, - batch_idx: int = 0, - ) -> List[plt.Artist]: + self, + axes: list[plt.Axes], + data: Tensor, + batch_idx: int = 0, + ) -> list[plt.Artist]: r"""Update the information in our representation plot. This is used for creating an animation of the representation
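Taken together, the notebook and module changes support a workflow like the following. A hedged end-to-end sketch, assuming the class is importable as `po.simul.PortillaSimoncelliMasked` (the diff shows the module file but not its export path) and reusing the stripe masks and `/64` normalization from the notebook:

```python
import torch
import plenoptic as po

img = po.data.curie()  # 256 x 256 grayscale image shipped with plenoptic

# Four horizontal and four vertical stripe masks, as in the notebook; each
# horizontal/vertical pair overlaps on a 64x64 square, so dividing both banks
# by 64 makes every pairwise product sum to ~1.
n_stripes, sz = 4, 256 // 4
masks = [torch.zeros(n_stripes, 256, 256), torch.zeros(n_stripes, 256, 256)]
for j in range(n_stripes):
    masks[0][j, j * sz : (j + 1) * sz, :] = 1
    masks[1][j, :, j * sz : (j + 1) * sz] = 1
masks = [m / 64 for m in masks]

ps_mask = po.simul.PortillaSimoncelliMasked(img.shape[-2:], masks)
met = po.synth.MetamerCTF(img, ps_mask, loss_function=po.tools.optim.l2_norm)
met.synthesize(
    max_iter=500, store_progress=10, change_scale_criterion=None, ctf_iters_to_check=7
)
```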