Skip to content

Commit

Permalink
Merge pull request #769 from anarkiwi/ref
Browse files Browse the repository at this point in the history
batch arg/state, width/height, fix mesh_psd peak.
  • Loading branch information
anarkiwi authored Jul 6, 2023
2 parents 2c11626 + 5a10446 commit cc11965
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 90 deletions.
147 changes: 69 additions & 78 deletions gamutrf/waterfall.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,32 +229,27 @@ def argument_parser():
"--width",
default=28,
type=float,
help="Waterfall width",
help="Waterfall image width",
)
parser.add_argument(
"--height",
default=10,
type=float,
help="Waterfall image height",
)
parser.add_argument(
"--waterfall_height",
default=100,
type=int,
help="Waterfall height",
)
return parser


def reset_fig(
config,
state,
):
# RESET FIGURE
state.fig.clf()
plt.tight_layout()
plt.subplots_adjust(hspace=0.15)
state.ax_psd = state.fig.add_subplot(3, 1, 1)
state.ax = state.fig.add_subplot(3, 1, (2, 3))
state.psd_title = state.ax_psd.text(
0.5, 1.05, "", transform=state.ax_psd.transAxes, va="center", ha="center"
)
def reset_mesh_psd(config, state):
if state.mesh_psd:
state.mesh_psd.remove()

# PSD
XX, YY = np.meshgrid(
np.linspace(
config.min_freq,
Expand All @@ -263,12 +258,30 @@ def reset_fig(
),
np.linspace(state.db_min, state.db_max, config.psd_db_resolution),
)

state.psd_x_edges = XX[0]
state.psd_y_edges = YY[:, 0]

state.mesh_psd = state.ax_psd.pcolormesh(
XX, YY, np.zeros(XX[:-1, :-1].shape), shading="flat"
)


def reset_fig(
config,
state,
):
# RESET FIGURE
state.fig.clf()
plt.tight_layout()
plt.subplots_adjust(hspace=0.15)
state.ax_psd = state.fig.add_subplot(3, 1, 1)
state.ax = state.fig.add_subplot(3, 1, (2, 3))
state.psd_title = state.ax_psd.text(
0.5, 1.05, "", transform=state.ax_psd.transAxes, va="center", ha="center"
)

reset_mesh_psd(config, state)
(state.peak_lns,) = state.ax_psd.plot(
state.X[0],
state.db_min * np.ones(state.freq_data.shape[1]),
Expand Down Expand Up @@ -352,30 +365,27 @@ def reset_fig(
0.5, 1.05, "", transform=state.ax.transAxes, va="center", ha="center"
)

state.ax.xaxis.set_major_locator(MultipleLocator(state.major_tick_separator))
state.ax.xaxis.set_major_formatter("{x:.0f}")
state.ax.xaxis.set_minor_locator(state.minor_tick_separator)
state.ax_psd.xaxis.set_major_locator(MultipleLocator(state.major_tick_separator))
state.ax_psd.xaxis.set_major_formatter("{x:.0f}")
state.ax_psd.xaxis.set_minor_locator(state.minor_tick_separator)
for ax in (state.ax.xaxis, state.ax_psd.xaxis):
ax.set_major_locator(MultipleLocator(state.major_tick_separator))
ax.set_major_formatter("{x:.0f}")
ax.set_minor_locator(state.minor_tick_separator)

for ax in (state.ax_psd.yaxis, state.cbar_ax.yaxis, state.ax.yaxis):
ax.set_animated(True)

state.ax_psd.yaxis.set_animated(True)
state.cbar_ax.yaxis.set_animated(True)
state.ax.yaxis.set_animated(True)
if not config.batch:
plt.show(block=False)
plt.pause(0.1)

state.background = state.fig.canvas.copy_from_bbox(state.fig.bbox)

state.ax.draw_artist(state.mesh)
state.fig.canvas.blit(state.ax.bbox)
if config.savefig_path:
safe_savefig(config.savefig_path)

for ln in state.top_n_lns:
ln.set_alpha(0.75)

if config.savefig_path:
safe_savefig(config.savefig_path)


def init_state(
config,
Expand Down Expand Up @@ -417,12 +427,14 @@ def init_fig(

plt.rcParams["savefig.facecolor"] = "#2A3459"
plt.rcParams["figure.facecolor"] = "#2A3459"
text_color = "#d2d5dd"
plt.rcParams["text.color"] = text_color
plt.rcParams["axes.labelcolor"] = text_color
plt.rcParams["xtick.color"] = text_color
plt.rcParams["ytick.color"] = text_color
plt.rcParams["axes.facecolor"] = text_color
for param in (
"text.color",
"axes.labelcolor",
"xtick.color",
"ytick.color",
"axes.facecolor",
):
plt.rcParams[param] = "#d2d5dd"

state.fig = plt.figure(figsize=(config.width, config.height), dpi=100)
if not config.batch:
Expand Down Expand Up @@ -586,27 +598,13 @@ def update_fig(config, state, zmqr, rotate_secs, save_time):
state.data /= np.max(state.data)
# data /= np.max(data, axis=1)[:,None]

state.fig.canvas.restore_region(state.background)

top_n_bins = state.freq_bins[
np.argsort(
np.nanvar(state.db_data - np.nanmin(state.db_data, axis=0), axis=0)
)[::-1][: config.top_n]
]

for i, ln in enumerate(state.top_n_lns):
ln.set_xdata([top_n_bins[i]] * len(state.Y[:, 0]))

state.fig.canvas.blit(state.ax.yaxis.axes.figure.bbox)

scan_time = scan_df.ts.iloc[-1]
row_time = datetime.datetime.fromtimestamp(scan_time)
state.scan_times.append(scan_time)
state.scan_config_history[scan_time] = scan_configs
if len(state.scan_times) > config.waterfall_height:
remove_time = state.scan_times.pop(0)
if state.save_path:
state.scan_config_history.pop(remove_time)
# assert len(scan_config_history) <= waterfall_height
row_time = datetime.datetime.fromtimestamp(scan_time)
state.scan_config_history.pop(remove_time)

if state.counter % config.y_label_skip == 0:
state.y_labels.append(row_time.strftime("%Y-%m-%d %H:%M:%S"))
Expand All @@ -621,30 +619,21 @@ def update_fig(config, state, zmqr, rotate_secs, save_time):

state.ax.set_yticks(state.y_ticks, labels=state.y_labels)

if state.save_path:
state.scan_config_history[scan_time] = scan_configs

state.counter += 1

if state.counter % config.draw_rate == 0:
XX, YY = np.meshgrid(
np.linspace(
config.min_freq,
config.max_freq,
int(
(config.max_freq - config.min_freq) / (config.freq_resolution)
+ 1
),
),
np.linspace(state.db_min, state.db_max, config.psd_db_resolution),
)
state.fig.canvas.restore_region(state.background)

state.psd_x_edges = XX[0]
state.psd_y_edges = YY[:, 0]
top_n_bins = state.freq_bins[
np.argsort(
np.nanvar(state.db_data - np.nanmin(state.db_data, axis=0), axis=0)
)[::-1][: config.top_n]
]

state.mesh_psd = state.ax_psd.pcolormesh(
XX, YY, np.zeros(XX[:-1, :-1].shape), shading="flat"
)
for i, ln in enumerate(state.top_n_lns):
ln.set_xdata([top_n_bins[i]] * len(state.Y[:, 0]))

state.fig.canvas.blit(state.ax.yaxis.axes.figure.bbox)

# db_norm = db_data
db_norm = (state.db_data - state.db_min) / (state.db_max - state.db_min)
Expand All @@ -653,7 +642,7 @@ def update_fig(config, state, zmqr, rotate_secs, save_time):
(state.db_data - np.nanmin(state.db_data, axis=0)) - config.snr_min
) / (config.snr_max - config.snr_min)

# ax_psd.clear()
reset_mesh_psd(config, state)

state.ax_psd.set_ylim(state.db_min, state.db_max)
state.mesh_psd.set_array(state.cmap_psd(state.data.T))
Expand Down Expand Up @@ -686,11 +675,9 @@ def update_fig(config, state, zmqr, rotate_secs, save_time):

state.sm.set_clim(vmin=state.db_min, vmax=state.db_max)
state.cbar.update_normal(state.sm)
# cbar.draw_all()
state.cbar_ax.draw_artist(state.cbar_ax.yaxis)
state.fig.canvas.blit(state.cbar_ax.yaxis.axes.figure.bbox)
state.ax_psd.draw_artist(state.ax_psd.yaxis)
state.fig.canvas.blit(state.ax_psd.yaxis.axes.figure.bbox)
for ax in (state.cbar_ax.yaxis, state.ax_psd.yaxis):
state.cbar_ax.draw_artist(ax)
state.fig.canvas.blit(ax.axes.figure.bbox)
for ln in state.top_n_lns:
state.ax.draw_artist(ln)

Expand Down Expand Up @@ -734,14 +721,15 @@ def __init__(
base_save_path,
width,
height,
waterfall_height,
batch,
):
self.engine = engine
self.plot_snr = plot_snr
self.savefig_path = savefig_path
self.snr_min = 0
self.snr_max = 50
self.waterfall_height = 100 # number of waterfall rows
self.waterfall_height = waterfall_height # number of waterfall rows
self.marker_distance = 0.1
self.scale = 1e6
self.freq_resolution = sampling_rate / fft_len / self.scale
Expand Down Expand Up @@ -795,7 +783,7 @@ def __init__(self):
self.ax_psd = None
self.ax = None
self.save_path = None
self.mesg_psd = None
self.mesh_psd = None
self.data = None
self.peak_finder = None

Expand All @@ -815,6 +803,7 @@ def waterfall(
rotate_secs,
width,
height,
waterfall_height,
batch,
zmqr,
):
Expand All @@ -830,6 +819,7 @@ def waterfall(
base_save_path,
width,
height,
waterfall_height,
batch,
)
state = WaterfallState()
Expand Down Expand Up @@ -943,6 +933,7 @@ def main():
args.rotate_secs,
args.width,
args.height,
args.waterfall_height,
batch,
zmqr,
)
Expand Down
33 changes: 21 additions & 12 deletions tests/test_waterfall.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,33 @@ class FakeZmqReceiver:
def __init__(self, run_secs, peak_min, peak_max, peak_val):
self.start_time = time.time()
self.run_secs = run_secs
df = pd.DataFrame(
[
{"ts": 1.0, "freq": 1 + (i * 0.001), "db": peak_val / 2}
for i in range(1000)
]
)
df.loc[(df.freq >= peak_min) & (df.freq <= peak_max), "db"] = peak_val
self.fake_results = [
({}, df),
(None, None),
]
self.serve_results = None
self.peak_min = peak_min
self.peak_max = peak_max
self.peak_val = peak_val

def healthy(self):
return time.time() - self.start_time < self.run_secs

def read_buff(self):
if not self.serve_results:
self.serve_results = copy.deepcopy(self.fake_results)
df = pd.DataFrame(
[
{
"ts": time.time(),
"freq": 1 + (i * 0.001),
"db": self.peak_val / 2,
}
for i in range(1000)
]
)
df.loc[
(df.freq >= self.peak_min) & (df.freq <= self.peak_max), "db"
] = self.peak_val
self.serve_results = [
({}, df),
(None, None),
]
return self.serve_results.pop()

def stop(self):
Expand Down Expand Up @@ -68,6 +76,7 @@ def test_run_waterfall(self):
60, # args.rotate_secs,
10, # args.width,
5, # args.height,
10, # args.waterfall_height,
True, # args.batch
zmqr,
)
Expand Down

0 comments on commit cc11965

Please sign in to comment.