Type 'potential_fn' properly, not 'Callable' #1055

Closed
Baschdl opened this issue Mar 20, 2024 · 2 comments · Fixed by #1222
Labels: architecture (Internal changes without API consequences)


Baschdl (Contributor) commented Mar 20, 2024

We currently type potential_fn only as Callable:

```python
class ConditionedPotential:
    def __init__(
        self,
        potential_fn: Callable,
        # ... (remaining parameters omitted in this excerpt)
```

even though the class implicitly requires it to be a BasePotential, because it calls methods such as set_x:

```python
self.potential_fn.set_x(x_o)
```

Changing the type to BasePotential will require further changes elsewhere in the codebase.
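The sketch below illustrates why the Callable annotation hides a requirement. It uses a simplified, hypothetical stand-in for BasePotential, not sbi's actual class. Annotating with Callable admits a plain function, which then fails at set_x. Annotating with BasePotential makes set_x part of the declared interface.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional


class BasePotential(ABC):
    """Hypothetical, simplified stand-in for sbi's BasePotential."""

    def __init__(self) -> None:
        self.x_o: Optional[Any] = None

    def set_x(self, x_o: Any) -> None:
        self.x_o = x_o

    @abstractmethod
    def __call__(self, theta: Any) -> Any: ...


class LooseConditioned:
    # Annotated as Callable: the annotation says nothing about set_x.
    def __init__(self, potential_fn: Callable, x_o: Any) -> None:
        self.potential_fn = potential_fn
        self.potential_fn.set_x(x_o)


class StrictConditioned:
    # Annotated as BasePotential: set_x is part of the declared interface.
    def __init__(self, potential_fn: BasePotential, x_o: Any) -> None:
        self.potential_fn = potential_fn
        self.potential_fn.set_x(x_o)


class NegSquaredDistance(BasePotential):
    def __call__(self, theta: Any) -> Any:
        return -((theta - self.x_o) ** 2)


def plain_potential(theta: Any, x_o: Any) -> Any:
    return -((theta - x_o) ** 2)


try:
    LooseConditioned(plain_potential, x_o=1.0)
    loose_failed = False
except AttributeError:
    loose_failed = True  # a plain function has no set_x method

strict = StrictConditioned(NegSquaredDistance(), x_o=1.0)
print(loose_failed)              # True
print(strict.potential_fn(3.0))  # -4.0
```

With the BasePotential annotation, a static type checker would reject a plain function such as plain_potential at the call site. It would no longer fail inside the constructor.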

Baschdl added the architecture (Internal changes without API consequences) label on Mar 20, 2024.
janfb added this to the Hackathon and release 2024 milestone on Aug 6, 2024.
janfb (Contributor) commented Aug 6, 2024

Note that we have a "hacky" runtime check for this in base_posterior.py (but not in ConditionedPotential):

```python
if not isinstance(potential_fn, BasePotential):
    kwargs_of_callable = list(inspect.signature(potential_fn).parameters.keys())
    for key in ["theta", "x_o"]:
        assert key in kwargs_of_callable, (
            "If you pass a `Callable` as `potential_fn` then it must have "
            "`theta` and `x_o` as inputs, even if some of these keyword "
            "arguments are unused."
        )
    # If the `potential_fn` is a Callable then we wrap it as a
    # `CallablePotentialWrapper` which inherits from `BasePotential`.
    potential_device = "cpu" if device is None else device
    potential_fn = CallablePotentialWrapper(
        potential_fn, prior=None, x_o=None, device=potential_device
    )
```
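The check-then-wrap logic above can be sketched in isolation. CallableWrapperSketch and missing_potential_kwargs are hypothetical names. This is a simplified illustration, not sbi's CallablePotentialWrapper, and it leaves out prior and device handling.

```python
import inspect
from typing import Any, Callable, List, Optional

REQUIRED_KWARGS = ("theta", "x_o")


def missing_potential_kwargs(potential_fn: Callable[..., Any]) -> List[str]:
    """Return the required keyword names absent from the callable's parameters."""
    params = inspect.signature(potential_fn).parameters
    return [key for key in REQUIRED_KWARGS if key not in params]


class CallableWrapperSketch:
    """Hypothetical, simplified stand-in for CallablePotentialWrapper.

    It stores x_o so that a plain function gains a set_x method.
    """

    def __init__(self, potential_fn: Callable[..., Any], x_o: Optional[Any] = None) -> None:
        missing = missing_potential_kwargs(potential_fn)
        if missing:
            raise TypeError(f"potential_fn must accept keyword arguments {missing}")
        self.potential_fn = potential_fn
        self.x_o = x_o

    def set_x(self, x_o: Any) -> None:
        self.x_o = x_o

    def __call__(self, theta: Any) -> Any:
        # Forward both values by keyword, matching what the name check requires.
        return self.potential_fn(theta=theta, x_o=self.x_o)


def gaussian_potential(theta: float, x_o: float) -> float:
    return -((theta - x_o) ** 2)


wrapped = CallableWrapperSketch(gaussian_potential)
wrapped.set_x(1.0)
print(wrapped(3.0))                                 # -4.0
print(missing_potential_kwargs(lambda theta: 0.0))  # ['x_o']
```

The sketch raises TypeError instead of using assert, because assert statements are stripped when Python runs with the -O flag.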

From a discussion with my colleagues at Transferlab, I would suggest using a Protocol to define the required properties of the Callable. Here is an outline for implementing this:

Step 1: Define a Protocol

Define a Protocol that includes a method with the required arguments (theta and x_o in this case). This method will define the expected signature for potential_fn.

```python
from typing import Protocol, Any

class PotentialProtocol(Protocol):
    def __call__(self, *, theta: Any, x_o: Any) -> float:
        ...
```
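The following minimal sketch illustrates the structural conformance this Protocol relies on. Both a plain function and an unrelated class satisfy PotentialProtocol without inheriting from it. The names gaussian_log_potential, ScaledPotential, and evaluate are hypothetical.

```python
from typing import Any, Protocol


class PotentialProtocol(Protocol):
    def __call__(self, *, theta: Any, x_o: Any) -> float:
        ...


def gaussian_log_potential(theta: float, x_o: float) -> float:
    # Positional-or-keyword parameters also satisfy the keyword-only protocol,
    # because callers of the protocol always pass theta= and x_o= by name.
    return -0.5 * (theta - x_o) ** 2


class ScaledPotential:
    # Any object with a matching __call__ conforms structurally; no inheritance needed.
    def __init__(self, scale: float) -> None:
        self.scale = scale

    def __call__(self, *, theta: Any, x_o: Any) -> float:
        return -self.scale * abs(theta - x_o)


def evaluate(potential_fn: PotentialProtocol, theta: float, x_o: float) -> float:
    return potential_fn(theta=theta, x_o=x_o)


print(evaluate(gaussian_log_potential, theta=2.0, x_o=0.0))  # -2.0
print(evaluate(ScaledPotential(3.0), theta=2.0, x_o=0.0))    # -6.0
```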

Step 2: Type Checking with the Protocol

Instead of checking the arguments at runtime, you use the PotentialProtocol as a type hint. This way, a static type checker such as mypy can verify that any function or callable object passed as potential_fn adheres to this protocol.

Example Usage in Class

```python
from typing import Any
import inspect

from sbi.inference.potentials.base_potential import BasePotential

# PotentialProtocol is the Protocol defined in Step 1.


class ConditionedPotential:
    def __init__(self, potential_fn: PotentialProtocol) -> None:
        self.potential_fn = potential_fn

        # Optional: Runtime check if not using static type checking tools
        if not isinstance(potential_fn, BasePotential):
            self._validate_callable_signature(potential_fn)

    def _validate_callable_signature(self, fn: PotentialProtocol) -> None:
        # Get the signature of the callable
        kwargs_of_callable = list(inspect.signature(fn).parameters.keys())

        # Ensure required arguments are present
        for key in ["theta", "x_o"]:
            assert key in kwargs_of_callable, (
                "If you pass a `Callable` as `potential_fn` then it must have "
                "`theta` and `x_o` as inputs, even if some of these keyword "
                "arguments are unused."
            )

    def some_method(self, theta: Any, x_o: Any) -> float:
        return self.potential_fn(theta=theta, x_o=x_o)
```

Explanation

  1. PotentialProtocol Definition:

    • This protocol defines a __call__ method that takes theta and x_o as keyword-only arguments and returns a float. A static type checker will then flag any callable passed as potential_fn that cannot be called with these keyword arguments.
  2. ConditionedPotential Class:

    • The constructor now uses PotentialProtocol to type the potential_fn argument. This provides a clear expectation of the function's interface.
    • The _validate_callable_signature method is an optional runtime check that the callable's parameter names include theta and x_o. This is useful if you're not relying solely on static type checking.
  3. Static Type Checking:

    • Using tools like mypy, you can catch potential issues at development time, ensuring that all functions passed as potential_fn conform to the PotentialProtocol.
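The following minimal sketch makes the static-versus-runtime distinction concrete. Holder and missing_x_o are hypothetical names. A Protocol annotation alone does not stop a non-conforming callable at runtime. The error is caught early only by a static checker or by an explicit runtime check.

```python
from typing import Any, Protocol


class PotentialProtocol(Protocol):
    def __call__(self, *, theta: Any, x_o: Any) -> float:
        ...


class Holder:
    """Hypothetical minimal class that stores a potential and calls it."""

    def __init__(self, potential_fn: PotentialProtocol) -> None:
        self.potential_fn = potential_fn

    def evaluate(self, theta: Any, x_o: Any) -> float:
        return self.potential_fn(theta=theta, x_o=x_o)


def missing_x_o(theta: float) -> float:
    return -(theta ** 2)


# A static checker such as mypy reports an incompatible argument type here,
# but Python does not enforce annotations, so construction succeeds.
holder = Holder(missing_x_o)

failed_at_call = False
try:
    holder.evaluate(theta=1.0, x_o=0.0)
except TypeError:
    failed_at_call = True  # the error only surfaces when the potential is called
print(failed_at_call)  # True
```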

Additional Considerations

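  • BasePotential: If you have existing classes that inherit from BasePotential, ensure that their __call__ accepts theta and x_o as keyword arguments, so that they conform to PotentialProtocol. Because Protocols are matched structurally, these classes do not need to inherit from PotentialProtocol explicitly.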

  • Runtime vs. Static Checking: The protocol enables static type checking, which catches issues during development, but Python does not enforce type annotations at runtime. The runtime check (_validate_callable_signature) can therefore be kept for additional safety. This matters especially if you're unsure about the inputs, or when integrating with dynamically typed parts of the code.
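A runtime check remains useful because even a @runtime_checkable Protocol cannot verify call signatures: isinstance() only checks that the required members exist. The minimal sketch below shows this limitation. It also swaps in one alternative to the name-based check, inspect.Signature.bind, which additionally accepts callables that take **kwargs. The helper accepts_potential_kwargs is a hypothetical name.

```python
import inspect
from typing import Any, Callable, Protocol, runtime_checkable


@runtime_checkable
class PotentialProtocol(Protocol):
    def __call__(self, *, theta: Any, x_o: Any) -> float:
        ...


def wrong_signature(theta: float) -> float:
    return -(theta ** 2)


# isinstance() on a runtime_checkable Protocol only checks that __call__ exists,
# not its parameters, so this function passes despite lacking x_o.
print(isinstance(wrong_signature, PotentialProtocol))  # True


def accepts_potential_kwargs(fn: Callable[..., Any]) -> bool:
    """Hypothetical helper: True if fn can be called with theta= and x_o=."""
    try:
        inspect.signature(fn).bind(theta=None, x_o=None)
    except TypeError:
        return False
    return True


def with_kwargs(**kwargs: Any) -> float:
    return 0.0


print(accepts_potential_kwargs(wrong_signature))  # False
print(accepts_potential_kwargs(with_kwargs))      # True
```

A check that only compares parameter names would reject with_kwargs, even though it can be called with theta and x_o. Binding the keywords tests what the caller actually does.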

This approach formalizes the expected interface and improves both clarity and safety, reducing the likelihood of runtime errors due to incorrectly structured callables.

schroedk (Contributor) commented:

@janfb As discussed last week, the argument potential_fn of class ConditionedPotential must be of type BasePotential.
I created another issue for our discussion of the different callable types: #1223


3 participants