torch backend: keras.utils.split_dataset requires tensorflow #20947

Open
oluwandabira opened this issue Feb 23, 2025 · 5 comments

Comments

@oluwandabira

The keras.utils.split_dataset function requires TensorFlow to be installed, even though its docstring states that the dataset argument can be a torch.utils.data.Dataset (I assumed this meant it would just work on a torch dataset with no issues).

@TaiXeflar

TaiXeflar commented Feb 23, 2025

Fellow PyTorch user here.
Keras 3 uses TensorFlow as its default backend if KERAS_BACKEND is not set.
You can export the environment variable, or just add this at the top of your Python script (before importing keras):

```python
import os
os.environ["KERAS_BACKEND"] = "torch"
```

I think this should help.
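The ordering matters because Keras picks its backend when it is first imported, so setting the variable after `import keras` has no effect on that session. A minimal sketch of a guard for this; `set_keras_backend` is a hypothetical helper, not part of Keras:

```python
import os
import sys


def set_keras_backend(name: str) -> None:
    """Select the Keras backend through the KERAS_BACKEND environment variable.

    Hypothetical helper: Keras reads KERAS_BACKEND at import time, so this
    refuses to run once `keras` has already been imported in this process.
    """
    if "keras" in sys.modules:
        raise RuntimeError(
            "keras is already imported; set KERAS_BACKEND before the first `import keras`"
        )
    os.environ["KERAS_BACKEND"] = name


set_keras_backend("torch")
# import keras  # import Keras only after the backend has been selected
```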

@oluwandabira
Author

Thanks, but that isn't the issue. I've set the KERAS_BACKEND environment variable to torch, and I can fit and evaluate my PyTorch model with PyTorch DataLoaders just fine, but splitting my PyTorch dataset with keras.utils.split_dataset fails because I don't have TensorFlow installed.
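To confirm that the missing package is the cause, one option is to check which modules are importable without actually importing them. A small sketch using the standard library's `importlib.util.find_spec`; `is_installed` is a hypothetical helper name:

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if a top-level module can be found, without importing it."""
    return importlib.util.find_spec(module_name) is not None


# With the reporter's pip freeze output, tensorflow is absent, so the first
# line would report False while torch reports True.
print("tensorflow installed:", is_installed("tensorflow"))
print("torch installed:", is_installed("torch"))
```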

@sonali-kumari1
Contributor

Hi @oluwandabira -

Could you please provide standalone code and the PyTorch dataset you are trying to split with keras.utils.split_dataset, so we can replicate this issue? Thanks!

@oluwandabira
Author

oluwandabira commented Feb 25, 2025

```python
import os
os.environ["KERAS_BACKEND"] = "torch"
import torch
import keras

class Dataset(torch.utils.data.Dataset):
    def __init__(self, length: int):
        super().__init__()
        self._len = length

    def __len__(self):
        return self._len

    def __getitem__(self, index):
        return torch.rand((224, 224, 3)), torch.rand(())

dataset = Dataset(10)

train, test = keras.utils.split_dataset(dataset, left_size=0.7)
```
Output of pip freeze:

```
absl-py==2.1.0
contourpy==1.3.1
cycler==0.12.1
filelock==3.17.0
fire==0.7.0
fonttools==4.56.0
fsspec==2025.2.0
h5py==3.13.0
Jinja2==3.1.5
joblib==1.4.2
keras==3.8.0
kiwisolver==1.4.8
markdown-it-py==3.0.0
MarkupSafe==3.0.2
matplotlib==3.10.0
mdurl==0.1.2
ml_dtypes==0.5.1
mpmath==1.3.0
namex==0.0.8
networkx==3.4.2
numpy==2.2.3
nvidia-cublas-cu12==12.4.5.8
nvidia-cuda-cupti-cu12==12.4.127
nvidia-cuda-nvrtc-cu12==12.4.127
nvidia-cuda-runtime-cu12==12.4.127
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.2.1.3
nvidia-curand-cu12==10.3.5.147
nvidia-cusolver-cu12==11.6.1.9
nvidia-cusparse-cu12==12.3.1.170
nvidia-cusparselt-cu12==0.6.2
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127
optree==0.14.0
packaging==24.2
pillow==11.1.0
Pygments==2.19.1
pyparsing==3.2.1
python-dateutil==2.9.0.post0
rich==13.9.4
scikit-learn==1.6.1
scipy==1.15.2
six==1.17.0
sympy==1.13.1
termcolor==2.5.0
threadpoolctl==3.5.0
torch==2.6.0
torchvision==0.21.0
triton==3.2.0
typing_extensions==4.12.2
```

Running the above code in this environment raises an ImportError: This requires the tensorflow module. You can install it via 'pip install tensorflow'
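Until split_dataset handles this case without TensorFlow, one possible workaround is to split the dataset by shuffled indices, which only requires __len__ and __getitem__. A minimal sketch in plain Python; IndexSubset and split_map_style are hypothetical names, and rounding the left size down is an assumption that may not match split_dataset exactly:

```python
import random
from typing import Any, Sequence


class IndexSubset:
    """View of `dataset` restricted to `indices` (same idea as torch.utils.data.Subset)."""

    def __init__(self, dataset: Any, indices: Sequence[int]):
        self.dataset = dataset
        self.indices = list(indices)

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, i: int) -> Any:
        # Map the subset position back to an index in the parent dataset.
        return self.dataset[self.indices[i]]


def split_map_style(dataset: Any, left_fraction: float, seed: int = 0):
    """Shuffle all indices with a seeded RNG and cut them into left/right subsets.

    Assumption: the left size is floor(left_fraction * len(dataset)).
    """
    if not 0.0 < left_fraction < 1.0:
        raise ValueError("left_fraction must be strictly between 0 and 1")
    n = len(dataset)
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # reproducible for a given seed
    left_size = int(left_fraction * n)
    return (
        IndexSubset(dataset, indices[:left_size]),
        IndexSubset(dataset, indices[left_size:]),
    )
```

With the Dataset from the repro, train, test = split_map_style(dataset, 0.7) returns two index-based views without importing Keras or TensorFlow. If staying within PyTorch is acceptable, torch.utils.data.random_split(dataset, [0.7, 0.3]) does the same job directly, since PyTorch releases such as the 2.6.0 listed above accept fractional lengths.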

@sonali-kumari1
Contributor

Hi @oluwandabira -

Thanks for sharing the reproducible code. We will look into it and update you soon.


4 participants