Skip to content

Commit

Permalink
simplify plot helper by removing logimshow
Browse files Browse the repository at this point in the history
  • Loading branch information
joseph-long committed Jun 3, 2024
1 parent b5d2007 commit 070efdc
Showing 1 changed file with 18 additions and 26 deletions.
44 changes: 18 additions & 26 deletions doodads/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,24 @@
'add_colorbar',
'imshow',
'matshow',
'logimshow',
'image_grid',
'show_diff',
'three_panel_diff_plot',
'norm',
'zscale',
'contrast_limits_plot',
'inferno_k',
'inferno_g',
'magma_k',
'magma_g',
'gray_k',
'gray_g',
'complex_color'
)
inferno_k = matplotlib.cm.inferno.copy()
inferno_k.set_bad('k')
inferno_g = matplotlib.cm.inferno.copy()
inferno_g.set_bad('0.5')

magma_k = matplotlib.cm.magma.copy()
magma_k.set_bad('k')
Expand Down Expand Up @@ -139,7 +144,13 @@ def imshow(im, *args, ax=None, log=False, colorbar=True, title=None, origin='cen
raise NotImplementedError("No log=True for complex images (yet)")
im = complex_color(im)
if log:
mappable = logimshow(im, *args, ax=ax, **kwargs)
vmin = kwargs.pop('vmin') if 'vmin' in kwargs else None
vmax = kwargs.pop('vmax') if 'vmax' in kwargs else None
norm = astroviz.simple_norm(im, stretch='log', min_cut=vmin, max_cut=vmax)
kwargs.update({
'norm': norm
})
mappable = ax.imshow(im, *args, **kwargs)
else:
mappable = ax.imshow(im, *args, **kwargs)
if colorbar:
Expand All @@ -154,16 +165,6 @@ def imshow(im, *args, ax=None, log=False, colorbar=True, title=None, origin='cen
ax.set(xlim=(ctr_x-crop, ctr_x+crop), ylim=(ctr_y-crop, ctr_y+crop))
return mappable

@supply_argument(ax=lambda: gca())
def logimshow(im, *args, ax=None, **kwargs):
vmin = kwargs.pop('vmin') if 'vmin' in kwargs else None
vmax = kwargs.pop('vmax') if 'vmax' in kwargs else None
norm = astroviz.simple_norm(im, stretch='log', vmin=vmin, vmax=vmax)
kwargs.update({
'norm': norm
})
return ax.imshow(im, *args, **kwargs)

@supply_argument(ax=lambda: gca())
def matshow(im, *args, **kwargs):
kwargs.update({'origin': 'upper'})
Expand All @@ -190,10 +191,7 @@ def image_grid(cube, columns, colorbar=False, cmap=None, fig=None, log=False, ma
if idx >= cube.shape[0]:
break
ax = fig.add_subplot(gs[row, col])
if log:
im = logimshow(cube[idx], cmap=cmap, vmin=vmin, vmax=vmax)
else:
im = ax.imshow(cube[idx], cmap=cmap, vmin=vmin, vmax=vmax)
im = imshow(cube[idx], cmap=cmap, vmin=vmin, vmax=vmax, log=log, ax=ax)
if colorbar:
add_colorbar(im)
return fig
Expand Down Expand Up @@ -243,7 +241,7 @@ def show_diff(im1, im2, ax=None, vmin=None, vmax=None, cmap=matplotlib.cm.RdBu_r
clim_min = vmin
else:
clim_min = -clim
im = ax.imshow(diff, vmin=clim_min, vmax=clim, cmap=cmap, **kwargs) # pylint: disable=invalid-unary-operand-type
im = imshow(diff, vmin=clim_min, vmax=clim, cmap=cmap, ax=ax, colorbar=False, **kwargs) # pylint: disable=invalid-unary-operand-type
if colorbar:
cbar = add_colorbar(im)
if as_percent:
Expand Down Expand Up @@ -276,14 +274,8 @@ def three_panel_diff_plot(image_a, image_b, title_a='', title_b='',
fig = ax_a.figure
if match_clim and (not 'vmin' in kwargs) and (not 'vmax' in kwargs):
kwargs.update({'vmin': np.min([image_a, image_b]), 'vmax': np.max([image_a, image_b])})
if log:
mappable_a = logimshow(image_a, ax=ax_a, **kwargs)
mappable_b = logimshow(image_b, ax=ax_b, **kwargs)
else:
mappable_a = ax_a.imshow(image_a, **kwargs)
mappable_b = ax_b.imshow(image_b, **kwargs)
add_colorbar(mappable_a)
add_colorbar(mappable_b)
imshow(image_a, ax=ax_a, log=log, **kwargs)
imshow(image_b, ax=ax_b, log=log, **kwargs)
ax_a.set_title(title_a)
ax_b.set_title(title_b)
ax_aminusb.set_title(title_diff)
Expand Down Expand Up @@ -353,4 +345,4 @@ def complex_color(z, log=False):
c = np.vectorize(hls_to_rgb)(h,l,s)
c = np.array(c)
c = c.swapaxes(0,2)
return c
return c

0 comments on commit 070efdc

Please sign in to comment.