Skip to content

Commit

Permalink
Improve tests to detect circular imports and resolve all of them (#1357)
Browse files Browse the repository at this point in the history
* Improve tests to resolve circular imports

* ruff format

* Ruff format compatible -> Ignore string typing

* Ignore types given as string

* Also pyright ignore string types

* Add better documentation of the test

* Test should work

* Fix comment

* torch < 2.6
  • Loading branch information
manuelgloeckler authored Jan 29, 2025
1 parent 448cef2 commit 6d527f7
Show file tree
Hide file tree
Showing 12 changed files with 88 additions and 41 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ dependencies = [
"scikit-learn",
"scipy<1.13",
"tensorboard",
"torch>=1.13.0",
"torch>=1.13.0, <2.6.0",
"tqdm",
"pymc>=5.0.0",
"zuko>=1.2.0",
Expand Down
4 changes: 3 additions & 1 deletion sbi/inference/posteriors/score_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
from sbi.neural_nets.estimators.shape_handling import (
reshape_to_batch_event,
)
from sbi.samplers.score import Corrector, Diffuser, Predictor
from sbi.samplers.score.correctors import Corrector
from sbi.samplers.score.diffuser import Diffuser
from sbi.samplers.score.predictors import Predictor
from sbi.sbi_types import Shape
from sbi.utils import check_prior
from sbi.utils.torchutils import ensure_theta_batched
Expand Down
8 changes: 4 additions & 4 deletions sbi/inference/posteriors/vi_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@

from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.potentials.base_potential import BasePotential
from sbi.samplers.vi import (
from sbi.samplers.vi.vi_divergence_optimizers import get_VI_method
from sbi.samplers.vi.vi_pyro_flows import get_flow_builder
from sbi.samplers.vi.vi_quality_control import get_quality_metric
from sbi.samplers.vi.vi_utils import (
adapt_variational_distribution,
check_variational_distribution,
get_VI_method,
get_flow_builder,
get_quality_metric,
make_object_deepcopy_compatible,
move_all_tensor_to_device,
)
Expand Down
3 changes: 1 addition & 2 deletions sbi/inference/potentials/score_based_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
reshape_to_sample_batch_event,
)
from sbi.sbi_types import TorchTransform
from sbi.utils import mcmc_transform
from sbi.utils.sbiutils import within_support
from sbi.utils.sbiutils import mcmc_transform, within_support
from sbi.utils.torchutils import ensure_theta_batched


Expand Down
4 changes: 4 additions & 0 deletions sbi/samplers/rejection/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from sbi.samplers.rejection.rejection import (
accept_reject_sample,
rejection_sample,
)
2 changes: 1 addition & 1 deletion sbi/samplers/score/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from sbi.samplers.score.correctors import Corrector, get_corrector
from sbi.samplers.score.diffuser import Diffuser
from sbi.samplers.score.predictors import Predictor, get_predictor
from sbi.samplers.score.score import Diffuser
10 changes: 4 additions & 6 deletions sbi/samplers/score/score.py → sbi/samplers/score/diffuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@
from torch import Tensor
from tqdm.auto import tqdm

from sbi.inference.potentials.score_based_potential import (
PosteriorScoreBasedPotential,
)
from sbi.samplers.score import Corrector, Predictor, get_corrector, get_predictor
from sbi.samplers.score.correctors import Corrector, get_corrector
from sbi.samplers.score.predictors import Predictor, get_predictor


class Diffuser:
Expand All @@ -19,7 +17,7 @@ class Diffuser:

def __init__(
self,
score_based_potential: PosteriorScoreBasedPotential,
score_based_potential: 'PosteriorScoreBasedPotential', # noqa: F821 # type: ignore
predictor: Union[str, Predictor],
corrector: Optional[Union[str, Corrector]] = None,
predictor_params: Optional[dict] = None,
Expand Down Expand Up @@ -64,7 +62,7 @@ def __init__(
def set_predictor(
self,
predictor: Union[str, Predictor],
score_based_potential: PosteriorScoreBasedPotential,
score_based_potential: 'PosteriorScoreBasedPotential', # noqa: F821 # type: ignore
**kwargs,
):
"""Set the predictor for the diffusion-based sampler."""
Expand Down
12 changes: 5 additions & 7 deletions sbi/samplers/score/predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,13 @@
import torch
from torch import Tensor

from sbi.inference.potentials.score_based_potential import (
PosteriorScoreBasedPotential,
)

PREDICTORS = {}


def get_predictor(
name: str, score_based_potential: PosteriorScoreBasedPotential, **kwargs
name: str,
score_based_potential: 'PosteriorScoreBasedPotential', # noqa: F821 # type: ignore
**kwargs,
) -> "Predictor":
"""Helper function to get predictor by name.
Expand Down Expand Up @@ -54,7 +52,7 @@ class Predictor(ABC):

def __init__(
self,
potential_fn: PosteriorScoreBasedPotential,
potential_fn: 'PosteriorScoreBasedPotential', # noqa: F821 # type: ignore
):
"""Initialize predictor.
Expand Down Expand Up @@ -94,7 +92,7 @@ def predict(self, theta: Tensor, t1: Tensor, t0: Tensor) -> Tensor:
class EulerMaruyama(Predictor):
def __init__(
self,
potential_fn: PosteriorScoreBasedPotential,
potential_fn: 'PosteriorScoreBasedPotential', # noqa: F821 # type: ignore
eta: float = 1.0,
):
"""Simple Euler-Maruyama discretization of the associated family of reverse
Expand Down
7 changes: 0 additions & 7 deletions sbi/samplers/vi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,3 @@
)
from sbi.samplers.vi.vi_pyro_flows import get_default_flows, get_flow_builder
from sbi.samplers.vi.vi_quality_control import get_quality_metric
from sbi.samplers.vi.vi_utils import (
adapt_variational_distribution,
check_variational_distribution,
detach_all_non_leaf_tensors,
make_object_deepcopy_compatible,
move_all_tensor_to_device,
)
3 changes: 1 addition & 2 deletions sbi/samplers/vi/vi_divergence_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from torch.optim.rmsprop import RMSprop
from torch.optim.sgd import SGD

from sbi.inference.potentials.base_potential import BasePotential
from sbi.samplers.vi.vi_utils import (
filter_kwrags_for_func,
make_object_deepcopy_compatible,
Expand All @@ -47,7 +46,7 @@ class DivergenceOptimizer(ABC):

def __init__(
self,
potential_fn: BasePotential,
potential_fn: 'BasePotential', # noqa: F821 # type: ignore
q: PyroTransformedDistribution,
prior: Optional[Distribution] = None,
n_particles: int = 256,
Expand Down
4 changes: 2 additions & 2 deletions sbi/utils/restriction_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from torch.utils import data
from torch.utils.data.sampler import SubsetRandomSampler, WeightedRandomSampler

from sbi.samplers import rejection
from sbi.samplers.importance.sir import sampling_importance_resampling
from sbi.samplers.rejection.rejection import accept_reject_sample
from sbi.sbi_types import Shape
from sbi.utils.sbiutils import (
get_simulations_since_round,
Expand Down Expand Up @@ -684,7 +684,7 @@ def sample(
sample_with = self._sample_with if sample_with is None else sample_with

if sample_with == "rejection":
samples, acceptance_rate = accept_reject_sample(
samples, acceptance_rate = rejection.accept_reject_sample(
proposal=self._prior,
accept_reject_fn=self._accept_reject_fn,
num_samples=num_samples,
Expand Down
70 changes: 62 additions & 8 deletions tests/circular_import_test.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
import contextlib
import importlib
import inspect
import pkgutil
import random
import sys
from types import ModuleType


def import_first_function(module_name: str):
# This is a helper which emulates the following:
# from sbi.xxx import yyy
"""This is a helper which imports the first function from a module.
Performs: `from sbi.module_name import first_function_name`
Args:
module_name: Name of the module.
"""

module = importlib.import_module(module_name)
functions = inspect.getmembers(module, inspect.isfunction)

Expand All @@ -20,33 +29,78 @@ def import_first_function(module_name: str):
pass


def find_submodules(package_name):
# This is a helper which finds all modules from which we could import
def reset_environment():
"""This is a helper which resets the environment by deleting all sbi modules."""
# NOTE: Beware when working with sys.modules directly as this can lead to side
# effects, e.g., sporadic errors with pickle!
for module_name in list(sys.modules.keys()):
if module_name.startswith("sbi"):
if module_name in globals():
del globals()[module_name]
elif module_name in locals():
del locals()[module_name]


def find_submodules(package_name: str):
"""This is a helper which finds all submodules of a package.
Args:
package_name: Name of the package.
"""
submodules = []
package = __import__(package_name)

for _, name, _ in pkgutil.walk_packages(package.__path__):
full_name = package.__name__ + "." + name
submodules.append(full_name)
def walk_submodules(package: ModuleType):
"""This is a recursive helper function which walks all submodules of a package.
Args:
package: The package to crawl through.
"""
for _, name, is_pkg in pkgutil.walk_packages(
package.__path__, package.__name__ + "."
):
submodules.append(name)
if is_pkg:
# There are some wanted import errors for deprecated modules
with contextlib.suppress(ImportError):
walk_submodules(importlib.import_module(name))

walk_submodules(package)
return submodules


def test_for_circular_imports():
"""This test checks for circular imports in the sbi package.
In order to do so, it is tested if we can directly import from all submodules i.e.
`from sbi.module_name import first_function_name` or `import sbi.module_name`
without any import errors.
"""
modules = find_submodules("sbi")
# Permute the list of modules
random.shuffle(modules)

errors = []
for module_name in modules:
# Try to import
if "sbi.examples" in module_name:
# This is not really a module :/ Hence skip it...
continue
try:
reset_environment()
# Tests if we can: import module_name
module = importlib.import_module(module_name)
reset_environment()
# Tests if we can: from module_name import xxx
import_first_function(module.__name__)

del module
except ImportError as e:
raise AssertionError(f"Cannot import {module_name}. Error: {e}") from e
# NOTE: There might be other errors which are intended
if "circular import" in str(e):
# This is a circular import detected
errors.append(f"Circular import detected in {module_name}. Error: {e}")
print(f"Circular import detected in {module_name}")

assert len(errors) == 0, "\n".join(errors)

0 comments on commit 6d527f7

Please sign in to comment.