Skip to content

Commit

Permalink
Merge pull request #183 from kammerje/stretch_align
Browse files Browse the repository at this point in the history
Add sqrt_stretch and custom mask to align steps
  • Loading branch information
AarynnCarter authored Jun 18, 2024
2 parents 3df125f + 612c7dd commit f7ffaba
Showing 1 changed file with 69 additions and 3 deletions.
72 changes: 69 additions & 3 deletions spaceKLIP/imagetools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1599,6 +1599,7 @@ def recenter_frames(self,
method='fourier',
subpix_first_sci_only=False,
spectral_type='G2V',
shft_exp=1,
kwargs={},
subdir='recentered'):
"""
Expand All @@ -1624,6 +1625,8 @@ def recenter_frames(self,
spectral_type : str, optional
Host star spectral type for the WebbPSF model used to determine the
star position behind the coronagraphic mask. The default is 'G2V'.
shft_exp : float, optional
Take image to the given power before cross correlating for shifts, default is 1. For instance, 1/2 helps align nircam bar/narrow data (or other data with weird speckles)
kwargs : dict, optional
Keyword arguments for the scipy.ndimage.shift routine. The default
is {}.
Expand Down Expand Up @@ -1691,6 +1694,7 @@ def recenter_frames(self,
xc, yc, xshift, yshift = self.find_nircam_centers(data0=data[k].copy(),
key=key,
j=j,
shft_exp=shft_exp,
spectral_type=spectral_type,
date=head_pri['DATE-BEG'],
output_dir=output_dir)
Expand Down Expand Up @@ -1820,6 +1824,7 @@ def find_nircam_centers(self,
key,
j,
spectral_type='G2V',
shft_exp=1,
date=None,
output_dir=None,
fov_pix=65,
Expand All @@ -1840,6 +1845,8 @@ def find_nircam_centers(self,
spectral_type : str, optional
Host star spectral type for the WebbPSF model used to determine the
star position behind the coronagraphic mask. The default is 'G2V'.
shft_exp : float, optional
Take image to the given power before cross correlating for shifts, default is 1.
date : str, optional
Observation date in the format 'YYYY-MM-DDTHH:MM:SS.MMM'. The
default is None.
Expand Down Expand Up @@ -1919,9 +1926,16 @@ def find_nircam_centers(self,
xycen=(xc, yc),
npix=fov_pix)

if shft_exp == 1:
img1 = datasub* masksub
img2 = model_psf* masksub
else:
img1 = np.power(np.abs(datasub), shft_exp)* masksub
img2 = np.power(np.abs(model_psf), shft_exp) * masksub

# Determine relative shift between data and model PSF.
shift, error, phasediff = phase_cross_correlation(datasub * masksub,
model_psf * masksub,
shift, error, phasediff = phase_cross_correlation(img1,
img2,
upsample_factor=1000,
normalization=None)
yshift, xshift = shift
Expand Down Expand Up @@ -1973,6 +1987,9 @@ def find_nircam_centers(self,
def align_frames(self,
method='fourier',
align_algo='leastsq',
mask_override=None,
msk_shp=8,
shft_exp=1,
kwargs={},
subdir='aligned'):
"""
Expand All @@ -1985,6 +2002,12 @@ def align_frames(self,
align_algo : 'leastsq' or 'header'
Algorithm to determine the alignment offsets. Default is 'leastsq',
'header' assumes perfect header offsets.
mask_override : str, optional
Mask some pixels when cross correlating for shifts
msk_shp : int, optional
Shape (height or radius) for custom mask invoked by "mask_override"
shft_exp : float, optional
Take image to the given power before cross correlating for shifts, default is 1. For instance, 1/2 helps align nircam bar/narrow data (or other data with weird speckles)
kwargs : dict, optional
Keyword arguments for the scipy.ndimage.shift routine. The default
is {}.
Expand All @@ -2002,6 +2025,31 @@ def align_frames(self,
output_dir = os.path.join(self.database.output_dir, subdir)
if not os.path.exists(output_dir):
os.makedirs(output_dir)

# useful masks for computing shifts:
def create_circular_mask(h, w, center=None, radius=None):

if center is None: # use the middle of the image
center = (int(w/2), int(h/2))
if radius is None: # use the smallest distance between the center and image walls
radius = min(center[0], center[1], w-center[0], h-center[1])

Y, X = np.ogrid[:h, :w]
dist_from_center = np.sqrt((X - center[0])**2 + (Y-center[1])**2)

mask = dist_from_center <= radius
return mask

def create_rec_mask(h, w, center=None, z=None):
if center is None: # use the middle of the image
center = (int(w/2), int(h/2))
if z is None:
z = h//4

mask = np.zeros((h,w), dtype=bool)
mask[center[1]-z:center[1]+z,:] = True

return mask

# Loop through concatenations.
database_temp = deepcopy(self.database.obs)
Expand All @@ -2024,6 +2072,20 @@ def align_frames(self,
data, erro, pxdq, head_pri, head_sci, is2d, imshifts, maskoffs = ut.read_obs(fitsfile)
maskfile = self.database.obs[key]['MASKFILE'][j]
mask = ut.read_msk(maskfile)
if mask_override is not None:
if mask_override == 'circ':
mask_circ = create_circular_mask(data[0].shape[0],data[0].shape[1], radius=msk_shp)
elif mask_override == 'rec':
mask_circ = create_rec_mask(data[0].shape[0],data[0].shape[1], z=msk_shp)
else:
raise ValueError('There are `circ` and `rec` custom masks available')
mask_temp = data[0].copy()
mask_temp[~mask_circ] = 1
mask_temp[mask_circ] = 0
elif mask is None:
mask_temp = np.ones_like(data[0])
else:
mask_temp = mask

# Align frames.
head, tail = os.path.split(fitsfile)
Expand Down Expand Up @@ -2064,10 +2126,14 @@ def align_frames(self,
if (np.abs(xshift) < 1e-3) and (np.abs(yshift) < 1e-3):
p0 = np.array([0., 0., 1.])
if align_algo == 'leastsq':
if shft_exp != 1:
args = (np.power(np.abs(data[k]), shft_exp), np.power(np.abs(ref_image), shft_exp), mask_temp, method, kwargs)
else:
args = (data[k], ref_image, mask_temp, method, kwargs)
# Use header values to initiate least squares fit
pp = leastsq(ut.alignlsq,
p0,
args=(data[k], ref_image, mask, method, kwargs))[0]
args=args)[0]
elif align_algo == 'header':
# Just assume the header values are correct
pp = p0
Expand Down

0 comments on commit f7ffaba

Please sign in to comment.