Skip to content

Commit

Permalink
Normalize rss using mean weights in gamlss update
Browse files Browse the repository at this point in the history
  • Loading branch information
BerriJ committed Jul 17, 2024
1 parent c41bff1 commit 7c11956
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 26 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ development
literature/*
presentation/*
.cache
.devcontainer/*
.devcontainer/*
build.txt
44 changes: 22 additions & 22 deletions example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -114,16 +114,16 @@
"LASSO Coefficients \n",
"\n",
"[[151.969 3.973 25. ]\n",
" [ -0. -0. -0. ]\n",
" [ -9.815 -0. -0. ]\n",
" [ 24.232 0.044 -0. ]\n",
" [ 13.786 0. 0. ]\n",
" [ -4.886 -0. -0. ]\n",
" [ -0. -0. 0. ]\n",
" [ -9.815 -0. 0. ]\n",
" [ 24.232 0.044 0. ]\n",
" [ 13.786 0. -0. ]\n",
" [ -4.886 -0. 0. ]\n",
" [ -0. 0. 0. ]\n",
" [-10.781 -0.026 0. ]\n",
" [ 0. 0. 0. ]\n",
" [-10.781 -0.026 -0. ]\n",
" [ 0. 0. -0. ]\n",
" [ 24.851 0. 0. ]\n",
" [ 2.192 0. 0. ]]\n"
" [ 2.192 0. -0. ]]\n"
]
}
],
Expand Down Expand Up @@ -210,17 +210,17 @@
"\n",
"Coefficients after update call \n",
"\n",
"[[152.144 3.933 3.475]\n",
" [ -0.71 -0.045 0.511]\n",
" [-12.521 -0.098 0.618]\n",
" [ 24.582 0.036 -0.223]\n",
" [ 15.341 0.067 -0.232]\n",
" [-33.336 -0.412 2.878]\n",
" [ 19.242 0.388 -2.365]\n",
" [ 2.281 -0.002 -0.51 ]\n",
" [ 7.311 -0.089 0.267]\n",
" [ 35.07 0.19 -1.478]\n",
" [ 2.397 0.044 -0.235]]\n"
"[[152.023 3.917 2.737]\n",
" [ -0. -0. 0. ]\n",
" [-10.774 -0. 0. ]\n",
" [ 24.566 0.035 -0.277]\n",
" [ 14.204 0. -0. ]\n",
" [ -5.537 0. 0. ]\n",
" [ -0. -0. 0. ]\n",
" [-10.838 -0. 0. ]\n",
" [ 0. 0. -0. ]\n",
" [ 25.509 0. -0. ]\n",
" [ 1.748 0. -0.143]]\n"
]
}
],
Expand All @@ -239,9 +239,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python (ROLCH Test)",
"display_name": "Python 3",
"language": "python",
"name": "myenv"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -253,7 +253,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.10.12"
}
},
"nbformat": 4,
Expand Down
8 changes: 5 additions & 3 deletions src/rolch/online_gamlss.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def update_beta(
rss = (
(residuals**2).flatten() * w
+ (1 - self.forget) * (self.rss[param] * self.sum_of_weights[param])
) / (self.sum_of_weights[param] * (1 - self.forget) + w)
) / (self.mean_of_weights[param] * (1 - self.forget) + w)

elif (self.method == "lasso") & self.intercept_only[param]:
lambda_max = None
Expand All @@ -264,7 +264,7 @@ def update_beta(
rss = (
(residuals**2).flatten() * w
+ (1 - self.forget) * (self.rss[param] * self.sum_of_weights[param])
) / (self.sum_of_weights[param] * (1 - self.forget) + w)
) / (self.mean_of_weights[param] * (1 - self.forget) + w)

elif self.method == "lasso":
intercept = (
Expand Down Expand Up @@ -293,7 +293,7 @@ def update_beta(
rss = (
(residuals**2).flatten() * w
+ (1 - self.forget) * (self.rss[param] * self.sum_of_weights[param])
) / (self.sum_of_weights[param] * (1 - self.forget) + w)
) / (self.mean_of_weights[param] * (1 - self.forget) + w)

model_params_n = np.sum(np.isclose(beta_path, 0), axis=1)
best_ic = select_best_model_by_information_criterion(
Expand Down Expand Up @@ -405,6 +405,7 @@ def fit(
# distribution parameter for
# model selection online
self.sum_of_weights = {}
self.mean_of_weights = {}

(
self.betas,
Expand Down Expand Up @@ -779,6 +780,7 @@ def _inner_fit(

## Sum of weights
self.sum_of_weights[param] = np.sum(w * wt)
self.mean_of_weights[param] = np.mean(w * wt)

self.beta_iterations_inner[param][iteration_outer][iteration_inner] = beta
self.beta_path_iterations_inner[param][iteration_outer][
Expand Down

0 comments on commit 7c11956

Please sign in to comment.