-
Notifications
You must be signed in to change notification settings - Fork 55
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
27 changed files
with
1,306 additions
and
202 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
import sys | ||
|
||
# These imports won't be available at runtime, but will help VSCode completion. | ||
from .api import API as API | ||
from .api import AutoMasterAddressPort as AutoMasterAddressPort | ||
from .config import * | ||
from .interface import Interface | ||
from .utils import ClinicaClusterResolverWarning as ClinicaClusterResolverWarning | ||
from .utils import Rank0Filter as Rank0Filter | ||
|
||
sys.modules[__name__] = Interface() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from .auto_master_addr_port import AutoMasterAddressPort | ||
from .base import API | ||
from .default import DefaultAPI | ||
from .slurm import SlurmAPI | ||
from .torchelastic import TorchElasticAPI | ||
|
||
__all__ = [ | ||
"API", | ||
"AutoMasterAddressPort", | ||
"DefaultAPI", | ||
"SlurmAPI", | ||
"TorchElasticAPI", | ||
] |
46 changes: 46 additions & 0 deletions
46
clinicadl/utils/maps_manager/cluster/api/auto_master_addr_port.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
#! /usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
|
||
import os | ||
from functools import wraps | ||
from typing import Callable, Type | ||
|
||
from ..config import __all__ as all_API_methods | ||
from .base import API | ||
|
||
# Defines a class decorator to make wraps any API methods so that the Master Address | ||
# and the Master Port are set in order to allow the process group to initialize | ||
# correctly. | ||
|
||
env_variables_set: bool = False | ||
|
||
|
||
def set_master_addr_port_env_variables(func): | ||
# The parameter should be a method of a subclass of the API abstract class. | ||
@wraps(func) | ||
def wrapper(self): | ||
global env_variables_set | ||
if not env_variables_set: | ||
env_variables_set = True # must be done before actually setting the variable to prevent stackoverflow | ||
os.environ["MASTER_ADDR"] = self.master_address() | ||
os.environ["MASTER_PORT"] = str(self.port()) | ||
return func(self) | ||
|
||
return wrapper | ||
|
||
|
||
def decorate_methods(cls: Type[API], func_to_apply: Callable) -> Type[API]: | ||
# Decorate all API methods defined in the config file with the given function. | ||
for obj_name in dir(cls): | ||
if obj_name in all_API_methods: | ||
decorated = func_to_apply(getattr(cls, obj_name)) | ||
setattr(cls, obj_name, decorated) | ||
|
||
return cls | ||
|
||
|
||
def AutoMasterAddressPort(cls: Type[API]) -> Type[API]: | ||
# When we call a cluster API function for the first time, we set the MASTER_ADDR | ||
# and MASTER_PORT environment variables, so that the Pytorch wrapper | ||
# DistributedDataParallel can set up communication correctly. | ||
return decorate_methods(cls, func_to_apply=set_master_addr_port_env_variables) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
#! /usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import List, Union | ||
|
||
|
||
class API(ABC): | ||
priority: int = 5000 | ||
name: str = "AbstractAPI" | ||
|
||
@abstractmethod | ||
def is_launcher(self) -> bool: | ||
""" | ||
Detects if the given API is the one used to launch the current job. | ||
""" | ||
raise NotImplementedError() | ||
|
||
@abstractmethod | ||
def rank(self) -> int: | ||
""" | ||
Property containing the rank of the process. | ||
""" | ||
raise NotImplementedError() | ||
|
||
@abstractmethod | ||
def local_rank(self) -> int: | ||
""" | ||
Property containing the local rank of the process. | ||
""" | ||
raise NotImplementedError() | ||
|
||
@abstractmethod | ||
def world_size(self) -> int: | ||
""" | ||
Property containing the number of processes launched. | ||
""" | ||
raise NotImplementedError() | ||
|
||
@abstractmethod | ||
def local_world_size(self) -> int: | ||
""" | ||
Property containing the number of processes launched of each node. | ||
""" | ||
raise NotImplementedError() | ||
|
||
@abstractmethod | ||
def num_nodes(self) -> int: | ||
""" | ||
Property containing the number of nodes. | ||
""" | ||
raise NotImplementedError() | ||
|
||
@abstractmethod | ||
def cpus(self) -> int: | ||
""" | ||
Property containing the number of CPUs allocated to each process. | ||
""" | ||
raise NotImplementedError() | ||
|
||
@abstractmethod | ||
def gpus(self) -> List[str]: | ||
""" | ||
Property containing all GPUs ids. | ||
""" | ||
raise NotImplementedError() | ||
|
||
@abstractmethod | ||
def nodelist(self) -> Union[str, List[str]]: | ||
""" | ||
Property containing the list of nodes. | ||
""" | ||
raise NotImplementedError() | ||
|
||
@abstractmethod | ||
def master_address(self) -> str: | ||
""" | ||
Property containing the master node. | ||
""" | ||
raise NotImplementedError() | ||
|
||
@abstractmethod | ||
def port(self) -> int: | ||
""" | ||
Property containing the port to communicate with the master process. | ||
""" | ||
raise NotImplementedError() | ||
|
||
def is_master(self) -> bool: | ||
""" | ||
Detects whether or not the given process is the master (i.e. rank 0) | ||
""" | ||
return self.rank() == 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
#! /usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
|
||
import os | ||
import socket | ||
from contextlib import closing | ||
from typing import List, Optional | ||
|
||
from .auto_master_addr_port import AutoMasterAddressPort | ||
from .base import API | ||
|
||
|
||
@AutoMasterAddressPort | ||
class DefaultAPI(API): | ||
priority: int = 0 | ||
name: str = "Sequential" | ||
|
||
def __init__(self): | ||
self.current_port: Optional[int] = None | ||
|
||
@staticmethod | ||
def find_available_port() -> int: | ||
""" | ||
Tries to bind to local port until it finds one which is available. | ||
This is used to set the master port environment variable. | ||
""" | ||
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: | ||
sock.bind(("localhost", 0)) | ||
port = sock.getsockname()[1] | ||
return port | ||
|
||
def is_launcher(self) -> bool: | ||
return True | ||
|
||
def rank(self) -> int: | ||
return 0 | ||
|
||
def local_rank(self) -> int: | ||
return 0 | ||
|
||
def world_size(self) -> int: | ||
return 1 | ||
|
||
def local_world_size(self) -> int: | ||
return 1 | ||
|
||
def num_nodes(self) -> int: | ||
return 1 | ||
|
||
def cpus(self) -> int: | ||
return len(os.sched_getaffinity(0)) | ||
|
||
def gpus(self) -> List[str]: | ||
return [] | ||
|
||
def nodelist(self) -> List[str]: | ||
return ["localhost"] | ||
|
||
def master_address(self) -> str: | ||
return "localhost" | ||
|
||
def port(self) -> int: | ||
if self.current_port is None: | ||
self.current_port = self.find_available_port() | ||
return self.current_port |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
#! /usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
|
||
import os | ||
from typing import List | ||
|
||
from ..utils import get_first_host | ||
from .auto_master_addr_port import AutoMasterAddressPort | ||
from .base import API | ||
|
||
|
||
@AutoMasterAddressPort | ||
class SlurmAPI(API): | ||
priority: int = 10000 | ||
name: str = "Slurm" | ||
|
||
def is_launcher(self) -> bool: | ||
return "SLURM_STEP_ID" in os.environ | ||
|
||
def rank(self) -> int: | ||
return int(os.environ["SLURM_PROCID"]) | ||
|
||
def local_rank(self) -> int: | ||
return int(os.environ["SLURM_LOCALID"]) | ||
|
||
def world_size(self) -> int: | ||
return int(os.environ["SLURM_STEP_NUM_TASKS"]) | ||
|
||
def local_world_size(self) -> int: | ||
return int(os.environ["SLURM_STEP_TASKS_PER_NODE"]) | ||
|
||
def num_nodes(self) -> int: | ||
return int(os.environ["SLURM_STEP_NUM_NODES"]) | ||
|
||
def cpus(self) -> int: | ||
cpu = int(os.environ.get("SLURM_CPUS_PER_TASK", 0)) | ||
return cpu or len(os.sched_getaffinity(0)) | ||
|
||
def gpus(self) -> List[str]: | ||
step_gpus = os.environ.get("SLURM_STEP_GPUS", None) | ||
if step_gpus is not None: | ||
return step_gpus.split(",") | ||
return [] | ||
|
||
def nodelist(self) -> str: | ||
return os.environ["SLURM_STEP_NODELIST"] | ||
|
||
def master_address(self) -> str: | ||
return get_first_host(self.nodelist()) | ||
|
||
def jobid(self) -> int: | ||
return int(os.environ["SLURM_JOB_ID"]) | ||
|
||
def port(self) -> int: | ||
return 10000 + self.jobid() % 20000 |
Oops, something went wrong.