Distributed Data Parallelism (#402)
* add fully_sharded_data_parallel option
ncassereau authored Sep 21, 2023
1 parent a1464c5 commit 1672205
Showing 27 changed files with 1,306 additions and 202 deletions.
1 change: 1 addition & 0 deletions clinicadl/resources/config/train_config.toml
@@ -39,6 +39,7 @@ gpu = true
n_proc = 2
batch_size = 8
evaluation_steps = 0
fully_sharded_data_parallel = false
amp = false

[Reproducibility]
1 change: 1 addition & 0 deletions clinicadl/train/tasks/classification_cli.py
@@ -18,6 +18,7 @@
@train_option.n_proc
@train_option.batch_size
@train_option.evaluation_steps
@train_option.fully_sharded_data_parallel
@train_option.amp
# Reproducibility
@train_option.seed
1 change: 1 addition & 0 deletions clinicadl/train/tasks/reconstruction_cli.py
@@ -18,6 +18,7 @@
@train_option.n_proc
@train_option.batch_size
@train_option.evaluation_steps
@train_option.fully_sharded_data_parallel
@train_option.amp
# Reproducibility
@train_option.seed
1 change: 1 addition & 0 deletions clinicadl/train/tasks/regression_cli.py
@@ -18,6 +18,7 @@
@train_option.n_proc
@train_option.batch_size
@train_option.evaluation_steps
@train_option.fully_sharded_data_parallel
@train_option.amp
# Reproducibility
@train_option.seed
1 change: 1 addition & 0 deletions clinicadl/train/tasks/task_utils.py
@@ -45,6 +45,7 @@ def task_launcher(network_task: str, task_options_list: List[str], **kwargs):
"dropout",
"epochs",
"evaluation_steps",
"fully_sharded_data_parallel",
"gpu",
"learning_rate",
"multi_cohort",
11 changes: 11 additions & 0 deletions clinicadl/utils/cli_param/train_option.py
@@ -44,6 +44,17 @@
help="Fix the number of iterations to perform before computing an evaluation. Default will only "
"perform one evaluation at the end of each epoch.",
)
fully_sharded_data_parallel = cli_param.option_group.computational_group.option(
"--fully_sharded_data_parallel",
"-fsdp",
type=bool,
is_flag=True,
help="Enables Fully Sharded Data Parallel with Pytorch to save memory at the cost of communications. "
"Currently this only enables ZeRO Stage 1 but will be entirely replaced by FSDP in a later patch, "
"this flag is already set to FSDP to that the zero flag is never actually removed.",
default=False,
)

amp = cli_param.option_group.computational_group.option(
"--amp/--no-amp",
type=bool,
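The diff above only adds the CLI flag; the trainer code that consumes it is not part of this excerpt. Below is a hedged sketch of what "ZeRO Stage 1" usually maps to in PyTorch terms; the function name, optimizer choice and learning rate are illustrative, not taken from clinicadl.

import torch
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP


def build_model_and_optimizer(model: torch.nn.Module, fsdp: bool, lr: float = 1e-4):
    # Assumes torch.distributed.init_process_group() has already been called
    # and that each rank has been assigned a GPU.
    model = DDP(model.cuda())
    if fsdp:
        # ZeRO Stage 1: each rank keeps only its shard of the optimizer state,
        # trading extra communication for a smaller memory footprint.
        optimizer = ZeroRedundancyOptimizer(
            model.parameters(), optimizer_class=torch.optim.Adam, lr=lr
        )
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    return model, optimizer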
11 changes: 11 additions & 0 deletions clinicadl/utils/maps_manager/cluster/__init__.py
@@ -0,0 +1,11 @@
import sys

# These imports won't be available at runtime, but will help VSCode completion.
from .api import API as API
from .api import AutoMasterAddressPort as AutoMasterAddressPort
from .config import *
from .interface import Interface
from .utils import ClinicaClusterResolverWarning as ClinicaClusterResolverWarning
from .utils import Rank0Filter as Rank0Filter

sys.modules[__name__] = Interface()
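
A note on the last line: replacing the entry in sys.modules turns the cluster package into a live object, so attribute access such as cluster.rank() is resolved dynamically by Interface, while the imports above exist mainly so that static analysis still sees the names. A minimal, self-contained sketch of the pattern (module and method names are illustrative, not clinicadl's):

# pattern_sketch.py -- illustrative only, not part of clinicadl
import sys


class _Interface:
    """Stands in for the module and answers attribute access dynamically."""

    def rank(self) -> int:
        # A real implementation would delegate to the detected cluster API.
        return 0


# Replace the module object with an instance, so that
# `import pattern_sketch; pattern_sketch.rank()` calls the method above.
sys.modules[__name__] = _Interface()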
13 changes: 13 additions & 0 deletions clinicadl/utils/maps_manager/cluster/api/__init__.py
@@ -0,0 +1,13 @@
from .auto_master_addr_port import AutoMasterAddressPort
from .base import API
from .default import DefaultAPI
from .slurm import SlurmAPI
from .torchelastic import TorchElasticAPI

__all__ = [
"API",
"AutoMasterAddressPort",
"DefaultAPI",
"SlurmAPI",
"TorchElasticAPI",
]
46 changes: 46 additions & 0 deletions clinicadl/utils/maps_manager/cluster/api/auto_master_addr_port.py
@@ -0,0 +1,46 @@
#! /usr/bin/env python
# -*- coding: utf-8 -*-

import os
from functools import wraps
from typing import Callable, Type

from ..config import __all__ as all_API_methods
from .base import API

# Defines a class decorator that wraps API methods so that the master address
# and the master port are set, allowing the process group to initialize
# correctly.

env_variables_set: bool = False


def set_master_addr_port_env_variables(func):
# The parameter should be a method of a subclass of the API abstract class.
@wraps(func)
def wrapper(self):
global env_variables_set
if not env_variables_set:
            env_variables_set = True  # must be done before actually setting the variables to prevent a stack overflow
os.environ["MASTER_ADDR"] = self.master_address()
os.environ["MASTER_PORT"] = str(self.port())
return func(self)

return wrapper


def decorate_methods(cls: Type[API], func_to_apply: Callable) -> Type[API]:
# Decorate all API methods defined in the config file with the given function.
for obj_name in dir(cls):
if obj_name in all_API_methods:
decorated = func_to_apply(getattr(cls, obj_name))
setattr(cls, obj_name, decorated)

return cls


def AutoMasterAddressPort(cls: Type[API]) -> Type[API]:
# When we call a cluster API function for the first time, we set the MASTER_ADDR
    # and MASTER_PORT environment variables, so that the PyTorch wrapper
# DistributedDataParallel can set up communication correctly.
return decorate_methods(cls, func_to_apply=set_master_addr_port_env_variables)
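
Once the first decorated call has exported MASTER_ADDR and MASTER_PORT, PyTorch's standard env:// rendezvous can consume them. A hedged usage sketch, assuming the Interface instance forwards the API methods declared in base.py (this call site is not part of the diff shown here):

import torch.distributed as dist

from clinicadl.utils.maps_manager import cluster  # the Interface instance

# The first call into the cluster API triggers the decorator above and
# exports MASTER_ADDR / MASTER_PORT.
rank = cluster.rank()
world_size = cluster.world_size()

# init_method="env://" reads MASTER_ADDR and MASTER_PORT from the environment.
dist.init_process_group(
    backend="nccl" if cluster.gpus() else "gloo",
    init_method="env://",
    rank=rank,
    world_size=world_size,
)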
93 changes: 93 additions & 0 deletions clinicadl/utils/maps_manager/cluster/api/base.py
@@ -0,0 +1,93 @@
#! /usr/bin/env python
# -*- coding: utf-8 -*-

from abc import ABC, abstractmethod
from typing import List, Union


class API(ABC):
priority: int = 5000
name: str = "AbstractAPI"

@abstractmethod
def is_launcher(self) -> bool:
"""
Detects if the given API is the one used to launch the current job.
"""
raise NotImplementedError()

@abstractmethod
def rank(self) -> int:
"""
Property containing the rank of the process.
"""
raise NotImplementedError()

@abstractmethod
def local_rank(self) -> int:
"""
Property containing the local rank of the process.
"""
raise NotImplementedError()

@abstractmethod
def world_size(self) -> int:
"""
Property containing the number of processes launched.
"""
raise NotImplementedError()

@abstractmethod
def local_world_size(self) -> int:
"""
        Property containing the number of processes launched on each node.
"""
raise NotImplementedError()

@abstractmethod
def num_nodes(self) -> int:
"""
Property containing the number of nodes.
"""
raise NotImplementedError()

@abstractmethod
def cpus(self) -> int:
"""
Property containing the number of CPUs allocated to each process.
"""
raise NotImplementedError()

@abstractmethod
def gpus(self) -> List[str]:
"""
Property containing all GPUs ids.
"""
raise NotImplementedError()

@abstractmethod
def nodelist(self) -> Union[str, List[str]]:
"""
Property containing the list of nodes.
"""
raise NotImplementedError()

@abstractmethod
def master_address(self) -> str:
"""
Property containing the master node.
"""
raise NotImplementedError()

@abstractmethod
def port(self) -> int:
"""
Property containing the port to communicate with the master process.
"""
raise NotImplementedError()

def is_master(self) -> bool:
"""
        Detects whether the current process is the master (i.e. rank 0).
"""
return self.rank() == 0
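
The priority attribute (0 for DefaultAPI and 10000 for SlurmAPI in the files below) suggests that the Interface object picks the highest-priority API whose is_launcher() returns True. The Interface class itself is not shown in this excerpt, so the resolver below is a hypothetical reconstruction, not clinicadl's actual code:

from typing import Sequence, Type

from clinicadl.utils.maps_manager.cluster.api import API, DefaultAPI, SlurmAPI


def resolve_api(candidates: Sequence[Type[API]] = (SlurmAPI, DefaultAPI)) -> API:
    # Keep the APIs that recognise the current launcher; DefaultAPI always
    # reports itself as a launcher and therefore acts as the fallback.
    apis = [cls() for cls in candidates]
    launchers = [api for api in apis if api.is_launcher()]
    # Prefer the most specific environment: SlurmAPI (priority 10000) wins
    # over DefaultAPI (priority 0) when both claim the job.
    return max(launchers, key=lambda api: api.priority)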
65 changes: 65 additions & 0 deletions clinicadl/utils/maps_manager/cluster/api/default.py
@@ -0,0 +1,65 @@
#! /usr/bin/env python
# -*- coding: utf-8 -*-

import os
import socket
from contextlib import closing
from typing import List, Optional

from .auto_master_addr_port import AutoMasterAddressPort
from .base import API


@AutoMasterAddressPort
class DefaultAPI(API):
priority: int = 0
name: str = "Sequential"

def __init__(self):
self.current_port: Optional[int] = None

@staticmethod
def find_available_port() -> int:
"""
        Asks the OS for an available local port (by binding to port 0).
        This is used to set the master port environment variable.
"""
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
sock.bind(("localhost", 0))
port = sock.getsockname()[1]
return port

def is_launcher(self) -> bool:
return True

def rank(self) -> int:
return 0

def local_rank(self) -> int:
return 0

def world_size(self) -> int:
return 1

def local_world_size(self) -> int:
return 1

def num_nodes(self) -> int:
return 1

def cpus(self) -> int:
return len(os.sched_getaffinity(0))

def gpus(self) -> List[str]:
return []

def nodelist(self) -> List[str]:
return ["localhost"]

def master_address(self) -> str:
return "localhost"

def port(self) -> int:
if self.current_port is None:
self.current_port = self.find_available_port()
return self.current_port
55 changes: 55 additions & 0 deletions clinicadl/utils/maps_manager/cluster/api/slurm.py
@@ -0,0 +1,55 @@
#! /usr/bin/env python
# -*- coding: utf-8 -*-

import os
from typing import List

from ..utils import get_first_host
from .auto_master_addr_port import AutoMasterAddressPort
from .base import API


@AutoMasterAddressPort
class SlurmAPI(API):
priority: int = 10000
name: str = "Slurm"

def is_launcher(self) -> bool:
return "SLURM_STEP_ID" in os.environ

def rank(self) -> int:
return int(os.environ["SLURM_PROCID"])

def local_rank(self) -> int:
return int(os.environ["SLURM_LOCALID"])

def world_size(self) -> int:
return int(os.environ["SLURM_STEP_NUM_TASKS"])

def local_world_size(self) -> int:
return int(os.environ["SLURM_STEP_TASKS_PER_NODE"])

def num_nodes(self) -> int:
return int(os.environ["SLURM_STEP_NUM_NODES"])

def cpus(self) -> int:
cpu = int(os.environ.get("SLURM_CPUS_PER_TASK", 0))
return cpu or len(os.sched_getaffinity(0))

def gpus(self) -> List[str]:
step_gpus = os.environ.get("SLURM_STEP_GPUS", None)
if step_gpus is not None:
return step_gpus.split(",")
return []

def nodelist(self) -> str:
return os.environ["SLURM_STEP_NODELIST"]

def master_address(self) -> str:
return get_first_host(self.nodelist())

def jobid(self) -> int:
return int(os.environ["SLURM_JOB_ID"])

def port(self) -> int:
return 10000 + self.jobid() % 20000
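
Everything SlurmAPI reports is derived from the step's environment: the master address comes from the first host of SLURM_STEP_NODELIST via get_first_host (defined in ..utils, not shown in this excerpt), and the port is pinned deterministically per job, e.g. job 123456 gets port 10000 + 123456 % 20000 = 13456, so every task agrees on the rendezvous port without communicating. A rough stand-in for what get_first_host presumably does with a compressed Slurm nodelist (purely illustrative; the real helper may differ):

import re


def first_host(nodelist: str) -> str:
    # Extract the first hostname from a Slurm nodelist such as
    # "gpu[001-004,010],cpu01" -> "gpu001".
    match = re.match(r"([^\[,]+)(\[([^\]]+)\])?", nodelist)
    prefix, inside = match.group(1), match.group(3)
    if inside is None:
        return prefix
    # First element of the bracket group, e.g. "001-004" -> "001".
    return prefix + inside.split(",")[0].split("-")[0]


assert first_host("gpu[001-004,010],cpu01") == "gpu001"
assert first_host("node12") == "node12"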