diff --git a/src/plenoptic/simulate/models/portilla_simoncelli_masked.py b/src/plenoptic/simulate/models/portilla_simoncelli_masked.py index cabb5a54..7ed5cf3a 100644 --- a/src/plenoptic/simulate/models/portilla_simoncelli_masked.py +++ b/src/plenoptic/simulate/models/portilla_simoncelli_masked.py @@ -824,9 +824,21 @@ def _compute_pixel_stats(self, mask: list[Tensor], image: Tensor) -> Tensor: f"{self._mask_input_idx}, b c h w -> b c {self._mask_output_idx}" ) mean = einops.einsum(*mask, image, weighted_avg_expr) - var = einops.einsum(*mask, image.pow(2), weighted_avg_expr) - skew = einops.einsum(*mask, image.pow(3), weighted_avg_expr) - kurtosis = einops.einsum(*mask, image.pow(4), weighted_avg_expr) + # these are all non-central moments... + moment_2 = einops.einsum(*mask, image.pow(2), weighted_avg_expr) + moment_3 = einops.einsum(*mask, image.pow(3), weighted_avg_expr) + moment_4 = einops.einsum(*mask, image.pow(4), weighted_avg_expr) + # ... which we use to compute the var, skew, and kurtosis. the formulas we use + # for var and skew here can be found on their respective wikipedia pages, and + # the one for kurtosis comes from Eero working through the algebra + var = moment_2 - mean.pow(2) + skew = (moment_3 - 3 * mean * var - mean.pow(3)) / (var.pow(1.5)) + kurtosis = ( + moment_4 + - 4 * mean * moment_3 + + 6 * mean.pow(2) * moment_2 + - 3 * mean.pow(4) + ) / (var.pow(2)) return einops.rearrange( [mean, var, skew, kurtosis], f"stats b c {self._mask_output_idx} -> b c ({self._mask_output_idx}) stats",