From cd5dfbe4e09e3e450b384eacbc2d3292734ea9e7 Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Sun, 29 Dec 2024 08:11:49 -0300 Subject: [PATCH] refactor rng_fn method (#212) --- pymc_bart/bart.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pymc_bart/bart.py b/pymc_bart/bart.py index ac2be35..5114b6e 100644 --- a/pymc_bart/bart.py +++ b/pymc_bart/bart.py @@ -55,12 +55,12 @@ def rng_fn( # pylint: disable=W0237 if not size: size = None - if isinstance(cls.Y, (TensorSharedVariable, TensorVariable)): - Y = cls.Y.eval() - else: - Y = cls.Y - if not cls.all_trees: + if isinstance(cls.Y, (TensorSharedVariable, TensorVariable)): + Y = cls.Y.eval() + else: + Y = cls.Y + if size is not None: return np.full((size[0], Y.shape[0]), Y.mean()) else: