diff --git a/spaceKLIP/imagetools.py b/spaceKLIP/imagetools.py index b762b8fd..c15b3818 100644 --- a/spaceKLIP/imagetools.py +++ b/spaceKLIP/imagetools.py @@ -1599,6 +1599,7 @@ def recenter_frames(self, method='fourier', subpix_first_sci_only=False, spectral_type='G2V', + shft_exp=1, kwargs={}, subdir='recentered'): """ @@ -1624,6 +1625,8 @@ def recenter_frames(self, spectral_type : str, optional Host star spectral type for the WebbPSF model used to determine the star position behind the coronagraphic mask. The default is 'G2V'. + shft_exp : float, optional + Take image to the given power before cross correlating for shifts, default is 1. For instance, 1/2 helps align nircam bar/narrow data (or other data with weird speckles) kwargs : dict, optional Keyword arguments for the scipy.ndimage.shift routine. The default is {}. @@ -1691,6 +1694,7 @@ def recenter_frames(self, xc, yc, xshift, yshift = self.find_nircam_centers(data0=data[k].copy(), key=key, j=j, + shft_exp=shft_exp, spectral_type=spectral_type, date=head_pri['DATE-BEG'], output_dir=output_dir) @@ -1820,6 +1824,7 @@ def find_nircam_centers(self, key, j, spectral_type='G2V', + shft_exp=1, date=None, output_dir=None, fov_pix=65, @@ -1840,6 +1845,8 @@ def find_nircam_centers(self, spectral_type : str, optional Host star spectral type for the WebbPSF model used to determine the star position behind the coronagraphic mask. The default is 'G2V'. + shft_exp : float, optional + Take image to the given power before cross correlating for shifts, default is 1. date : str, optional Observation date in the format 'YYYY-MM-DDTHH:MM:SS.MMM'. The default is None. @@ -1919,9 +1926,16 @@ def find_nircam_centers(self, xycen=(xc, yc), npix=fov_pix) + if shft_exp == 1: + img1 = datasub* masksub + img2 = model_psf* masksub + else: + img1 = np.power(np.abs(datasub), shft_exp)* masksub + img2 = np.power(np.abs(model_psf), shft_exp) * masksub + # Determine relative shift between data and model PSF. - shift, error, phasediff = phase_cross_correlation(datasub * masksub, - model_psf * masksub, + shift, error, phasediff = phase_cross_correlation(img1, + img2, upsample_factor=1000, normalization=None) yshift, xshift = shift @@ -1973,6 +1987,9 @@ def find_nircam_centers(self, def align_frames(self, method='fourier', align_algo='leastsq', + mask_override=None, + msk_shp=8, + shft_exp=1, kwargs={}, subdir='aligned'): """ @@ -1985,6 +2002,12 @@ def align_frames(self, align_algo : 'leastsq' or 'header' Algorithm to determine the alignment offsets. Default is 'leastsq', 'header' assumes perfect header offsets. + mask_override : str, optional + Mask some pixels when cross correlating for shifts + msk_shp : int, optional + Shape (height or radius) for custom mask invoked by "mask_override" + shft_exp : float, optional + Take image to the given power before cross correlating for shifts, default is 1. For instance, 1/2 helps align nircam bar/narrow data (or other data with weird speckles) kwargs : dict, optional Keyword arguments for the scipy.ndimage.shift routine. The default is {}. @@ -2002,6 +2025,31 @@ def align_frames(self, output_dir = os.path.join(self.database.output_dir, subdir) if not os.path.exists(output_dir): os.makedirs(output_dir) + + # useful masks for computing shifts: + def create_circular_mask(h, w, center=None, radius=None): + + if center is None: # use the middle of the image + center = (int(w/2), int(h/2)) + if radius is None: # use the smallest distance between the center and image walls + radius = min(center[0], center[1], w-center[0], h-center[1]) + + Y, X = np.ogrid[:h, :w] + dist_from_center = np.sqrt((X - center[0])**2 + (Y-center[1])**2) + + mask = dist_from_center <= radius + return mask + + def create_rec_mask(h, w, center=None, z=None): + if center is None: # use the middle of the image + center = (int(w/2), int(h/2)) + if z is None: + z = h//4 + + mask = np.zeros((h,w), dtype=bool) + mask[center[1]-z:center[1]+z,:] = True + + return mask # Loop through concatenations. database_temp = deepcopy(self.database.obs) @@ -2024,6 +2072,20 @@ def align_frames(self, data, erro, pxdq, head_pri, head_sci, is2d, imshifts, maskoffs = ut.read_obs(fitsfile) maskfile = self.database.obs[key]['MASKFILE'][j] mask = ut.read_msk(maskfile) + if mask_override is not None: + if mask_override == 'circ': + mask_circ = create_circular_mask(data[0].shape[0],data[0].shape[1], radius=msk_shp) + elif mask_override == 'rec': + mask_circ = create_rec_mask(data[0].shape[0],data[0].shape[1], z=msk_shp) + else: + raise ValueError('There are `circ` and `rec` custom masks available') + mask_temp = data[0].copy() + mask_temp[~mask_circ] = 1 + mask_temp[mask_circ] = 0 + elif mask is None: + mask_temp = np.ones_like(data[0]) + else: + mask_temp = mask # Align frames. head, tail = os.path.split(fitsfile) @@ -2064,10 +2126,14 @@ def align_frames(self, if (np.abs(xshift) < 1e-3) and (np.abs(yshift) < 1e-3): p0 = np.array([0., 0., 1.]) if align_algo == 'leastsq': + if shft_exp != 1: + args = (np.power(np.abs(data[k]), shft_exp), np.power(np.abs(ref_image), shft_exp), mask_temp, method, kwargs) + else: + args = (data[k], ref_image, mask_temp, method, kwargs) # Use header values to initiate least squares fit pp = leastsq(ut.alignlsq, p0, - args=(data[k], ref_image, mask, method, kwargs))[0] + args=args)[0] elif align_algo == 'header': # Just assume the header values are correct pp = p0