Skip to content

Commit

Permalink
Update docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax authored Mar 20, 2024
1 parent b4326c2 commit d0aab23
Showing 1 changed file with 42 additions and 32 deletions.
74 changes: 42 additions & 32 deletions gpax/priors/priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,15 @@ def normal_dist(loc: float = None, scale: float = None
Generate a Normal distribution based on provided center (loc) and standard deviation (scale) parameters.
If neither are provided, uses 0 and 1 by default. It can be used to pass custom priors to GP models.
Example:
Examples:
Assign custom prior to kernel lengthscale during GP model initialization
>>> model = gpax.ExactGP(input_dim, kernel, lengthscale_prior_dist=gpax.priors.normal_dist(5, 1))
Train as usual
>>> model.fit(rng_key, X, y)
Assign custom prior to kernel lengthscale during GP model initialization
>>> model = gpax.ExactGP(input_dim, kernel, lengthscale_prior_dist=gpax.priors.normal_dist(5, 1))
Train as usual
>>> model.fit(rng_key, X, y)
"""
loc = loc if loc is not None else 0.0
Expand All @@ -95,15 +95,15 @@ def lognormal_dist(loc: float = None, scale: float = None) -> numpyro.distributi
Generate a LogNormal distribution based on provided center (loc) and standard deviation (scale) parameters.
If neither are provided, uses 0 and 1 by default. It can be used to pass custom priors to GP models.
Example:
Assign custom prior to kernel lengthscale during GP model initialization
Examples:
>>> model = gpax.ExactGP(input_dim, kernel, lengthscale_prior_dist=gpax.priors.lognormal_dist(0, 0.1))
Train as usual
>>> model.fit(rng_key, X, y)
Assign custom prior to kernel lengthscale during GP model initialization
>>> model = gpax.ExactGP(input_dim, kernel, lengthscale_prior_dist=gpax.priors.lognormal_dist(0, 0.1))
Train as usual
>>> model.fit(rng_key, X, y)
"""
loc = loc if loc is not None else 0.0
Expand All @@ -116,15 +116,15 @@ def halfnormal_dist(scale: float = None) -> numpyro.distributions.Distribution:
Generate a half-normal distribution based on provided standard deviation (scale).
If none is provided, uses 1.0 by default. It can be used to pass custom priors to GP models.
Example:
Assign custom prior to noise variance during GP model initialization
Examples:
>>> model = gpax.ExactGP(input_dim, kernel, noise_prior_dist=gpax.priors.halfnormal_dist(0.1))
Train as usual
>>> model.fit(rng_key, X, y)
Assign custom prior to noise variance during GP model initialization
>>> model = gpax.ExactGP(input_dim, kernel, noise_prior_dist=gpax.priors.halfnormal_dist(0.1))
Train as usual
>>> model.fit(rng_key, X, y)
"""
scale = scale if scale is not None else 1.0
Expand All @@ -140,15 +140,15 @@ def gamma_dist(c: float = None,
it attempts to infer it using the range of the input vector divided by 2. The rate parameter defaults to 1.0 if not provided.
It can be used to pass custom priors to GP models.
Example:
Assign custom prior to kernel lengthscale during GP model initialization
>>> model = gpax.ExactGP(input_dm, kernel, lengthscale_prior_dist=gpax.priors.gamma_dist(2, 5))
Examples:
Train as usual
>>> model.fit(rng_key, X, y)
Assign custom prior to kernel lengthscale during GP model initialization
>>> model = gpax.ExactGP(input_dm, kernel, lengthscale_prior_dist=gpax.priors.gamma_dist(2, 5))
Train as usual
>>> model.fit(rng_key, X, y)
"""
if c is None:
Expand All @@ -169,6 +169,16 @@ def uniform_dist(low: float = None,
Generate a Uniform distribution based on provided low and high bounds. If one of the bounds is not provided,
it attempts to infer the missing bound(s) using the minimum or maximum value from the input vector.
It can be used to pass custom priors to GP models.
Examples:
Assign custom prior to kernel lengthscale during GP model initialization
>>> model = gpax.ExactGP(input_dm, kernel, lengthscale_prior_dist=gpax.priors.uniform_dist(1, 3))
Train as usual
>>> model.fit(rng_key, X, y)
"""
if (low is None or high is None) and input_vec is None:
raise ValueError(
Expand Down

0 comments on commit d0aab23

Please sign in to comment.