Type `potential_fn` properly, not `Callable` #1055
Comments
Note that we have a "hacky" runtime check for this implemented in `sbi/sbi/inference/posteriors/base_posterior.py` (lines 55 to 69 in 366b67f).
From a discussion with my colleagues at TransferLab, I would suggest using a `Protocol` to describe the interface that `potential_fn` must satisfy.

Step 1: Define a Protocol

Define a `Protocol` that specifies the keyword arguments the callable must accept:

```python
from typing import Any, Protocol


class PotentialProtocol(Protocol):
    def __call__(self, *, theta: Any, x_o: Any) -> float:
        ...
```

Step 2: Type Checking with the Protocol

Instead of checking the arguments at runtime, you use `PotentialProtocol` as the type annotation for `potential_fn`, so that static type checkers can verify the signature.

Example Usage in Class

```python
import inspect
from typing import Any

# Assumes sbi's module layout for BasePotential.
from sbi.inference.potentials.base_potential import BasePotential

# `PotentialProtocol` is the Protocol defined in Step 1.


class ConditionedPotential:
    def __init__(self, potential_fn: PotentialProtocol) -> None:
        self.potential_fn = potential_fn
        # Optional: runtime check if not using static type checking tools.
        # BasePotential objects are accepted as-is; plain callables are inspected.
        if not isinstance(potential_fn, BasePotential):
            self._validate_callable_signature(potential_fn)

    def _validate_callable_signature(self, fn: PotentialProtocol) -> None:
        # Get the parameter names of the callable.
        kwargs_of_callable = list(inspect.signature(fn).parameters.keys())
        # Ensure the required keyword arguments are present.
        for key in ["theta", "x_o"]:
            assert key in kwargs_of_callable, (
                "If you pass a `Callable` as `potential_fn` then it must have "
                "`theta` and `x_o` as inputs, even if some of these keyword "
                "arguments are unused."
            )

    def some_method(self, theta: Any, x_o: Any) -> float:
        return self.potential_fn(theta=theta, x_o=x_o)
```
This approach formalizes the expected interface and improves both clarity and safety, reducing the likelihood of runtime errors caused by incorrectly structured callables.
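One caveat worth checking: a `Protocol` constrains signatures only for static type checkers such as mypy or pyright. At runtime, even with `@runtime_checkable`, `isinstance` only verifies that the object has the required attribute (here `__call__`), not its parameter names. So a signature check like `_validate_callable_signature` is still needed if callables must be validated at runtime. The standalone sketch below illustrates this with the standard library only; `good_potential`, `bad_potential`, and the `has_required_kwargs` helper are hypothetical names for illustration.

```python
import inspect
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class PotentialProtocol(Protocol):
    def __call__(self, *, theta: Any, x_o: Any) -> float: ...


def good_potential(*, theta: Any, x_o: Any) -> float:
    # Matches the protocol: keyword arguments `theta` and `x_o`.
    return 0.0


def bad_potential(theta: Any) -> float:
    # Missing `x_o`; a static type checker should reject passing this
    # where a PotentialProtocol is expected.
    return 0.0


def has_required_kwargs(fn: Any, required: tuple[str, ...] = ("theta", "x_o")) -> bool:
    """Hypothetical helper: runtime check of parameter names via inspect."""
    params = inspect.signature(fn).parameters
    return all(name in params for name in required)


# isinstance() with a runtime_checkable Protocol only checks that `__call__`
# exists, so both functions pass.
print(isinstance(good_potential, PotentialProtocol))  # True
print(isinstance(bad_potential, PotentialProtocol))  # True

# Inspecting the signature distinguishes them.
print(has_required_kwargs(good_potential))  # True
print(has_required_kwargs(bad_potential))  # False
```

In other words, the Protocol and the runtime check are complementary: the former documents the interface for tooling, the latter catches mismatches for users who do not run a type checker.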
We currently type `potential_fn` only as `Callable` (`sbi/sbi/utils/conditional_density_utils.py`, lines 273 to 276 in bae6994), despite implicitly requiring it to be a `BasePotential` so that methods such as `set_x` can be called on it (`sbi/sbi/utils/conditional_density_utils.py`, line 330 in bae6994). Changing the type to `BasePotential` will require further changes down the line.
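One possible direction, sketched here under explicit assumptions, is to accept either a potential object or a plain callable and wrap the latter so that `set_x` is always available. Code that needs `set_x` can then rely on one interface while callable-based call sites keep working. `MinimalPotential` is a hypothetical, simplified stand-in for sbi's `BasePotential` (not its actual definition), and `CallableAdapter` and `as_potential` are hypothetical names, not existing sbi API.

```python
from typing import Any, Callable, Optional, Union


class MinimalPotential:
    """Hypothetical stand-in for sbi's BasePotential: stores x_o via set_x()
    and evaluates the potential on theta."""

    def __init__(self) -> None:
        self.x_o: Optional[Any] = None

    def set_x(self, x_o: Any) -> None:
        self.x_o = x_o

    def __call__(self, theta: Any) -> float:
        raise NotImplementedError


class CallableAdapter(MinimalPotential):
    """Hypothetical adapter: wraps a plain callable taking keyword arguments
    `theta` and `x_o` so that it exposes set_x() like a potential object."""

    def __init__(self, fn: Callable[..., float]) -> None:
        super().__init__()
        self.fn = fn

    def __call__(self, theta: Any) -> float:
        if self.x_o is None:
            raise ValueError("Call set_x() before evaluating the potential.")
        return self.fn(theta=theta, x_o=self.x_o)


def as_potential(
    potential_fn: Union[MinimalPotential, Callable[..., float]],
) -> MinimalPotential:
    """Return an object on which set_x() can always be called."""
    if isinstance(potential_fn, MinimalPotential):
        return potential_fn
    return CallableAdapter(potential_fn)


def quadratic(theta: float, x_o: float) -> float:
    # Toy potential: negative squared distance between theta and x_o.
    return -((theta - x_o) ** 2)


potential = as_potential(quadratic)
potential.set_x(1.0)
print(potential(3.0))  # -4.0
```

The trade-off is that the public annotation can stay permissive (`Union[...]`) while the internals are typed against the narrower potential interface, which localizes the "other required changes" to a single normalization step.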