Skip to content

Commit

Permalink
testing individual classo model
Browse files Browse the repository at this point in the history
  • Loading branch information
adamovanja committed May 21, 2024
1 parent 9195392 commit 5fb4c83
Showing 1 changed file with 123 additions and 30 deletions.
153 changes: 123 additions & 30 deletions experiments/implement_matrixA.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@
"outputs": [],
"source": [
"import numpy as np\n",
"import os\n",
"import pandas as pd\n",
"import qiime2 as q2\n",
"import skbio\n",
"import pickle\n",
"\n",
"from classo import classo_problem\n",
"from classo import Classo, classo_problem\n",
"from numpy import linalg\n",
"from qiime2.plugins import phylogeny\n",
"from skbio import TreeNode\n",
"\n",
"from q2_ritme.process_data import load_n_split_data\n",
"\n",
"%matplotlib inline\n",
Expand Down Expand Up @@ -446,15 +445,15 @@
},
"outputs": [],
"source": [
"# save model: A, label, alpha (includes selected_ft)\n",
"# todo: adjust path\n",
"path2out = \"test_model\"\n",
"if not os.path.exists(path2out):\n",
" os.makedirs(path2out)\n",
"# # save model: A, label, alpha (includes selected_ft)\n",
"# # todo: adjust path\n",
"# path2out = \"test_model\"\n",
"# if not os.path.exists(path2out):\n",
"# os.makedirs(path2out)\n",
"\n",
"# storing A w labels\n",
"df_A_with_labels = pd.DataFrame(A, columns=label, index=label[:nb_features])\n",
"df_A_with_labels.to_csv(os.path.join(path2out, \"matrix_a_w_labels.csv\"), index=True)"
"# # storing A w labels\n",
"# df_A_with_labels = pd.DataFrame(A, columns=label, index=label[:nb_features])\n",
"# df_A_with_labels.to_csv(os.path.join(path2out, \"matrix_a_w_labels.csv\"), index=True)"
]
},
{
Expand All @@ -465,18 +464,112 @@
},
"outputs": [],
"source": [
"# storing alpha w labels\n",
"idx_alpha = [\"intercept\"] + df_A_with_labels.columns.tolist()\n",
"df_alpha_with_labels = pd.DataFrame(alpha, columns=[\"alpha\"], index=idx_alpha)\n",
"df_alpha_with_labels.to_csv(\n",
" os.path.join(path2out, \"model_alpha_w_labels.csv\"), index=True\n",
"# # storing alpha w labels\n",
"# idx_alpha = [\"intercept\"] + df_A_with_labels.columns.tolist()\n",
"# df_alpha_with_labels = pd.DataFrame(alpha, columns=[\"alpha\"], index=idx_alpha)\n",
"# df_alpha_with_labels.to_csv(\n",
"# os.path.join(path2out, \"model_alpha_w_labels.csv\"), index=True\n",
"# )\n",
"\n",
"# # we can get selected features from alpha\n",
"# selected_ft_inf = df_alpha_with_labels[\n",
"# df_alpha_with_labels[\"alpha\"] != 0\n",
"# ].index.tolist()\n",
"# assert selected_ft_inf[1:] == selected_ft.tolist()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Solve with Classo directly"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# former alpha AFTER refit\n",
"# lam used in code = 0.001191103133283007 = problem.solution.CV.lambda_1SE\n",
"alpha"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"lam = problem.solution.CV.lambda_1SE\n",
"\n",
"matrices = (log_geom_trainval, np.ones((1, len(log_geom_trainval[0]))), y_train_val)\n",
"\n",
"method = \"Path-Alg\"\n",
"intercept = True\n",
"beta_norefit = Classo(\n",
" matrix=matrices, lam=lam, typ=\"R1\", meth=method, w=1 / nleaves, intercept=intercept\n",
")\n",
"# new alpha\n",
"beta_norefit"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def solve_unpenalized_least_squares(cmatrices, intercept=False):\n",
" # adapted from classo > misc_functions.py > unpenalised\n",
" if intercept:\n",
" A1, C1, y = cmatrices\n",
" A = np.concatenate([np.ones((len(A1), 1)), A1], axis=1)\n",
" C = np.concatenate([np.zeros((len(C1), 1)), C1], axis=1)\n",
" else:\n",
" A, C, y = cmatrices\n",
"\n",
" k = len(C)\n",
" d = len(A[0])\n",
" M1 = np.concatenate([A.T.dot(A), C.T], axis=1)\n",
" M2 = np.concatenate([C, np.zeros((k, k))], axis=1)\n",
" M = np.concatenate([M1, M2], axis=0)\n",
" b = np.concatenate([A.T.dot(y), np.zeros(k)])\n",
" sol = linalg.lstsq(M, b, rcond=None)[0]\n",
" beta = sol[:d]\n",
" return beta\n",
"\n",
"\n",
"def min_least_squares_solution(matrices, selected, intercept=False):\n",
" \"\"\"Minimum Least Squares solution for selected features.\"\"\"\n",
" # adapted from classo > misc_functions.py > min_LS\n",
" X, C, y = matrices\n",
" beta = np.zeros(len(selected))\n",
"\n",
" if intercept:\n",
" beta[selected] = solve_unpenalized_least_squares(\n",
" (X[:, selected[1:]], C[:, selected[1:]], y), intercept=selected[0]\n",
" )\n",
" else:\n",
" beta[selected] = solve_unpenalized_least_squares(\n",
" (X[:, selected], C[:, selected], y), intercept=False\n",
" )\n",
"\n",
" return beta"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"selected_param = abs(beta_norefit) > 1e-5\n",
"beta_refit = min_least_squares_solution(matrices, selected_param, intercept=intercept)\n",
"\n",
"\n",
"# we can get selected features from alpha\n",
"selected_ft_inf = df_alpha_with_labels[\n",
" df_alpha_with_labels[\"alpha\"] != 0\n",
"].index.tolist()\n",
"assert selected_ft_inf[1:] == selected_ft.tolist()"
"assert np.array_equal(alpha, beta_refit)"
]
},
{
Expand Down Expand Up @@ -515,13 +608,13 @@
},
"outputs": [],
"source": [
"# test pickle\n",
"# Create a dictionary to store the dataframes\n",
"model = {\"model\": df_alpha_with_labels, \"matrix_a\": df_A_with_labels}\n",
"# # test pickle\n",
"# # Create a dictionary to store the dataframes\n",
"# model = {\"model\": df_alpha_with_labels, \"matrix_a\": df_A_with_labels}\n",
"\n",
"# Serialize the dictionary to a pickle file\n",
"with open(\"data.pkl\", \"wb\") as file:\n",
" pickle.dump(model, file)"
"# # Serialize the dictionary to a pickle file\n",
"# with open(\"data.pkl\", \"wb\") as file:\n",
"# pickle.dump(model, file)"
]
},
{
Expand All @@ -532,8 +625,8 @@
},
"outputs": [],
"source": [
"with open(\"data.pkl\", \"rb\") as file:\n",
" model = pickle.load(file)"
"# with open(\"data.pkl\", \"rb\") as file:\n",
"# model = pickle.load(file)"
]
}
],
Expand Down

0 comments on commit 5fb4c83

Please sign in to comment.