Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

v2.8.1 #736

Merged
merged 4 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions .github/workflows/test_datasets.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
name: datasets

on:
pull_request:
branches: [ master, dev ]
paths:
- 'src/**'
- 'tests/**'
- '.github/workflows/**'

jobs:
call-base-test-workflow:
uses: ./.github/workflows/base_test_workflow.yml
with:
module-to-test: datasets
8 changes: 4 additions & 4 deletions docs/datasets.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ datasets.base_dataset.BaseDataset(
## CUB-200-2011

```python
datasets.cub.CUB(*args, **kwargs)
datasets.CUB(*args, **kwargs)
```

**Defined splits**:
Expand Down Expand Up @@ -75,7 +75,7 @@ train_and_test_dataset = CUB(root="data",
## Cars196

```python
datasets.cars196.Cars196(*args, **kwargs)
datasets.Cars196(*args, **kwargs)
```

**Defined splits**:
Expand Down Expand Up @@ -110,7 +110,7 @@ train_and_test_dataset = Cars196(root="data",
## INaturalist2018

```python
datasets.inaturalist2018.INaturalist2018(*args, **kwargs)
datasets.INaturalist2018(*args, **kwargs)
```

**Defined splits**:
Expand Down Expand Up @@ -146,7 +146,7 @@ train_and_test_dataset = INaturalist2018(root="data",
## StanfordOnlineProducts

```python
datasets.sop.StanfordOnlineProducts(*args, **kwargs)
datasets.StanfordOnlineProducts(*args, **kwargs)
```

**Defined splits**:
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_metric_learning/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "2.8.0"
__version__ = "2.8.1"
5 changes: 5 additions & 0 deletions src/pytorch_metric_learning/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .base_dataset import BaseDataset
from .cars196 import Cars196
from .cub import CUB
from .inaturalist2018 import INaturalist2018
from .sop import StanfordOnlineProducts
2 changes: 2 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
)
device_from_environ = os.environ.get("TEST_DEVICE", "cuda")
with_collect_stats = os.environ.get("WITH_COLLECT_STATS", "false")
test_datasets = os.environ.get("TEST_DATASETS", "false")

TEST_DTYPES = [getattr(torch, x) for x in dtypes_from_environ]
TEST_DEVICE = torch.device(device_from_environ)

assert c_f.COLLECT_STATS is False

WITH_COLLECT_STATS = True if with_collect_stats == "true" else False
TEST_DATASETS = True if test_datasets == "true" else False
c_f.COLLECT_STATS = WITH_COLLECT_STATS

print(
Expand Down
7 changes: 5 additions & 2 deletions tests/datasets/test_cars196.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from pytorch_metric_learning.datasets.cars196 import Cars196
from pytorch_metric_learning.datasets import Cars196
from .. import TEST_DATASETS


class TestCars196(unittest.TestCase):
Expand All @@ -19,6 +20,7 @@ def setUpClass(cls):
if os.path.exists(cls.CARS_196_ROOT):
cls.ALREADY_EXISTS = True

@unittest.skipUnless(TEST_DATASETS, "TEST_DATASETS is false")
def test_Cars196(self):
train_test_data = Cars196(
root=TestCars196.CARS_196_ROOT, split="train+test", download=True
Expand All @@ -34,6 +36,7 @@ def test_Cars196(self):
self.assertTrue(len(train_data) == 8054)
self.assertTrue(len(test_data) == 8131)

@unittest.skipUnless(TEST_DATASETS, "TEST_DATASETS is false")
def test_CARS_196_dataloader(self):
test_data = Cars196(
root=TestCars196.CARS_196_ROOT,
Expand All @@ -50,5 +53,5 @@ def test_CARS_196_dataloader(self):

@classmethod
def tearDownClass(cls):
if not cls.ALREADY_EXISTS:
if not cls.ALREADY_EXISTS and os.path.isdir(cls.CARS_196_ROOT):
shutil.rmtree(cls.CARS_196_ROOT)
7 changes: 5 additions & 2 deletions tests/datasets/test_cub.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from pytorch_metric_learning.datasets.cub import CUB
from pytorch_metric_learning.datasets import CUB
from .. import TEST_DATASETS


class TestCUB(unittest.TestCase):
Expand All @@ -19,6 +20,7 @@ def setUpClass(cls):
if os.path.exists(cls.CUB_ROOT):
cls.ALREADY_EXISTS = True

@unittest.skipUnless(TEST_DATASETS, "TEST_DATASETS is false")
def test_CUB(self):
train_test_data = CUB(root=TestCUB.CUB_ROOT, split="train+test", download=True)
train_data = CUB(root=TestCUB.CUB_ROOT, split="train", download=True)
Expand All @@ -28,6 +30,7 @@ def test_CUB(self):
self.assertTrue(len(train_data) == 5864)
self.assertTrue(len(test_data) == 5924)

@unittest.skipUnless(TEST_DATASETS, "TEST_DATASETS is false")
def test_CUB_dataloader(self):
test_data = CUB(
root=TestCUB.CUB_ROOT,
Expand All @@ -44,5 +47,5 @@ def test_CUB_dataloader(self):

@classmethod
def tearDownClass(cls):
if not cls.ALREADY_EXISTS:
if not cls.ALREADY_EXISTS and os.path.isdir(cls.CUB_ROOT):
shutil.rmtree(cls.CUB_ROOT)
7 changes: 5 additions & 2 deletions tests/datasets/test_inaturalist2018.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from pytorch_metric_learning.datasets.inaturalist2018 import INaturalist2018
from pytorch_metric_learning.datasets import INaturalist2018
from .. import TEST_DATASETS


class TestINaturalist2018(unittest.TestCase):
Expand All @@ -19,6 +20,7 @@ def setUpClass(cls):
if os.path.exists(cls.INATURALIST2018_ROOT):
cls.ALREADY_EXISTS = True

@unittest.skipUnless(TEST_DATASETS, "TEST_DATASETS is false")
def test_INaturalist2018(self):
train_test_data = INaturalist2018(
root=TestINaturalist2018.INATURALIST2018_ROOT,
Expand All @@ -36,6 +38,7 @@ def test_INaturalist2018(self):
self.assertTrue(len(train_data) == 325846)
self.assertTrue(len(test_data) == 136093)

@unittest.skipUnless(TEST_DATASETS, "TEST_DATASETS is false")
def test_INaturalist2018_dataloader(self):
test_data = INaturalist2018(
root=TestINaturalist2018.INATURALIST2018_ROOT,
Expand All @@ -52,5 +55,5 @@ def test_INaturalist2018_dataloader(self):

@classmethod
def tearDownClass(cls):
if not cls.ALREADY_EXISTS:
if not cls.ALREADY_EXISTS and os.path.isdir(cls.INATURALIST2018_ROOT):
shutil.rmtree(cls.INATURALIST2018_ROOT)
8 changes: 6 additions & 2 deletions tests/datasets/test_sop.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from pytorch_metric_learning.datasets.sop import StanfordOnlineProducts
from pytorch_metric_learning.datasets import StanfordOnlineProducts
from .. import TEST_DATASETS


class TestStanfordOnlineProducts(unittest.TestCase):
Expand All @@ -19,6 +20,7 @@ def setUpClass(cls):
if os.path.exists(cls.SOP_ROOT):
cls.ALREADY_EXISTS = True

@unittest.skipUnless(TEST_DATASETS, "TEST_DATASETS is false")
def test_SOP(self):
train_test_data = StanfordOnlineProducts(
root=TestStanfordOnlineProducts.SOP_ROOT, split="train+test", download=True
Expand All @@ -34,6 +36,7 @@ def test_SOP(self):
self.assertTrue(len(train_data) == 59551)
self.assertTrue(len(test_data) == 60502)

@unittest.skipUnless(TEST_DATASETS, "TEST_DATASETS is false")
def test_SOP_dataloader(self):
test_data = StanfordOnlineProducts(
root=TestStanfordOnlineProducts.SOP_ROOT,
Expand All @@ -48,7 +51,8 @@ def test_SOP_dataloader(self):
self.assertTupleEqual(tuple(inputs.shape), (8, 3, 224, 224))
self.assertTupleEqual(tuple(labels.shape), (8,))

@unittest.skipUnless(TEST_DATASETS, "TEST_DATASETS is false")
@classmethod
def tearDownClass(cls):
if not cls.ALREADY_EXISTS:
if not cls.ALREADY_EXISTS and os.path.isdir(cls.SOP_ROOT):
shutil.rmtree(cls.SOP_ROOT)
Loading