diff --git a/src/lighteval/models/transformers/base_model.py b/src/lighteval/models/transformers/base_model.py index 9b815d2b0..b9a958ae1 100644 --- a/src/lighteval/models/transformers/base_model.py +++ b/src/lighteval/models/transformers/base_model.py @@ -355,7 +355,7 @@ def init_model_parallel(self, model_parallel: bool | None = None) -> Tuple[bool, return False, None, None self.num_local_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1)) - self.num_machines = int(os.environ.get("WORLD_SIZE", 0)) // self.num_local_processes + self.num_machines = torch.cuda.device_count() // self.num_local_processes if self.num_machines == 0: logger.info("We are not in a distributed setting. Setting model_parallel to False.") model_parallel = False