Skip to content

Commit

Permalink
correct h5 inference
Browse files Browse the repository at this point in the history
  • Loading branch information
wwbwang committed May 16, 2024
1 parent 2a4c440 commit 2be9ba2
Show file tree
Hide file tree
Showing 5 changed files with 308 additions and 58 deletions.
123 changes: 123 additions & 0 deletions scripts/downsample_h5.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tifffile\n",
"import numpy as np\n",
"import torch\n",
"import os\n",
"import random\n",
"import h5py\n",
"import cv2\n",
"import math\n",
"from scipy.ndimage.interpolation import zoom\n",
"from PIL import Image\n",
"\n",
"from tqdm import tqdm\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from sr_3dunet.utils.data_utils import random_crop_3d, random_crop_2d, augment_3d, augment_2d, preprocess, get_projection, get_rotated_img, crop_block"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ims_path = '/share/home/wangwb/workspace/sr_3dunet/datasets/cellbody/out_h5/output_res0.h5'\n",
"out_ims_path = '/share/home/wangwb/workspace/sr_3dunet/datasets/cellbody/out_h5/output_res2.h5'\n",
"\n",
"h5 = h5py.File(ims_path, 'r')\n",
"img_total = h5['DataSet']['ResolutionLevel 0']['TimePoint 0']['Channel 0']['Data']\n",
"h5_dir = 'DataSet/ResolutionLevel 2/TimePoint 0/Channel 0/Data'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"step_size = 512\n",
"\n",
"downscale_factor = 4\n",
"\n",
"with h5py.File(out_ims_path, 'w') as f:\n",
" f.create_dataset(h5_dir, shape=img_total.shape, chunks=(1, 128, 128), dtype=img_total.dtype)\n",
"\n",
"len1 = math.ceil(img_total.shape[0]/step_size)\n",
"len2 = math.ceil(img_total.shape[1]/step_size)\n",
"len3 = math.ceil(img_total.shape[2]/step_size)\n",
"pbar1 = tqdm(total=len1*len2*len3, unit='h5_img', desc='inference')\n",
"\n",
"for start_h in range(0, img_total.shape[0], step_size):\n",
" end_h = img_total.shape[0] if start_h+step_size>img_total.shape[0] else start_h+step_size\n",
" for start_w in range(0, img_total.shape[1], step_size):\n",
" end_w = img_total.shape[1] if start_w+step_size>img_total.shape[1] else start_w+step_size\n",
" for start_d in range(0, img_total.shape[2], step_size):\n",
" end_d = img_total.shape[2] if start_d+step_size>img_total.shape[2] else start_d+step_size\n",
" \n",
" new_shape = ((end_h-start_h) // downscale_factor,\n",
" (end_w-start_w) // downscale_factor,\n",
" (end_d-start_d) // downscale_factor)\n",
" \n",
" # img_total[start_h:end_h, start_w:end_w, start_d:end_d].resize(new_shape, resample=Image.BICUBIC)\n",
" new_img = zoom(img_total[start_h:end_h, start_w:end_w, start_d:end_d], zoom = 0.5, order=1)\n",
"\n",
" with h5py.File(out_ims_path, 'r+') as f:\n",
" f[h5_dir][start_h//2:end_h//2, start_w//2:end_w//2, start_d//2:end_d//2] = new_img\n",
"\n",
" pbar1.update(1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"original_matrix = np.random.rand(1000, 1000, 1000)\n",
"\n",
"\n",
"\n",
"new_shape = (original_matrix.shape[0] // downscale_factor,\n",
" original_matrix.shape[1] // downscale_factor,\n",
" original_matrix.shape[2] // downscale_factor)\n",
"\n",
"new_matrix = np.zeros(new_shape)\n",
"for i in range(new_shape[0]):\n",
" for j in range(new_shape[1]):\n",
" for k in range(new_shape[2]):\n",
" new_matrix[i, j, k] = np.mean(original_matrix[i*downscale_factor:(i+1)*downscale_factor,\n",
" j*downscale_factor:(j+1)*downscale_factor,\n",
" k*downscale_factor:(k+1)*downscale_factor])\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "MPCN",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
84 changes: 62 additions & 22 deletions scripts/inference_from_h5.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,19 @@
import torch
import tifffile
import h5py
import math
from os import path as osp
from tqdm import tqdm
from functools import partial

from sr_3dunet.utils.data_utils import preprocess, postprocess
from sr_3dunet.utils.inference_big_tif import handle_bigtif
from sr_3dunet.utils.inference_big_tif import handle_bigtif, extend_block
from sr_3dunet.archs.unet_3d_generator_arch import UNet_3d_Generator

def get_inference_model(args, device) -> UNet_3d_Generator:
"""return an on device model with eval mode"""
# set up model
model = UNet_3d_Generator(in_channels=1, out_channels=1, features=[64, 128, 256], norm_type=None, dim=3)
model = UNet_3d_Generator(in_channels=1, out_channels=1, features=[64, 128, 256, 512], norm_type=None, dim=3)

model_path = args.model_path
assert os.path.isfile(model_path), \
Expand Down Expand Up @@ -51,36 +52,75 @@ def main():
print("Model size: {:.5f}M".format(sum(p.numel() for p in model.parameters())*4/1048576))
print("Model parameters: {}".format(sum(p.numel() for p in model.parameters())))

model = partial(handle_bigtif, model, args.piece_size, args.piece_overlap)

percentiles=[0, 0.9999] # [0.01,0.999999] # [0.01, 0.9985]
dataset_mean=0

h5 = h5py.File(args.input, 'r')
# img = h5['DataSet']['ResolutionLevel 0']['TimePoint 0']['Channel 3']['Data']
img_path = args.h5_dir.split('/')
img = h5
img_total = h5
for key in img_path:
img = img[key]
img_total = img_total[key]
h, w, d = img_total.shape

with h5py.File(args.output, 'w') as f:
f.create_dataset(args.h5_dir, shape=img_total.shape, chunks=(1, 256, 256), dtype=img_total.dtype)

len1 = math.ceil(h/(args.piece_size-args.piece_overlap))
len2 = math.ceil(w/(args.piece_size-args.piece_overlap))
len3 = math.ceil(d/(args.piece_size-args.piece_overlap))
pbar1 = tqdm(total=len1*len2*len3, unit='h5_img', desc='inference')

img = np.clip(img, 0, 65535)
origin_shape = img.shape
img, min_value, max_value = preprocess(img, percentiles, dataset_mean)
piece_size = args.piece_size
overlap = args.piece_overlap

for start_h in range(0, h, piece_size-overlap):
end_h = start_h + piece_size

for start_w in range(0, w, piece_size-overlap):
end_w = start_w + piece_size

for start_d in range(0, d, piece_size-overlap):
end_d = start_d + piece_size

img = img_total[start_h:end_h, start_w:end_w, start_d:end_d]
img = np.clip(img, 0, 65535)
img, min_value, max_value = preprocess(img, percentiles, dataset_mean)

end_h = h if end_h>h else end_h
end_w = w if end_w>w else end_w
end_d = d if end_d>d else end_d
origin_shape = img.shape

if end_h == h or end_w==w or end_d==d:
img = np.pad(img, ((0, piece_size-end_h+start_h),
(0, piece_size-end_w+start_w),
(0, piece_size-end_d+start_d)), mode='constant')

h_cutleft = 0 if start_h==0 else overlap//2
w_cutleft = 0 if start_w==0 else overlap//2
d_cutleft = 0 if start_d==0 else overlap//2

h_cutright = 0 if end_h==h else overlap//2
w_cutright = 0 if end_w==w else overlap//2
d_cutright = 0 if end_d==d else overlap//2

img = img.astype(np.float32)[None, None,]
img = torch.from_numpy(img).to(device) # to float32

img = img.astype(np.float32)[None, None,]
img = torch.from_numpy(img).to(device) # to float32
start_time = time.time()
torch.cuda.synchronize()
out_img = model(img)
out_img = out_img[:,:,0+h_cutleft:end_h-start_h-h_cutright, 0+w_cutleft:end_w-start_w-w_cutright, 0+d_cutleft:end_d-start_d-d_cutright]
torch.cuda.synchronize()
end_time = time.time()

start_time = time.time()
torch.cuda.synchronize()
out_img = model(img)
torch.cuda.synchronize()
end_time = time.time()
print("avg-time_model:", (end_time-start_time)*1000, "ms,", "N, C, H, W, D:", origin_shape)
out_img = out_img[0,0].cpu().numpy()
out_img = postprocess(out_img, min_value, max_value, dataset_mean)

with h5py.File(args.output, 'r+') as f:
f[args.h5_dir][start_h+h_cutleft:end_h-h_cutright, start_w+w_cutleft:end_w-w_cutright, start_d+d_cutleft:end_d-d_cutright] = out_img

out_img = out_img[0,0].cpu().numpy()
with h5py.File(args.output, 'w') as hf:
hf.create_dataset(args.h5_dir, data=postprocess(out_img, min_value, max_value, dataset_mean))
# tifffile.imwrite(args.output, postprocess(out_img, min_value, max_value, dataset_mean))
pbar1.update(1)

if __name__ == '__main__':
main()
127 changes: 95 additions & 32 deletions scripts/showimg.ipynb

Large diffs are not rendered by default.

27 changes: 27 additions & 0 deletions slurm_inference_from_h5.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#!/bin/bash

#SBATCH --job-name=inference_cellbody
#SBATCH --nodelist=c001
#SBATCH --gres=gpu:2
#SBATCH --ntasks-per-node=2
#SBATCH --ntasks=2
#SBATCH --cpus-per-task=16

source activate MPCN

# TODO num_io_consumer half

# i: Path to a single H5 file
# o: Path to the output single H5 file
# h5_dir: Dictionary of images in the specified H5 file
# model_path:
# piece_size: Determining the size of smaller images
# piece_overlap: Overlap between neighboring small images

CUDA_VISIBLE_DEVICES=0 python scripts/inference_from_h5.py \
-i /share/data/VISoR_Reconstruction/SIAT_ION/LiuCiRong/20230910_CJ004/CJ4-1um-ROI1/CJ4ROI1.ims \
-o datasets/cellbody/out_h5/output_res0.h5 \
--h5_dir "DataSet/ResolutionLevel 0/TimePoint 0/Channel 0/Data" \
--model_path weights/MPCN_VISoR_oldbaseline_cellbody_correctproj_net_g_80000.pth \
--piece_size 128 --piece_overlap 16

5 changes: 1 addition & 4 deletions sr_3dunet/utils/inference_big_tif.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,19 +87,16 @@ def handle_bigtif(model, piece_size, overlap, img):

img_out[:, :, start_h+h_cutleft:end_h-h_cutright, start_w+w_cutleft:end_w-w_cutright, start_d+d_cutleft:end_d-d_cutright] = img_tmp[
:,:,0+h_cutleft:piece_size-h_cutright, 0+w_cutleft:piece_size-w_cutright, 0+d_cutleft:piece_size-d_cutright]

if end_d==d_now:
break
if end_w==w_now:
break
if end_h==h_now:
break



return img_out[:,:,:h,:w,:d]




# if (end_h-start_h)%piece_size!=0 or (end_w-start_w)%piece_size!=0 or (end_d-start_d)%piece_size!=0:
# extend_img = extend_block(img[:, :, start_h:end_h, start_w:end_w, start_d:end_d], piece_mod_size)
Expand Down

0 comments on commit 2be9ba2

Please sign in to comment.