errors in examples/time_series/bayesian_var_model.ipynb #768

Open
ikuzmin404 opened this issue Jan 23, 2025 · 0 comments

Error in bayesian_var_model.ipynb:
Notebook url: https://github.com/pymc-devs/pymc-examples/tree/main/examples/time_series/bayesian_var_model.ipynb

Issue description

Error in imports: `from pymc.sampling_jax import sample_blackjax_nuts` should be replaced with `from pymc.sampling.jax import sample_blackjax_nuts`.
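If the notebook also has to run on older PyMC releases that still ship the legacy module, one option is a version-tolerant import. This is a sketch under my own assumptions, not what the notebook does: the helper name `import_blackjax_sampler` is hypothetical. It tries `pymc.sampling.jax` first, falls back to the legacy `pymc.sampling_jax`, and returns `None` if neither can be imported (for example when PyMC or JAX is not installed).

```python
import importlib


def import_blackjax_sampler():
    """Hypothetical helper: return sample_blackjax_nuts from whichever module
    this PyMC version provides, or None if it cannot be imported."""
    # Current location first, then the legacy module used by the notebook.
    for module_name in ("pymc.sampling.jax", "pymc.sampling_jax"):
        try:
            module = importlib.import_module(module_name)
        except ImportError:  # module missing, or PyMC/JAX not installed
            continue
        sampler = getattr(module, "sample_blackjax_nuts", None)
        if sampler is not None:
            return sampler
    return None


sample_blackjax_nuts = import_blackjax_sampler()
if sample_blackjax_nuts is None:
    print("sample_blackjax_nuts is unavailable in this environment")
```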

Error when creating `betaX` in `make_model` and `make_hierarchical_model`:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[44], line 1
----> 1 make_model(n_lags, n_eqs, df, priors)

Cell In[43], line 44, in make_model(n_lags, n_eqs, df, priors, mv_norm, prior_checks)
     41 data_obs = pm.Data("data_obs", df.values[n_lags:], dims=["time", "equations"])
     43 betaX = calc_ar_step(lag_coefs, n_eqs, n_lags, df)
---> 44 betaX = pm.Deterministic(
     45     "betaX",
     46     betaX,
     47     dims=[
     48         "time",
     49     ],
     50 )
     51 mean = alpha + betaX
     53 if mv_norm:

File c:\Users\Ivan\anaconda3\envs\pymc_env\Lib\site-packages\pymc\model\core.py:2254, in Deterministic(name, var, model, dims)
   2252 var = var.copy(model.name_for(name))
   2253 model.deterministics.append(var)
-> 2254 model.add_named_variable(var, dims)
   2256 from pymc.printing import str_for_potential_or_deterministic
   2258 var.str_repr = types.MethodType(
   2259     functools.partial(str_for_potential_or_deterministic, dist_name="Deterministic"), var
   2260 )

File c:\Users\Ivan\anaconda3\envs\pymc_env\Lib\site-packages\pymc\model\core.py:1472, in Model.add_named_variable(self, var, dims)
   1470     # This check implicitly states that only vars with .ndim attribute can have dims
   1471     if var.ndim != len(dims):
-> 1472         raise ValueError(
   1473             f"{var} has {var.ndim} dims but {len(dims)} dim labels were provided."
   1474         )
   1475     self.named_vars_to_dims[var.name] = dims
   1477 self.named_vars[var.name] = var

ValueError: betaX has 2 dims but 1 dim labels were provided.

Proposed solution

Adding the missing second dimension label ("equations") to the failing call solves the problem:

betaX = pm.Deterministic(
    "betaX",
    betaX,
    dims=[
        "time",
        "equations",
    ],
)
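Why the second label is needed: the AR term holds one value per time step and per equation, so it is a 2-D (time, equations) array, and `Model.add_named_variable` raises whenever `var.ndim != len(dims)`, as the traceback shows. The NumPy sketch below reproduces that shape logic. `ar_step_numpy` and `check_dims` are hypothetical stand-ins, not the notebook's `calc_ar_step` or PyMC's own code, and the coefficient layout (equations, lags, equations) is an assumption.

```python
import numpy as np


def ar_step_numpy(lag_coefs, y, n_lags):
    """Hypothetical stand-in for calc_ar_step.

    lag_coefs[j, l, k]: effect of variable k at lag l + 1 on equation j,
    shape (n_eqs, n_lags, n_eqs). y: observed series, shape (n_time, n_eqs).
    Returns shape (n_time - n_lags, n_eqs): one value per time AND equation.
    """
    n_time, n_eqs = y.shape
    out = np.zeros((n_time - n_lags, n_eqs))
    for t in range(n_lags, n_time):
        for lag in range(n_lags):
            # y[t - lag - 1] is the observation vector lag + 1 steps back
            out[t - n_lags] += lag_coefs[:, lag, :] @ y[t - lag - 1]
    return out


def check_dims(var, dims, name="var"):
    """Mimic the rank check from pymc's Model.add_named_variable."""
    if var.ndim != len(dims):
        raise ValueError(
            f"{name} has {var.ndim} dims but {len(dims)} dim labels were provided."
        )


rng = np.random.default_rng(0)
y = rng.normal(size=(20, 2))        # 20 time steps, 2 equations
coefs = rng.normal(size=(2, 2, 2))  # (n_eqs, n_lags, n_eqs)
beta_x = ar_step_numpy(coefs, y, n_lags=2)
print(beta_x.shape)  # (18, 2)

check_dims(beta_x, ["time", "equations"], name="betaX")  # passes
try:
    check_dims(beta_x, ["time"], name="betaX")  # the notebook's original dims
    message = None
except ValueError as err:
    message = str(err)
print(message)  # betaX has 2 dims but 1 dim labels were provided.
```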

Another issue

As far as I can tell, this error is Windows-specific (see here). In the function `make_hierarchical_model`, this line fails:
`idata.extend(sample_blackjax_nuts(2000, random_seed=120))`. The same error occurs with `sample_numpyro_nuts`.

The first error is `RuntimeError: Incorrect output dtype for return value #0: Expected: int64, Actual: int32`. It can be fixed as described in this issue.
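My working assumption, not verified against the linked issue, is that the int32 comes from NumPy's platform-dependent default integer: with NumPy 1.x on Windows it is 32-bit, while on 64-bit Linux and macOS, and on 64-bit Windows from NumPy 2.0 on, it is 64-bit. A quick way to check which case an environment falls into (both helper names are hypothetical):

```python
import numpy as np


def default_int_dtype():
    """Return the dtype NumPy assigns to an array built from plain Python ints."""
    return np.array([0]).dtype


def has_32bit_default_int():
    """True if the default integer is 32-bit, which would be consistent with
    the 'Expected: int64, Actual: int32' error (assumption, see above)."""
    return default_int_dtype() == np.dtype(np.int32)


print(np.__version__, default_int_dtype(), has_32bit_default_int())
```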

After that, a second error appears:

TypeError: true_fun and false_fun output must have identical types, got
Proposal(state=IntegratorState(position=['DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', ..., 'DIFFERENT ShapedArray(float64[6]) vs. ShapedArray(float32[6])', ..., 'DIFFERENT ShapedArray(float64[3,2,3]) vs. ShapedArray(float32[3,2,3])', 'DIFFERENT ShapedArray(float64[3]) vs. ShapedArray(float32[3])', ...], momentum=['ShapedArray(float64[])', ..., 'ShapedArray(float64[6])', ...], logdensity='ShapedArray(float64[])', logdensity_grad=['DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', ...]), energy='ShapedArray(float64[])', weight='ShapedArray(float64[])', sum_log_p_accept='ShapedArray(float64[])').

(Output abridged: every entry in position and logdensity_grad is a float64 vs. float32 mismatch, with shapes [], [6], [3,2,3] and [3] repeating; every entry in momentum, as well as logdensity, energy, weight and sum_log_p_accept, is float64.)

I have no idea how to solve this one.

Possible solution

A workaround (assuming the problem does not occur on Linux) is to use plain `pm.sample` instead of `sample_blackjax_nuts` when the code runs on Windows, which can be detected with, for example, `if os.name == 'nt'`.

This behavior was also fixed in NumPy 2.0 (link to release notes), so this workaround may only be needed temporarily.
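The `os.name` workaround could be wrapped in a small helper so the sampling cells stay unchanged. This is a sketch only: `sample_with_windows_fallback` is a hypothetical name, and it assumes both samplers accept the same keyword arguments (`draws` and `random_seed` here). Stand-in samplers let it run without PyMC or JAX installed.

```python
import os


def sample_with_windows_fallback(jax_sampler, default_sampler, os_name=None, **kwargs):
    """Run default_sampler on Windows and jax_sampler everywhere else.

    Hypothetical helper: Windows is detected via os.name == "nt".
    Both samplers receive the same keyword arguments.
    """
    if os_name is None:
        os_name = os.name  # "nt" on Windows, "posix" on Linux/macOS
    sampler = default_sampler if os_name == "nt" else jax_sampler
    return sampler(**kwargs)


# Stand-in samplers so the sketch runs without PyMC or JAX.
def fake_jax_sampler(**kwargs):
    return ("jax", kwargs)


def fake_default_sampler(**kwargs):
    return ("pm.sample", kwargs)


print(sample_with_windows_fallback(fake_jax_sampler, fake_default_sampler,
                                   os_name="nt", draws=2000, random_seed=120))
# ('pm.sample', {'draws': 2000, 'random_seed': 120})
```

In the notebook, inside the same model context as the original call, this might read `idata.extend(sample_with_windows_fallback(sample_blackjax_nuts, pm.sample, draws=2000, random_seed=120))`.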
