torch backend: keras.utils.split_dataset requires tensorflow #20947
Same here; I'm also a PyTorch user.
I think this might help?
Thanks, but that isn't the issue. I've set the
Hi @oluwandabira, could you please provide standalone code and the PyTorch dataset you are trying to split with
import os

os.environ["KERAS_BACKEND"] = "torch"  # must be set before importing keras

import torch
import keras


class Dataset(torch.utils.data.Dataset):
    def __init__(self, length: int):
        super().__init__()
        self._length = length

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        # Random 224x224x3 "image" and a scalar target
        return torch.rand((224, 224, 3)), torch.rand(())


dataset = Dataset(10)
# Fails here when tensorflow is not installed
train, test = keras.utils.split_dataset(dataset, left_size=0.7)

pip freeze
absl-py==2.1.0
contourpy==1.3.1
cycler==0.12.1
filelock==3.17.0
fire==0.7.0
fonttools==4.56.0
fsspec==2025.2.0
h5py==3.13.0
Jinja2==3.1.5
joblib==1.4.2
keras==3.8.0
kiwisolver==1.4.8
markdown-it-py==3.0.0
MarkupSafe==3.0.2
matplotlib==3.10.0
mdurl==0.1.2
ml_dtypes==0.5.1
mpmath==1.3.0
namex==0.0.8
networkx==3.4.2
numpy==2.2.3
nvidia-cublas-cu12==12.4.5.8
nvidia-cuda-cupti-cu12==12.4.127
nvidia-cuda-nvrtc-cu12==12.4.127
nvidia-cuda-runtime-cu12==12.4.127
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.2.1.3
nvidia-curand-cu12==10.3.5.147
nvidia-cusolver-cu12==11.6.1.9
nvidia-cusparse-cu12==12.3.1.170
nvidia-cusparselt-cu12==0.6.2
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127
optree==0.14.0
packaging==24.2
pillow==11.1.0
Pygments==2.19.1
pyparsing==3.2.1
python-dateutil==2.9.0.post0
rich==13.9.4
scikit-learn==1.6.1
scipy==1.15.2
six==1.17.0
sympy==1.13.1
termcolor==2.5.0
threadpoolctl==3.5.0
torch==2.6.0
torchvision==0.21.0
triton==3.2.0
typing_extensions==4.12.2
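The package list above contains no `tensorflow` entry, which matches the report that `split_dataset` fails in this environment. A minimal stdlib-only sketch for checking both conditions before calling it; `tensorflow_installed` and `keras_backend` are hypothetical helper names, not part of Keras:

```python
import importlib.util
import os


def tensorflow_installed() -> bool:
    """Return True if the `tensorflow` package can be found on the import path."""
    # find_spec returns None when the top-level package is not installed.
    return importlib.util.find_spec("tensorflow") is not None


def keras_backend() -> str:
    """Return the backend requested via KERAS_BACKEND, or "<unset>" if the variable is absent."""
    # When unset, Keras uses the backend from its own config file instead.
    return os.environ.get("KERAS_BACKEND", "<unset>")


print("KERAS_BACKEND:", keras_backend())
print("tensorflow installed:", tensorflow_installed())
```

In the environment above this would be expected to report the torch backend with TensorFlow missing.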
Running the above code in the above environment results in a
Hi @oluwandabira, thanks for sharing the reproducible code. We will look into it and update you soon.
The `keras.utils.split_dataset` function requires TensorFlow to be installed, even though the docstring states that the `dataset` argument can be a `torch.utils.data.Dataset` (I assumed that meant it would work on a torch dataset with no issues).
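Until this is addressed in Keras, one possible workaround (not proposed in the thread) is to split with PyTorch's own `torch.utils.data.random_split`, which does not need TensorFlow. A minimal sketch, assuming a map-style dataset like the one in the reproduction; `split_torch_dataset` is a hypothetical helper, and its `left_size` argument only mimics the Keras parameter name (fractions only here):

```python
import torch
from torch.utils.data import Dataset, random_split


class RandomDataset(Dataset):
    """Same toy dataset as the reproduction: random images and scalar targets."""

    def __init__(self, length: int):
        super().__init__()
        self._length = length

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        return torch.rand((224, 224, 3)), torch.rand(())


def split_torch_dataset(dataset, left_size: float, seed: int = 0):
    """Split a map-style torch dataset into two Subsets without TensorFlow.

    `left_size` is the fraction of samples for the first split (0 < left_size < 1).
    """
    if not 0.0 < left_size < 1.0:
        raise ValueError("left_size must be a fraction strictly between 0 and 1")
    n_total = len(dataset)
    n_left = int(round(n_total * left_size))
    # Seeded generator so the shuffle behind the split is reproducible.
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, [n_left, n_total - n_left], generator=generator)


dataset = RandomDataset(10)
train, test = split_torch_dataset(dataset, left_size=0.7)
print(len(train), len(test))  # 7 3
```

The two results are `torch.utils.data.Subset` objects, so they can be wrapped in a `torch.utils.data.DataLoader`, which Keras 3 `fit()` accepts as input.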