diff --git a/src/lighteval/data.py b/src/lighteval/data.py index 7cb105e63..dd1fd5341 100644 --- a/src/lighteval/data.py +++ b/src/lighteval/data.py @@ -25,8 +25,15 @@ from typing import Iterator, Tuple import torch +from packaging import version from torch.utils.data import Dataset -from torch.utils.data.distributed import DistributedSampler, T_co + + +if version.parse(torch.__version__) >= version.parse("2.5.0"): + from torch.utils.data.distributed import DistributedSampler, _T_co +else: + from torch.utils.data.distributed import DistributedSampler + from torch.utils.data.distributed import T_co as _T_co from lighteval.tasks.requests import ( GreedyUntilRequest, @@ -318,7 +325,7 @@ class GenDistributedSampler(DistributedSampler): as our samples are sorted by length. """ - def __iter__(self) -> Iterator[T_co]: + def __iter__(self) -> Iterator[_T_co]: if self.shuffle: # deterministically shuffle based on epoch and seed g = torch.Generator()