Skip to content

Commit

Permalink
Merge pull request #17 from simon-hirsch/add_gamma_dist
Browse files Browse the repository at this point in the history
Add gamma distribution
  • Loading branch information
simon-hirsch authored Sep 11, 2024
2 parents c568892 + e38dc1c commit 4beb4d6
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 3 deletions.
4 changes: 3 additions & 1 deletion docs/distributions.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,6 @@ All distributions are based on `scipy.stats` distributions. We implement the pro

::: rolch.DistributionT

::: rolch.DistributionJSU
::: rolch.DistributionJSU

::: rolch.DistributionGamma
8 changes: 7 additions & 1 deletion src/rolch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
online_coordinate_descent_path,
soft_threshold,
)
from rolch.distributions import DistributionJSU, DistributionNormal, DistributionT
from rolch.distributions import (
DistributionGamma,
DistributionJSU,
DistributionNormal,
DistributionT,
)
from rolch.gram import (
init_forget_vector,
init_gram,
Expand Down Expand Up @@ -54,6 +59,7 @@
"DistributionNormal",
"DistributionT",
"DistributionJSU",
"DistributionGamma",
"init_forget_vector",
"init_gram",
"update_gram",
Expand Down
8 changes: 7 additions & 1 deletion src/rolch/distributions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from .gamma import DistributionGamma
from .johnsonsu import DistributionJSU
from .normal import DistributionNormal
from .studentt import DistributionT

__all__ = [DistributionNormal, DistributionT, DistributionJSU]
__all__ = [
"DistributionNormal",
"DistributionT",
"DistributionJSU",
"DistributionGamma",
]
143 changes: 143 additions & 0 deletions src/rolch/distributions/gamma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import numpy as np
import scipy.special as spc
import scipy.stats as st

from rolch.abc import Distribution, LinkFunction
from rolch.link import LogLink


class DistributionGamma(Distribution):
"""The Gamma Distribution for GAMLSS.
The distribution function is defined as in GAMLSS as:
$$
f(y|\mu,\sigma)=\\frac{y^{(1/\sigma^2-1)}\exp[-y/(\sigma^2 \mu)]}{(\sigma^2 \mu)^{(1/\sigma^2)} \Gamma(1/\sigma^2)}
$$
with the location and shape parameters $\mu, \sigma > 0$.
!!! Note
The function is parameterized as GAMLSS' GA() distribution.
This parameterization is different to the `scipy.stats.gamma(alpha, loc, scale)` parameterization.
We can use `DistributionGamma().gamlss_to_scipy(mu, sigma)` to map the distribution parameters to scipy.
The `scipy.stats.gamma()` distribution is defined as:
$$
f(x, \\alpha, \\beta) = \\frac{\\beta^\\alpha x^{\\alpha - 1} \exp[-\\beta x]}{\Gamma(\\alpha)}
$$
with the paramters $\\alpha, \\beta >0$. The parameters can be mapped as follows:
$$
\\alpha = 1/\sigma^2 \Leftrightarrow \sigma = \sqrt{1 / \\alpha}
$$
and
$$
\\beta = 1/(\sigma^2\mu).
$$
Args:
loc_link (LinkFunction, optional): The link function for $\mu$. Defaults to LogLink().
scale_link (LinkFunction, optional): The link function for $\sigma$. Defaults to LogLink().
"""

def __init__(
self, loc_link: LinkFunction = LogLink(), scale_link: LinkFunction = LogLink()
):
self.loc_link = loc_link
self.scale_link = scale_link
# Set up links as dict
self.links = {0: self.loc_link, 1: self.scale_link}
# Set distribution params
self.n_params = 2
self.corresponding_gamlss = "GA"
self.scipy_dist = st.gamma

def theta_to_params(self, theta):
mu = theta[:, 0]
sigma = theta[:, 1]
return mu, sigma

@staticmethod
def gamlss_to_scipy(mu: np.ndarray, sigma: np.ndarray):
"""Map GAMLSS Parameters to scipy parameters.
Args:
mu (np.ndarray): mu parameter
sigma (np.ndarray): sigma parameter
Returns:
tuple: Tuple of (alpha, loc, scale) for scipy.stats.gamma(alpha, loc, scale)
"""
alpha = 1 / sigma**2
beta = 1 / (sigma**2 * mu)
loc = 0
scale = 1 / beta
return alpha, loc, scale

def dl1_dp1(self, y, theta, param=0):
mu, sigma = self.theta_to_params(theta)

if param == 0:
return (y - mu) / ((sigma**2) * (mu**2))

if param == 1:
return (2 / sigma**3) * (
(y / mu)
- np.log(y)
+ np.log(mu)
+ np.log(sigma**2)
- 1
+ spc.digamma(1 / (sigma**2))
)

def dl2_dp2(self, y, theta, param=0):
mu, sigma = self.theta_to_params(theta)
if param == 0:
# MU
return -1 / ((sigma**2) * (mu**2))

if param == 1:
# SIGMA
return (4 / sigma**4) - (4 / sigma**6) * spc.polygamma(1, (1 / sigma**2))

def dl2_dpp(self, y, theta, params=(0, 1)):
if sorted(params) == [0, 1]:
return np.zeros_like(y)

def link_function(self, y, param=0):
return self.links[param].link(y)

def link_inverse(self, y, param=0):
return self.links[param].inverse(y)

def link_derivative(self, y, param=0):
return self.links[param].derivative(y)

def initial_values(self, y, param=0, axis=None):
if param == 0:
return (y + np.mean(y, axis=None)) / 2
if param == 1:
return np.ones_like(y)

def cdf(self, y, theta):
mu, sigma = self.theta_to_params(theta)
return self.scipy_dist(*self.gamlss_to_scipy(mu, sigma)).cdf(y)

def pdf(self, y, theta):
mu, sigma = self.theta_to_params(theta)
return self.scipy_dist(*self.gamlss_to_scipy(mu, sigma)).pdf(y)

def ppf(self, q, theta):
mu, sigma = self.theta_to_params(theta)
return self.scipy_dist(*self.gamlss_to_scipy(mu, sigma)).ppf(q)

def rvs(self, size, theta):
mu, sigma = self.theta_to_params(theta)
return (
self.scipy_dist(*self.gamlss_to_scipy(mu, sigma))
.rvs((size, theta.shape[0]))
.T
)

0 comments on commit 4beb4d6

Please sign in to comment.