Commit

Distributed training (#25)
* distributed training

* use torchrun to start the training, but it quit immediately

* fixed several distributed bugs. I am going to change the Dataset to a map-style Dataset rather than an IterableDataset, since IterableDataset does not work with DistributedSampler

* add a mode option to the dataset

* fix remove_contact bug

* distributed training is working.
xiuliren authored May 4, 2023
1 parent 3027181 commit c365d9c
Showing 7 changed files with 229 additions and 93 deletions.
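The key change is switching DatasetBase from an IterableDataset to a map-style Dataset so that it can be paired with DistributedSampler. A minimal sketch of that pairing, not taken from this commit (the stand-in dataset and the loader settings below are assumptions):

import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

class RandomPatchDataset(Dataset):
    """Stand-in for a map-style dataset such as AffinityMapVolumeWithMask:
    __getitem__ ignores the index and returns a random patch, and __len__
    only reports a nominal size for the sampler to divide."""
    def __len__(self):
        return 1024

    def __getitem__(self, index):
        image = torch.rand(1, 64, 64, 64)
        label = torch.rand(3, 64, 64, 64)
        return image, label

def build_loader(dataset: Dataset, batch_size: int = 2, num_workers: int = 4):
    # DistributedSampler hands each rank a disjoint slice of the indices;
    # it requires the process group to be initialized first (e.g. by torchrun).
    sampler = DistributedSampler(dataset, shuffle=True)
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler,
                        num_workers=num_workers, pin_memory=True)
    return loader, sampler

# per-epoch usage:
#   loader, sampler = build_loader(RandomPatchDataset())
#   sampler.set_epoch(epoch)   # reshuffles consistently across the ranks
#   for image, label in loader: ...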
1 change: 1 addition & 0 deletions .gitignore
@@ -1,3 +1,4 @@
*.npy
*.chkpt
*.yaml
*events.out.tfevents.*
11 changes: 11 additions & 0 deletions examples/launch_gpu_node_bash.sh
@@ -0,0 +1,11 @@
#!/bin/bash

module load slurm
module load cuda
module load cudnn

CORES=8

#neutrain-pre --config-file ./config.yaml
srun -p gpu --gpus 2 -C v100 --cpus-per-gpu=$CORES --pty bash -i

25 changes: 23 additions & 2 deletions examples/train.sh
Original file line number Diff line number Diff line change
@@ -4,8 +4,29 @@ module load slurm
module load cuda
module load cudnn

CORES=12
export TF_CPP_MIN_LOG_LEVEL=1
export NCCL_DEBUG="INFO"
export TORCH_DISTRIBUTED_DEBUG="INFO"
export TORCH_SHOW_CPP_STACKTRACES="1"

CORES_PER_GPU=6
NUM_TRAINERS=2
RANK=0
CONSTRAIN="a100"

#module list

#neutrain-pre --config-file ./config.yaml
srun -p gpu --gpus 1 --cpus-per-gpu=$CORES neutrain-pre --config-file ./config.yaml
#srun -p gpu --gpus 1 --cpus-per-gpu=$CORES neutrain-pre --config-file ./config.yaml
#python -m torch.distributed.launch --nproc_per_node=2 neutrain-affs-vol -c whole_brain_affs.yaml

#torchrun \
srun -p gpu --gpus $NUM_TRAINERS --cpus-per-gpu=$CORES_PER_GPU -C $CONSTRAIN torchrun \
--standalone \
--nnodes=1 \
--nproc_per_node $NUM_TRAINERS \
--no_python \
neutrain-affs-vol -c whole_brain_affs.yaml


#/mnt/home/jwu/code/neutorch/neutorch/train/whole_brain_affinity_map.py -c whole_brain_affs.yaml
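With --standalone --nnodes=1, torchrun spawns NUM_TRAINERS worker processes on the node and exports RANK, LOCAL_RANK and WORLD_SIZE to each of them; the launched entry point is expected to read these variables and join the process group. A minimal sketch of that consumer side, not from this repository (the helper names are illustrative):

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_distributed() -> int:
    # torchrun exports LOCAL_RANK/RANK/WORLD_SIZE for every worker process
    local_rank = int(os.environ["LOCAL_RANK"])
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)
    return local_rank

def wrap_model(model: torch.nn.Module, local_rank: int) -> DDP:
    # gradients are all-reduced across the NUM_TRAINERS processes
    model = model.cuda(local_rank)
    return DDP(model, device_ids=[local_rank])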
98 changes: 67 additions & 31 deletions neutorch/data/dataset.py
@@ -19,20 +19,20 @@ def load_cfg(cfg_file: str, freeze: bool = True):
cfg.freeze()
return cfg

def worker_init_fn(worker_id: int):
worker_info = torch.utils.data.get_worker_info()
# def worker_init_fn(worker_id: int):
# worker_info = torch.utils.data.get_worker_info()

# the dataset copy in this worker process
dataset = worker_info.dataset
overall_start = 0
overall_end = dataset.sample_num

# configure the dataset to only process the split workload
per_worker = int(math.ceil(
(overall_end - overall_start) / float(worker_info.num_workers)))
worker_id = worker_info.id
dataset.start = overall_start + worker_id * per_worker
dataset.end = min(dataset.start + per_worker, overall_end)
# # the dataset copy in this worker process
# dataset = worker_info.dataset
# overall_start = 0
# overall_end = dataset.sample_num

# # configure the dataset to only process the split workload
# per_worker = int(math.ceil(
# (overall_end - overall_start) / float(worker_info.num_workers)))
# worker_id = worker_info.id
# dataset.start = overall_start + worker_id * per_worker
# dataset.end = min(dataset.start + per_worker, overall_end)

def path_to_dataset_name(path: str, dataset_names: list):
for dataset_name in dataset_names:
@@ -52,9 +52,9 @@ def to_tensor(arr):
return arr


class DatasetBase(torch.utils.data.IterableDataset):
class DatasetBase(torch.utils.data.Dataset):
def __init__(self,
samples: list,
samples: List[AbstractSample],
):
"""
Parameters:
@@ -99,24 +99,26 @@ def random_patch(self):
patch = sample.random_patch
# patch.to_tensor()
return patch.image, patch.label

def __next__(self):

def __len__(self):
patch_num = 0
for sample in self.samples:
patch_num += len(sample)
return patch_num

def __getitem__(self, index: int):
"""return a random patch from a random sample
the exact index does not matter!
Args:
index (int): index of the patch
"""
image_chunk, label_chunk = self.random_patch
image = to_tensor(image_chunk.array)
label = to_tensor(label_chunk.array)

return image, label

def __iter__(self):
"""generate random patches from samples
Yields:
tuple[tensor, tensor]: image and label tensors
"""
while True:
yield next(self)


class SemanticDataset(DatasetBase):
def __init__(self, samples: list):
#patch_size: Cartesian = DEFAULT_PATCH_SIZE):
@@ -152,7 +154,7 @@ def from_config(cls, cfg: CfgNode, is_train: bool, **kwargs):


class OrganelleDataset(SemanticDataset):
def __init__(self, samples: list,
def __init__(self, samples: List[AbstractSample],
num_classes: int = 1,
skip_classes: list = None,
selected_classes: list = None):
@@ -226,23 +228,57 @@ def __next__(self):
return image, target

class AffinityMapVolumeWithMask(DatasetBase):
def __init__(self, samples: list):
def __init__(self, samples: List[AbstractSample]):
super().__init__(samples)

@classmethod
def from_config(cls, cfg: CfgNode, **kwargs):
def from_config(cls, cfg: CfgNode, mode: str = 'training', **kwargs):
"""construct affinity map volume with mask dataset
Args:
cfg (CfgNode): the configuration node
mode (str, optional): one of ['training', 'validation', 'test']. Defaults to 'training'.
"""
output_patch_size = Cartesian.from_collection(
cfg.train.patch_size)

worker_info = torch.utils.data.get_worker_info()
sample_names = cfg.dataset[mode]

# single process data loading
iter_start = 0
iter_stop = len(sample_names)
if worker_info is not None:
# multiple processing for data loading, split workload
worker_id = worker_info.id
if len(sample_names) > worker_info.num_workers:
per_worker = int(math.ceil(
(iter_stop - iter_start) / float(worker_info.num_workers)))
iter_start = iter_start + worker_id * per_worker
iter_stop = min(iter_start + per_worker, iter_stop)
else:
iter_start = worker_id % len(sample_names)
iter_stop = iter_start + 1

samples = []
for sample_name in cfg.samples:
for sample_name in sample_names[iter_start : iter_stop]:
sample_cfg = cfg.samples[sample_name]
sample_class = eval(sample_cfg.type)
sample = sample_class.from_config(
sample_cfg, output_patch_size)
samples.append(sample)
return cls(samples)

def __len__(self):
# num = 0
# for sample in self.samples:
# num += sample.
# return num
# return self.sample_num * cfg.system.gpus * cfg.train.batch_size * 16
# our patches are randomly sampled from a chunk or volume, so the supply is effectively unlimited.
# return a big enough number to make distributed sampler work
return 1024

class AffinityMapDataset(DatasetBase):
def __init__(self, samples: list):
super().__init__(samples)
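The __len__ of 1024 above is only a nominal epoch length: patches are drawn randomly in __getitem__, so the number mainly controls how many indices DistributedSampler hands out per epoch. A back-of-envelope split under assumed values (2 ranks, batch size 2):

# illustrative numbers only
dataset_len = 1024   # the DatasetBase.__len__ placeholder
world_size = 2       # NUM_TRAINERS ranks started by torchrun
batch_size = 2

indices_per_rank = -(-dataset_len // world_size)   # ceiling division -> 512 patches per rank
steps_per_epoch = indices_per_rank // batch_size   # 256 batches per rank per epoch
print(indices_per_rank, steps_per_epoch)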
19 changes: 18 additions & 1 deletion neutorch/data/sample.py
@@ -50,6 +50,12 @@ def sampling_weight(self) -> int:
"""
return 1

def __len__(self):
"""number of patches
we simply return a default number to make it work with DistributedSampler.
"""
return 64

@cached_property
def transform(self):
return Compose([
@@ -196,6 +202,10 @@ def random_patch_center(self):
center = Cartesian(cz, cy, cx)
return center

def __len__(self):
patch_num = np.prod(self.center_stop - self.center_start + 1)
return patch_num

def patch_from_center(self, center: Cartesian):
start = center - self.patch_size_before_transform // 2
bbox = BoundingBox.from_delta(start, self.patch_size_before_transform)
@@ -343,7 +353,11 @@ def sampling_weight(self) -> int:
block_num = len(self.nonzero_block_bounding_boxes)
block_size = self.label.block_size * self.voxel_size_factors
return np.product(block_size) * block_num


def __len__(self):
patch_num = len(self.nonzero_block_bounding_boxes) * \
np.prod(self.label.block_size - self.patch_size_before_transform + 1)
return patch_num


class SampleWithPointAnnotation(Sample):
@@ -476,6 +490,9 @@ def random_patch(self):

return Patch(image, label)

def __len__(self):
return self.synapses.pre_num


class SemanticSample(Sample):
def __init__(self,
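The per-sample __len__ methods above estimate how many distinct patches a sample can yield, e.g. np.prod(self.center_stop - self.center_start + 1) counts the valid patch centers. A quick numeric check with made-up sizes (the geometry and the center-range definitions below are assumptions, not taken from sample.py):

import numpy as np

chunk_size = np.array([128, 128, 128])   # made-up chunk shape
patch_size = np.array([64, 64, 64])      # made-up patch size before transform

# assume a patch center is valid wherever the full patch still fits inside the chunk
center_start = patch_size // 2
center_stop = chunk_size - patch_size // 2

patch_num = np.prod(center_stop - center_start + 1)
print(patch_num)   # (128 - 64 + 1)**3 = 65**3 = 274625 distinct patch positions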
