Skip to content

Commit

Permalink
Add support checks for the link functions and distributions
Browse files Browse the repository at this point in the history
  • Loading branch information
simon-hirsch committed Feb 26, 2025
1 parent dcabf77 commit 733dfec
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 58 deletions.
11 changes: 11 additions & 0 deletions src/rolch/base/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,17 @@ def dl2_dpp(
) -> np.ndarray:
"""Take the first derivative of the likelihood function with respect to both parameters."""

def _validate_links(self):
for param, link in self.links.items():
if link.link_support[0] < self.parameter_support[param][0]:
raise ValueError(
f"Lower bound of parameter support {param} does not match link function."
)
if link.link_support[1] > self.parameter_support[param][1]:
raise ValueError(
f"Upper bound of parameter support {param} does not match link function."
)

def _validate_dln_dpn_inputs(
self, y: np.ndarray, theta: np.ndarray, param: int
) -> None:
Expand Down
4 changes: 2 additions & 2 deletions src/rolch/base/link.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Dict
from typing import Tuple

import numpy as np

Expand All @@ -9,7 +9,7 @@ class LinkFunction(ABC):

@property
@abstractmethod
def link_support(self) -> Dict[int, float]:
def link_support(self) -> Tuple[float, float]:
"""The support of the distribution."""
pass

Expand Down
22 changes: 11 additions & 11 deletions src/rolch/distributions/gamma.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,29 +44,29 @@ class DistributionGamma(Distribution, ScipyMixin):
scale_link (LinkFunction, optional): The link function for $\sigma$. Defaults to LogLink().
"""

def __init__(
self,
loc_link: LinkFunction = LogLink(),
scale_link: LinkFunction = LogLink(),
) -> None:
self.loc_link: LinkFunction = loc_link
self.scale_link: LinkFunction = scale_link
self.links: dict[int, LinkFunction] = {0: self.loc_link, 1: self.scale_link}
self.corresponding_gamlss: str = "GA"

parameter_names = {0: "mu", 1: "sigma"}
parameter_support = {
0: (np.nextafter(0, 1), np.inf),
1: (np.nextafter(0, 1), np.inf),
}
distribution_support = (np.nextafter(0, 1), np.inf)

# Scipy equivalent and parameter mapping rolch -> scipy
scipy_dist = st.gamma
# Theta columns do not map 1:1 to scipy parameters for gamma
# So we have to overload theta_to_scipy_params
scipy_names = {}

def __init__(
self,
loc_link: LinkFunction = LogLink(),
scale_link: LinkFunction = LogLink(),
) -> None:
self.loc_link: LinkFunction = loc_link
self.scale_link: LinkFunction = scale_link
self.links: dict[int, LinkFunction] = {0: self.loc_link, 1: self.scale_link}
self.corresponding_gamlss: str = "GA"
self._validate_links()

def theta_to_scipy_params(self, theta: np.ndarray) -> dict:
"""Map GAMLSS Parameters to scipy parameters.
Expand Down
39 changes: 20 additions & 19 deletions src/rolch/distributions/johnsonsu.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Tuple
from typing import Dict, Tuple

import numpy as np
import scipy.stats as st
Expand All @@ -19,24 +19,6 @@ class DistributionJSU(Distribution, ScipyMixin):
3 : Tail behaviour
"""

def __init__(
self,
loc_link: LinkFunction = IdentityLink(),
scale_link: LinkFunction = LogLink(),
skew_link: LinkFunction = IdentityLink(),
tail_link: LinkFunction = LogLink(),
) -> None:
self.loc_link = loc_link
self.scale_link = scale_link
self.skew_link = skew_link
self.tail_link = tail_link
self.links = [
self.loc_link,
self.scale_link,
self.skew_link,
self.tail_link,
]

parameter_names = {0: "mu", 1: "sigma", 2: "nu", 3: "tau"}
parameter_support = {
0: (-np.inf, np.inf),
Expand All @@ -50,6 +32,25 @@ def __init__(
scipy_dist = st.johnsonsu
scipy_names = {"mu": "loc", "sigma": "scale", "nu": "a", "tau": "b"}

def __init__(
self,
loc_link: LinkFunction = IdentityLink(),
scale_link: LinkFunction = LogLink(),
skew_link: LinkFunction = IdentityLink(),
tail_link: LinkFunction = LogLink(),
) -> None:
self.loc_link = loc_link
self.scale_link = scale_link
self.skew_link = skew_link
self.tail_link = tail_link
self.links: Dict[int, LinkFunction] = {
0: self.loc_link,
1: self.scale_link,
2: self.skew_link,
3: self.tail_link,
}
self._validate_links()

def dl1_dp1(self, y: np.ndarray, theta: np.ndarray, param: int = 0) -> np.ndarray:
self._validate_dln_dpn_inputs(y, theta, param)
mu, sigma, nu, tau = self.theta_to_params(theta)
Expand Down
24 changes: 14 additions & 10 deletions src/rolch/distributions/normal.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Tuple, Union
from typing import Dict, Tuple, Union

import numpy as np
import scipy.stats as st
Expand All @@ -11,15 +11,6 @@
class DistributionNormal(Distribution, ScipyMixin):
"""Corresponds to GAMLSS NO() and scipy.stats.norm()"""

def __init__(
self,
loc_link: LinkFunction = IdentityLink(),
scale_link: LinkFunction = LogLink(),
) -> None:
self.loc_link: LinkFunction = loc_link
self.scale_link: LinkFunction = scale_link
self.links: list[LinkFunction] = [self.loc_link, self.scale_link]

parameter_names = {0: "mu", 1: "sigma"}
parameter_support = {0: (-np.inf, np.inf), 1: (np.nextafter(0, 1), np.inf)}
distribution_support = (-np.inf, np.inf)
Expand All @@ -28,6 +19,19 @@ def __init__(
scipy_dist = st.norm
scipy_names = {"mu": "loc", "sigma": "scale"}

def __init__(
self,
loc_link: LinkFunction = IdentityLink(),
scale_link: LinkFunction = LogLink(),
) -> None:
self.loc_link: LinkFunction = loc_link
self.scale_link: LinkFunction = scale_link
self.links: Dict[int, LinkFunction] = {
0: self.loc_link,
1: self.scale_link,
}
self._validate_links()

def dl1_dp1(self, y: np.ndarray, theta: np.ndarray, param: int = 0) -> np.ndarray:
self._validate_dln_dpn_inputs(y, theta, param)
mu, sigma = self.theta_to_params(theta)
Expand Down
33 changes: 17 additions & 16 deletions src/rolch/distributions/studentt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Tuple
from typing import Dict, Tuple

import numpy as np
import scipy.special as sp
Expand All @@ -12,21 +12,6 @@
class DistributionT(Distribution, ScipyMixin):
"""Corresponds to GAMLSS TF() and scipy.stats.t()"""

def __init__(
self,
loc_link: LinkFunction = IdentityLink(),
scale_link: LinkFunction = LogLink(),
tail_link: LinkFunction = LogShiftTwoLink(),
) -> None:
self.loc_link: LinkFunction = loc_link
self.scale_link: LinkFunction = scale_link
self.tail_link: LinkFunction = tail_link
self.links: List[LinkFunction] = [
self.loc_link,
self.scale_link,
self.tail_link,
]

parameter_names = {0: "mu", 1: "sigma", 2: "nu"}
parameter_support = {
0: (-np.inf, np.inf),
Expand All @@ -39,6 +24,22 @@ def __init__(
scipy_dist = st.t
scipy_names = {"mu": "loc", "sigma": "scale", "nu": "df"}

def __init__(
self,
loc_link: LinkFunction = IdentityLink(),
scale_link: LinkFunction = LogLink(),
tail_link: LinkFunction = LogShiftTwoLink(),
) -> None:
self.loc_link: LinkFunction = loc_link
self.scale_link: LinkFunction = scale_link
self.tail_link: LinkFunction = tail_link
self.links: Dict[LinkFunction] = {
0: self.loc_link,
1: self.scale_link,
2: self.tail_link,
}
self._validate_links()

def dl1_dp1(self, y: np.ndarray, theta: np.ndarray, param: int = 0) -> np.ndarray:
self._validate_dln_dpn_inputs(y, theta, param)
mu, sigma, nu = self.theta_to_params(theta)
Expand Down
19 changes: 19 additions & 0 deletions src/rolch/link.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Tuple

import numpy as np

from rolch.base import LinkFunction
Expand All @@ -14,6 +16,8 @@ class LogLink(LinkFunction):
The log-link function is defined as \(g(x) = \log(x)\).
"""

link_support = (np.nextafter(0, 1), np.inf)

def __init__(self):
pass

Expand Down Expand Up @@ -43,6 +47,8 @@ class IdentityLink(LinkFunction):
The identity link is defined as \(g(x) = x\).
"""

link_support = (-np.inf, np.inf)

def __init__(self):
pass

Expand Down Expand Up @@ -75,6 +81,10 @@ class LogShiftValueLink(LinkFunction):
def __init__(self, value: float):
self.value = value

@property
def link_support(self) -> Tuple[float, float]:
return (self.value + np.nextafter(0, 1), np.inf)

def link(self, x: np.ndarray) -> np.ndarray:
return np.log(x - self.value + LOG_LOWER_BOUND)

Expand Down Expand Up @@ -114,6 +124,8 @@ class SqrtLink(LinkFunction):
The square root link function is defined as $$g(x) = \sqrt(x)$$.
"""

link_support = (np.nextafter(0, 1), np.inf)

def __init__(self):
pass

Expand Down Expand Up @@ -146,6 +158,10 @@ class SqrtShiftValueLink(LinkFunction):
def __init__(self, value: float):
self.value = value

@property
def link_support(self) -> Tuple[float, float]:
return (self.value + np.nextafter(0, 1), np.inf)

def link(self, x: np.ndarray) -> np.ndarray:
return np.sqrt(x - self.value + LOG_LOWER_BOUND)

Expand Down Expand Up @@ -185,6 +201,8 @@ class LogIdentLink(LinkFunction):
the estimation procedure.
"""

link_support = (np.nextafter(0, 1), np.inf)

def __init__(self):
pass

Expand Down Expand Up @@ -213,4 +231,5 @@ def link_second_derivative(self, x) -> np.ndarray:
"SqrtLink",
"SqrtShiftValueLink",
"SqrtShiftTwoLink",
"SqrtShiftTwoLink",
]

0 comments on commit 733dfec

Please sign in to comment.