From 24afde293223eba0eb482aab9c9e3746ba11f647 Mon Sep 17 00:00:00 2001
From: Nathan Habib <30601243+NathanHB@users.noreply.github.com>
Date: Thu, 2 Jan 2025 12:20:19 +0100
Subject: [PATCH] fix model parallel (#481)

fixes #447
---
 src/lighteval/models/transformers/base_model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lighteval/models/transformers/base_model.py b/src/lighteval/models/transformers/base_model.py
index 9b815d2b0..b9a958ae1 100644
--- a/src/lighteval/models/transformers/base_model.py
+++ b/src/lighteval/models/transformers/base_model.py
@@ -355,7 +355,7 @@ def init_model_parallel(self, model_parallel: bool | None = None) -> Tuple[bool,
             return False, None, None
 
         self.num_local_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
-        self.num_machines = int(os.environ.get("WORLD_SIZE", 0)) // self.num_local_processes
+        self.num_machines = torch.cuda.device_count() // self.num_local_processes
         if self.num_machines == 0:
             logger.info("We are not in a distributed setting. Setting model_parallel to False.")
             model_parallel = False