Skip to content

Commit

Permalink
Merge pull request #35 from simon-hirsch/links_derivatives
Browse files Browse the repository at this point in the history
Links derivatives
  • Loading branch information
simon-hirsch authored Dec 23, 2024
2 parents e8c14c8 + 279ab8a commit 7f1eaaf
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
7 changes: 6 additions & 1 deletion src/rolch/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
update_inverted_gram,
update_y_gram,
)
from rolch.utils import handle_param_dict

if HAS_PANDAS:
import pandas as pd
Expand All @@ -37,9 +36,15 @@ def link_derivative(self, x: np.ndarray) -> np.ndarray:
"""Calculate the first derivative of the link function"""
raise NotImplementedError("Currently not implemented. Will be needed for GLMs")

@abstractmethod
def link_second_derivative(self, x: np.ndarray) -> np.ndarray:
"""Calculate the second derivative for the link function"""
raise NotImplementedError("Currently not implemented.")

@abstractmethod
def inverse_derivative(self, x: np.ndarray) -> np.ndarray:
"""Calculate the first derivative for the inverse link function"""
raise NotImplementedError("Currently not implemented.")


class Distribution(ABC):
Expand Down
22 changes: 20 additions & 2 deletions src/rolch/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ def inverse_derivative(self, x: np.ndarray) -> np.ndarray:
def link_derivative(self, x: np.ndarray) -> np.ndarray:
return 1 / x

def link_second_derivative(self, x: np.ndarray) -> np.ndarray:
return -1 / x**2


class IdentityLink(LinkFunction):
"""
Expand All @@ -55,6 +58,9 @@ def inverse_derivative(self, x: np.ndarray) -> np.ndarray:
def link_derivative(self, x: np.ndarray) -> np.ndarray:
return np.ones_like(x)

def link_second_derivative(self, x: np.ndarray) -> np.ndarray:
return np.zeros_like(x)


class LogShiftValueLink(LinkFunction):
"""
Expand All @@ -81,7 +87,10 @@ def inverse_derivative(self, x: np.ndarray) -> np.ndarray:
return np.fmax(np.exp(np.fmin(x, EXP_UPPER_BOUND)), LOG_LOWER_BOUND)

def link_derivative(self, x: np.ndarray) -> np.ndarray:
return super().link_derivative(x)
return 1 / (x - self.value)

def link_second_derivative(self, x: np.ndarray) -> np.ndarray:
return -1 / (x - self.value) ** 2


class LogShiftTwoLink(LogShiftValueLink):
Expand Down Expand Up @@ -120,6 +129,9 @@ def inverse_derivative(self, x: np.ndarray) -> np.ndarray:
def link_derivative(self, x: np.ndarray) -> np.ndarray:
return 1 / (2 * np.sqrt(x))

def link_second_derivative(self, x) -> np.ndarray:
return -1 / (4 * x ** (3 / 2))


class SqrtShiftValueLink(LinkFunction):
"""
Expand All @@ -144,7 +156,10 @@ def inverse_derivative(self, x: np.ndarray) -> np.ndarray:
return 2 * x

def link_derivative(self, x: np.ndarray) -> np.ndarray:
return super().link_derivative(x)
return 1 / (2 * np.sqrt(x - self.value))

def link_second_derivative(self, x) -> np.ndarray:
return -1 / (4 * (x - self.value) ** (3 / 2))


class SqrtShiftTwoLink(SqrtShiftValueLink):
Expand Down Expand Up @@ -185,6 +200,9 @@ def inverse_derivative(self, x: np.ndarray) -> np.ndarray:
def link_derivative(self, x: np.ndarray) -> np.ndarray:
return super().link_derivative(x)

def link_second_derivative(self, x) -> np.ndarray:
return super().link_second_derivative(x)


__all__ = [
"LogLink",
Expand Down

0 comments on commit 7f1eaaf

Please sign in to comment.