diff --git a/src/pyFAI/gui/pilx/MainWindow.py b/src/pyFAI/gui/pilx/MainWindow.py index 65e66a4e4..7b14a07b0 100644 --- a/src/pyFAI/gui/pilx/MainWindow.py +++ b/src/pyFAI/gui/pilx/MainWindow.py @@ -46,8 +46,6 @@ from silx.gui import qt from silx.gui.colors import Colormap from silx.image.marchingsquares import find_contours -from silx.io.url import DataUrl -from silx.io import get_data from silx.gui.plot.items.image import ImageBase from .models import ImageIndices @@ -56,6 +54,7 @@ compute_radial_values, get_dataset, get_indices_from_values, + get_mask_image, get_radial_dataset, ) from .widgets.DiffractionImagePlotWidget import DiffractionImagePlotWidget @@ -63,10 +62,11 @@ from .widgets.MapPlotWidget import MapPlotWidget from .widgets.TitleWidget import TitleWidget from ...io.integration_config import WorkerConfig -from ...utils.mathutil import binning as rebin_fct -from ...detectors import Detector +from ...utils.mathutil import binning + logger = logging.getLogger(__name__) + class MainWindow(qt.QMainWindow): sigFileChanged = qt.Signal(str) @@ -253,28 +253,25 @@ def getMask(self, image, maskfile=None): :return: 2D array """ if maskfile: - mask_image = get_data(url=DataUrl(maskfile)) - if mask_image.shape != image.shape: - binning = [m//i for i, m in zip(image.shape, mask_image.shape)] - if min(binning)<1: - mask_image = None - else: - mask_image = rebin_fct(mask_image, binning) + mask_image = get_mask_image(maskfile, image.shape) else: mask_image = None + detector = self.worker_config.poni.detector - if detector: - detector_mask = detector.mask - if detector.shape != image.shape: - detector.guess_binning(image) - detector_mask = rebin_fct(detector_mask, detector.binning) - if mask_image is None: - mask_image = detector_mask - else: - numpy.logical_or(mask_image, detector_mask, out=mask_image) - detector.mask = mask_image - mask_image = detector.dynamic_mask(image) - return mask_image + if not detector: + return mask_image + + detector_mask = detector.mask + if detector.shape != image.shape: + detector.guess_binning(image) + detector_mask = binning(detector_mask, detector.binning) + + if mask_image is None: + detector.mask = detector_mask + else: + detector.mask = numpy.logical_or(mask_image, detector_mask) + + return detector.dynamic_mask(image) def displayImageAtIndices(self, indices: ImageIndices): if self._file_name is None: diff --git a/src/pyFAI/gui/pilx/utils.py b/src/pyFAI/gui/pilx/utils.py index 66a2c1444..1ac22096b 100644 --- a/src/pyFAI/gui/pilx/utils.py +++ b/src/pyFAI/gui/pilx/utils.py @@ -28,6 +28,7 @@ """Tool to visualize diffraction maps.""" from __future__ import annotations + __author__ = "Loïc Huder" __contact__ = "loic.huder@ESRF.eu" __license__ = "MIT" @@ -35,16 +36,20 @@ __date__ = "19/02/2025" __status__ = "development" -from typing import Iterable, Optional import logging +from typing import Iterable, Optional, Tuple + logger = logging.getLogger(__name__) -import json +import os.path + import h5py import numpy -import os.path +from silx.io import get_data +from silx.io.url import DataUrl + from ...integrator.azimuthal import AzimuthalIntegrator -from ...detectors import Detector from ...io.integration_config import WorkerConfig +from ...utils.mathutil import binning def compute_radial_values(worker_config: WorkerConfig) -> numpy.ndarray: @@ -115,3 +120,18 @@ def guess_axis_path(existing_axis_path: str, parent: h5py.Group) -> str | None: return guessed_axis_path return None + + +def get_mask_image(maskfile: str, image_shape: Tuple[int, int]) -> numpy.ndarray | None: + """Retrieves mask image from the URL. Rebin to match""" + mask_image = get_data(url=DataUrl(maskfile)) + assert isinstance(mask_image, numpy.ndarray) + if mask_image.shape == image_shape: + return mask_image + + # If mismatched shapes, try to rebin + bin_size = [m // i for i, m in zip(image_shape, mask_image.shape)] + if bin_size[0] == 0 or bin_size[1] == 0: + return None + + return binning(mask_image, bin_size)