diff --git a/spaceKLIP/imagetools.py b/spaceKLIP/imagetools.py index 0c6a9bdf..77f62b8a 100644 --- a/spaceKLIP/imagetools.py +++ b/spaceKLIP/imagetools.py @@ -2012,7 +2012,7 @@ def align_frames(self, mask_override : str, optional Mask some pixels when cross correlating for shifts msk_shp : int, optional - Shape (height or radius) for custom mask invoked by "mask_override" + Shape (height or radius, or [inner radius, outer radius]) for custom mask invoked by "mask_override" shft_exp : float, optional Take image to the given power before cross correlating for shifts, default is 1. For instance, 1/2 helps align nircam bar/narrow data (or other data with weird speckles) kwargs : dict, optional @@ -2034,6 +2034,18 @@ def align_frames(self, os.makedirs(output_dir) # useful masks for computing shifts: + def create_annulus_mask(h, w, center=None, radius=None): + + if center is None: # use the middle of the image + center = (int(w/2), int(h/2)) + if radius is None: # use the smallest distance between the center and image walls + radius = min(center[0], center[1], w-center[0], h-center[1]) + + Y, X = np.ogrid[:h, :w] + dist_from_center = np.sqrt((X - center[0])**2 + (Y-center[1])**2) + + mask = (dist_from_center <= radius[0]) | (dist_from_center >= radius[1]) + return mask def create_circular_mask(h, w, center=None, radius=None): if center is None: # use the middle of the image @@ -2080,7 +2092,9 @@ def create_rec_mask(h, w, center=None, z=None): maskfile = self.database.obs[key]['MASKFILE'][j] mask = ut.read_msk(maskfile) if mask_override is not None: - if mask_override == 'circ': + if mask_override == 'ann': + mask_circ = create_annulus_mask(data[0].shape[0], data[0].shape[1], radius=msk_shp) + elif mask_override == 'circ': mask_circ = create_circular_mask(data[0].shape[0],data[0].shape[1], radius=msk_shp) elif mask_override == 'rec': mask_circ = create_rec_mask(data[0].shape[0],data[0].shape[1], z=msk_shp) @@ -2093,7 +2107,7 @@ def create_rec_mask(h, w, center=None, z=None): mask_temp = np.ones_like(data[0]) else: mask_temp = mask - + # Align frames. head, tail = os.path.split(fitsfile) log.info(' --> Align frames: ' + tail)