Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/DecentAI/cellDetection into…
Browse files Browse the repository at this point in the history
… main

Temp pull for UNET3 update
  • Loading branch information
stevekangLunit committed Feb 24, 2023
2 parents a89b8b4 + fc69477 commit a08f3a1
Show file tree
Hide file tree
Showing 10 changed files with 327 additions and 129 deletions.
Empty file modified CellSegmentation/data_inspection.py
100644 → 100755
Empty file.
13 changes: 9 additions & 4 deletions CellSegmentation/dataset.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,13 @@ def __len__(self):

def __getitem__(self, idx):
if idx < self.no_of_ims:
image = np.array(io.imread(self.images[idx]),dtype=np.float32)
mask = np.array(io.imread(self.masks[idx]),dtype=np.float32)
# binarize mask 0 - background 1-foreground

name = self.images[idx]
image = np.array(io.imread(name),dtype=np.float32)
mask_name = name.replace(".tif","_mask.png")
mask_name = mask_name.replace("_images/","_images_masks/")
mask = np.array(io.imread(mask_name),dtype=np.float32)

mask[mask==255.0] = 1.0
else:
print('set idx out of bound!')
Expand All @@ -35,4 +39,5 @@ def __getitem__(self, idx):
image = augmentations["image"]
mask = augmentations["mask"]

return image, mask
return image, mask

2 changes: 2 additions & 0 deletions CellSegmentation/init_weights.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

import torch
import torch.nn as nn
from torch.nn import init
Expand Down Expand Up @@ -57,4 +58,5 @@ def init_weights(net, init_type='normal'):
elif init_type == 'orthogonal':
net.apply(weights_init_orthogonal)
else:

raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
34 changes: 19 additions & 15 deletions CellSegmentation/label_generation.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
import json
import os

from tqdm import tqdm


dir_to_test_ims = r'/lunit/home/stevekang/decentAI/MCF7/LIVECell_dataset_2021/images_mask/livecell_test_images'
dir_to_train_ims = r'/lunit/home/stevekang/decentAI/MCF7/LIVECell_dataset_2021/images_mask/livecell_train_val_images'
dir_to_val_labels = r'/lunit/home/stevekang/decentAI/MCF7/LIVECell_dataset_2021/annotations/LIVECell_single_cells/mcf7/val.json'
dir_to_train_labels = r'/lunit/home/stevekang/decentAI/MCF7/LIVECell_dataset_2021/annotations/LIVECell_single_cells/mcf7/train.json'
dir_to_test_labels = r'/lunit/home/stevekang/decentAI/MCF7/LIVECell_dataset_2021/annotations/LIVECell_single_cells/mcf7/test.json'
dir_to_test_ims = r'decentAI/cellDetection/MCF7/LIVECell_dataset_2021/images/livecell_test_images'
dir_to_train_ims = r'decentAI/cellDetection/MCF7/LIVECell_dataset_2021/images/livecell_train_val_images'
dir_to_val_labels = r'decentAI/cellDetection/MCF7/LIVECell_dataset_2021/annotations/LIVECell_single_cells/mcf7/val.json'
dir_to_train_labels = r'decentAI/cellDetection/MCF7/LIVECell_dataset_2021/annotations/LIVECell_single_cells/mcf7/train.json'
dir_to_test_labels = r'decentAI/cellDetection/MCF7/LIVECell_dataset_2021/annotations/LIVECell_single_cells/mcf7/test.json'


# dir_to_train_labels = r'C:\Code\Dataset\LIVECell_dataset_2021\annotations\LIVECell\livecell_coco_train.json'
# dir_to_test_labels = r'C:\Code\Dataset\LIVECell_dataset_2021\annotations\LIVECell\livecell_coco_test.json'
Expand Down Expand Up @@ -102,12 +104,6 @@ def generate_segmentation_mask(im, annotation):



dir_to_masks = dir_to_test_ims + '_masks'
if not os.path.isdir(dir_to_masks):
os.makedirs(dir_to_masks)

# load image file paths
ims = glob.glob(dir_to_test_ims + '/*.tif')

# load label json object
with open(dir_to_val_labels) as f:
Expand All @@ -119,11 +115,19 @@ def generate_segmentation_mask(im, annotation):
with open(dir_to_test_labels) as f:
coco_test_labels = json.load(f)

# generate label masks
for image in ims:

dir_to_masks = dir_to_train_ims + '_masks'
if not os.path.isdir(dir_to_masks):
os.makedirs(dir_to_masks)

# load image file paths
ims = glob.glob(dir_to_train_ims + '/*.tif')

for image in tqdm(ims):
imname = os.path.split(image)[-1]
im = io.imread(image)
single_label, train_val_flag = get_image_id_by_name(imname, coco_test_labels, coco_test_labels)
annotations = get_annotations_by_id(single_label, train_val_flag, coco_test_labels, coco_test_labels)
single_label, train_val_flag = get_image_id_by_name(imname, coco_val_labels, coco_train_labels)
annotations = get_annotations_by_id(single_label, train_val_flag, coco_val_labels, coco_train_labels)

segmentation_mask = generate_segmentation_mask(im, annotations)
io.imsave(dir_to_masks+'/'+ imname[:-4]+'_mask.png', segmentation_mask)
4 changes: 2 additions & 2 deletions CellSegmentation/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding

if is_batchnorm:
for i in range(1, n + 1):
conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p),nn.Dropout(p=0.3),
conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p),
nn.BatchNorm2d(out_size), nn.ReLU(inplace=True),)
setattr(self, 'conv%d' % i, conv)
in_size = out_size
else:
for i in range(1, n + 1):
conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), nn.Dropout(p=0.3), nn.ReLU(inplace=True), )
conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), nn.ReLU(inplace=True), )
setattr(self, 'conv%d' % i, conv)
in_size = out_size

Expand Down
3 changes: 3 additions & 0 deletions CellSegmentation/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
tqdm
imagecodecs
argparse
1 change: 1 addition & 0 deletions CellSegmentation/runcommand.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
python train.py --mode test --best_pth UNet3_no_sigmoid_MCD.pth.tar
148 changes: 82 additions & 66 deletions CellSegmentation/train.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from unet import Unet, UNet3Plus

from unet import Unet,UNet3Plus

from dataset import CellDataset
from utils import (
load_checkpoint,
Expand All @@ -14,45 +16,50 @@
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch.nn.functional as F
import numpy as np

import argparse
import os


LEARNING_RATE = 1E-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 8
# hyperparams

LEARNING_RATE = 1E-4
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# DEVICE = "cpu"
BATCH_SIZE = 8
NUM_EPOCHS = 1
NUM_EPOCHS = 20

NUM_WORKERS = 8
IMAGE_HEIGHT = 520
IMAGE_WIDTH = 704
PIN_MEMORY = True
LOAD_MODEL = False
# TRAIN_IM_DIR = r'C:\Code\Dataset\LIVECell_dataset_2021\images\images\livecell_train_val_images'
# TRAIN_MASK_DIR = r'C:\Code\Dataset\LIVECell_dataset_2021\images\images\livecell_train_val_images_masks'
# VAL_IM_DIR = r'C:\Code\Dataset\LIVECell_dataset_2021\images\images\livecell_test_images'
# VAL_MASK_DIR = r'C:\Code\Dataset\LIVECell_dataset_2021\images\images\livecell_test_images_masks'

# dir_to_test_ims = r'/lunit/home/stevekang/decentAI/MCF7/LIVECell_dataset_2021/images/livecell_test_images'
# dir_to_train_ims = r'/lunit/home/stevekang/decentAI/MCF7/LIVECell_dataset_2021/images/livecell_train_val_images'
TRAIN_IM_DIR = r'/lunit/home/stevekang/cellDetection/MCF7/LIVECell_dataset_2021/images/livecell_train_val_images'
TRAIN_MASK_DIR = r'/lunit/home/stevekang/cellDetection/MCF7/LIVECell_dataset_2021/images/livecell_train_val_images_masks'
VAL_IM_DIR = r'/lunit/home/stevekang/cellDetection/MCF7/LIVECell_dataset_2021/images/livecell_test_images'
VAL_MASK_DIR = r'/lunit/home/stevekang/cellDetection/MCF7/LIVECell_dataset_2021/images/livecell_test_images_masks'

TRAIN_IM_DIR = r'/lunit/home/stevekang/decentAI/MCF7/LIVECell_dataset_2021/images_mask/livecell_train_val_images'
TRAIN_MASK_DIR = r'/lunit/home/stevekang/decentAI/MCF7/LIVECell_dataset_2021/images_mask/livecell_train_val_images_masks'
VAL_IM_DIR = r'/lunit/home/stevekang/decentAI/MCF7/LIVECell_dataset_2021/images_mask/livecell_test_images'
VAL_MASK_DIR = r'/lunit/home/stevekang/decentAI/MCF7/LIVECell_dataset_2021/images_mask/livecell_test_images_masks'


def train_model(loader, model, optimizer, loss_fcn, scaler):
loop = tqdm(loader)
# model= nn.DataParallel(model)
# model.to(device)

for batch_idx, (data, targets) in enumerate(loader):
data = data.to(device=DEVICE).unsqueeze(1)
targets = targets.to(device=DEVICE).unsqueeze(1)
# forward path through model

with torch.cuda.amp.autocast():
pred = model(data)

_,ch,h,w = targets.shape
targets = F.interpolate(targets, size=((h//32)*32, (w//32)*32), mode='bilinear', align_corners=True)
pred = F.interpolate(pred, size=((h//32)*32, (w//32)*32), mode='bilinear', align_corners=True)

loss = loss_fcn(pred,targets)
# backward path
optimizer.zero_grad()
Expand All @@ -66,32 +73,34 @@ def train_model(loader, model, optimizer, loss_fcn, scaler):


def main():
# train_transform = A.Compose(
# [
# A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
# A.Rotate(limit=35, p=1.0),
# A.HorizontalFlip(p=0.5),
# A.VerticalFlip(p=0.1),
# A.Normalize(
# mean=[0.0, 0.0, 0.0],
# std=[1.0, 1.0, 1.0],
# max_pixel_value=255.0,
# ),
# ToTensorV2(),
# ],
# )

# val_transforms = A.Compose(
# [
# A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
# A.Normalize(
# mean=[0.0, 0.0, 0.0],
# std=[1.0, 1.0, 1.0],
# max_pixel_value=255.0,
# ),
# ToTensorV2(),
# ],
# )

train_transform = A.Compose(
[
A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
A.Rotate(limit=35, p=1.0),
A.HorizontalFlip(p=0.5),
A.VerticalFlip(p=0.1),
A.Normalize(
mean=[0.0, 0.0, 0.0],
std=[1.0, 1.0, 1.0],
max_pixel_value=255.0,
),
ToTensorV2(),
],
)

val_transforms = A.Compose(
[
A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
A.Normalize(
mean=[0.0, 0.0, 0.0],
std=[1.0, 1.0, 1.0],
max_pixel_value=255.0,
),
ToTensorV2(),
],
)

train_loader, val_loader = get_loaders(
TRAIN_IM_DIR,
TRAIN_MASK_DIR,
Expand All @@ -110,33 +119,40 @@ def main():
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

parser = argparse.ArgumentParser(description='test or train')
parser.add_argument('--mode', type=str, default="test")
parser.add_argument('--best_pth', type=str, default="UNet3_no_sigmoid_MCD.pth.tar")
args = parser.parse_args()

if LOAD_MODEL:
load_checkpoint(torch.load("my_checkpoint.pth.tar"), model)


check_accuracy(val_loader, model, device=DEVICE)
scaler = torch.cuda.amp.GradScaler()

for epoch in range(NUM_EPOCHS):
train_model(train_loader, model, optimizer, loss_fn, scaler)

# save model
checkpoint = {
"state_dict": model.state_dict(),
"optimizer":optimizer.state_dict(),
}
save_checkpoint(checkpoint)

# check accuracy
if(args.mode == "test"):
LOAD_MODEL = True
load_checkpoint(torch.load(args.best_pth), model)
check_accuracy(val_loader, model, device=DEVICE)

# print some examples to a folder
# save_predictions_as_imgs(
# val_loader, model, folder="saved_images/", device=DEVICE
# )
img_dir = "saved_test_images_UNet3/"
isExist = os.path.exists(img_dir)
if not isExist:
os.makedirs(img_dir)
save_predictions_as_imgs(val_loader, model, folder=img_dir, device=DEVICE)
if(args.mode == "train"):
scaler = torch.cuda.amp.GradScaler()


for epoch in range(NUM_EPOCHS):
train_model(train_loader, model, optimizer, loss_fn, scaler)

# save model
checkpoint = {
"state_dict": model.state_dict(),
"optimizer":optimizer.state_dict(),
}
save_checkpoint(checkpoint, filename=f"UNet3_training_epoch_{epoch}.pth.tar")
# check accuracy
check_accuracy(val_loader, model, device=DEVICE)

# img_dir = "saved_train_images_UNet3/"
# isExist = os.path.exists(img_dir)
# if not isExist:
# os.makedirs(img_dir)
# save_predictions_as_imgs(val_loader, model, folder=img_dir, device=DEVICE)


if __name__ == "__main__":
Expand Down
38 changes: 23 additions & 15 deletions CellSegmentation/unet.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ def forward(self,x):

return self.out(x)


#example = torch.rand(8,1,704,520) # no of ims, RGB channels, im widths, im heights
#print(example.shape)
#myunet = Unet(1,1)
#out = myunet.forward(example)
#print(out.shape)


import numpy as np
import torch
import torch.nn as nn
Expand Down Expand Up @@ -278,32 +286,32 @@ def forward(self, inputs):
# print("B4 : ", inputs.shape , "Aftr : ", x.shape)
## -------------Encoder-------------
h1 = self.conv1(x) # h1->320*320*64

h1 = self.dropout(h1)
h2 = self.maxpool1(h1)
h2 = self.conv2(h2) # h2->160*160*128

h2 = self.dropout(h2)
h3 = self.maxpool2(h2)
h3 = self.conv3(h3) # h3->80*80*256

h3 = self.dropout(h3)
h4 = self.maxpool3(h3)
h4 = self.conv4(h4) # h4->40*40*512

h4 = self.dropout(h4)
h5 = self.maxpool4(h4)
hd5 = self.conv5(h5) # h5->20*20*1024

h5 = self.dropout(h5)
## -------------Decoder-------------
h1_PT_hd4 = self.h1_PT_hd4_relu(self.h1_PT_hd4_bn(self.h1_PT_hd4_conv(self.h1_PT_hd4(h1))))
h2_PT_hd4 = self.h2_PT_hd4_relu(self.h2_PT_hd4_bn(self.h2_PT_hd4_conv(self.h2_PT_hd4(h2))))
h3_PT_hd4 = self.h3_PT_hd4_relu(self.h3_PT_hd4_bn(self.h3_PT_hd4_conv(self.h3_PT_hd4(h3))))
h1_PT_hd4 = self.h1_PT_hd4_relu(self.h1_PT_hd4_bn(self.dropout(self.h1_PT_hd4_conv(self.h1_PT_hd4(h1)))))
h2_PT_hd4 = self.h2_PT_hd4_relu(self.h2_PT_hd4_bn(self.dropout( self.h2_PT_hd4_conv(self.h2_PT_hd4(h2)))))
h3_PT_hd4 = self.h3_PT_hd4_relu(self.h3_PT_hd4_bn(self.dropout (self.h3_PT_hd4_conv(self.h3_PT_hd4(h3)))))
h4_Cat_hd4 = self.h4_Cat_hd4_relu(self.h4_Cat_hd4_bn(self.h4_Cat_hd4_conv(h4)))
hd5_UT_hd4 = self.hd5_UT_hd4_relu(self.hd5_UT_hd4_bn(self.hd5_UT_hd4_conv(self.hd5_UT_hd4(hd5))))
hd4 = self.relu4d_1(self.bn4d_1(self.conv4d_1(torch.cat((h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4), 1)))) # hd4->40*40*UpChannels

h1_PT_hd3 = self.h1_PT_hd3_relu(self.h1_PT_hd3_bn(self.h1_PT_hd3_conv(self.h1_PT_hd3(h1))))
h2_PT_hd3 = self.h2_PT_hd3_relu(self.h2_PT_hd3_bn(self.h2_PT_hd3_conv(self.h2_PT_hd3(h2))))
h1_PT_hd3 = self.h1_PT_hd3_relu(self.h1_PT_hd3_bn(self.dropout(self.h1_PT_hd3_conv(self.h1_PT_hd3(h1)))))
h2_PT_hd3 = self.h2_PT_hd3_relu(self.h2_PT_hd3_bn(self.dropout(self.h2_PT_hd3_conv(self.h2_PT_hd3(h2)))))
h3_Cat_hd3 = self.h3_Cat_hd3_relu(self.h3_Cat_hd3_bn(self.h3_Cat_hd3_conv(h3)))
hd4_UT_hd3 = self.hd4_UT_hd3_relu(self.hd4_UT_hd3_bn(self.hd4_UT_hd3_conv(self.hd4_UT_hd3(hd4))))
hd5_UT_hd3 = self.hd5_UT_hd3_relu(self.hd5_UT_hd3_bn(self.hd5_UT_hd3_conv(self.hd5_UT_hd3(hd5))))
hd4_UT_hd3 = self.hd4_UT_hd3_relu(self.hd4_UT_hd3_bn(self.dropout( self.hd4_UT_hd3_conv(self.hd4_UT_hd3(hd4)))))
hd5_UT_hd3 = self.hd5_UT_hd3_relu(self.hd5_UT_hd3_bn(self.dropout(self.hd5_UT_hd3_conv(self.hd5_UT_hd3(hd5)))))
hd3 = self.relu3d_1(self.bn3d_1(self.conv3d_1(torch.cat((h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3), 1)))) # hd3->80*80*UpChannels

h1_PT_hd2 = self.h1_PT_hd2_relu(self.h1_PT_hd2_bn(self.h1_PT_hd2_conv(self.h1_PT_hd2(h1))))
Expand All @@ -315,11 +323,11 @@ def forward(self, inputs):

h1_Cat_hd1 = self.h1_Cat_hd1_relu(self.h1_Cat_hd1_bn(self.h1_Cat_hd1_conv(h1)))
hd2_UT_hd1 = self.hd2_UT_hd1_relu(self.hd2_UT_hd1_bn(self.hd2_UT_hd1_conv(self.hd2_UT_hd1(hd2))))
hd3_UT_hd1 = self.hd3_UT_hd1_relu(self.hd3_UT_hd1_bn(self.hd3_UT_hd1_conv(self.hd3_UT_hd1(hd3)))) # error here?
hd3_UT_hd1 = self.hd3_UT_hd1_relu(self.hd3_UT_hd1_bn(self.hd3_UT_hd1_conv(self.hd3_UT_hd1(hd3))))
hd4_UT_hd1 = self.hd4_UT_hd1_relu(self.hd4_UT_hd1_bn(self.hd4_UT_hd1_conv(self.hd4_UT_hd1(hd4))))
hd5_UT_hd1 = self.hd5_UT_hd1_relu(self.hd5_UT_hd1_bn(self.hd5_UT_hd1_conv(self.hd5_UT_hd1(hd5))))
hd1 = self.relu1d_1(self.bn1d_1(self.conv1d_1(torch.cat((h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1), 1)))) # hd1->320*320*UpChannels

d1 = self.outconv1(hd1) # d1->320*320*n_classes
return F.sigmoid(d1)


return d1
Loading

0 comments on commit a08f3a1

Please sign in to comment.