Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve tests to detect circular imports and resolve all of them #1357

Merged
merged 9 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ dependencies = [
"scikit-learn",
"scipy<1.13",
"tensorboard",
"torch>=1.13.0",
"torch>=1.13.0, <2.6.0",
"tqdm",
"pymc>=5.0.0",
"zuko>=1.2.0",
Expand Down
4 changes: 3 additions & 1 deletion sbi/inference/posteriors/score_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
from sbi.neural_nets.estimators.shape_handling import (
reshape_to_batch_event,
)
from sbi.samplers.score import Corrector, Diffuser, Predictor
from sbi.samplers.score.correctors import Corrector
from sbi.samplers.score.diffuser import Diffuser
from sbi.samplers.score.predictors import Predictor
from sbi.sbi_types import Shape
from sbi.utils import check_prior
from sbi.utils.torchutils import ensure_theta_batched
Expand Down
8 changes: 4 additions & 4 deletions sbi/inference/posteriors/vi_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@

from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.potentials.base_potential import BasePotential
from sbi.samplers.vi import (
from sbi.samplers.vi.vi_divergence_optimizers import get_VI_method
from sbi.samplers.vi.vi_pyro_flows import get_flow_builder
from sbi.samplers.vi.vi_quality_control import get_quality_metric
from sbi.samplers.vi.vi_utils import (
adapt_variational_distribution,
check_variational_distribution,
get_VI_method,
get_flow_builder,
get_quality_metric,
make_object_deepcopy_compatible,
move_all_tensor_to_device,
)
Expand Down
3 changes: 1 addition & 2 deletions sbi/inference/potentials/score_based_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
reshape_to_sample_batch_event,
)
from sbi.sbi_types import TorchTransform
from sbi.utils import mcmc_transform
from sbi.utils.sbiutils import within_support
from sbi.utils.sbiutils import mcmc_transform, within_support
from sbi.utils.torchutils import ensure_theta_batched


Expand Down
4 changes: 4 additions & 0 deletions sbi/samplers/rejection/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from sbi.samplers.rejection.rejection import (
accept_reject_sample,
rejection_sample,
)
2 changes: 1 addition & 1 deletion sbi/samplers/score/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from sbi.samplers.score.correctors import Corrector, get_corrector
from sbi.samplers.score.diffuser import Diffuser
from sbi.samplers.score.predictors import Predictor, get_predictor
from sbi.samplers.score.score import Diffuser
10 changes: 4 additions & 6 deletions sbi/samplers/score/score.py → sbi/samplers/score/diffuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@
from torch import Tensor
from tqdm.auto import tqdm

from sbi.inference.potentials.score_based_potential import (
PosteriorScoreBasedPotential,
)
from sbi.samplers.score import Corrector, Predictor, get_corrector, get_predictor
from sbi.samplers.score.correctors import Corrector, get_corrector
from sbi.samplers.score.predictors import Predictor, get_predictor


class Diffuser:
Expand All @@ -19,7 +17,7 @@ class Diffuser:

def __init__(
self,
score_based_potential: PosteriorScoreBasedPotential,
score_based_potential: 'PosteriorScoreBasedPotential', # noqa: F821 # type: ignore
predictor: Union[str, Predictor],
corrector: Optional[Union[str, Corrector]] = None,
predictor_params: Optional[dict] = None,
Expand Down Expand Up @@ -64,7 +62,7 @@ def __init__(
def set_predictor(
self,
predictor: Union[str, Predictor],
score_based_potential: PosteriorScoreBasedPotential,
score_based_potential: 'PosteriorScoreBasedPotential', # noqa: F821 # type: ignore
**kwargs,
):
"""Set the predictor for the diffusion-based sampler."""
Expand Down
12 changes: 5 additions & 7 deletions sbi/samplers/score/predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,13 @@
import torch
from torch import Tensor

from sbi.inference.potentials.score_based_potential import (
PosteriorScoreBasedPotential,
)

PREDICTORS = {}


def get_predictor(
name: str, score_based_potential: PosteriorScoreBasedPotential, **kwargs
name: str,
score_based_potential: 'PosteriorScoreBasedPotential', # noqa: F821 # type: ignore
**kwargs,
) -> "Predictor":
"""Helper function to get predictor by name.

Expand Down Expand Up @@ -54,7 +52,7 @@ class Predictor(ABC):

def __init__(
self,
potential_fn: PosteriorScoreBasedPotential,
potential_fn: 'PosteriorScoreBasedPotential', # noqa: F821 # type: ignore
):
"""Initialize predictor.

Expand Down Expand Up @@ -94,7 +92,7 @@ def predict(self, theta: Tensor, t1: Tensor, t0: Tensor) -> Tensor:
class EulerMaruyama(Predictor):
def __init__(
self,
potential_fn: PosteriorScoreBasedPotential,
potential_fn: 'PosteriorScoreBasedPotential', # noqa: F821 # type: ignore
eta: float = 1.0,
):
"""Simple Euler-Maruyama discretization of the associated family of reverse
Expand Down
7 changes: 0 additions & 7 deletions sbi/samplers/vi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,3 @@
)
from sbi.samplers.vi.vi_pyro_flows import get_default_flows, get_flow_builder
from sbi.samplers.vi.vi_quality_control import get_quality_metric
from sbi.samplers.vi.vi_utils import (
adapt_variational_distribution,
check_variational_distribution,
detach_all_non_leaf_tensors,
make_object_deepcopy_compatible,
move_all_tensor_to_device,
)
3 changes: 1 addition & 2 deletions sbi/samplers/vi/vi_divergence_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from torch.optim.rmsprop import RMSprop
from torch.optim.sgd import SGD

from sbi.inference.potentials.base_potential import BasePotential
from sbi.samplers.vi.vi_utils import (
filter_kwrags_for_func,
make_object_deepcopy_compatible,
Expand All @@ -47,7 +46,7 @@ class DivergenceOptimizer(ABC):

def __init__(
self,
potential_fn: BasePotential,
potential_fn: 'BasePotential', # noqa: F821 # type: ignore
q: PyroTransformedDistribution,
prior: Optional[Distribution] = None,
n_particles: int = 256,
Expand Down
4 changes: 2 additions & 2 deletions sbi/utils/restriction_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from torch.utils import data
from torch.utils.data.sampler import SubsetRandomSampler, WeightedRandomSampler

from sbi.samplers import rejection
from sbi.samplers.importance.sir import sampling_importance_resampling
from sbi.samplers.rejection.rejection import accept_reject_sample
from sbi.sbi_types import Shape
from sbi.utils.sbiutils import (
get_simulations_since_round,
Expand Down Expand Up @@ -684,7 +684,7 @@ def sample(
sample_with = self._sample_with if sample_with is None else sample_with

if sample_with == "rejection":
samples, acceptance_rate = accept_reject_sample(
samples, acceptance_rate = rejection.accept_reject_sample(
proposal=self._prior,
accept_reject_fn=self._accept_reject_fn,
num_samples=num_samples,
Expand Down
70 changes: 62 additions & 8 deletions tests/circular_import_test.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
import contextlib
import importlib
import inspect
import pkgutil
import random
import sys
from types import ModuleType


def import_first_function(module_name: str):
# This is a helper which emulates the following:
# from sbi.xxx import yyy
"""This is a helper which imports the first function from a module.
Performs: `from sbi.module_name import first_function_name`
Args:
module_name: Name of the module.
"""

module = importlib.import_module(module_name)
functions = inspect.getmembers(module, inspect.isfunction)

Expand All @@ -20,33 +29,78 @@ def import_first_function(module_name: str):
pass


def find_submodules(package_name):
# This is a helper which finds all modules from which we could import
def reset_environment():
"""This is a helper which resets the environment by deleting all sbi modules."""
# NOTE: Beware when working with sys.modules directly as this can lead to side
# effects, e.g., sporadic errors with pickle!
for module_name in list(sys.modules.keys()):
if module_name.startswith("sbi"):
if module_name in globals():
del globals()[module_name]
elif module_name in locals():
del locals()[module_name]


def find_submodules(package_name: str):
"""This is a helper which finds all submodules of a package.
Args:
package_name: Name of the package.
"""
submodules = []
package = __import__(package_name)

for _, name, _ in pkgutil.walk_packages(package.__path__):
full_name = package.__name__ + "." + name
submodules.append(full_name)
def walk_submodules(package: ModuleType):
"""This is a recursive helper function which walks all submodules of a package.
Args:
package: The package to crawl through.
"""
for _, name, is_pkg in pkgutil.walk_packages(
package.__path__, package.__name__ + "."
):
submodules.append(name)
if is_pkg:
# There are some wanted import errors for deprecated modules
with contextlib.suppress(ImportError):
walk_submodules(importlib.import_module(name))

walk_submodules(package)
return submodules


def test_for_circular_imports():
"""This test checks for circular imports in the sbi package.
In order to do so, it is tested if we can directly import from all submodules i.e.
`from sbi.module_name import first_function_name` or `import sbi.module_name`
without any import errors.
"""
modules = find_submodules("sbi")
# Permute the list of modules
random.shuffle(modules)

errors = []
for module_name in modules:
# Try to import
if "sbi.examples" in module_name:
# This is not really a module :/ Hence skip it...
continue
try:
reset_environment()
# Tests if we can: import module_name
module = importlib.import_module(module_name)
reset_environment()
# Tests if we can: from module_name import xxx
import_first_function(module.__name__)

del module
except ImportError as e:
raise AssertionError(f"Cannot import {module_name}. Error: {e}") from e
# NOTE: There might be other errors which are intended
if "circular import" in str(e):
# This is a circular import detected
errors.append(f"Circular import detected in {module_name}. Error: {e}")
print(f"Circular import detected in {module_name}")

assert len(errors) == 0, "\n".join(errors)