From 2fd6da8653d8b706b5b9298a2821cbac3524e6d5 Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Wed, 11 Dec 2024 14:10:52 +0100 Subject: [PATCH] Implement proper truncation for prior distributions Currently, when sampled startpoints are outside the bounds, their value is set to the upper/lower bounds. This may put too much probability mass on the bounds. With these changes, we properly sample from the respective truncated distributions. Closes #330. --- doc/conf.py | 1 + doc/example/distributions.ipynb | 28 +++-- petab/v1/distributions.py | 181 +++++++++++++++++++++++++++----- petab/v1/priors.py | 46 ++++++-- tests/v1/test_distributions.py | 64 ++++++----- 5 files changed, 248 insertions(+), 72 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index 4dbd3009..002a7bf0 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -85,6 +85,7 @@ nb_execution_mode = "force" nb_execution_raise_on_error = True nb_execution_show_tb = True +nb_execution_timeout = 90 # max. seconds/cell source_suffix = { ".rst": "restructuredtext", diff --git a/doc/example/distributions.ipynb b/doc/example/distributions.ipynb index 86235fe1..06c08272 100644 --- a/doc/example/distributions.ipynb +++ b/doc/example/distributions.ipynb @@ -23,7 +23,10 @@ }, { "metadata": { - "collapsed": true + "collapsed": true, + "jupyter": { + "is_executing": true + } }, "cell_type": "code", "source": [ @@ -42,7 +45,7 @@ " if ax is None:\n", " fig, ax = plt.subplots()\n", "\n", - " sample = prior.sample(10000)\n", + " sample = prior.sample(20_000)\n", "\n", " # pdf\n", " xmin = min(sample.min(), prior.lb_scaled if prior.bounds is not None else sample.min())\n", @@ -138,11 +141,13 @@ "metadata": {}, "cell_type": "code", "source": [ + "# different, because transformation!=LIN\n", "plot(Prior(UNIFORM, (0.01, 2), transformation=LOG10))\n", "plot(Prior(PARAMETER_SCALE_UNIFORM, (0.01, 2), transformation=LOG10))\n", "\n", + "# same, because transformation=LIN\n", "plot(Prior(UNIFORM, (0.01, 2), transformation=LIN))\n", - "plot(Prior(PARAMETER_SCALE_UNIFORM, (0.01, 2), transformation=LIN))\n" + "plot(Prior(PARAMETER_SCALE_UNIFORM, (0.01, 2), transformation=LIN))" ], "id": "5ca940bc24312fc6", "outputs": [], @@ -151,15 +156,18 @@ { "metadata": {}, "cell_type": "markdown", - "source": "To prevent the sampled parameters from exceeding the bounds, the sampled parameters are clipped to the bounds. The bounds are defined in the parameter table. Note that the current implementation does not support sampling from a truncated distribution. Instead, the samples are clipped to the bounds. This may introduce unwanted bias, and thus, should only be used with caution (i.e., the bounds should be chosen wide enough):", + "source": "The given distributions are truncated at the bounds defined in the parameter table:", "id": "b1a8b17d765db826" }, { "metadata": {}, "cell_type": "code", "source": [ - "plot(Prior(NORMAL, (0, 1), bounds=(-4, 4))) # negligible clipping-bias at 4 sigma\n", - "plot(Prior(UNIFORM, (0, 1), bounds=(0.1, 0.9))) # significant clipping-bias" + "plot(Prior(NORMAL, (0, 1), bounds=(-2, 2)))\n", + "plot(Prior(UNIFORM, (0, 1), bounds=(0.1, 0.9)))\n", + "plot(Prior(UNIFORM, (1e-8, 1), bounds=(0.1, 0.9), transformation=LOG10))\n", + "plot(Prior(LAPLACE, (0, 1), bounds=(-0.5, 0.5)))\n", + "plot(Prior(PARAMETER_SCALE_UNIFORM, (-3, 1), bounds=(1e-2, 1), transformation=LOG10))\n" ], "id": "4ac42b1eed759bdd", "outputs": [], @@ -175,9 +183,11 @@ "metadata": {}, "cell_type": "code", "source": [ - "plot(Prior(NORMAL, (10, 1), bounds=(6, 14), transformation=\"log10\"))\n", - "plot(Prior(PARAMETER_SCALE_NORMAL, (10, 1), bounds=(10**6, 10**14), transformation=\"log10\"))\n", - "plot(Prior(LAPLACE, (10, 2), bounds=(6, 14)))" + "plot(Prior(NORMAL, (10, 1), bounds=(6, 11), transformation=\"log10\"))\n", + "plot(Prior(PARAMETER_SCALE_NORMAL, (10, 1), bounds=(10**9, 10**14), transformation=\"log10\"))\n", + "plot(Prior(LAPLACE, (10, 2), bounds=(6, 14)))\n", + "plot(Prior(LOG_LAPLACE, (1, 0.5), bounds=(0.5, 8)))\n", + "plot(Prior(LOG_NORMAL, (2, 1), bounds=(0.5, 8)))" ], "id": "581e1ac431860419", "outputs": [], diff --git a/petab/v1/distributions.py b/petab/v1/distributions.py index 418f5b44..e0e5c279 100644 --- a/petab/v1/distributions.py +++ b/petab/v1/distributions.py @@ -28,15 +28,65 @@ class Distribution(abc.ABC): If a float, the distribution is transformed to its corresponding log distribution with the given base (e.g., Normal -> Log10Normal). If ``False``, no transformation is applied. + :param trunc: The truncation points (lower, upper) of the distribution + or ``None`` if the distribution is not truncated. """ - def __init__(self, log: bool | float = False): + def __init__( + self, *, log: bool | float = False, trunc: tuple[float, float] = None + ): if log is True: log = np.exp(1) + + if trunc == (-np.inf, np.inf): + trunc = None + + if trunc is not None and trunc[0] > trunc[1]: + raise ValueError( + "The lower truncation limit must be smaller " + "than the upper truncation limit." + ) + self._logbase = log + self._trunc = trunc + + self._cd_low = None + self._cd_high = None + self._truncation_normalizer = 1 + + if self._trunc is not None: + try: + # the cumulative density of the transformed distribution at the + # truncation limits + self._cd_low = self._cdf_transformed_untruncated( + self.trunc_low + ) + self._cd_high = self._cdf_transformed_untruncated( + self.trunc_high + ) + # normalization factor for the PDF of the transformed + # distribution to account for truncation + self._truncation_normalizer = 1 / ( + self._cd_high - self._cd_low + ) + except NotImplementedError: + pass + + @property + def trunc_low(self) -> float: + """The lower truncation limit of the transformed distribution.""" + return self._trunc[0] if self._trunc else -np.inf + + @property + def trunc_high(self) -> float: + """The upper truncation limit of the transformed distribution.""" + return self._trunc[1] if self._trunc else np.inf - def _undo_log(self, x: np.ndarray | float) -> np.ndarray | float: - """Undo the log transformation. + def _exp(self, x: np.ndarray | float) -> np.ndarray | float: + """Exponentiate / undo the log transformation according. + + Exponentiate if a log transformation is applied to the distribution. + Otherwise, return the input. :param x: The sample to transform. :return: The transformed sample @@ -45,9 +95,12 @@ def _undo_log(self, x: np.ndarray | float) -> np.ndarray | float: return x return self._logbase**x - def _apply_log(self, x: np.ndarray | float) -> np.ndarray | float: + def _log(self, x: np.ndarray | float) -> np.ndarray | float: """Apply the log transformation. + Compute the log of x with the specified base if a log transformation + is applied to the distribution. Otherwise, return the input. + :param x: The value to transform. :return: The transformed value. """ @@ -61,12 +114,17 @@ def sample(self, shape=None) -> np.ndarray: :param shape: The shape of the sample. :return: A sample from the distribution. """ - sample = self._sample(shape) - return self._undo_log(sample) + sample = ( + self._exp(self._sample(shape)) + if self._trunc is None + else self._inverse_transform_sample(shape) + ) + + return sample @abc.abstractmethod def _sample(self, shape=None) -> np.ndarray: - """Sample from the underlying distribution. + """Sample from the underlying distribution, accounting for truncation. :param shape: The shape of the sample. :return: A sample from the underlying distribution, @@ -85,7 +143,11 @@ def pdf(self, x): chain_rule_factor = ( (1 / (x * np.log(self._logbase))) if self._logbase else 1 ) - return self._pdf(self._apply_log(x)) * chain_rule_factor + return ( + self._pdf(self._log(x)) + * chain_rule_factor + * self._truncation_normalizer + ) @abc.abstractmethod def _pdf(self, x): @@ -104,13 +166,71 @@ def logbase(self) -> bool | float: """ return self._logbase + def cdf(self, x): + """Cumulative distribution function at x. + + :param x: The value at which to evaluate the CDF. + :return: The value of the CDF at ``x``. + """ + return self._cdf_transformed_untruncated(x) - self._cd_low + + def _cdf_transformed_untruncated(self, x): + """Cumulative distribution function of the transformed, but untruncated + distribution at x. + + :param x: The value at which to evaluate the CDF. + :return: The value of the CDF at ``x``. + """ + return self._cdf_untransformed_untruncated(self._log(x)) + + def _cdf_untransformed_untruncated(self, x): + """Cumulative distribution function of the underlying + (untransformed, untruncated) distribution at x. + + :param x: The value at which to evaluate the CDF. + :return: The value of the CDF at ``x``. + """ + raise NotImplementedError + + def _ppf_untransformed_untruncated(self, q): + """Percent point function of the underlying + (untransformed, untruncated) distribution at q. + + :param q: The quantile at which to evaluate the PPF. + :return: The value of the PPF at ``q``. + """ + raise NotImplementedError + + def _ppf_transformed_untruncated(self, q): + """Percent point function of the transformed, but untruncated + distribution at q. + + :param q: The quantile at which to evaluate the PPF. + :return: The value of the PPF at ``q``. + """ + return self._exp(self._ppf_untransformed_untruncated(q)) + + def _inverse_transform_sample(self, shape): + """Generate an inverse transform sample from the transformed and + truncated distribution. + + :param shape: The shape of the sample. + :return: The sample. + """ + uniform_sample = np.random.uniform( + low=self._cd_low, high=self._cd_high, size=shape + ) + return self._ppf_transformed_untruncated(uniform_sample) + class Normal(Distribution): """A (log-)normal distribution. :param loc: The location parameter of the distribution. :param scale: The scale parameter of the distribution. - :param truncation: The truncation limits of the distribution. + :param trunc: The truncation limits of the distribution. + ``None`` if the distribution is not truncated. The truncation limits + are the truncation limits of the transformed distribution. :param log: If ``True``, the distribution is transformed to a log-normal distribution. If a float, the distribution is transformed to a log-normal distribution with the given base. @@ -124,19 +244,15 @@ def __init__( self, loc: float, scale: float, - truncation: tuple[float, float] | None = None, + trunc: tuple[float, float] | None = None, log: bool | float = False, ): - super().__init__(log=log) self._loc = loc self._scale = scale - self._truncation = truncation - - if truncation is not None: - raise NotImplementedError("Truncation is not yet implemented.") + super().__init__(log=log, trunc=trunc) def __repr__(self): - trunc = f", truncation={self._truncation}" if self._truncation else "" + trunc = f", trunc={self._trunc}" if self._trunc else "" log = f", log={self._logbase}" if self._logbase else "" return f"Normal(loc={self._loc}, scale={self._scale}{trunc}{log})" @@ -146,6 +262,12 @@ def _sample(self, shape=None): def _pdf(self, x): return norm.pdf(x, loc=self._loc, scale=self._scale) + def _cdf_untransformed_untruncated(self, x): + return norm.cdf(x, loc=self._loc, scale=self._scale) + + def _ppf_untransformed_untruncated(self, q): + return norm.ppf(q, loc=self._loc, scale=self._scale) + @property def loc(self): """The location parameter of the underlying distribution.""" @@ -177,9 +299,9 @@ def __init__( *, log: bool | float = False, ): - super().__init__(log=log) self._low = low self._high = high + super().__init__(log=log) def __repr__(self): log = f", log={self._logbase}" if self._logbase else "" @@ -191,13 +313,21 @@ def _sample(self, shape=None): def _pdf(self, x): return uniform.pdf(x, loc=self._low, scale=self._high - self._low) + def _cdf_untransformed_untruncated(self, x): + return uniform.cdf(x, loc=self._low, scale=self._high - self._low) + + def _ppf_untransformed_untruncated(self, q): + return uniform.ppf(q, loc=self._low, scale=self._high - self._low) + class Laplace(Distribution): """A (log-)Laplace distribution. :param loc: The location parameter of the distribution. :param scale: The scale parameter of the distribution. - :param truncation: The truncation limits of the distribution. + :param trunc: The truncation limits of the distribution. + ``None`` if the distribution is not truncated. The truncation limits + are the truncation limits of the transformed distribution. :param log: If ``True``, the distribution is transformed to a log-Laplace distribution. If a float, the distribution is transformed to a log-Laplace distribution with the given base. @@ -211,18 +341,15 @@ def __init__( self, loc: float, scale: float, - truncation: tuple[float, float] | None = None, + trunc: tuple[float, float] | None = None, log: bool | float = False, ): - super().__init__(log=log) self._loc = loc self._scale = scale - self._truncation = truncation - if truncation is not None: - raise NotImplementedError("Truncation is not yet implemented.") + super().__init__(log=log, trunc=trunc) def __repr__(self): - trunc = f", truncation={self._truncation}" if self._truncation else "" + trunc = f", trunc={self._trunc}" if self._trunc else "" log = f", log={self._logbase}" if self._logbase else "" return f"Laplace(loc={self._loc}, scale={self._scale}{trunc}{log})" @@ -232,6 +359,12 @@ def _sample(self, shape=None): def _pdf(self, x): return laplace.pdf(x, loc=self._loc, scale=self._scale) + def _cdf_untransformed_untruncated(self, x): + return laplace.cdf(x, loc=self._loc, scale=self._scale) + + def _ppf_untransformed_untruncated(self, q): + return laplace.ppf(q, loc=self._loc, scale=self._scale) + @property def loc(self): """The location parameter of the underlying distribution.""" diff --git a/petab/v1/priors.py b/petab/v1/priors.py index f0f37f75..8f29bed3 100644 --- a/petab/v1/priors.py +++ b/petab/v1/priors.py @@ -40,6 +40,9 @@ __all__ = ["priors_to_measurements"] +# TODO: does anybody really rely on the old behavior? +USE_PROPER_TRUNCATION = True + class Prior: """A PEtab parameter prior. @@ -88,26 +91,49 @@ def __init__( self._bounds = bounds self._transformation = transformation + truncation = bounds if USE_PROPER_TRUNCATION else None + if truncation is not None: + # for uniform, we don't want to implement truncation and just + # adapt the distribution parameters + if type_ == C.PARAMETER_SCALE_UNIFORM: + parameters = ( + max(parameters[0], scale(truncation[0], transformation)), + min(parameters[1], scale(truncation[1], transformation)), + ) + elif type_ == C.UNIFORM: + parameters = ( + max(parameters[0], truncation[0]), + min(parameters[1], truncation[1]), + ) + # create the underlying distribution match type_, transformation: case (C.UNIFORM, _) | (C.PARAMETER_SCALE_UNIFORM, C.LIN): self.distribution = Uniform(*parameters) case (C.NORMAL, _) | (C.PARAMETER_SCALE_NORMAL, C.LIN): - self.distribution = Normal(*parameters) + self.distribution = Normal(*parameters, trunc=truncation) case (C.LAPLACE, _) | (C.PARAMETER_SCALE_LAPLACE, C.LIN): - self.distribution = Laplace(*parameters) + self.distribution = Laplace(*parameters, trunc=truncation) case (C.PARAMETER_SCALE_UNIFORM, C.LOG): self.distribution = Uniform(*parameters, log=True) case (C.LOG_NORMAL, _) | (C.PARAMETER_SCALE_NORMAL, C.LOG): - self.distribution = Normal(*parameters, log=True) + self.distribution = Normal( + *parameters, log=True, trunc=truncation + ) case (C.LOG_LAPLACE, _) | (C.PARAMETER_SCALE_LAPLACE, C.LOG): - self.distribution = Laplace(*parameters, log=True) + self.distribution = Laplace( + *parameters, log=True, trunc=truncation + ) case (C.PARAMETER_SCALE_UNIFORM, C.LOG10): self.distribution = Uniform(*parameters, log=10) case (C.PARAMETER_SCALE_NORMAL, C.LOG10): - self.distribution = Normal(*parameters, log=10) + self.distribution = Normal( + *parameters, log=10, trunc=truncation + ) case (C.PARAMETER_SCALE_LAPLACE, C.LOG10): - self.distribution = Laplace(*parameters, log=10) + self.distribution = Laplace( + *parameters, log=10, trunc=truncation + ) case _: raise ValueError( "Unsupported distribution type / transformation: " @@ -149,9 +175,8 @@ def sample(self, shape=None) -> np.ndarray: def _scale_sample(self, sample): """Scale the sample to the parameter space""" - # if self.on_parameter_scale: - # return sample - + # we also need to scale paramterScale* distributions, because + # internally, they are handled as (unscaled) log-distributions return scale(sample, self.transformation) def _clip_to_bounds(self, x): @@ -159,8 +184,7 @@ def _clip_to_bounds(self, x): :param x: The values to clip. Assumed to be on the parameter scale. """ - # TODO: replace this by proper truncation - if self.bounds is None: + if self.bounds is None or USE_PROPER_TRUNCATION: return x return np.maximum( diff --git a/tests/v1/test_distributions.py b/tests/v1/test_distributions.py index 9df830fa..1f2ecf59 100644 --- a/tests/v1/test_distributions.py +++ b/tests/v1/test_distributions.py @@ -27,6 +27,11 @@ Uniform(2, 4, log=10), Laplace(1, 2), Laplace(1, 0.5, log=True), + Normal(2, 1, trunc=(1, 2)), + Normal(2, 1, log=True, trunc=(0.5, 8)), + Normal(2, 1, log=10), + Laplace(1, 2, trunc=(1, 2)), + Laplace(1, 0.5, log=True, trunc=(0.5, 8)), ], ) def test_sample_matches_pdf(distribution): @@ -53,34 +58,37 @@ def cdf(x): # Test samples match scipy CDFs reference_pdf = None - if isinstance(distribution, Normal) and distribution.logbase is False: - reference_pdf = norm.pdf(sample, distribution.loc, distribution.scale) - elif isinstance(distribution, Uniform) and distribution.logbase is False: - reference_pdf = uniform.pdf( - sample, distribution._low, distribution._high - distribution._low - ) - elif isinstance(distribution, Laplace) and distribution.logbase is False: - reference_pdf = laplace.pdf( - sample, distribution.loc, distribution.scale - ) - elif isinstance(distribution, Normal) and distribution.logbase == np.exp( - 1 - ): - reference_pdf = lognorm.pdf( - sample, scale=np.exp(distribution.loc), s=distribution.scale - ) - elif isinstance(distribution, Uniform) and distribution.logbase == np.exp( - 1 - ): - reference_pdf = loguniform.pdf( - sample, np.exp(distribution._low), np.exp(distribution._high) - ) - elif isinstance(distribution, Laplace) and distribution.logbase == np.exp( - 1 - ): - reference_pdf = loglaplace.pdf( - sample, c=1 / distribution.scale, scale=np.exp(distribution.loc) - ) + if distribution._trunc is None and distribution.logbase is False: + if isinstance(distribution, Normal): + reference_pdf = norm.pdf( + sample, distribution.loc, distribution.scale + ) + elif isinstance(distribution, Uniform): + reference_pdf = uniform.pdf( + sample, + distribution._low, + distribution._high - distribution._low, + ) + elif isinstance(distribution, Laplace): + reference_pdf = laplace.pdf( + sample, distribution.loc, distribution.scale + ) + + if distribution._trunc is None and distribution.logbase == np.exp(1): + if isinstance(distribution, Normal): + reference_pdf = lognorm.pdf( + sample, scale=np.exp(distribution.loc), s=distribution.scale + ) + elif isinstance(distribution, Uniform): + reference_pdf = loguniform.pdf( + sample, np.exp(distribution._low), np.exp(distribution._high) + ) + elif isinstance(distribution, Laplace): + reference_pdf = loglaplace.pdf( + sample, + c=1 / distribution.scale, + scale=np.exp(distribution.loc), + ) if reference_pdf is not None: assert_allclose( distribution.pdf(sample), reference_pdf, rtol=1e-10, atol=1e-14