DataModules: add configurable args to dataloader #2333

Open: anton-emanuel wants to merge 8 commits into main from feat/datamodules-configurable-dataloader


@anton-emanuel anton-emanuel commented Oct 4, 2024

Closes #2332.

@github-actions github-actions bot added the datamodules PyTorch Lightning datamodules label Oct 4, 2024
@anton-emanuel anton-emanuel marked this pull request as draft October 4, 2024 08:09
@adamjstewart
Collaborator

How much of a speed difference do you notice?

At the moment, the datamodules are all designed to pass additional kwargs to the dataset class. So we should technically add these new args to every datamodule, not just to the base class. We could design things such that all kwargs are either passed to the data loader or the dataset class, but I'm worried users might get confused by this. Of course, they don't need to understand the feature if they don't want to use it.

I guess my concern is, why stop at only these variables?

@adamjstewart adamjstewart added this to the 0.7.0 milestone Oct 4, 2024
@anton-emanuel
Author

anton-emanuel commented Oct 4, 2024

By varying all three of those parameters, I was able to at least halve the time spent in [_TrainingEpochLoop].train_dataloader_next (probably reduce it even further). But I am reading from COGs over the internet, so I guess the results are very different from having the raster data on the same device.

I agree that it would be better to solve this for all arguments, and that there is a trade-off between full configurability and simplicity. I only added those variables because they were the ones I wanted to vary.

Would it be an option to do something like what is done in BaseDataModule._valid_attribute, but take only the kwargs that start with dataloader_ and pass them to the DataLoader (I feel like I have seen this solution somewhere in TorchGeo, but might have been another library)?

# ...
DATALOADER_KWARG_PREFIX = "dataloader_"

    # ...
    dataloader_kwargs = {
        k.removeprefix(DATALOADER_KWARG_PREFIX): v
        for k, v in kwargs.items()
        if k.startswith(DATALOADER_KWARG_PREFIX)
    }
    
    # ...
    DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        sampler=sampler,
        batch_sampler=batch_sampler,
        num_workers=self.num_workers,
        collate_fn=self.collate_fn,
        persistent_workers=self.num_workers > 0,
        **dataloader_kwargs
    )

@anton-emanuel anton-emanuel force-pushed the feat/datamodules-configurable-dataloader branch from 1af0682 to 1f56161 Compare October 8, 2024 11:38
@anton-emanuel anton-emanuel force-pushed the feat/datamodules-configurable-dataloader branch from 1f56161 to 8257b8e Compare October 8, 2024 11:38
@anton-emanuel
Author

anton-emanuel commented Oct 8, 2024

I changed it to use kwargs now:

        self.dataloader_kwargs, self.kwargs = split_prefixed_kwargs(
            'dataloader_', **kwargs
        )
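
For readers following along, here is a minimal sketch of what a helper like split_prefixed_kwargs could look like. The signature mirrors the call above, but the body is an assumption and not the code from this PR; the keyword names in the usage lines (prefetch_factor, pin_memory, root) are only example keys.

```python
from typing import Any


def split_prefixed_kwargs(
    prefix: str, **kwargs: Any
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Split kwargs into (prefixed, rest), stripping the prefix from the former.

    Hypothetical sketch: the body is an assumption, not the PR's implementation.
    """
    prefixed: dict[str, Any] = {}
    rest: dict[str, Any] = {}
    for key, value in kwargs.items():
        if key.startswith(prefix):
            # Slice instead of str.replace so that only the leading
            # occurrence of the prefix is removed.
            prefixed[key[len(prefix):]] = value
        else:
            rest[key] = value
    return prefixed, rest


# Example keys only; which arguments a datamodule receives is up to the caller.
dataloader_kwargs, dataset_kwargs = split_prefixed_kwargs(
    'dataloader_',
    dataloader_prefetch_factor=4,
    dataloader_pin_memory=True,
    root='data',
)
```

With this split, the first dict can be unpacked into the DataLoader call and the second forwarded to the dataset class, as before.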

@anton-emanuel
Author

@microsoft-github-policy-service agree company="Kongsberg Satellite Services"

@anton-emanuel anton-emanuel marked this pull request as ready for review October 8, 2024 12:19
@DimitrisMantas
Contributor

DimitrisMantas commented Oct 8, 2024

I think it would be better for kwargs to be a JSON-like dict, something like {"dataset": ..., "dataloader": ...}. We could even turn this into a typed dict or a dataclass so that it can be clearly documented and tested. I think the transformers library likes to use these config structures a lot.
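
A rough sketch of the dataclass variant of this idea, to make the shape concrete. All class and field names here (DataLoaderConfig, DataModuleConfig, and the two example options) are hypothetical and do not exist in TorchGeo.

```python
from dataclasses import asdict, dataclass, field
from typing import Any, Optional


@dataclass
class DataLoaderConfig:
    """Hypothetical container for options forwarded to the DataLoader."""

    pin_memory: bool = False
    prefetch_factor: Optional[int] = None


@dataclass
class DataModuleConfig:
    """Hypothetical top-level config with separate dataset and dataloader parts."""

    dataset: dict[str, Any] = field(default_factory=dict)
    dataloader: DataLoaderConfig = field(default_factory=DataLoaderConfig)


config = DataModuleConfig(
    dataset={'root': 'data'},
    dataloader=DataLoaderConfig(pin_memory=True, prefetch_factor=4),
)

# Convert the typed part back to plain kwargs for unpacking into a DataLoader call.
dataloader_kwargs = asdict(config.dataloader)
```

The typed fields give one place to document and test each option, at the cost of changing how every datamodule is constructed.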

@anton-emanuel anton-emanuel force-pushed the feat/datamodules-configurable-dataloader branch from 8fc54ef to b5291a6 Compare October 8, 2024 17:07
@anton-emanuel anton-emanuel force-pushed the feat/datamodules-configurable-dataloader branch from 9a78733 to 88cedbc Compare October 8, 2024 17:23
@github-actions github-actions bot added the testing Continuous integration testing label Oct 9, 2024
@adamjstewart
Collaborator

I feel like I have seen this solution somewhere in TorchGeo, but might have been another library

Yes, we do something similar in our combined GeoDatasets (NAIP + Chesapeake, Sentinel-2 + CDL/EuroCrops/NCCM/SAS).

I think it would be better for kwargs to be a JSON-like dict

The downside of this is that it would be backwards-incompatible and affect all datamodules.

I'm still thinking about this idea. I wish there was a way to predict a good default value. I'm curious if any other maintainers have experimented with these variables and how important they are to tune. In the meantime, anyone can do this themselves by overriding the respective methods like so:

from torch.utils.data import DataLoader
from torchgeo.datamodules import TropicalCycloneDataModule

class MyTropicalCycloneDataModule(TropicalCycloneDataModule):
    def _dataloader_factory(self, split):
        dataset = self._valid_attribute(f'{split}_dataset', 'dataset')
        batch_size = self._valid_attribute(f'{split}_batch_size', 'batch_size')
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=split == 'train',
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            persistent_workers=self.num_workers > 0,
            # ... any additional DataLoader arguments go here
        )

Labels
datamodules PyTorch Lightning datamodules testing Continuous integration testing
Development

Successfully merging this pull request may close these issues.

Expose more arguments to dataloader
3 participants