diff --git a/spaceKLIP/imagetools.py b/spaceKLIP/imagetools.py index e44c9fd..93ff71c 100644 --- a/spaceKLIP/imagetools.py +++ b/spaceKLIP/imagetools.py @@ -30,6 +30,7 @@ from spaceKLIP import utils as ut from spaceKLIP.psf import JWST_PSF from spaceKLIP.xara import core +from spaceKLIP.utils import gaussian_kernel from webbpsf_ext import robust from webbpsf_ext.coords import dist_image from webbpsf_ext.webbpsf_ext_core import _transmission_map @@ -42,6 +43,7 @@ from webbpsf.constants import JWST_CIRCUMSCRIBED_DIAMETER from astropy.io import fits from spaceKLIP.starphot import get_stellar_magnitudes, read_spec_file +import scipy.ndimage import logging log = logging.getLogger(__name__) @@ -67,29 +69,29 @@ class ImageTools(): """ The spaceKLIP image manipulation tools class. - + """ - + def __init__(self, database): """ Initialize the spaceKLIP image manipulation tools class. - + Parameters ---------- database : spaceKLIP.Database SpaceKLIP database on which the image manipulation steps shall be run. - + Returns ------- None. - + """ - + # Make an internal alias of the spaceKLIP database class. self.database = database - + pass def _get_output_dir(self, subdir): @@ -144,7 +146,7 @@ def remove_frames(self, subdir='removed'): """ Remove individual frames from the data. - + Parameters ---------- index : int or list of int or dict of list of list of int, optional @@ -163,40 +165,40 @@ def remove_frames(self, subdir : str, optional Name of the directory where the data products shall be saved. The default is 'removed'. - + Returns ------- None. - + """ - + # Check input. if isinstance(index, int): index = [index] - + # Set output directory. output_dir = os.path.join(self.database.output_dir, subdir) if not os.path.exists(output_dir): os.makedirs(output_dir) - + # Loop through concatenations. for i, key in enumerate(self.database.obs.keys()): log.info('--> Concatenation ' + key) - + # Loop through FITS files. 
nfitsfiles = len(self.database.obs[key]) for j in range(nfitsfiles): - + # Read FITS file and PSF mask. fitsfile = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, imshifts, maskoffs = ut.read_obs(fitsfile) maskfile = self.database.obs[key]['MASKFILE'][j] mask = ut.read_msk(maskfile) nints = self.database.obs[key]['NINTS'][j] - + # Skip file types that are not in the list of types. if self.database.obs[key]['TYPE'][j] in types: - + # Remove frames. head, tail = os.path.split(fitsfile) log.info(' --> Frame removal: ' + tail) @@ -213,24 +215,24 @@ def remove_frames(self, if maskoffs is not None: maskoffs = np.delete(maskoffs, index_temp, axis=0) nints = data.shape[0] - + # Write FITS file and PSF mask. head_pri['NINTS'] = nints fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d, imshifts, maskoffs) maskfile = ut.write_msk(maskfile, mask, fitsfile) - + # Update spaceKLIP database. self.database.update_obs(key, j, fitsfile, maskfile, nints=nints) - + pass - + def crop_frames(self, npix=1, types=['SCI', 'SCI_BG', 'REF', 'REF_BG'], subdir='cropped'): """ Crop all frames. - + Parameters ---------- npix : int or list of four int, optional @@ -244,32 +246,32 @@ def crop_frames(self, subdir : str, optional Name of the directory where the data products shall be saved. The default is 'cropped'. - + Returns ------- None. - + """ - + # Check input. if isinstance(npix, int): npix = [npix, npix, npix, npix] # left, right, bottom, top if len(npix) != 4: raise UserWarning('Parameter npix must either be an int or a list of four int (left, right, bottom, top)') - + # Set output directory. output_dir = os.path.join(self.database.output_dir, subdir) if not os.path.exists(output_dir): os.makedirs(output_dir) - + # Loop through concatenations. for i, key in enumerate(self.database.obs.keys()): log.info('--> Concatenation ' + key) - + # Loop through FITS files. 
nfitsfiles = len(self.database.obs[key]) for j in range(nfitsfiles): - + # Read FITS file and PSF mask. fitsfile = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, imshifts, maskoffs = ut.read_obs(fitsfile) @@ -277,10 +279,10 @@ def crop_frames(self, mask = ut.read_msk(maskfile) crpix1 = self.database.obs[key]['CRPIX1'][j] crpix2 = self.database.obs[key]['CRPIX2'][j] - + # Skip file types that are not in the list of types. if self.database.obs[key]['TYPE'][j] in types: - + # Crop frames. head, tail = os.path.split(fitsfile) log.info(' --> Frame cropping: ' + tail) @@ -293,18 +295,18 @@ def crop_frames(self, crpix1 -= npix[0] crpix2 -= npix[2] log.info(' --> Frame cropping: old shape = ' + str(sh[1:]) + ', new shape = ' + str(data.shape[1:])) - + # Write FITS file and PSF mask. head_sci['CRPIX1'] = crpix1 head_sci['CRPIX2'] = crpix2 fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d, imshifts, maskoffs) maskfile = ut.write_msk(maskfile, mask, fitsfile) - + # Update spaceKLIP database. self.database.update_obs(key, j, fitsfile, maskfile, crpix1=crpix1, crpix2=crpix2) - + pass - + def pad_frames(self, npix=1, cval=np.nan, @@ -312,7 +314,7 @@ def pad_frames(self, subdir='padded'): """ Pad all frames. - + Parameters ---------- npix : int or list of four int, optional @@ -328,32 +330,32 @@ def pad_frames(self, subdir : str, optional Name of the directory where the data products shall be saved. The default is 'padded'. - + Returns ------- None. - + """ - + # Check input. if isinstance(npix, int): npix = [npix, npix, npix, npix] # left, right, bottom, top if len(npix) != 4: raise UserWarning('Parameter npix must either be an int or a list of four int (left, right, bottom, top)') - + # Set output directory. output_dir = os.path.join(self.database.output_dir, subdir) if not os.path.exists(output_dir): os.makedirs(output_dir) - + # Loop through concatenations. 
for i, key in enumerate(self.database.obs.keys()): log.info('--> Concatenation ' + key) - + # Loop through FITS files. nfitsfiles = len(self.database.obs[key]) for j in range(nfitsfiles): - + # Read FITS file and PSF mask. fitsfile = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, imshifts, maskoffs = ut.read_obs(fitsfile) @@ -361,10 +363,10 @@ def pad_frames(self, mask = ut.read_msk(maskfile) crpix1 = self.database.obs[key]['CRPIX1'][j] crpix2 = self.database.obs[key]['CRPIX2'][j] - + # Skip file types that are not in the list of types. if self.database.obs[key]['TYPE'][j] in types: - + # Crop frames. head, tail = os.path.split(fitsfile) log.info(' --> Frame padding: ' + tail) @@ -377,25 +379,25 @@ def pad_frames(self, crpix1 += npix[0] crpix2 += npix[2] log.info(' --> Frame padding: old shape = ' + str(sh[1:]) + ', new shape = ' + str(data.shape[1:]) + ', fill value = %.2f' % cval) - + # Write FITS file and PSF mask. head_sci['CRPIX1'] = crpix1 head_sci['CRPIX2'] = crpix2 fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d, imshifts, maskoffs) maskfile = ut.write_msk(maskfile, mask, fitsfile) - + # Update spaceKLIP database. self.database.update_obs(key, j, fitsfile, maskfile, crpix1=crpix1, crpix2=crpix2) - + pass - + def coadd_frames(self, nframes=None, types=['SCI', 'SCI_BG', 'REF', 'REF_BG'], subdir='coadded'): """ Coadd frames. - + Parameters ---------- nframes : int, optional @@ -407,13 +409,13 @@ def coadd_frames(self, subdir : str, optional Name of the directory where the data products shall be saved. The default is 'coadded'. - + Returns ------- None. - + """ - + # Set output directory. output_dir = os.path.join(self.database.output_dir, subdir) if not os.path.exists(output_dir): @@ -421,15 +423,15 @@ def coadd_frames(self, # The starting value. nframes0 = nframes - + # Loop through concatenations. 
for i, key in enumerate(self.database.obs.keys()): log.info('--> Concatenation ' + key) - + # Loop through FITS files. nfitsfiles = len(self.database.obs[key]) for j in range(nfitsfiles): - + # Read FITS file and PSF mask. fitsfile = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, imshifts, maskoffs = ut.read_obs(fitsfile) @@ -437,14 +439,14 @@ def coadd_frames(self, mask = ut.read_msk(maskfile) nints = self.database.obs[key]['NINTS'][j] effinttm = self.database.obs[key]['EFFINTTM'][j] - + # If nframes is not provided, collapse everything. if nframes0 is None: nframes = nints - + # Skip file types that are not in the list of types. if self.database.obs[key]['TYPE'][j] in types: - + # Coadd frames. head, tail = os.path.split(fitsfile) log.info(' --> Frame coadding: ' + tail) @@ -464,29 +466,29 @@ def coadd_frames(self, nints = data.shape[0] effinttm *= nframes log.info(' --> Frame coadding: %.0f coadd(s) of %.0f frame(s)' % (ncoadds, nframes)) - + # Write FITS file and PSF mask. head_pri['NINTS'] = nints head_pri['EFFINTTM'] = effinttm fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d, imshifts, maskoffs) maskfile = ut.write_msk(maskfile, mask, fitsfile) - + # Update spaceKLIP database. self.database.update_obs(key, j, fitsfile, maskfile, nints=nints, effinttm=effinttm) - + pass - + def subtract_median(self, types=['SCI', 'SCI_TA', 'SCI_BG', 'REF', 'REF_TA', 'REF_BG'], method='border', sigma=3.0, borderwidth=32, subdir='medsub'): - + """ Subtract the median from each frame. Clip everything brighter than 5- sigma from the background before computing the median. - + Parameters ---------- types : list of str, optional @@ -515,30 +517,30 @@ def subtract_median(self, None. """ - + # Set output directory. 
output_dir = os.path.join(self.database.output_dir, subdir) if not os.path.exists(output_dir): os.makedirs(output_dir) - + log.info(f'Median subtraction using method={method}') # Loop through concatenations. for i, key in enumerate(self.database.obs.keys()): log.info('--> Concatenation ' + key) - + # Loop through FITS files. nfitsfiles = len(self.database.obs[key]) for j in range(nfitsfiles): - + # Read FITS file and PSF mask. fitsfile = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, imshifts, maskoffs = ut.read_obs(fitsfile) maskfile = self.database.obs[key]['MASKFILE'][j] mask = ut.read_msk(maskfile) - + # Skip file types that are not in the list of types. if self.database.obs[key]['TYPE'][j] in types: - + # Subtract median. head, tail = os.path.split(fitsfile) log.info(' --> Median subtraction: ' + tail) @@ -590,13 +592,14 @@ def subtract_median(self, data -= bg_median log.info(' --> Median subtraction: mean of frame median = %.2f' % np.mean(bg_median)) - + # Write FITS file and PSF mask. fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d, imshifts, maskoffs) maskfile = ut.write_msk(maskfile, mask, fitsfile) - + # Update spaceKLIP database. self.database.update_obs(key, j, fitsfile, maskfile) + def subtract_background_godoy(self, types=['SCI', 'REF'], @@ -729,33 +732,34 @@ def subtract_background_godoy(self, self.database.update_obs(key, j, fitsfile) pass + def subtract_background(self, nints_per_med=None, subdir='bgsub'): """ Median subtract the corresponding background observations from the SCI and REF - data in the spaceKLIP database. - + data in the spaceKLIP database. + Parameters ---------- nints_per_med : int Number of integrations per median. For example, if you have a target + background dataset with 20 integrations each and nints_per_med is set to 5, a median of every 5 background images will be subtracted from - the corresponding 5 target images. The default is None (i.e. 
a median + the corresponding 5 target images. The default is None (i.e. a median across all images). subdir : str, optional Name of the directory where the data products shall be saved. The default is 'bgsub'. - + Returns ------- None. - + """ - + # Set output directory. output_dir = os.path.join(self.database.output_dir, subdir) if not os.path.exists(output_dir): @@ -763,24 +767,24 @@ def subtract_background(self, # Store the nints_per_med parameter orig_nints_per_med = deepcopy(nints_per_med) - + # Loop through concatenations. for i, key in enumerate(self.database.obs.keys()): log.info('--> Concatenation ' + key) - + # Find science, reference, and background files. ww = np.where((self.database.obs[key]['TYPE'] == 'SCI') | (self.database.obs[key]['TYPE'] == 'REF'))[0] ww_sci_bg = np.where(self.database.obs[key]['TYPE'] == 'SCI_BG')[0] ww_ref_bg = np.where(self.database.obs[key]['TYPE'] == 'REF_BG')[0] - + # Loop through science background files. if len(ww_sci_bg) != 0: sci_bg_data = [] sci_bg_erro = [] sci_bg_pxdq = [] for j in ww_sci_bg: - + # Read science background file. fitsfile = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, imshifts, maskoffs = ut.read_obs(fitsfile) @@ -790,7 +794,7 @@ def subtract_background(self, if orig_nints_per_med == None: nints_per_med = nints indxs = np.arange(nints) - split_inds = [x+1 for x in indxs if (x+1)%nints_per_med == 0 + split_inds = [x+1 for x in indxs if (x+1)%nints_per_med == 0 and x < (nints-nints_per_med)] # Compute median science background. @@ -811,14 +815,14 @@ def subtract_background(self, sci_bg_pxdq_split[k] = np.sum(sci_bg_pxdq_split[k] & 1 == 1, axis=0) != 0 else: sci_bg_data = None - + # Loop through reference background files. if len(ww_ref_bg) != 0: ref_bg_data = [] ref_bg_erro = [] ref_bg_pxdq = [] for j in ww_ref_bg: - + # Read reference background file. 
fitsfile = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, imshifts, maskoffs = ut.read_obs(fitsfile) @@ -828,7 +832,7 @@ def subtract_background(self, if orig_nints_per_med == None: nints_per_med = nints indxs = np.arange(nints) - split_inds = [x+1 for x in indxs if (x+1)%nints_per_med == 0 + split_inds = [x+1 for x in indxs if (x+1)%nints_per_med == 0 and x < (nints-nints_per_med)] # Compute median reference background. ref_bg_data += [data] @@ -848,12 +852,12 @@ def subtract_background(self, ref_bg_pxdq_split[k] = np.sum(ref_bg_pxdq_split[k] & 1 == 1, axis=0) != 0 else: ref_bg_data = None - + # Check input. if sci_bg_data is None and ref_bg_data is None: raise UserWarning('Could not find any background files') - - # Loop through science and reference files. + + # Loop through science and reference files. for j in ww: # Read FITS file and PSF mask. fitsfile = self.database.obs[key]['FITSFILE'][j] @@ -872,9 +876,9 @@ def subtract_background(self, if orig_nints_per_med == None: nints_per_med = nints indxs = np.arange(nints) - split_inds = [x+1 for x in indxs if (x+1)%nints_per_med == 0 + split_inds = [x+1 for x in indxs if (x+1)%nints_per_med == 0 and x < (nints-nints_per_med)] - + # Subtract background. head, tail = os.path.split(fitsfile) log.info(' --> Background subtraction: ' + tail) @@ -903,12 +907,12 @@ def subtract_background(self, # Write FITS file and PSF mask. fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d, imshifts, maskoffs) maskfile = ut.write_msk(maskfile, mask, fitsfile) - + # Update spaceKLIP database. self.database.update_obs(key, j, fitsfile, maskfile) - + pass - + def fix_bad_pixels(self, method='timemed+dqmed+medfilt', @@ -988,18 +992,18 @@ def fix_bad_pixels(self, subdir : str, optional Name of the directory where the data products shall be saved. The default is 'bpcleaned'. - + Returns ------- None. """ - + # Set output directory. 
output_dir = os.path.join(self.database.output_dir, subdir) if not os.path.exists(output_dir): os.makedirs(output_dir) - + # Loop through concatenations. for i, key in enumerate(self.database.obs.keys()): # if we limit to only processing some concatenations, check whether this concatenation matches the pattern @@ -1007,20 +1011,20 @@ def fix_bad_pixels(self, continue log.info('--> Concatenation ' + key) - + # Loop through FITS files. nfitsfiles = len(self.database.obs[key]) for j in range(nfitsfiles): - + # Read FITS file and PSF mask. fitsfile = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, imshifts, maskoffs = ut.read_obs(fitsfile) maskfile = self.database.obs[key]['MASKFILE'][j] mask = ut.read_msk(maskfile) - + # Skip file types that are not in the list of types. if self.database.obs[key]['TYPE'][j] in types: - + # Call bad pixel cleaning routines. pxdq_temp = pxdq.copy() # if self.database.obs[key]['TELESCOP'][j] == 'JWST' and self.database.obs[key]['INSTRUME'][j] == 'NIRCAM': @@ -1054,7 +1058,7 @@ def fix_bad_pixels(self, # pxdq[(pxdq != 0) & np.logical_not(pxdq & 512 == 512) & (pxdq_temp == 0)] = 0 # else: # pxdq[(pxdq & 1 == 1) & np.logical_not(pxdq & 512 == 512) & (pxdq_temp == 0)] = 0 - + # update the pixel DQ bit flags for the output files. # The pxdq variable here is effectively just the DO_NOT_USE flag, discarding other bits. # We want to make a new dq which retains the other bits as much as possible. @@ -1069,12 +1073,12 @@ def fix_bad_pixels(self, # Write FITS file and PSF mask. fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, new_dq, head_pri, head_sci, is2d, imshifts, maskoffs) maskfile = ut.write_msk(maskfile, mask, fitsfile) - + # Update spaceKLIP database. 
self.database.update_obs(key, j, fitsfile, maskfile) - + pass - + def find_bad_pixels_bpclean(self, data, erro, @@ -1084,7 +1088,7 @@ def find_bad_pixels_bpclean(self, """ Use an iterative sigma clipping algorithm to identify additional bad pixels in the data. - + Parameters ---------- data : 3D-array @@ -1110,13 +1114,13 @@ def find_bad_pixels_bpclean(self, The default is [-1, 0, 1]. The default is {}. - + Returns ------- None. """ - + # Check input. if 'sigclip' not in bpclean_kwargs.keys(): bpclean_kwargs['sigclip'] = 5. @@ -1128,7 +1132,7 @@ def find_bad_pixels_bpclean(self, bpclean_kwargs['shift_x'] += [0] if 0 not in bpclean_kwargs['shift_y']: bpclean_kwargs['shift_y'] += [0] - + # Pad data. pad_left = np.abs(np.min(bpclean_kwargs['shift_x'])) pad_right = np.abs(np.max(bpclean_kwargs['shift_x'])) @@ -1143,7 +1147,7 @@ def find_bad_pixels_bpclean(self, else: top = -pad_top pad_vals = ((pad_bottom, pad_top), (pad_left, pad_right)) - + # Find bad pixels using median of neighbors. pxdq_orig = pxdq.copy() ww = pxdq != 0 @@ -1152,23 +1156,23 @@ def find_bad_pixels_bpclean(self, erro_temp = erro.copy() erro_temp[ww] = np.nan for i in range(ww.shape[0]): - + # Get median background and standard deviation. bg_med = np.nanmedian(data_temp[i]) bg_std = robust.medabsdev(data_temp[i]) bg_ind = data[i] < (bg_med + 10. * bg_std) # clip bright PSFs for final calculation bg_med = np.nanmedian(data_temp[i][bg_ind]) bg_std = robust.medabsdev(data_temp[i][bg_ind]) - + # Create initial mask of large negative values. ww[i] = ww[i] | (data[i] < bg_med - bpclean_kwargs['sigclip'] * bg_std) - + # Loop through max 10 iterations. for it in range(10): data_temp[i][ww[i]] = np.nan erro_temp[i][ww[i]] = np.nan - - # Shift data. + + # Shift data. 
pad_data = np.pad(data_temp[i], pad_vals, mode='edge') pad_erro = np.pad(erro_temp[i], pad_vals, mode='edge') data_arr = [] @@ -1196,9 +1200,9 @@ def find_bad_pixels_bpclean(self, pxdq[i][ww[i]] = 1 print('') log.info(' --> Method bpclean: identified %.0f additional bad pixel(s) -- %.2f%%' % (np.sum(pxdq) - np.sum(pxdq_orig), 100. * (np.sum(pxdq) - np.sum(pxdq_orig)) / np.prod(pxdq.shape))) - + pass - + def find_bad_pixels_custom(self, data, erro, @@ -1207,7 +1211,7 @@ def find_bad_pixels_custom(self, custom_kwargs={}): """ Use a custom bad pixel map to flag additional bad pixels in the data. - + Parameters ---------- data : 3D-array @@ -1224,23 +1228,23 @@ def find_bad_pixels_custom(self, match the keys of the observations database and the dictionary content must be binary bad pixel maps (1 = bad, 0 = good) with the same shape as the corresponding data. The default is {}. - + Returns ------- None. - + """ - + # Find bad pixels using median of neighbors. pxdq_orig = pxdq.copy() pxdq_custom = custom_kwargs[key] != 0 if pxdq_custom.ndim == pxdq.ndim - 1: # Enable 3D bad pixel map to flag individual frames - pxdq_custom = np.array([pxdq_custom] * pxdq.shape[0]) + pxdq_custom = np.array([pxdq_custom] * pxdq.shape[0]) pxdq[pxdq_custom] = 1 log.info(' --> Method custom: flagged %.0f additional bad pixel(s) -- %.2f%%' % (np.sum(pxdq) - np.sum(pxdq_orig), 100. * (np.sum(pxdq) - np.sum(pxdq_orig)) / np.prod(pxdq.shape))) - + pass - + def fix_bad_pixels_timemed(self, data, erro, @@ -1249,7 +1253,7 @@ def fix_bad_pixels_timemed(self, """ Replace pixels which are only bad in some frames with their median value from the good frames. - + Parameters ---------- data : 3D-array @@ -1263,13 +1267,13 @@ def fix_bad_pixels_timemed(self, Keyword arguments for the 'timemed' method. Available keywords are: - n/a The default is {}. - + Returns ------- None. - + """ - + # Fix bad pixels using time median. 
ww = pxdq != 0 ww_all_bad = np.array([np.sum(ww, axis=0) == ww.shape[0]] * ww.shape[0]) @@ -1280,9 +1284,9 @@ def fix_bad_pixels_timemed(self, erro[ww_not_all_bad] = np.nan erro[ww_not_all_bad] = np.array([np.nanmedian(erro, axis=0)] * erro.shape[0])[ww_not_all_bad] pxdq[ww_not_all_bad] = 0 - + pass - + def fix_bad_pixels_dqmed(self, data, erro, @@ -1291,7 +1295,7 @@ def fix_bad_pixels_dqmed(self, """ Replace bad pixels with the median value of their surrounding good pixels. - + Parameters ---------- data : 3D-array @@ -1312,13 +1316,13 @@ def fix_bad_pixels_dqmed(self, The default is [-1, 0, 1]. The default is {}. - + Returns ------- None. """ - + # Check input. if 'shift_x' not in dqmed_kwargs.keys(): dqmed_kwargs['shift_x'] = [-1, 0, 1] @@ -1328,7 +1332,7 @@ def fix_bad_pixels_dqmed(self, dqmed_kwargs['shift_x'] += [0] if 0 not in dqmed_kwargs['shift_y']: dqmed_kwargs['shift_y'] += [0] - + # Pad data. pad_left = np.abs(np.min(dqmed_kwargs['shift_x'])) pad_right = np.abs(np.max(dqmed_kwargs['shift_x'])) @@ -1343,7 +1347,7 @@ def fix_bad_pixels_dqmed(self, else: top = -pad_top pad_vals = ((0, 0), (pad_bottom, pad_top), (pad_left, pad_right)) - + # Fix bad pixels using median of neighbors. ww = pxdq != 0 data_temp = data.copy() @@ -1371,9 +1375,9 @@ def fix_bad_pixels_dqmed(self, erro[i][ww[i]] = erro_med[ww[i]] pxdq[i][ww[i]] = 0 log.info(' --> Method dqmed: fixing %.0f bad pixel(s) -- %.2f%%' % (np.sum(ww), 100. * np.sum(ww) / np.prod(ww.shape))) - + pass - + def fix_bad_pixels_medfilt(self, data, erro, @@ -1381,7 +1385,7 @@ def fix_bad_pixels_medfilt(self, medfilt_kwargs={}): """ Replace bad pixels with an image plane median filter. - + Parameters ---------- data : 3D-array @@ -1404,11 +1408,11 @@ def fix_bad_pixels_medfilt(self, None. """ - + # Check input. if 'size' not in medfilt_kwargs.keys(): medfilt_kwargs['size'] = 4 - + # Fix bad pixels using median filter. 
ww = pxdq != 0 log.info(' --> Method medfilt: fixing %.0f bad pixel(s) -- %.2f%%' % (np.sum(ww), 100. * np.sum(ww) / np.prod(ww.shape))) @@ -1420,16 +1424,16 @@ def fix_bad_pixels_medfilt(self, data[i][ww[i]] = median_filter(data_temp[i], **medfilt_kwargs)[ww[i]] erro[i][ww[i]] = median_filter(erro_temp[i], **medfilt_kwargs)[ww[i]] pxdq[i][ww[i]] = 0 - + pass - + def replace_nans(self, cval=0., types=['SCI', 'SCI_BG', 'REF', 'REF_BG'], subdir='nanreplaced'): """ Replace all nans in the data with a constant value. - + Parameters ---------- cval : float, optional @@ -1440,58 +1444,58 @@ def replace_nans(self, subdir : str, optional Name of the directory where the data products shall be saved. The default is 'nanreplaced'. - + Returns ------- None. - + """ - + # Set output directory. output_dir = os.path.join(self.database.output_dir, subdir) if not os.path.exists(output_dir): os.makedirs(output_dir) - + # Loop through concatenations. for i, key in enumerate(self.database.obs.keys()): log.info('--> Concatenation ' + key) - + # Loop through FITS files. nfitsfiles = len(self.database.obs[key]) for j in range(nfitsfiles): - + # Read FITS file and PSF mask. fitsfile = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, imshifts, maskoffs = ut.read_obs(fitsfile) maskfile = self.database.obs[key]['MASKFILE'][j] mask = ut.read_msk(maskfile) - + # Skip file types that are not in the list of types. if self.database.obs[key]['TYPE'][j] in types: - + # Replace nans. head, tail = os.path.split(fitsfile) log.info(' --> Nan replacement: ' + tail) ww = np.isnan(data) data[ww] = cval log.info(' --> Nan replacement: replaced %.0f nan pixel(s) with value ' % (np.sum(ww)) + str(cval) + ' -- %.2f%%' % (100. * np.sum(ww)/np.prod(ww.shape))) - + # Write FITS file and PSF mask. 
fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d, imshifts, maskoffs) maskfile = ut.write_msk(maskfile, mask, fitsfile) - + # Update spaceKLIP database. self.database.update_obs(key, j, fitsfile, maskfile) - + pass - + def blur_frames(self, fact='auto', types=['SCI', 'SCI_BG', 'REF', 'REF_BG'], subdir='blurred'): """ Blur frames with a Gaussian filter. - + Parameters ---------- fact : 'auto' or 'fix23' or float or dict of list of float or None, optional @@ -1512,36 +1516,36 @@ def blur_frames(self, subdir : str, optional Name of the directory where the data products shall be saved. The default is 'blurred'. - + Returns ------- None. - + """ - + # Set output directory. output_dir = os.path.join(self.database.output_dir, subdir) if not os.path.exists(output_dir): os.makedirs(output_dir) - + # Loop through concatenations. for i, key in enumerate(self.database.obs.keys()): log.info('--> Concatenation ' + key) - + # Loop through FITS files. Nfitsfiles = len(self.database.obs[key]) for j in range(Nfitsfiles): - + # Read FITS file. fitsfile = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, imshifts, maskoffs = ut.read_obs(fitsfile) maskfile = self.database.obs[key]['MASKFILE'][j] mask = ut.read_msk(maskfile) - + # Skip file types that are not in the list of types. fact_temp = None if self.database.obs[key]['TYPE'][j] in types: - + # Blur frames. head, tail = os.path.split(fitsfile) log.info(' --> Frame blurring: ' + tail) @@ -1581,7 +1585,7 @@ def blur_frames(self, mask = gaussian_filter(mask, fact_temp) else: log.info(' --> Frame blurring: skipped') - + # Write FITS file. if fact_temp is None: pass @@ -1589,22 +1593,22 @@ def blur_frames(self, head_pri['BLURFWHM'] = fact_temp * np.sqrt(8. 
* np.log(2.)) # Factor to convert from sigma to FWHM fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d, imshifts, maskoffs) maskfile = ut.write_msk(maskfile, mask, fitsfile) - + # Update spaceKLIP database. if fact_temp is None: self.database.update_obs(key, j, fitsfile, maskfile, blurfwhm=np.nan) else: self.database.update_obs(key, j, fitsfile, maskfile, blurfwhm=fact_temp * np.sqrt(8. * np.log(2.))) - + pass - + def hpf(self, size='auto', types=['SCI', 'SCI_BG', 'REF', 'REF_BG'], subdir='filtered'): """ Blur frames with a Gaussian filter. - + Parameters ---------- size : 'auto' or float or dict of list of float or None, optional @@ -1623,36 +1627,36 @@ def hpf(self, subdir : str, optional Name of the directory where the data products shall be saved. The default is 'blurred'. - + Returns ------- None. - + """ - + # Set output directory. output_dir = os.path.join(self.database.output_dir, subdir) if not os.path.exists(output_dir): os.makedirs(output_dir) - + # Loop through concatenations. for i, key in enumerate(self.database.obs.keys()): log.info('--> Concatenation ' + key) - + # Loop through FITS files. Nfitsfiles = len(self.database.obs[key]) for j in range(Nfitsfiles): - + # Read FITS file. fitsfile = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, imshifts, maskoffs = ut.read_obs(fitsfile) maskfile = self.database.obs[key]['MASKFILE'][j] mask = ut.read_msk(maskfile) - + # Skip file types that are not in the list of types. size_temp = None if self.database.obs[key]['TYPE'][j] in types: - + # High-pass filter frames. head, tail = os.path.split(fitsfile) log.info(' --> Frame filtering: ' + tail) @@ -1667,7 +1671,7 @@ def hpf(self, erro = parallelized.high_pass_filter_imgs(erro, numthreads=None, filtersize=fourier_sigma_size) else: log.info(' --> Frame filtering: skipped') - + # Write FITS file. 
if size_temp is None: pass @@ -1675,10 +1679,10 @@ def hpf(self, head_pri['HPFSIZE'] = size_temp fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d, imshifts, maskoffs) maskfile = ut.write_msk(maskfile, mask, fitsfile) - + # Update spaceKLIP database. self.database.update_obs(key, j, fitsfile, maskfile) - + pass def inject_companions(self, @@ -1893,18 +1897,20 @@ def inject_companions(self, mcomp = mstar[filt] -2.5*np.log10(guess_flux) offsetpsf *= fzero[filt] / 10 ** (mcomp / 2.5) / 1e6 / pxar # MJy/sr - ## Question: am I injecting companions too bright?! Shouldn't I rescale them to the distance - ## of the target? Why doesen't work if I try. - # parallax_mas=12.1549 - # parallax_arcseconds = parallax_mas / 1000 - # distance_parsecs = 1 / parallax_arcseconds - # offsetpsf *= (1 / distance_parsecs) ** 2 - - # Apply scale factor to incorporate the coronagraphic # mask througput. offsetpsf *= scale_factor + # For Test only, we apply a gaussian kernel to the psf we want to inject to test if we are able + # to recover it later when using Analysis.extract_companions + if 'sigma_xy' in kwargs.keys(): + if 'theta_degrees' not in kwargs.keys(): + kwargs['theta_degrees'] = 0 + sigma_xy = kwargs['sigma_xy'] + theta_degrees = kwargs['theta_degrees'] + kernel = gaussian_kernel(sigma_x=sigma_xy[0], sigma_y=sigma_xy[1], theta_degrees=theta_degrees,n=6) + offsetpsf = scipy.ndimage.convolve(offsetpsf, kernel) + # Injected PSF needs to be a 3D array that matches dataset inj_psf_3d = np.array([offsetpsf for k in range(dataset.input.shape[0])]) @@ -1936,56 +1942,56 @@ def update_nircam_centers(self): This information will eventually be applied as updates into the SIAF, after which point this step will become not necessary. - + Returns ------- None. - + """ - + # Loop through concatenations. for i, key in enumerate(self.database.obs.keys()): log.info('--> Concatenation ' + key) - + # Loop through FITS files. 
for j in range(len(self.database.obs[key])): - + # Skip file types that are not NIRCam coronagraphy. if self.database.obs[key]['EXP_TYPE'][j] == 'NRC_CORON': - + # Read FITS file and PSF mask. fitsfile = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, imshifts, maskoffs = ut.read_obs(fitsfile) maskfile = self.database.obs[key]['MASKFILE'][j] mask = ut.read_msk(maskfile) - + # Update current reference pixel position. head, tail = os.path.split(fitsfile) log.info(' --> Update NIRCam coronagraphy centers: ' + tail) - + # Get current reference pixel position. crpix1 = self.database.obs[key]['CRPIX1'][j] crpix2 = self.database.obs[key]['CRPIX2'][j] - + # Get SIAF reference pixel position. siaf = pysiaf.Siaf('NIRCAM') apsiaf = siaf[self.database.obs[key]['APERNAME'][j]] xsciref, ysciref = (apsiaf.XSciRef, apsiaf.YSciRef) - + # Get true mask center from Jarron. try: crpix1_jarron, crpix2_jarron = crpix_jarron[self.database.obs[key]['APERNAME'][j]] except KeyError: log.warning(' --> Update NIRCam coronagraphy centers: no true mask center found for ' + self.database.obs[key]['APERNAME'][j]) crpix1_jarron, crpix2_jarron = xsciref, ysciref - + # Get filter shift from Jarron. try: xshift_jarron, yshift_jarron = filter_shifts_jarron[self.database.obs[key]['FILTER'][j]] except KeyError: log.warning(' --> Update NIRCam coronagraphy centers: no filter shift found for ' + self.database.obs[key]['FILTER'][j]) xshift_jarron, yshift_jarron = 0., 0. - + # Determine offset between SIAF reference pixel position # and true mask center from Jarron and update current # reference pixel position. Account for filter-dependent @@ -1994,12 +2000,12 @@ def update_nircam_centers(self): log.info(' --> Update NIRCam coronagraphy centers: old = (%.2f, %.2f), new = (%.2f, %.2f)' % (crpix1, crpix2, crpix1 + xoff, crpix2 + yoff)) crpix1 += xoff crpix2 += yoff - + # Update spaceKLIP database. 
self.database.update_obs(key, j, fitsfile, maskfile, crpix1=crpix1, crpix2=crpix2) - + pass - + def recenter_frames(self, method='fourier', subpix_first_sci_only=False, @@ -2015,7 +2021,7 @@ def recenter_frames(self, behind the coronagraphic mask for the first SCI frame. Then, shift all other SCI and REF frames by the same amount. For MIRI coronagraphy, do nothing. For all other data types, simply recenter the host star PSF. - + Parameters ---------- method : 'fourier' or 'spline' (not recommended), optional @@ -2044,45 +2050,45 @@ def recenter_frames(self, subdir : str, optional Name of the directory where the data products shall be saved. The default is 'recentered'. - + Returns ------- None. - + """ - + # Update NIRCam coronagraphy centers, i.e., change SIAF CRPIX position # to true mask center determined by Jarron. # self.update_nircam_centers() # shall be run purposely by the user - + # Set output directory. output_dir = os.path.join(self.database.output_dir, subdir) if not os.path.exists(output_dir): os.makedirs(output_dir) - + # Loop through concatenations. for i, key in enumerate(self.database.obs.keys()): log.info('--> Concatenation ' + key) - + # Find science and reference files. ww_sci = np.where(self.database.obs[key]['TYPE'] == 'SCI')[0] ww_sci_ta = np.where(self.database.obs[key]['TYPE'] == 'SCI_TA')[0] ww_ref = np.where(self.database.obs[key]['TYPE'] == 'REF')[0] ww_ref_ta = np.where(self.database.obs[key]['TYPE'] == 'REF_TA')[0] - + # Loop through FITS files. ww_all = np.append(ww_sci, ww_ref) ww_all = np.append(ww_all, ww_sci_ta) ww_all = np.append(ww_all, ww_ref_ta) shifts_all = [] for j in ww_all: - + # Read FITS file and PSF mask. fitsfile = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, imshifts, maskoffs = ut.read_obs(fitsfile) maskfile = self.database.obs[key]['MASKFILE'][j] mask = ut.read_msk(maskfile) - + # Recenter frames. Use different algorithms based on data type. 
head, tail = os.path.split(fitsfile) log.info(' --> Recenter frames: ' + tail) @@ -2090,14 +2096,14 @@ def recenter_frames(self, raise UserWarning('Please replace nan pixels before attempting to recenter frames') shifts = [] # shift between star position and image center (data.shape // 2) maskoffs_temp = [] # shift between star and coronagraphic mask position - + # SCI and REF data. if j in ww_sci or j in ww_ref: - + # NIRCam coronagraphy. if self.database.obs[key]['EXP_TYPE'][j] in ['NRC_CORON']: for k in range(data.shape[0]): - + # For the first SCI frame, get the star position # and the shift between the star and coronagraphic # mask position. @@ -2129,7 +2135,7 @@ def recenter_frames(self, elif self.database.obs[key]['EXP_TYPE'][j] in ['MIR_4QPM', 'MIR_LYOT']: log.warning(' --> Recenter frames: not implemented for MIRI coronagraphy, skipped') for k in range(data.shape[0]): - + # Do nothing. shifts += [np.array([0., 0.])] maskoffs_temp += [np.array([0., 0.])] @@ -2137,11 +2143,11 @@ def recenter_frames(self, yoffset = self.database.obs[key]['YOFFSET'][j] # arcsec crpix1 = self.database.obs[key]['CRPIX1'][j] # 1-indexed crpix2 = self.database.obs[key]['CRPIX2'][j] # 1-indexed - + # Other data types. else: for k in range(data.shape[0]): - + # Recenter SCI and REF frames to subpixel precision # using the 'BCEN' routine from XARA. # https://github.com/fmartinache/xara @@ -2154,7 +2160,7 @@ def recenter_frames(self, else: shifts += [np.array([0., 0.])] maskoffs_temp += [np.array([0., 0.])] - + # Recenter SCI and REF frames to integer pixel # precision by rolling the image. ww_max = np.unravel_index(np.argmax(data[k]), data[k].shape) @@ -2168,11 +2174,11 @@ def recenter_frames(self, yoffset = 0. # arcsec crpix1 = data.shape[-1]//2 + 1 # 1-indexed crpix2 = data.shape[-2]//2 + 1 # 1-indexed - + # TA data. if j in ww_sci_ta or j in ww_ref_ta: for k in range(data.shape[0]): - + # Center TA frames on the nearest pixel center. 
This # pixel center is not necessarily the image center, # which is why a subsequent integer pixel recentering @@ -2185,7 +2191,7 @@ def recenter_frames(self, maskoffs_temp += [np.array([0., 0.])] data[k] = ut.imshift(data[k], [shifts[k][0], shifts[k][1]], method=method, kwargs=kwargs) erro[k] = ut.imshift(erro[k], [shifts[k][0], shifts[k][1]], method=method, kwargs=kwargs) - + # Recenter TA frames to integer pixel precision by # rolling the image. ww_max = np.unravel_index(np.argmax(data[k]), data[k].shape) @@ -2210,14 +2216,14 @@ def recenter_frames(self, maskoffs += maskoffs_temp else: maskoffs = maskoffs_temp - + # Compute shift distances. dist = np.sqrt(np.sum(shifts[:, :2]**2, axis=1)) # pix dist *= self.database.obs[key]['PIXSCALE'][j] * 1000 # mas head, tail = os.path.split(self.database.obs[key]['FITSFILE'][j]) log.info(' --> Recenter frames: ' + tail) log.info(' --> Recenter frames: median required shift = %.2f mas' % np.median(dist)) - + # Write FITS file and PSF mask. head_pri['XOFFSET'] = xoffset #arcsec head_pri['YOFFSET'] = yoffset #arcsec @@ -2225,12 +2231,12 @@ def recenter_frames(self, head_sci['CRPIX2'] = crpix2 fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d, imshifts, maskoffs) maskfile = ut.write_msk(maskfile, mask, fitsfile) - + # Update spaceKLIP database. self.database.update_obs(key, j, fitsfile, maskfile, xoffset=xoffset, yoffset=yoffset, crpix1=crpix1, crpix2=crpix2) - + pass - + @plt.style.context('spaceKLIP.sk_style') def find_nircam_centers(self, data0, @@ -2247,7 +2253,7 @@ def find_nircam_centers(self, """ Find the star position behind the coronagraphic mask using a WebbPSF model. - + Parameters ---------- data0 : list @@ -2273,7 +2279,7 @@ def find_nircam_centers(self, use_coeff : bool, optional Use pre-computed coefficients to generate the WebbPSF model. The default is False. 
- + Returns ------- xc : float @@ -2284,16 +2290,16 @@ def find_nircam_centers(self, X-shift between star and coronagraphic mask position (pix). yshift : float Y-shift between star and coronagraphic mask position (pix). - + """ - + # Generate host star spectrum. spectrum = webbpsf_ext.stellar_spectrum(spectral_type) - + # Get true mask center. crpix1 = self.database.obs[key]['CRPIX1'][j] - 1 # 0-indexed crpix2 = self.database.obs[key]['CRPIX2'][j] - 1 # 0-indexed - + # Initialize JWST_PSF object. Use odd image size so that PSF is # centered in pixel center. log.info(' --> Recenter frames: generating WebbPSF image for absolute centering (this might take a while)') @@ -2307,11 +2313,11 @@ def find_nircam_centers(self, 'sp': spectrum } psf = JWST_PSF(APERNAME, FILTER, **kwargs) - + # Get SIAF reference pixel position. apsiaf = psf.inst_on.siaf_ap xsciref, ysciref = (apsiaf.XSciRef, apsiaf.YSciRef) - + # Generate model PSF. Apply offset between SIAF reference pixel # position and true mask center. xoff = (crpix1 + 1) - xsciref @@ -2413,10 +2419,10 @@ def find_nircam_centers(self, log.info(f" Plot saved in {output_file}") # plt.show() plt.close(fig) - + # Return star position. return xc, yc, median_xshift, median_yshift - + @plt.style.context('spaceKLIP.sk_style') def align_frames(self, method='fourier', @@ -2430,14 +2436,14 @@ def align_frames(self, subdir='aligned'): """ Align all SCI and REF frames to the first SCI frame. - + Parameters ---------- method : 'fourier' or 'spline' (not recommended), optional Method for shifting the frames. The default is 'fourier'. align_algo : 'leastsq' or 'header' Algorithm to determine the alignment offsets. Default is 'leastsq', - 'header' assumes perfect header offsets. + 'header' assumes perfect header offsets. 
mask_override : str, optional Mask some pixels when cross correlating for shifts msk_shp : int, optional @@ -2458,13 +2464,13 @@ def align_frames(self, subdir : str, optional Name of the directory where the data products shall be saved. The default is 'aligned'. - + Returns ------- None. - + """ - + # Set output directory. output_dir = os.path.join(self.database.output_dir, subdir) if not os.path.exists(output_dir): @@ -2506,19 +2512,19 @@ def create_rec_mask(h, w, center=None, z=None): mask[center[1]-z:center[1]+z,:] = True return mask - + # Loop through concatenations. database_temp = deepcopy(self.database.obs) for i, key in enumerate(self.database.obs.keys()): log.info('--> Concatenation ' + key) - + # Find science and reference files. ww_sci = np.where(self.database.obs[key]['TYPE'] == 'SCI')[0] if len(ww_sci) == 0: raise UserWarning('Could not find any science files') ww_ref = np.where(self.database.obs[key]['TYPE'] == 'REF')[0] ww_all = np.append(ww_sci, ww_ref) - + # Loop through FITS files. if align_to_file is not None: try: @@ -2529,7 +2535,7 @@ def create_rec_mask(h, w, center=None, z=None): ref_image = np.nanmedian(ref_image, axis=0) shifts_all = [] for j in ww_all: - + # Read FITS file and PSF mask. fitsfile = self.database.obs[key]['FITSFILE'][j] data, erro, pxdq, head_pri, head_sci, is2d, imshifts, maskoffs = ut.read_obs(fitsfile) @@ -2559,7 +2565,7 @@ def create_rec_mask(h, w, center=None, z=None): raise UserWarning('Please replace nan pixels before attempting to align frames') shifts = [] for k in range(data.shape[0]): - + # Take the first science frame as reference frame. if j == ww_sci[0] and k == 0: if align_to_file is None: @@ -2570,7 +2576,7 @@ def create_rec_mask(h, w, center=None, z=None): crpix1 = self.database.obs[key]['CRPIX1'][j] #pixels crpix2 = self.database.obs[key]['CRPIX2'][j] #pixels pxsc = self.database.obs[key]['PIXSCALE'][j] #arcsec - + # Align all other SCI and REF frames to the first science # frame. 
if align_to_file is not None or j != ww_sci[0] or k != 0: @@ -2620,7 +2626,7 @@ def create_rec_mask(h, w, center=None, z=None): pp = p0 # Append shifts to array and apply shift to image - # using defined method. + # using defined method. shifts += [np.array([pp[0], pp[1], pp[2]])] if align_to_file is not None or j != ww_sci[0] or k != 0: data[k] = ut.imshift(data[k], [shifts[k][0], shifts[k][1]], method=method, kwargs=kwargs) @@ -2639,7 +2645,7 @@ def create_rec_mask(h, w, center=None, z=None): maskoffs -= shifts[:, :-1] else: maskoffs = -shifts[:, :-1] - + # Compute shift distances. dist = np.sqrt(np.sum(shifts[:, :2]**2, axis=1)) # pix dist *= self.database.obs[key]['PIXSCALE'][j]*1000 # mas @@ -2656,25 +2662,25 @@ def create_rec_mask(h, w, center=None, z=None): ww = np.where(ww == True)[0] if align_algo != 'header': log.warning(' --> The following frames might not be properly aligned: '+str(ww)) - + # Write FITS file and PSF mask. head_pri['XOFFSET'] = xoffset #arcseconds - head_pri['YOFFSET'] = yoffset #arcseconds + head_pri['YOFFSET'] = yoffset #arcseconds head_sci['CRPIX1'] = crpix1 head_sci['CRPIX2'] = crpix2 fitsfile = ut.write_obs(fitsfile, output_dir, data, erro, pxdq, head_pri, head_sci, is2d, imshifts, maskoffs) maskfile = ut.write_msk(maskfile, mask, fitsfile) - + # Update spaceKLIP database. self.database.update_obs(key, j, fitsfile, maskfile, xoffset=xoffset, yoffset=yoffset, crpix1=crpix1, crpix2=crpix2) - + # Plot science frame alignment. 
def gaussian_kernel(sigma_x=1, sigma_y=1, theta_degrees=0, n=6):
    """
    Generate a 2D elliptical Gaussian kernel with specified standard
    deviations and rotation, normalized to unit sum.

    Parameters
    ----------
    sigma_x : float
        Standard deviation of the Gaussian in the x direction (pixels).
    sigma_y : float
        Standard deviation of the Gaussian in the y direction (pixels).
    theta_degrees : float
        Rotation angle of the Gaussian kernel in degrees.
    n : float
        Kernel extent in units of sigma: each axis spans roughly
        n * sigma pixels (minimum size 3, always odd).

    Returns
    -------
    numpy.ndarray
        The generated Gaussian kernel of shape
        (kernel_size_y, kernel_size_x), normalized so that it sums to 1
        (flux-preserving under convolution).
    """
    # Ensure kernel size is at least 3x3 and odd. An odd size keeps the
    # Gaussian peak on the central pixel, so convolution does not shift
    # the image.
    kernel_size_x = max(3, int(n * sigma_x + 1) | 1)  # | 1 forces odd
    kernel_size_y = max(3, int(n * sigma_y + 1) | 1)  # | 1 forces odd

    # Convert theta from degrees to radians.
    theta = np.deg2rad(theta_degrees)

    # Unit-spaced integer pixel coordinates, symmetric about 0.
    # NOTE: the previous linspace(-size // 2, size // 2, size) was buggy:
    # floor division makes -7 // 2 == -4 while 7 // 2 == 3, giving an
    # asymmetric span [-4, 3] with non-unit sample spacing, which put the
    # Gaussian peak off the central pixel and shifted convolved images.
    x = np.arange(kernel_size_x) - kernel_size_x // 2
    y = np.arange(kernel_size_y) - kernel_size_y // 2
    x, y = np.meshgrid(x, y)

    # Rotate the coordinate frame by theta.
    x_rot = x * np.cos(theta) + y * np.sin(theta)
    y_rot = -x * np.sin(theta) + y * np.cos(theta)

    # Elliptical Gaussian; normalize to unit sum.
    kernel = np.exp(-(x_rot ** 2 / (2 * sigma_x ** 2) + y_rot ** 2 / (2 * sigma_y ** 2)))
    kernel /= kernel.sum()
    return kernel