Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove logging of figures #2184

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions tests/datamodules/test_fair1m.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import os

import matplotlib.pyplot as plt
import pytest

from torchgeo.datamodules import FAIR1MDataModule
Expand All @@ -29,14 +28,3 @@ def test_val_dataloader(self, datamodule: FAIR1MDataModule) -> None:
def test_predict_dataloader(self, datamodule: FAIR1MDataModule) -> None:
datamodule.setup('predict')
next(iter(datamodule.predict_dataloader()))

def test_plot(self, datamodule: FAIR1MDataModule) -> None:
datamodule.setup('validate')
batch = next(iter(datamodule.val_dataloader()))
sample = {
'image': batch['image'][0],
'boxes': batch['boxes'][0],
'label': batch['label'][0],
}
datamodule.plot(sample)
plt.close()
18 changes: 0 additions & 18 deletions tests/datamodules/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@

from typing import Any

import matplotlib.pyplot as plt
import pytest
import torch
from _pytest.fixtures import SubRequest
from lightning.pytorch import Trainer
from matplotlib.figure import Figure
from rasterio.crs import CRS
from torch import Tensor

Expand All @@ -34,9 +32,6 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
image = torch.arange(3 * 2 * 2).view(3, 2, 2)
return {'image': image, 'crs': CRS.from_epsg(4326), 'bounds': query}

def plot(self, *args: Any, **kwargs: Any) -> Figure:
return plt.figure()


class CustomGeoDataModule(GeoDataModule):
def __init__(self) -> None:
Expand Down Expand Up @@ -73,9 +68,6 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
def __len__(self) -> int:
return self.length

def plot(self, *args: Any, **kwargs: Any) -> Figure:
return plt.figure()


class CustomNonGeoDataModule(NonGeoDataModule):
def __init__(self) -> None:
Expand Down Expand Up @@ -133,11 +125,6 @@ def test_predict(self, datamodule: CustomGeoDataModule) -> None:
batch = datamodule.transfer_batch_to_device(batch, torch.device('cpu'), 1)
batch = datamodule.on_after_batch_transfer(batch, 0)

def test_plot(self, datamodule: CustomGeoDataModule) -> None:
datamodule.setup('validate')
datamodule.plot()
plt.close()

def test_no_datasets(self) -> None:
dm = CustomGeoDataModule()
msg = r'CustomGeoDataModule\.setup must define one of '
Expand Down Expand Up @@ -235,11 +222,6 @@ def test_predict(self, datamodule: CustomNonGeoDataModule) -> None:
batch = next(iter(datamodule.predict_dataloader()))
batch = datamodule.on_after_batch_transfer(batch, 0)

def test_plot(self, datamodule: CustomNonGeoDataModule) -> None:
datamodule.setup('validate')
datamodule.plot()
plt.close()

def test_no_datasets(self) -> None:
dm = CustomNonGeoDataModule()
msg = r'CustomNonGeoDataModule\.setup must define one of '
Expand Down
9 changes: 0 additions & 9 deletions tests/datamodules/test_usavars.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@

import os

import matplotlib.pyplot as plt
import pytest
from _pytest.fixtures import SubRequest

from torchgeo.datamodules import USAVarsDataModule
from torchgeo.datasets import unbind_samples


class TestUSAVarsDataModule:
Expand Down Expand Up @@ -41,10 +39,3 @@ def test_test_dataloader(self, datamodule: USAVarsDataModule) -> None:
assert len(datamodule.test_dataloader()) == 1
batch = next(iter(datamodule.test_dataloader()))
assert batch['image'].shape[0] == datamodule.batch_size

def test_plot(self, datamodule: USAVarsDataModule) -> None:
datamodule.setup('validate')
batch = next(iter(datamodule.val_dataloader()))
sample = unbind_samples(batch)[0]
datamodule.plot(sample)
plt.close()
9 changes: 0 additions & 9 deletions tests/datamodules/test_xview2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@

import os

import matplotlib.pyplot as plt
import pytest

from torchgeo.datamodules import XView2DataModule
from torchgeo.datasets import unbind_samples


class TestXView2DataModule:
Expand All @@ -33,10 +31,3 @@ def test_val_dataloader(self, datamodule: XView2DataModule) -> None:
def test_test_dataloader(self, datamodule: XView2DataModule) -> None:
datamodule.setup('test')
next(iter(datamodule.test_dataloader()))

def test_plot(self, datamodule: XView2DataModule) -> None:
datamodule.setup('validate')
batch = next(iter(datamodule.val_dataloader()))
sample = unbind_samples(batch)[0]
datamodule.plot(sample)
plt.close()
70 changes: 1 addition & 69 deletions tests/trainers/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
EuroSATDataModule,
MisconfigurationException,
)
from torchgeo.datasets import BigEarthNet, EuroSAT, RGBBandsMissingError
from torchgeo.datasets import BigEarthNet, EuroSAT
from torchgeo.main import main
from torchgeo.models import ResNet18_Weights
from torchgeo.trainers import ClassificationTask, MultiLabelClassificationTask
Expand Down Expand Up @@ -61,14 +61,6 @@ def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
return state_dict


def plot(*args: Any, **kwargs: Any) -> None:
return None


def plot_missing_bands(*args: Any, **kwargs: Any) -> None:
raise RGBBandsMissingError()


class TestClassificationTask:
@pytest.mark.parametrize(
'name',
Expand Down Expand Up @@ -186,34 +178,6 @@ def test_invalid_loss(self) -> None:
with pytest.raises(ValueError, match=match):
ClassificationTask(model='resnet18', loss='invalid_loss')

def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
monkeypatch.setattr(EuroSATDataModule, 'plot', plot)
datamodule = EuroSATDataModule(
root='tests/data/eurosat', batch_size=1, num_workers=0
)
model = ClassificationTask(model='resnet18', in_channels=13, num_classes=10)
trainer = Trainer(
accelerator='cpu',
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.validate(model=model, datamodule=datamodule)

def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
monkeypatch.setattr(EuroSATDataModule, 'plot', plot_missing_bands)
datamodule = EuroSATDataModule(
root='tests/data/eurosat', batch_size=1, num_workers=0
)
model = ClassificationTask(model='resnet18', in_channels=13, num_classes=10)
trainer = Trainer(
accelerator='cpu',
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.validate(model=model, datamodule=datamodule)

def test_predict(self, fast_dev_run: bool) -> None:
datamodule = PredictClassificationDataModule(
root='tests/data/eurosat', batch_size=1, num_workers=0
Expand Down Expand Up @@ -277,38 +241,6 @@ def test_invalid_loss(self) -> None:
with pytest.raises(ValueError, match=match):
MultiLabelClassificationTask(model='resnet18', loss='invalid_loss')

def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
monkeypatch.setattr(BigEarthNetDataModule, 'plot', plot)
datamodule = BigEarthNetDataModule(
root='tests/data/bigearthnet', batch_size=1, num_workers=0
)
model = MultiLabelClassificationTask(
model='resnet18', in_channels=14, num_classes=19, loss='bce'
)
trainer = Trainer(
accelerator='cpu',
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.validate(model=model, datamodule=datamodule)

def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
monkeypatch.setattr(BigEarthNetDataModule, 'plot', plot_missing_bands)
datamodule = BigEarthNetDataModule(
root='tests/data/bigearthnet', batch_size=1, num_workers=0
)
model = MultiLabelClassificationTask(
model='resnet18', in_channels=14, num_classes=19, loss='bce'
)
trainer = Trainer(
accelerator='cpu',
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.validate(model=model, datamodule=datamodule)

def test_predict(self, fast_dev_run: bool) -> None:
datamodule = PredictMultiLabelClassificationDataModule(
root='tests/data/bigearthnet', batch_size=1, num_workers=0
Expand Down
38 changes: 1 addition & 37 deletions tests/trainers/test_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torch.nn.modules import Module

from torchgeo.datamodules import MisconfigurationException, NASAMarineDebrisDataModule
from torchgeo.datasets import NASAMarineDebris, RGBBandsMissingError
from torchgeo.datasets import NASAMarineDebris
from torchgeo.main import main
from torchgeo.trainers import ObjectDetectionTask

Expand All @@ -26,10 +26,6 @@ def setup(self, stage: str) -> None:
self.predict_dataset = NASAMarineDebris(**self.kwargs)


def plot_missing_bands(*args: Any, **kwargs: Any) -> None:
raise RGBBandsMissingError()


class ObjectDetectionTestModel(Module):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__()
Expand Down Expand Up @@ -62,10 +58,6 @@ def forward(self, images: Any, targets: Any = None) -> Any:
return output


def plot(*args: Any, **kwargs: Any) -> None:
return None


class TestObjectDetectionTask:
@pytest.mark.parametrize('name', ['nasa_marine_debris', 'vhr10'])
@pytest.mark.parametrize('model_name', ['faster-rcnn', 'fcos', 'retinanet'])
Expand Down Expand Up @@ -120,34 +112,6 @@ def test_invalid_backbone(self) -> None:
def test_pretrained_backbone(self) -> None:
ObjectDetectionTask(backbone='resnet18', weights=True)

def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
monkeypatch.setattr(NASAMarineDebrisDataModule, 'plot', plot)
datamodule = NASAMarineDebrisDataModule(
root='tests/data/nasa_marine_debris', batch_size=1, num_workers=0
)
model = ObjectDetectionTask(backbone='resnet18', num_classes=2)
trainer = Trainer(
accelerator='cpu',
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.validate(model=model, datamodule=datamodule)

def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
monkeypatch.setattr(NASAMarineDebrisDataModule, 'plot', plot_missing_bands)
datamodule = NASAMarineDebrisDataModule(
root='tests/data/nasa_marine_debris', batch_size=1, num_workers=0
)
model = ObjectDetectionTask(backbone='resnet18', num_classes=2)
trainer = Trainer(
accelerator='cpu',
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.validate(model=model, datamodule=datamodule)

def test_predict(self, fast_dev_run: bool) -> None:
datamodule = PredictObjectDetectionDataModule(
root='tests/data/nasa_marine_debris', batch_size=1, num_workers=0
Expand Down
38 changes: 1 addition & 37 deletions tests/trainers/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from torchvision.models._api import WeightsEnum

from torchgeo.datamodules import MisconfigurationException, TropicalCycloneDataModule
from torchgeo.datasets import RGBBandsMissingError, TropicalCyclone
from torchgeo.datasets import TropicalCyclone
from torchgeo.main import main
from torchgeo.models import ResNet18_Weights
from torchgeo.trainers import PixelwiseRegressionTask, RegressionTask
Expand Down Expand Up @@ -51,14 +51,6 @@ def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
return state_dict


def plot(*args: Any, **kwargs: Any) -> None:
return None


def plot_missing_bands(*args: Any, **kwargs: Any) -> None:
raise RGBBandsMissingError()


class TestRegressionTask:
@classmethod
def create_model(*args: Any, **kwargs: Any) -> Module:
Expand Down Expand Up @@ -156,34 +148,6 @@ def test_weight_str_download(self, weights: WeightsEnum) -> None:
in_channels=weights.meta['in_chans'],
)

def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
monkeypatch.setattr(TropicalCycloneDataModule, 'plot', plot)
datamodule = TropicalCycloneDataModule(
root='tests/data/cyclone', batch_size=1, num_workers=0
)
model = RegressionTask(model='resnet18')
trainer = Trainer(
accelerator='cpu',
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.validate(model=model, datamodule=datamodule)

def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
monkeypatch.setattr(TropicalCycloneDataModule, 'plot', plot_missing_bands)
datamodule = TropicalCycloneDataModule(
root='tests/data/cyclone', batch_size=1, num_workers=0
)
model = RegressionTask(model='resnet18')
trainer = Trainer(
accelerator='cpu',
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.validate(model=model, datamodule=datamodule)

def test_predict(self, fast_dev_run: bool) -> None:
datamodule = PredictRegressionDataModule(
root='tests/data/cyclone', batch_size=1, num_workers=0
Expand Down
Loading
Loading