Skip to content

Commit

Permalink
Merge pull request #28 from simon-hirsch/correct_names_for_link_deriv…
Browse files Browse the repository at this point in the history
…atives

Correct names for link derivatives
  • Loading branch information
simon-hirsch authored Nov 19, 2024
2 parents 2147be3 + 412a0b3 commit a6e65a4
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 33 deletions.
1 change: 1 addition & 0 deletions docs/links.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ For all link functions, we implement

- the link \(g(x)\)
- the inverse \(g^{-1}(x)\)
- the derivative of the link function $\frac{\partial g(x)}{\partial x}$.
- the first derivative _of the inverse_ of the link function \(\frac{\partial g(x)^{-1}}{\partial x}\). The choice of the inverse is justified by Equation (7) in Hirsch, Berrisch & Ziel ([2024](https://github.com/simon-hirsch/rolch/blob/main/paper.pdf)).

The link functions implemented in `ROLCH` implemenent these as class methods each. Currently, we have implemented the identity-link, log-link and shifted log-link functions.
Expand Down
13 changes: 11 additions & 2 deletions src/rolch/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,13 @@ def inverse(self, x: np.ndarray) -> np.ndarray:
"""Calculate the inverse of the link function"""

@abstractmethod
def derivative(self, x: np.ndarray) -> np.ndarray:
def link_derivative(self, x: np.ndarray) -> np.ndarray:
"""Calculate the first derivative of the link function"""
raise NotImplementedError("Currently not implemented. Will be needed for GLMs")

@abstractmethod
def inverse_derivative(self, x: np.ndarray) -> np.ndarray:
"""Calculate the first derivative for the inverse link function"""


class Distribution(ABC):
Expand Down Expand Up @@ -66,9 +71,13 @@ def link_inverse(self, y: np.ndarray, param: int = 0) -> np.ndarray:
"""Apply the inverse of the link function for param on y."""

@abstractmethod
def link_derivative(self, y: np.ndarray, param: int = 0) -> np.ndarray:
def link_function_derivative(self, y: np.ndarray, param: int = 0) -> np.ndarray:
"""Apply the derivative of the link function for param on y."""

@abstractmethod
def link_inverse_derivative(self, y: np.ndarray, param: int = 0) -> np.ndarray:
"""Apply the derivative of the inverse link function for param on y."""

@abstractmethod
def initial_values(
self, y: np.ndarray, param: int = 0, axis: int = None
Expand Down
7 changes: 5 additions & 2 deletions src/rolch/distributions/gamma.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,11 @@ def link_function(self, y, param=0):
def link_inverse(self, y, param=0):
return self.links[param].inverse(y)

def link_derivative(self, y, param=0):
return self.links[param].derivative(y)
def link_function_derivative(self, y: np.ndarray, param: int = 0) -> np.ndarray:
return self.links[param].link_derivative(y)

def link_inverse_derivative(self, y: np.ndarray, param: int = 0) -> np.ndarray:
return self.links[param].inverse_derivative(y)

def initial_values(self, y, param=0, axis=None):
if param == 0:
Expand Down
7 changes: 5 additions & 2 deletions src/rolch/distributions/johnsonsu.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,11 @@ def link_function(self, y, param=0):
def link_inverse(self, y, param=0):
return self.links[param].inverse(y)

def link_derivative(self, y, param=0):
return self.links[param].derivative(y)
def link_function_derivative(self, y: np.ndarray, param: int = 0) -> np.ndarray:
return self.links[param].link_derivative(y)

def link_inverse_derivative(self, y: np.ndarray, param: int = 0) -> np.ndarray:
return self.links[param].inverse_derivative(y)

def initial_values(self, y, param=0, axis=None):
if param == 0:
Expand Down
7 changes: 5 additions & 2 deletions src/rolch/distributions/normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,11 @@ def link_function(self, y, param=0):
def link_inverse(self, y, param=0):
return self.links[param].inverse(y)

def link_derivative(self, y, param=0):
return self.links[param].derivative(y)
def link_function_derivative(self, y: np.ndarray, param: int = 0) -> np.ndarray:
return self.links[param].link_derivative(y)

def link_inverse_derivative(self, y: np.ndarray, param: int = 0) -> np.ndarray:
return self.links[param].inverse_derivative(y)

def initial_values(self, y, param=0, axis=None):
if param == 0:
Expand Down
7 changes: 5 additions & 2 deletions src/rolch/distributions/studentt.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,11 @@ def link_function(self, y, param=0):
def link_inverse(self, y, param=0):
return self.links[param].inverse(y)

def link_derivative(self, y, param=0):
return self.links[param].derivative(y)
def link_function_derivative(self, y: np.ndarray, param: int = 0) -> np.ndarray:
return self.links[param].link_derivative(y)

def link_inverse_derivative(self, y: np.ndarray, param: int = 0) -> np.ndarray:
return self.links[param].inverse_derivative(y)

def initial_values(self, y, param=0, axis=None):
if param == 0:
Expand Down
58 changes: 38 additions & 20 deletions src/rolch/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,21 @@ class LogLink(LinkFunction):
def __init__(self):
pass

def link(self, x):
def link(self, x: np.ndarray) -> np.ndarray:
return np.log(np.fmax(x, LOG_LOWER_BOUND))

def inverse(self, x):
def inverse(self, x: np.ndarray) -> np.ndarray:
return np.fmax(
np.exp(np.fmin(x, EXP_UPPER_BOUND)),
LOG_LOWER_BOUND,
)

def derivative(self, x):
def inverse_derivative(self, x: np.ndarray) -> np.ndarray:
return np.exp(np.fmin(x, EXP_UPPER_BOUND))

def link_derivative(self, x: np.ndarray) -> np.ndarray:
return 1 / x


class IdentityLink(LinkFunction):
"""
Expand All @@ -40,13 +43,16 @@ class IdentityLink(LinkFunction):
def __init__(self):
pass

def link(self, x):
def link(self, x: np.ndarray) -> np.ndarray:
return x

def inverse(self, x):
def inverse(self, x: np.ndarray) -> np.ndarray:
return x

def derivative(self, x):
def inverse_derivative(self, x: np.ndarray) -> np.ndarray:
return np.ones_like(x)

def link_derivative(self, x: np.ndarray) -> np.ndarray:
return np.ones_like(x)


Expand All @@ -60,20 +66,23 @@ class LogShiftValueLink(LinkFunction):
don't fall below 2, hence ensuring that the variance exists.
"""

def __init__(self, value):
def __init__(self, value: float):
self.value = value

def link(self, x):
def link(self, x: np.ndarray) -> np.ndarray:
return np.log(x - self.value + LOG_LOWER_BOUND)

def inverse(self, x):
def inverse(self, x: np.ndarray) -> np.ndarray:
return self.value + np.fmax(
np.exp(np.fmin(x, EXP_UPPER_BOUND)), LOG_LOWER_BOUND
)

def derivative(self, x):
def inverse_derivative(self, x: np.ndarray) -> np.ndarray:
return np.fmax(np.exp(np.fmin(x, EXP_UPPER_BOUND)), LOG_LOWER_BOUND)

def link_derivative(self, x: np.ndarray) -> np.ndarray:
return super().link_derivative(x)


class LogShiftTwoLink(LogShiftValueLink):
"""
Expand All @@ -99,15 +108,18 @@ class SqrtLink(LinkFunction):
def __init__(self):
pass

def link(self, x):
def link(self, x: np.ndarray) -> np.ndarray:
return np.sqrt(x)

def inverse(self, x):
def inverse(self, x: np.ndarray) -> np.ndarray:
return np.power(x, 2)

def derivative(self, x):
def inverse_derivative(self, x: np.ndarray) -> np.ndarray:
return 2 * x

def link_derivative(self, x: np.ndarray) -> np.ndarray:
return 1 / (2 * np.sqrt(x))


class SqrtShiftValueLink(LinkFunction):
"""
Expand All @@ -119,18 +131,21 @@ class SqrtShiftValueLink(LinkFunction):
don't fall below 2, hence ensuring that the variance exists.
"""

def __init__(self, value):
def __init__(self, value: float):
self.value = value

def link(self, x):
def link(self, x: np.ndarray) -> np.ndarray:
return np.sqrt(x - self.value + LOG_LOWER_BOUND)

def inverse(self, x):
def inverse(self, x: np.ndarray) -> np.ndarray:
return self.value + np.power(np.fmin(x, EXP_UPPER_BOUND), 2)

def derivative(self, x):
def inverse_derivative(self, x: np.ndarray) -> np.ndarray:
return 2 * x

def link_derivative(self, x: np.ndarray) -> np.ndarray:
return super().link_derivative(x)


class SqrtShiftTwoLink(SqrtShiftValueLink):
"""
Expand Down Expand Up @@ -158,15 +173,18 @@ class LogIdentLink(LinkFunction):
def __init__(self):
pass

def link(self, x: np.ndarray):
def link(self, x: np.ndarray) -> np.ndarray:
return np.where(x <= 1, np.log(x), x - 1)

def inverse(self, x: np.ndarray):
def inverse(self, x: np.ndarray) -> np.ndarray:
return np.where(x <= 0, np.exp(x), x + 1)

def derivative(self, x: np.ndarray):
def inverse_derivative(self, x: np.ndarray) -> np.ndarray:
return np.where(x <= 0, np.exp(x), 1)

def link_derivative(self, x: np.ndarray) -> np.ndarray:
return super().link_derivative(x)


__all__ = [
"LogLink",
Expand Down
4 changes: 2 additions & 2 deletions src/rolch/online_gamlss.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,7 @@ def _inner_fit(
iteration_inner += 1
eta = self.distribution.link_function(fv[:, param], param=param)
# if iteration == 1:
dr = 1 / self.distribution.link_derivative(eta, param=param)
dr = 1 / self.distribution.link_inverse_derivative(eta, param=param)
# mu, sigma, nu vs. fv?
dl1dp1 = self.distribution.dl1_dp1(y, fv, param=param)
dl2dp2 = self.distribution.dl2_dp2(y, fv, param=param)
Expand Down Expand Up @@ -806,7 +806,7 @@ def _inner_update(

iteration_inner += 1
eta = self.distribution.link_function(fv[:, param], param=param)
dr = 1 / self.distribution.link_derivative(eta, param=param)
dr = 1 / self.distribution.link_inverse_derivative(eta, param=param)
# mu, sigma, nu vs. fv?
dl1dp1 = self.distribution.dl1_dp1(y, fv, param=param)
dl2dp2 = self.distribution.dl2_dp2(y, fv, param=param)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_link_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_link_positive_line(linkfun):

@pytest.mark.parametrize("linkfun", SHIFTED_LINKS)
@pytest.mark.parametrize("value", VALUES)
def test_link_positive_line(linkfun, value):
def test_link_positive_shifted_line(linkfun, value):
"""Test links that are shifted. This changes the domain of the links."""
instance = linkfun(value)
x = np.linspace(value, 100 + value, M)
Expand Down

0 comments on commit a6e65a4

Please sign in to comment.