Merge pull request #14 from kathryn-baker/main
allowing parameters to be initialised from a list
t-bz authored Oct 7, 2024
2 parents c6ae740 + c25f895 commit 67a646b
Showing 2 changed files with 24 additions and 1 deletion.
base.py (5 changes: 4 additions, 1 deletion)
@@ -139,7 +139,10 @@ def _initialize_parameter(
         # define initial and default value(s)
         for value, value_str in zip([initial, default], ["initial", "default"]):
             if not isinstance(value, Tensor):
-                value = float(value) * torch.ones(size)
+                if isinstance(value, list):
+                    value = torch.tensor(value)
+                else:
+                    value = float(value) * torch.ones(size)
             value_size = value.shape
             if value.dim() == 1 and isinstance(size, int):
                 value_size = value.shape[0]
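The change above only affects how non-Tensor values are coerced before the shape check that follows: a Python list is now converted element-wise with torch.tensor, while a scalar is still broadcast to the requested size via torch.ones. A minimal, self-contained sketch of just that branch (the coerce_value helper below is illustrative, not part of the repository's _initialize_parameter):

import torch
from torch import Tensor


def coerce_value(value, size):
    # Mirrors the new branch: lists keep their own values and length,
    # scalars are broadcast to the requested size.
    if not isinstance(value, Tensor):
        if isinstance(value, list):
            value = torch.tensor(value)
        else:
            value = float(value) * torch.ones(size)
    return value


# A scalar default is broadcast to the requested shape ...
assert coerce_value(0.5, 3).shape == (3,)
# ... while a list default is converted element-wise.
assert torch.equal(coerce_value([0.1, 0.2, 0.3], 3), torch.tensor([0.1, 0.2, 0.3]))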
tests/test_base.py (20 changes: 20 additions, 0 deletions)
@@ -119,6 +119,26 @@ def test_ndim_parameter(self, linear_model, parameter_name, ndim_size):
         assert param.shape == ndim_size
         assert raw_param.shape == ndim_size
 
+    def test_parameter_lists(self, linear_model, parameter_name, ndim_size):
+        parameter_prior = NormalPrior(loc=torch.zeros(ndim_size), scale=torch.ones(ndim_size))
+        kwargs = {
+            f"{parameter_name}_size": ndim_size,
+            f"{parameter_name}_default": torch.ones(ndim_size).tolist(),
+            f"{parameter_name}_initial": torch.ones(ndim_size).tolist(),
+            f"{parameter_name}_prior": parameter_prior,
+            f"{parameter_name}_constraint": Interval(lower_bound=-1.5, upper_bound=1.5),
+        }
+        m = ParameterModule(
+            model=linear_model,
+            parameter_names=[parameter_name],
+            **kwargs,
+        )
+        param = getattr(m, parameter_name)
+        raw_param = getattr(m, f"raw_{parameter_name}")
+
+        assert param.shape == ndim_size
+        assert raw_param.shape == ndim_size
+
     def test_parameter_mask(self, extensive_parameter_module):
         parameter_name = extensive_parameter_module.calibration_parameter_names[0]
         param = extensive_parameter_module.calibration_parameters[0]
