Skip to content

Commit

Permalink
add utils for setting up a noise variance model
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Nov 26, 2024
1 parent 762a27f commit f9d0ee9
Showing 1 changed file with 41 additions and 1 deletion.
42 changes: 41 additions & 1 deletion neurobayes/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import List, Dict, Any
import inspect
import re

from typing import List, Dict, Any, Callable

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -237,3 +240,40 @@ def flatten_params_dict(params_dict: Dict[str, Any]) -> Dict[str, Any]:
# Keep searching deeper
flattened.update(flatten_params_dict({key: value}))
return flattened


def set_fn(func: Callable) -> Callable:
"""
Transforms the given deterministic function to use a params dictionary
for its parameters, excluding the first one (assumed to be the dependent variable).
Args:
- func (Callable): The deterministic function to be transformed.
Returns:
- Callable: The transformed function where parameters are accessed
from a `params` dictionary.
"""
# Extract parameter names excluding the first one (assumed to be the dependent variable)
params_names = list(inspect.signature(func).parameters.keys())[1:]

# Create the transformed function definition
transformed_code = f"def {func.__name__}(x, params):\n"

# Retrieve the source code of the function and indent it to be a valid function body
source = inspect.getsource(func).split("\n", 1)[1]
source = " " + source.replace("\n", "\n ")

# Replace each parameter name with its dictionary lookup using regex
for name in params_names:
source = re.sub(rf'\b{name}\b', f'params["{name}"]', source)

# Combine to get the full source
transformed_code += source

# Define the transformed function in the local namespace
local_namespace = {}
exec(transformed_code, globals(), local_namespace)

# Return the transformed function
return local_namespace[func.__name__]

0 comments on commit f9d0ee9

Please sign in to comment.