diff --git a/ml4h/models/diffusion_blocks.py b/ml4h/models/diffusion_blocks.py index fcb8126cc..cc4f35ff2 100644 --- a/ml4h/models/diffusion_blocks.py +++ b/ml4h/models/diffusion_blocks.py @@ -896,7 +896,7 @@ def plot_images(self, epoch=None, logs=None, num_rows=1, num_cols=4, reseed=None def plot_reconstructions( self, batch, diffusion_amount=0, - epoch=None, logs=None, num_rows=4, num_cols=4, + epoch=None, logs=None, num_rows=4, num_cols=4, prefix='./figures/', ): images = batch[0][self.input_map.input_name()] self.normalizer.update_state(images) @@ -921,7 +921,11 @@ def plot_reconstructions( plt.imshow(generated_images[index], cmap='gray') plt.axis("off") plt.tight_layout() - plt.show() + now_string = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M') + figure_path = os.path.join(prefix, f'diffusion_image_generations_{now_string}{IMAGE_EXT}') + if not os.path.exists(os.path.dirname(figure_path)): + os.makedirs(os.path.dirname(figure_path)) + plt.savefig(figure_path, bbox_inches="tight") plt.close() def control_plot_images( diff --git a/ml4h/models/train.py b/ml4h/models/train.py index 612c40c8f..683565c75 100755 --- a/ml4h/models/train.py +++ b/ml4h/models/train.py @@ -375,6 +375,7 @@ def train_diffusion_control_model(args, supervised=False): if args.inspect_model: data, labels, paths = big_batch_from_minibatch_generator(generate_test, 1) predictions_to_pngs(data, args.tensor_maps_in, args.tensor_maps_in, data, labels, paths, '{args.output_folder}/{args.id}/') + model.plot_reconstructions(data, prefix=f'{args.output_folder}/{args.id}/') interpolate_controlled_generations(model, args.tensor_maps_out, args.tensor_maps_out[0], args.batch_size, f'{args.output_folder}/{args.id}/') if model.input_map.axes() == 2: