diff --git a/.github/codecov.yml b/.github/codecov.yml
new file mode 100644
index 0000000..db24720
--- /dev/null
+++ b/.github/codecov.yml
@@ -0,0 +1 @@
+comment: off
diff --git a/ci/recipe/meta.yaml b/ci/recipe/meta.yaml
index bbeb269..c54c1ff 100644
--- a/ci/recipe/meta.yaml
+++ b/ci/recipe/meta.yaml
@@ -22,9 +22,11 @@ requirements:
    - importlib-metadata
    - qiime2 {{ qiime2_epoch }}.*
    - q2-feature-table {{ qiime2_epoch }}.*
+   - q2-phylogeny {{ qiime2_epoch }}.*
    - lightning
    - mlflow
    - numpy
+   - packaging
    - pandas
    - pip
    - pytorch
@@ -44,6 +46,9 @@ requirements:
  run_constrained:
    - pip:
      - coral_pytorch
+     - c-lasso
+     # grpcio pinned due to incompatibility with ray caused by c-lasso
+     - grpcio==1.51.1
 
 test:
diff --git a/experiments/implement_matrixA.ipynb b/experiments/implement_matrixA.ipynb
new file mode 100644
index 0000000..bc9378f
--- /dev/null
+++ b/experiments/implement_matrixA.ipynb
@@ -0,0 +1,654 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "import qiime2 as q2\n",
+    "import skbio\n",
+    "from classo import Classo, classo_problem\n",
+    "from numpy import linalg\n",
+    "from qiime2.plugins import phylogeny\n",
+    "from skbio import TreeNode\n",
+    "\n",
+    "from q2_ritme.process_data import load_n_split_data\n",
+    "\n",
+    "%matplotlib inline\n",
+    "%load_ext autoreload\n",
+    "%autoreload 2"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def create_matrix_from_tree(tree):\n",
+    "    # Get all leaves and create a mapping from leaf names to indices\n",
+    "    leaves = list(tree.tips())\n",
+    "    leaf_names = [leaf.name for leaf in leaves]\n",
+    "    # map each leaf name to a unique index\n",
+    "    leaf_index_map = {name: idx for idx, name in enumerate(leaf_names)}\n",
+    "\n",
+    "    # Get the number of leaves and internal nodes\n",
+    "    num_leaves = len(leaf_names)\n",
+    "    # root is not included\n",
+    "    internal_nodes = list(tree.non_tips())\n",
+    "\n",
+    "    # Create the identity matrix for the leaves: A1 (num_leaves x num_leaves)\n",
+    "    A1 = np.eye(num_leaves)\n",
+    "\n",
+    "    # Create the matrix for the internal nodes: A2 (num_leaves x\n",
+    "    # num_internal_nodes)\n",
+    "    # initialise it with zeros\n",
+    "    A2 = np.zeros((num_leaves, len(internal_nodes)))\n",
+    "\n",
+    "    # Populate A2 with 1s for the leaves linked by each internal node:\n",
+    "    # iterate over all internal nodes, find the descendants of each node and\n",
+    "    # mark them accordingly\n",
+    "    a2_node_names = []\n",
+    "    for j, node in enumerate(internal_nodes):\n",
+    "        # todo: adjust names to consensus taxonomy from descendants\n",
+    "        # for now node names are just increasing integers - since node.name is float\n",
+    "        a2_node_names.append(\"n\" + str(j))\n",
+    "        descendant_leaves = {leaf.name for leaf in node.tips()}\n",
+    "        for leaf_name in leaf_names:\n",
+    "            if leaf_name in descendant_leaves:\n",
+    "                A2[leaf_index_map[leaf_name], j] = 1\n",
+    "\n",
+    "    # Concatenate A1 and A2 to create the final matrix A\n",
+    "    A = np.hstack((A1, A2))\n",
+    "\n",
+    "    return A, a2_node_names"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Example data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "metadata": {}
+   },
+   "outputs": [],
+   "source": [
+    "# Create the tree nodes with lengths\n",
+    "n1 = TreeNode(name=\"n1\")\n",
+    "f1 = TreeNode(name=\"f1\", length=1.0)\n",
+    "f2 = TreeNode(name=\"f2\",
length=1.0)\n", + "n2 = TreeNode(name=\"n2\")\n", + "f3 = TreeNode(name=\"f3\", length=1.0)\n", + "\n", + "# Build the tree structure with lengths\n", + "n1.extend([f1, f2])\n", + "n2.extend([n1, f3])\n", + "n1.length = 1.0\n", + "n2.length = 1.0\n", + "\n", + "# n2 is the root of this tree\n", + "tree = n2\n", + "print(tree.ascii_art())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "metadata": {} + }, + "outputs": [], + "source": [ + "A_example, a2_names_ex = create_matrix_from_tree(tree)\n", + "A_example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "metadata": {} + }, + "outputs": [], + "source": [ + "a2_names_ex" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Real data: MA2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "metadata": {} + }, + "outputs": [], + "source": [ + "# read feature table\n", + "art_feature_table = q2.Artifact.load(\"data/220728_monthly/all_otu_table_filt.qza\")\n", + "df_ft = art_feature_table.view(pd.DataFrame)\n", + "df_ft.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "metadata": {} + }, + "outputs": [], + "source": [ + "# read taxonomy\n", + "path_to_taxonomy = \"data/220728_monthly/otu_taxonomy_all.qza\"\n", + "art_taxonomy = q2.Artifact.load(path_to_taxonomy)\n", + "df_taxonomy = art_taxonomy.view(pd.DataFrame)\n", + "print(df_taxonomy.shape)\n", + "\n", + "# Filter the taxonomy based on the feature table\n", + "df_taxonomy_f = df_taxonomy[df_taxonomy.index.isin(df_ft.columns.tolist())]\n", + "print(df_taxonomy_f.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "metadata": {} + }, + "outputs": [], + "source": [ + "df_taxonomy_f" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "metadata": {} + }, + "outputs": [], + "source": [ + "# read silva phylo tree\n", + "path_to_phylo = \"data/220728_monthly/silva-138-99-rooted-tree.qza\"\n", + "art_phylo = q2.Artifact.load(path_to_phylo)\n", + "tree_phylo = art_phylo.view(skbio.TreeNode)\n", + "# total nodes\n", + "tree_phylo.count()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "metadata": {} + }, + "outputs": [], + "source": [ + "# filter tree by feature table: this prunes a phylogenetic tree to match the\n", + "# input ids\n", + "(art_phylo_f,) = phylogeny.actions.filter_tree(tree=art_phylo, table=art_feature_table)\n", + "tree_phylo_f = art_phylo_f.view(skbio.TreeNode)\n", + "\n", + "# total nodes\n", + "tree_phylo_f.count()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# ensure that # leaves in tree == feature table dimension\n", + "num_leaves = tree_phylo_f.count(tips=True)\n", + "assert num_leaves == df_ft.shape[1]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "metadata": {} + }, + "outputs": [], + "source": [ + "A, a2_names = create_matrix_from_tree(tree_phylo_f)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "metadata": {} + }, + "outputs": [], + "source": [ + "# verification\n", + "# no all 1 in one column\n", + "assert not np.any(np.all(A == 1.0, axis=0))\n", + "\n", + "# shape should be = feature_count + node_count\n", + "nb_features = df_ft.shape[1]\n", + "nb_non_leaf_nodes = len(list(tree_phylo_f.non_tips()))\n", + "\n", + "assert nb_features + nb_non_leaf_nodes == A.shape[1]" + ] + 
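},
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# added sanity-check sketch (assumes A, num_leaves and tree_phylo_f from\n",
+    "# the cells above): each internal-node column of A should contain exactly\n",
+    "# as many 1s as that node has descendant tips\n",
+    "internal_nodes = list(tree_phylo_f.non_tips())\n",
+    "col_sums = A[:, num_leaves:].sum(axis=0)\n",
+    "expected = np.array([n.count(tips=True) for n in internal_nodes])\n",
+    "assert np.array_equal(col_sums, expected)"
+   ]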
}, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Run trac with this" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "metadata": {} + }, + "outputs": [], + "source": [ + "# load metadata\n", + "target = \"age_months\"\n", + "train_val, test, _tx, _phlo = load_n_split_data(\n", + " path2md=\"data/220728_monthly/metadata_proc_v20240323_r0_r3_le_2yrs.tsv\",\n", + " path2ft=\"data/220728_monthly/all_otu_table_filt.qza\",\n", + " path2tax=\"data/220728_monthly/otu_taxonomy_all.qza\",\n", + " path2phylo=\"data/220728_monthly/silva-138-99-rooted-tree.qza\",\n", + " host_id=\"host_id\",\n", + " target=target,\n", + " train_size=0.8,\n", + " seed=12,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "metadata": {} + }, + "outputs": [], + "source": [ + "# preprocess taxonomy aggregation\n", + "def _preprocess_taxonomy_aggregation(x, A):\n", + " pseudo_count = 0.000001\n", + " # ? what happens if x is relative abundances\n", + " X = np.log(pseudo_count + x)\n", + " nleaves = np.sum(A, axis=0)\n", + " log_geom = X.dot(A) / nleaves\n", + "\n", + " return log_geom, nleaves" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "metadata": {} + }, + "outputs": [], + "source": [ + "# perform preprocessing on train\n", + "ft_cols = [x for x in train_val.columns if x.startswith(\"F\")]\n", + "x_train_val = train_val[ft_cols]\n", + "y_train_val = train_val[target]\n", + "# todo: afterwards perform it on test\n", + "log_geom_trainval, nleaves = _preprocess_taxonomy_aggregation(x_train_val.values, A)\n", + "\n", + "n, d = log_geom_trainval.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "metadata": {} + }, + "outputs": [], + "source": [ + "# get labels from taxonomy\n", + "# change labels to match new feature names\n", + "df_taxonomy_f.index = df_taxonomy_f.index.map(lambda x: \"F\" + str(x))\n", + "\n", + "# todo: add proper A2 labels for A -> for now it's just n + count\n", + "label = df_taxonomy_f[\"Taxon\"].values\n", + "label_short = np.array([la.split(\";\")[-1].strip() for la in label])\n", + "assert len(label) == len(ft_cols)\n", + "assert len(label) == len(label_short)\n", + "label = np.append(label, a2_names)\n", + "label_short = np.append(label_short, a2_names)\n", + "\n", + "assert len(label_short) == A.shape[1]\n", + "label_short" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "metadata": {} + }, + "outputs": [], + "source": [ + "# perform CV classo: trac\n", + "problem = classo_problem(log_geom_trainval, y_train_val.values, label=label_short)\n", + "\n", + "problem.formulation.w = 1 / nleaves\n", + "problem.formulation.intercept = True\n", + "problem.formulation.concomitant = False # not relevant for here\n", + "\n", + "# ! 
one form of model selection needs to be chosen\n", + "# stability selection: for pre-selected range of lambda find beta paths\n", + "problem.model_selection.StabSel = False\n", + "# calculate coefficients for a grid of lambdas\n", + "problem.model_selection.PATH = False\n", + "# todo: check if it is fair that trac is trained with CV internally whereas others are not\n", + "# lambda values checked with CV are `Nlam` points between 1 and `lamin`, with\n", + "# logarithm scale or not depending on `logscale`.\n", + "problem.model_selection.CV = True\n", + "problem.model_selection.CVparameters.seed = (\n", + " 6 # one could change logscale, Nsubset, oneSE\n", + ")\n", + "# 'one-standard-error' = select simplest model (largest lambda value) in CV\n", + "# whose CV score is within 1 stddev of best score\n", + "# ! create hyperparameter for this\n", + "problem.model_selection.CVparameters.oneSE = True\n", + "# ! create hyperparameter for this\n", + "problem.model_selection.CVparameters.Nlam = 80\n", + "# ! create hyperparameter for this\n", + "problem.model_selection.CVparameters.lamin = 0.001\n", + "\n", + "# ! for ritme: no feature_transformation to be used for trac\n", + "print(problem)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "metadata": {} + }, + "outputs": [], + "source": [ + "problem.solve()\n", + "# todo: find out how to extract the insights from the model to disk without changing classo\n", + "print(problem.solution)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "metadata": {} + }, + "outputs": [], + "source": [ + "# alpha [0] is learned intercept, alpha [1:] are learned coefficients for all features\n", + "# in logGeom (n_samples, n_features)\n", + "# ! if oneSE=True -> uses lambda_1SE else lambda_min (see CV in\n", + "# ! classo>cross_validation.py)\n", + "# refit -> solves unconstrained least squares problem with selected lambda and\n", + "# variables\n", + "alpha = problem.solution.CV.refit" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "metadata": {} + }, + "outputs": [], + "source": [ + "# ! 
class solution_CV: defined in @solver.py L930\n", + "selection = problem.solution.CV.selected_param[1:] # exclude the intercept\n", + "selected_ft = label[selection]\n", + "print(selected_ft)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "metadata": {} + }, + "outputs": [], + "source": [ + "# # selected lambda with 1-standard-error method\n", + "# problem.solution.CV.lambda_1SE\n", + "\n", + "# # selected lambda without 1-standard-error method\n", + "# problem.solution.CV.lambda_min" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "metadata": {} + }, + "outputs": [], + "source": [ + "# # save model: A, label, alpha (includes selected_ft)\n", + "# # todo: adjust path\n", + "# path2out = \"test_model\"\n", + "# if not os.path.exists(path2out):\n", + "# os.makedirs(path2out)\n", + "\n", + "# # storing A w labels\n", + "# df_A_with_labels = pd.DataFrame(A, columns=label, index=label[:nb_features])\n", + "# df_A_with_labels.to_csv(os.path.join(path2out, \"matrix_a_w_labels.csv\"), index=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "metadata": {} + }, + "outputs": [], + "source": [ + "# # storing alpha w labels\n", + "# idx_alpha = [\"intercept\"] + df_A_with_labels.columns.tolist()\n", + "# df_alpha_with_labels = pd.DataFrame(alpha, columns=[\"alpha\"], index=idx_alpha)\n", + "# df_alpha_with_labels.to_csv(\n", + "# os.path.join(path2out, \"model_alpha_w_labels.csv\"), index=True\n", + "# )\n", + "\n", + "# # we can get selected features from alpha\n", + "# selected_ft_inf = df_alpha_with_labels[\n", + "# df_alpha_with_labels[\"alpha\"] != 0\n", + "# ].index.tolist()\n", + "# assert selected_ft_inf[1:] == selected_ft.tolist()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Solve with Classo directly" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# former alpha AFTER refit\n", + "# lam used in code = 0.001191103133283007 = problem.solution.CV.lambda_1SE\n", + "alpha" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "lam = problem.solution.CV.lambda_1SE\n", + "\n", + "matrices = (log_geom_trainval, np.ones((1, len(log_geom_trainval[0]))), y_train_val)\n", + "\n", + "method = \"Path-Alg\"\n", + "intercept = True\n", + "beta_norefit = Classo(\n", + " matrix=matrices, lam=lam, typ=\"R1\", meth=method, w=1 / nleaves, intercept=intercept\n", + ")\n", + "# new alpha\n", + "beta_norefit" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def solve_unpenalized_least_squares(cmatrices, intercept=False):\n", + " # adapted from classo > misc_functions.py > unpenalised\n", + " if intercept:\n", + " A1, C1, y = cmatrices\n", + " A = np.concatenate([np.ones((len(A1), 1)), A1], axis=1)\n", + " C = np.concatenate([np.zeros((len(C1), 1)), C1], axis=1)\n", + " else:\n", + " A, C, y = cmatrices\n", + "\n", + " k = len(C)\n", + " d = len(A[0])\n", + " M1 = np.concatenate([A.T.dot(A), C.T], axis=1)\n", + " M2 = np.concatenate([C, np.zeros((k, k))], axis=1)\n", + " M = np.concatenate([M1, M2], axis=0)\n", + " b = np.concatenate([A.T.dot(y), np.zeros(k)])\n", + " sol = linalg.lstsq(M, b, rcond=None)[0]\n", + " beta = sol[:d]\n", + " return beta\n", + "\n", + "\n", + "def min_least_squares_solution(matrices, selected, intercept=False):\n", + " \"\"\"Minimum Least Squares solution for 
selected features.\"\"\"\n", + " # adapted from classo > misc_functions.py > min_LS\n", + " X, C, y = matrices\n", + " beta = np.zeros(len(selected))\n", + "\n", + " if intercept:\n", + " beta[selected] = solve_unpenalized_least_squares(\n", + " (X[:, selected[1:]], C[:, selected[1:]], y), intercept=selected[0]\n", + " )\n", + " else:\n", + " beta[selected] = solve_unpenalized_least_squares(\n", + " (X[:, selected], C[:, selected], y), intercept=False\n", + " )\n", + "\n", + " return beta" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "selected_param = abs(beta_norefit) > 1e-5\n", + "beta_refit = min_least_squares_solution(matrices, selected_param, intercept=intercept)\n", + "\n", + "\n", + "assert np.array_equal(alpha, beta_refit)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Perform prediction on test set" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "metadata": {} + }, + "outputs": [], + "source": [ + "# derive log_geom for test\n", + "ft_cols = [x for x in test.columns if x.startswith(\"F\")]\n", + "\n", + "x_test = test[ft_cols]\n", + "y_test = test[target]\n", + "# todo: read A\n", + "log_geom_test, nleaves = _preprocess_taxonomy_aggregation(x_test.values, A)\n", + "\n", + "# apply model to test\n", + "# todo: read alpha\n", + "y_test_pred = log_geom_test.dot(alpha[1:]) + alpha[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "metadata": {} + }, + "outputs": [], + "source": [ + "# # test pickle\n", + "# # Create a dictionary to store the dataframes\n", + "# model = {\"model\": df_alpha_with_labels, \"matrix_a\": df_A_with_labels}\n", + "\n", + "# # Serialize the dictionary to a pickle file\n", + "# with open(\"data.pkl\", \"wb\") as file:\n", + "# pickle.dump(model, file)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "metadata": {} + }, + "outputs": [], + "source": [ + "# with open(\"data.pkl\", \"rb\") as file:\n", + "# model = pickle.load(file)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "ritme_wclasso", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/experiments/test_classo.ipynb b/experiments/test_classo.ipynb new file mode 100644 index 0000000..b836e09 --- /dev/null +++ b/experiments/test_classo.ipynb @@ -0,0 +1,432 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Test classo\n", + "code edited from CentralPark example in GitHub repos [here](https://github.com/Leo-Simpson/c-lasso/blob/master/examples/example_CentralParkSoil.py)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Import" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from os.path import join\n", + "from classo import classo_problem\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "%matplotlib inline\n", + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load data" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + 
"outputs": [], + "source": [ + "data_dir = join(\"data\", \"CentralParkSoil\")\n", + "data = np.load(join(data_dir, \"cps.npz\"))\n", + "\n", + "# X are relative abundance counts\n", + "x = data[\"x\"] # (580, 3379)\n", + "\n", + "# y is target\n", + "y = data[\"y\"] # (580,)\n", + "\n", + "label = data[\"label\"] # (3704,) = 3379 OTUs + 325 nodes in tree\n", + "label_short = np.array([la.split(\"::\")[-1] for la in label])" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array(['Life::k__Bacteria::p__Proteobacteria::c__Gammaproteobacteria::o__Legionellales::f__Coxiellaceae::g__Aquicella::s__1::OTU_2211',\n", + " 'Life::k__Bacteria::p__Proteobacteria::c__Gammaproteobacteria::o__Legionellales::f__Coxiellaceae::g__Aquicella::s__2::OTU_1172',\n", + " 'Life::k__Bacteria::p__Proteobacteria::c__Gammaproteobacteria::o__Legionellales::f__Coxiellaceae::g__Aquicella::s__3::OTU_1734',\n", + " ...,\n", + " 'Life::k__Bacteria::p__Proteobacteria::c__Gammaproteobacteria',\n", + " 'Life::k__Bacteria::p__Proteobacteria', 'Life::k__Bacteria'],\n", + " dtype=' uses lambda_1SE else lambda_min (see CV in\n", + "# ! classo>cross_validation.py)\n", + "# refit -> solves unconstrained least squares problem with selected lambda and\n", + "# variables\n", + "alpha = problem.solution.CV.refit\n", + "len(alpha)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# model prediction\n", + "yhat = logGeom[te].dot(alpha[1:]) + alpha[0]\n", + "\n", + "M1, M2 = max(y[te]), min(y[te])\n", + "plt.plot(yhat, y[te], \"bo\", label=\"sample of the testing set\")\n", + "plt.plot([M1, M2], [M1, M2], \"k-\", label=\"identity\")\n", + "plt.xlabel(\"predictor yhat\"), plt.ylabel(\"real y\"), plt.legend()\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Stability selection" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "problem = classo_problem(logGeom[tr], y[tr], label=label_short)\n", + "\n", + "problem.formulation.w = 1 / nleaves\n", + "problem.formulation.intercept = True\n", + "problem.formulation.concomitant = False\n", + "\n", + "\n", + "problem.model_selection.PATH = False\n", + "problem.model_selection.CV = False\n", + "# can change q, B, nS, method, threshold etc in problem.model_selection.StabSelparameters\n", + "\n", + "problem.solve()\n", + "\n", + "print(problem, problem.solution)\n", + "\n", + "selection = problem.solution.StabSel.selected_param[1:] # exclude the intercept\n", + "print(label[selection])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prediction plot" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "te = np.array([i for i in range(len(y)) if i not in tr])\n", + "alpha = problem.solution.StabSel.refit\n", + "yhat = logGeom[te].dot(alpha[1:]) + alpha[0]\n", + "\n", + "M1, M2 = max(y[te]), min(y[te])\n", + "plt.plot(yhat, y[te], \"bo\", label=\"sample of the testing set\")\n", + "plt.plot([M1, M2], [M1, M2], \"k-\", label=\"identity\")\n", + "plt.xlabel(\"predictor yhat\"), plt.ylabel(\"real y\"), plt.legend()\n", + "plt.tight_layout()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "ritme_wclasso", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + 
"version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/experiments/trac_tree_problem.ipynb b/experiments/trac_tree_problem.ipynb new file mode 100644 index 0000000..6f32c43 --- /dev/null +++ b/experiments/trac_tree_problem.ipynb @@ -0,0 +1,874 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Notebook to illustrate problem of weird tree matching to taxonomy" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "import qiime2 as q2\n", + "import skbio\n", + "from qiime2.plugins import phylogeny\n", + "\n", + "%matplotlib inline\n", + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Read data" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(9478, 5580)" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# read feature table\n", + "art_feature_table = q2.Artifact.load(\"data/220728_monthly/all_otu_table_filt.qza\")\n", + "df_ft = art_feature_table.view(pd.DataFrame)\n", + "df_ft.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(5608, 2)\n", + "(5580, 2)\n" + ] + } + ], + "source": [ + "# read taxonomy\n", + "path_to_taxonomy = \"data/220728_monthly/otu_taxonomy_all.qza\"\n", + "art_taxonomy = q2.Artifact.load(path_to_taxonomy)\n", + "df_taxonomy = art_taxonomy.view(pd.DataFrame)\n", + "print(df_taxonomy.shape)\n", + "\n", + "# Filter the taxonomy based on the feature table\n", + "df_taxonomy_f = df_taxonomy[df_taxonomy.index.isin(df_ft.columns.tolist())]\n", + "print(df_taxonomy_f.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "870198\n" + ] + }, + { + "data": { + "text/plain": [ + "11159" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# read silva phylo tree\n", + "path_to_phylo = \"data/220728_monthly/silva-138-99-rooted-tree.qza\"\n", + "art_phylo = q2.Artifact.load(path_to_phylo)\n", + "tree_phylo = art_phylo.view(skbio.TreeNode)\n", + "# total nodes\n", + "print(tree_phylo.count())\n", + "\n", + "# filter tree by feature table: this prunes a phylogenetic tree to match the\n", + "# input ids\n", + "(art_phylo_f,) = phylogeny.actions.filter_tree(tree=art_phylo, table=art_feature_table)\n", + "tree_phylo_f = art_phylo_f.view(skbio.TreeNode)\n", + "\n", + "# total nodes\n", + "tree_phylo_f.count()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# ensure that # leaves in tree == feature table dimension\n", + "num_leaves = tree_phylo_f.count(tips=True)\n", + "assert num_leaves == df_ft.shape[1]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Get matrix A" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "def _create_matrix_from_tree(tree, tax):\n", + " # ! 
function copied from _process_train.py after git commit \"b51c22f\" in PR\n",
+    "    # ! #16, to keep the dic_node2leaf output\n",
+    "    # Get all leaves and create a mapping from leaf names to indices\n",
+    "    leaves = list(tree.tips())\n",
+    "    leaf_names = [leaf.name for leaf in leaves]\n",
+    "    # map each leaf name to a unique index\n",
+    "    leaf_index_map = {name: idx for idx, name in enumerate(leaf_names)}\n",
+    "\n",
+    "    # Get the number of leaves and internal nodes\n",
+    "    num_leaves = len(leaf_names)\n",
+    "    # root is not included\n",
+    "    internal_nodes = list(tree.non_tips())\n",
+    "\n",
+    "    # Create the identity matrix for the leaves: A1 (num_leaves x num_leaves)\n",
+    "    A1 = np.eye(num_leaves)\n",
+    "    # taxonomic name should include OTU name\n",
+    "    tax_e = tax.copy()\n",
+    "    tax_e[\"tax_ft\"] = tax_e[\"Taxon\"] + \"; otu__\" + tax_e.index\n",
+    "    a2_node_names = tax_e.loc[leaf_names, \"tax_ft\"].tolist()\n",
+    "    # Create the matrix for the internal nodes: A2 (num_leaves x\n",
+    "    # num_internal_nodes)\n",
+    "    # initialise it with zeros\n",
+    "    A2 = np.zeros((num_leaves, len(internal_nodes)))\n",
+    "\n",
+    "    # Populate A2 with 1s for the leaves linked by each internal node:\n",
+    "    # iterate over all internal nodes, find the descendants of each node and\n",
+    "    # mark them accordingly\n",
+    "    dict_node2leaf = {}\n",
+    "    for j, node in enumerate(internal_nodes):\n",
+    "        # per node keep track of leaf names - for consensus naming\n",
+    "        node_leaf_names = []\n",
+    "        # todo: adjust names to consensus taxonomy from descendants\n",
+    "        # for now node names are just increasing integers - since node.name is float\n",
+    "        descendant_leaves = {leaf.name for leaf in node.tips()}\n",
+    "        for leaf_name in leaf_names:\n",
+    "            if leaf_name in descendant_leaves:\n",
+    "                node_leaf_names.append(leaf_name)\n",
+    "                A2[leaf_index_map[leaf_name], j] = 1\n",
+    "\n",
+    "        # create consensus taxonomy from all leaf_names\n",
+    "        node_mapped_taxon = tax_e.loc[node_leaf_names, \"tax_ft\"].tolist()\n",
+    "        dict_node2leaf[j] = node_mapped_taxon\n",
+    "        str_consensus_taxon = os.path.commonprefix(node_mapped_taxon)\n",
+    "        # get name before last \";\"\n",
+    "        node_consensus_taxon = str_consensus_taxon.rpartition(\";\")[0]\n",
+    "\n",
+    "        if node_consensus_taxon in a2_node_names:\n",
+    "            # if consensus name already exists, add index to make it unique\n",
+    "            node_consensus_taxon = node_consensus_taxon + \"; n__\" + str(j)\n",
+    "        a2_node_names.append(node_consensus_taxon)\n",
+    "\n",
+    "    # Concatenate A1 and A2 to create the final matrix A\n",
+    "    A = np.hstack((A1, A2))\n",
+    "\n",
+    "    return A, a2_node_names, dict_node2leaf"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(5580, 11158)"
+      ]
+     },
+     "execution_count": 7,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "A, a2_names, dic_node2leaf = _create_matrix_from_tree(tree_phylo_f, df_taxonomy_f)\n",
+    "\n",
+    "df_A = pd.DataFrame(A, columns=a2_names)\n",
+    "df_A.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "(HTML preview of df_A.head() omitted; the text/plain rendering below shows the same data)
" + ], + "text/plain": [ + " d__Archaea; p__Nanoarchaeota; c__Nanoarchaeia; o__Woesearchaeales; f__SCGC_AAA011-D5; g__SCGC_AAA011-D5; s__Nanoarchaeota_archaeon; otu__ASMP01000002.125551.126982 \\\n", + "0 1.0 \n", + "1 0.0 \n", + "2 0.0 \n", + "3 0.0 \n", + "4 0.0 \n", + "\n", + " d__Archaea; p__Euryarchaeota; c__Methanobacteria; o__Methanobacteriales; f__Methanobacteriaceae; g__Methanobrevibacter; otu__JX833581.1.1262 \\\n", + "0 0.0 \n", + "1 1.0 \n", + "2 0.0 \n", + "3 0.0 \n", + "4 0.0 \n", + "\n", + " d__Archaea; p__Euryarchaeota; c__Methanobacteria; o__Methanobacteriales; f__Methanobacteriaceae; g__Methanobrevibacter; s__uncultured_Methanobacteriales; otu__AB535261.1.1262 \\\n", + "0 0.0 \n", + "1 0.0 \n", + "2 1.0 \n", + "3 0.0 \n", + "4 0.0 \n", + "\n", + " d__Archaea; p__Euryarchaeota; c__Methanobacteria; o__Methanobacteriales; f__Methanobacteriaceae; g__Methanosphaera; otu__CP000102.408655.410144 \\\n", + "0 0.0 \n", + "1 0.0 \n", + "2 0.0 \n", + "3 1.0 \n", + "4 0.0 \n", + "\n", + " d__Archaea; p__Euryarchaeota; c__Methanobacteria; o__Methanobacteriales; f__Methanobacteriaceae; g__Methanosphaera; s__uncultured_methanogenic; otu__AB905959.1.1268 \\\n", + "0 0.0 \n", + "1 0.0 \n", + "2 0.0 \n", + "3 0.0 \n", + "4 1.0 \n", + "\n", + " d__Archaea; p__Euryarchaeota; c__Methanobacteria; o__Methanobacteriales; f__Methanobacteriaceae; g__Methanobrevibacter; otu__AY196669.1.1262 \\\n", + "0 0.0 \n", + "1 0.0 \n", + "2 0.0 \n", + "3 0.0 \n", + "4 0.0 \n", + "\n", + " d__Archaea; p__Euryarchaeota; c__Methanobacteria; o__Methanobacteriales; f__Methanobacteriaceae; g__Methanobrevibacter; otu__AB905821.1.1267 \\\n", + "0 0.0 \n", + "1 0.0 \n", + "2 0.0 \n", + "3 0.0 \n", + "4 0.0 \n", + "\n", + " d__Archaea; p__Thermoplasmatota; c__Thermoplasmata; o__Methanomassiliicoccales; f__Methanomethylophilaceae; otu__DQ445723.1.1210 \\\n", + "0 0.0 \n", + "1 0.0 \n", + "2 0.0 \n", + "3 0.0 \n", + "4 0.0 \n", + "\n", + " d__Archaea; p__Thermoplasmatota; c__Thermoplasmata; o__Methanomassiliicoccales; f__Methanomethylophilaceae; otu__JF980498.1.1419 \\\n", + "0 0.0 \n", + "1 0.0 \n", + "2 0.0 \n", + "3 0.0 \n", + "4 0.0 \n", + "\n", + " d__Bacteria; p__Patescibacteria; c__Gracilibacteria; o__Absconditabacteriales_(SR1); f__Absconditabacteriales_(SR1); g__Absconditabacteriales_(SR1); s__SR1_bacterium; otu__AOTF01000010.101107.102585 \\\n", + "0 0.0 \n", + "1 0.0 \n", + "2 0.0 \n", + "3 0.0 \n", + "4 0.0 \n", + "\n", + " ... d__Bacteria; n__5568 \\\n", + "0 ... 0.0 \n", + "1 ... 0.0 \n", + "2 ... 0.0 \n", + "3 ... 0.0 \n", + "4 ... 
0.0 \n", + "\n", + " d__Bacteria; p__Chloroflexi; c__Chloroflexia; o__Thermomicrobiales; f__JG30-KF-CM45; g__JG30-KF-CM45 \\\n", + "0 0.0 \n", + "1 0.0 \n", + "2 0.0 \n", + "3 0.0 \n", + "4 0.0 \n", + "\n", + " d__Bacteria; p__Chloroflexi; c__Chloroflexia; o__Thermomicrobiales; f__JG30-KF-CM45; g__JG30-KF-CM45; n__5570 \\\n", + "0 0.0 \n", + "1 0.0 \n", + "2 0.0 \n", + "3 0.0 \n", + "4 0.0 \n", + "\n", + " d__Bacteria; p__Chloroflexi; c__Chloroflexia; o__Thermomicrobiales; f__JG30-KF-CM45; g__JG30-KF-CM45; n__5571 \\\n", + "0 0.0 \n", + "1 0.0 \n", + "2 0.0 \n", + "3 0.0 \n", + "4 0.0 \n", + "\n", + " d__Bacteria; p__Chloroflexi; c__Chloroflexia; o__Thermomicrobiales; f__JG30-KF-CM45; g__JG30-KF-CM45; n__5572 \\\n", + "0 0.0 \n", + "1 0.0 \n", + "2 0.0 \n", + "3 0.0 \n", + "4 0.0 \n", + "\n", + " d__Bacteria; p__Chloroflexi d__Bacteria; p__Chloroflexi; n__5574 \\\n", + "0 0.0 0.0 \n", + "1 0.0 0.0 \n", + "2 0.0 0.0 \n", + "3 0.0 0.0 \n", + "4 0.0 0.0 \n", + "\n", + " d__Bacteria; p__Chloroflexi; n__5575 d__Bacteria; p__Chloroflexi; n__5576 \\\n", + "0 0.0 0.0 \n", + "1 0.0 0.0 \n", + "2 0.0 0.0 \n", + "3 0.0 0.0 \n", + "4 0.0 0.0 \n", + "\n", + " d__Bacteria; n__5577 \n", + "0 0.0 \n", + "1 0.0 \n", + "2 0.0 \n", + "3 0.0 \n", + "4 0.0 \n", + "\n", + "[5 rows x 11158 columns]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_A.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Show problem with identical \"taxonomic nodes\"" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
(HTML preview omitted; the text/plain rendering below shows the same four columns)
" + ], + "text/plain": [ + " d__Bacteria; p__Chloroflexi d__Bacteria; p__Chloroflexi; n__5574 \\\n", + "0 0.0 0.0 \n", + "1 0.0 0.0 \n", + "2 0.0 0.0 \n", + "3 0.0 0.0 \n", + "4 0.0 0.0 \n", + "... ... ... \n", + "5575 1.0 1.0 \n", + "5576 1.0 1.0 \n", + "5577 1.0 1.0 \n", + "5578 0.0 0.0 \n", + "5579 0.0 0.0 \n", + "\n", + " d__Bacteria; p__Chloroflexi; n__5575 \\\n", + "0 0.0 \n", + "1 0.0 \n", + "2 0.0 \n", + "3 0.0 \n", + "4 0.0 \n", + "... ... \n", + "5575 0.0 \n", + "5576 0.0 \n", + "5577 0.0 \n", + "5578 1.0 \n", + "5579 1.0 \n", + "\n", + " d__Bacteria; p__Chloroflexi; n__5576 \n", + "0 0.0 \n", + "1 0.0 \n", + "2 0.0 \n", + "3 0.0 \n", + "4 0.0 \n", + "... ... \n", + "5575 1.0 \n", + "5576 1.0 \n", + "5577 1.0 \n", + "5578 1.0 \n", + "5579 1.0 \n", + "\n", + "[5580 rows x 4 columns]" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nb_otus = 5580\n", + "\n", + "df_A.iloc[:, nb_otus + 5573 : nb_otus + 5577]" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "d__Bacteria; p__Chloroflexi 6.0\n", + "d__Bacteria; p__Chloroflexi; n__5574 7.0\n", + "d__Bacteria; p__Chloroflexi; n__5575 2.0\n", + "d__Bacteria; p__Chloroflexi; n__5576 9.0\n", + "dtype: float64" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_A.iloc[:, nb_otus + 5573 : nb_otus + 5577].sum()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['d__Bacteria; p__Chloroflexi; c__OLB14; o__OLB14; f__OLB14; g__OLB14; s__uncultured_bacterium; otu__KT835595.1.1453',\n", + " 'd__Bacteria; p__Chloroflexi; c__Chloroflexia; o__Thermomicrobiales; f__JG30-KF-CM45; g__JG30-KF-CM45; s__bacterium_QTYC46b; otu__JQ624352.1.1469',\n", + " 'd__Bacteria; p__Chloroflexi; c__Chloroflexia; o__Thermomicrobiales; f__JG30-KF-CM45; g__JG30-KF-CM45; s__uncultured_bacterium; otu__JQ978845.1.1460',\n", + " 'd__Bacteria; p__Chloroflexi; c__Chloroflexia; o__Thermomicrobiales; f__JG30-KF-CM45; g__JG30-KF-CM45; s__uncultured_Chloroflexi; otu__AM935498.1.1337',\n", + " 'd__Bacteria; p__Chloroflexi; c__Chloroflexia; o__Thermomicrobiales; f__JG30-KF-CM45; g__JG30-KF-CM45; s__uncultured_soil; otu__FQ659571.1.1334',\n", + " 'd__Bacteria; p__Chloroflexi; c__Chloroflexia; o__Thermomicrobiales; f__JG30-KF-CM45; g__JG30-KF-CM45; s__metagenome; otu__FPLS01036268.22.1504']" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# get matching leave taxonomic names\n", + "dic_node2leaf[5573]" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['d__Bacteria; p__Chloroflexi; c__Anaerolineae; o__SBR1031; f__SBR1031; g__SBR1031; s__anaerobic_digester; otu__CZCB01001507.135.1626',\n", + " 'd__Bacteria; p__Chloroflexi; c__OLB14; o__OLB14; f__OLB14; g__OLB14; s__uncultured_bacterium; otu__KT835595.1.1453',\n", + " 'd__Bacteria; p__Chloroflexi; c__Chloroflexia; o__Thermomicrobiales; f__JG30-KF-CM45; g__JG30-KF-CM45; s__bacterium_QTYC46b; otu__JQ624352.1.1469',\n", + " 'd__Bacteria; p__Chloroflexi; c__Chloroflexia; o__Thermomicrobiales; f__JG30-KF-CM45; g__JG30-KF-CM45; s__uncultured_bacterium; otu__JQ978845.1.1460',\n", + " 'd__Bacteria; p__Chloroflexi; c__Chloroflexia; o__Thermomicrobiales; f__JG30-KF-CM45; g__JG30-KF-CM45; s__uncultured_Chloroflexi; 
otu__AM935498.1.1337',\n", + " 'd__Bacteria; p__Chloroflexi; c__Chloroflexia; o__Thermomicrobiales; f__JG30-KF-CM45; g__JG30-KF-CM45; s__uncultured_soil; otu__FQ659571.1.1334',\n", + " 'd__Bacteria; p__Chloroflexi; c__Chloroflexia; o__Thermomicrobiales; f__JG30-KF-CM45; g__JG30-KF-CM45; s__metagenome; otu__FPLS01036268.22.1504']" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dic_node2leaf[5574]" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['d__Bacteria; p__Chloroflexi; c__Gitt-GS-136; o__Gitt-GS-136; f__Gitt-GS-136; g__Gitt-GS-136; otu__HM299006.1.1326',\n", + " 'd__Bacteria; p__Chloroflexi; c__KD4-96; o__KD4-96; f__KD4-96; g__KD4-96; s__uncultured_bacterium; otu__KY190653.1.1395']" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dic_node2leaf[5575]" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['d__Bacteria; p__Chloroflexi; c__Anaerolineae; o__SBR1031; f__SBR1031; g__SBR1031; s__anaerobic_digester; otu__CZCB01001507.135.1626',\n", + " 'd__Bacteria; p__Chloroflexi; c__OLB14; o__OLB14; f__OLB14; g__OLB14; s__uncultured_bacterium; otu__KT835595.1.1453',\n", + " 'd__Bacteria; p__Chloroflexi; c__Chloroflexia; o__Thermomicrobiales; f__JG30-KF-CM45; g__JG30-KF-CM45; s__bacterium_QTYC46b; otu__JQ624352.1.1469',\n", + " 'd__Bacteria; p__Chloroflexi; c__Chloroflexia; o__Thermomicrobiales; f__JG30-KF-CM45; g__JG30-KF-CM45; s__uncultured_bacterium; otu__JQ978845.1.1460',\n", + " 'd__Bacteria; p__Chloroflexi; c__Chloroflexia; o__Thermomicrobiales; f__JG30-KF-CM45; g__JG30-KF-CM45; s__uncultured_Chloroflexi; otu__AM935498.1.1337',\n", + " 'd__Bacteria; p__Chloroflexi; c__Chloroflexia; o__Thermomicrobiales; f__JG30-KF-CM45; g__JG30-KF-CM45; s__uncultured_soil; otu__FQ659571.1.1334',\n", + " 'd__Bacteria; p__Chloroflexi; c__Chloroflexia; o__Thermomicrobiales; f__JG30-KF-CM45; g__JG30-KF-CM45; s__metagenome; otu__FPLS01036268.22.1504',\n", + " 'd__Bacteria; p__Chloroflexi; c__Gitt-GS-136; o__Gitt-GS-136; f__Gitt-GS-136; g__Gitt-GS-136; otu__HM299006.1.1326',\n", + " 'd__Bacteria; p__Chloroflexi; c__KD4-96; o__KD4-96; f__KD4-96; g__KD4-96; s__uncultured_bacterium; otu__KY190653.1.1395']" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dic_node2leaf[5576]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Is this a problem?" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "ritme", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/q2_ritme/eval_best_trial_overall.py b/q2_ritme/eval_best_trial_overall.py index 1d5d972..a11fb63 100644 --- a/q2_ritme/eval_best_trial_overall.py +++ b/q2_ritme/eval_best_trial_overall.py @@ -29,7 +29,7 @@ def parse_args(): "--ls_model_types", type=str, nargs="+", - default=["nn_reg", "nn_class", "nn_corn", "xgb", "linreg", "rf"], + default=["nn_reg", "nn_class", "nn_corn", "xgb", "linreg", "rf", "trac"], help="List of model types to evaluate. 
Separate each model type with a space.", ) return parser.parse_args() @@ -55,10 +55,13 @@ def main(): experiment_dir = f"{model_path}/*/{model}" analyses_ls = get_all_exp_analyses(experiment_dir) - # identify best trial from all analyses of this model type - best_trials_overall[model] = best_trial_name( - analyses_ls, "rmse_val", mode="min" - ) + if len(analyses_ls) == 0: + print(f"No analyses found for model type: {model}") + else: + # identify best trial from all analyses of this model type + best_trials_overall[model] = best_trial_name( + analyses_ls, "rmse_val", mode="min" + ) compare_trials(best_trials_overall, model_path, overall_comparison_output) diff --git a/q2_ritme/evaluate_models.py b/q2_ritme/evaluate_models.py index 7bfae54..e9219f7 100644 --- a/q2_ritme/evaluate_models.py +++ b/q2_ritme/evaluate_models.py @@ -1,4 +1,5 @@ import os +import pickle from typing import Any import matplotlib.pyplot as plt @@ -11,6 +12,9 @@ from ray.air.result import Result from sklearn.metrics import mean_squared_error +from q2_ritme.feature_space._process_trac_specific import ( + _preprocess_taxonomy_aggregation, +) from q2_ritme.feature_space.transform_features import transform_features from q2_ritme.model_space.static_trainables import NeuralNet @@ -48,6 +52,19 @@ def load_sklearn_model(result: Result) -> Any: return load(result.metrics["model_path"]) +def load_trac_model(result: Result) -> Any: + """ + Load a TRAC model from a given result object. + + :param result: The result object containing the model path. + :return: The loaded TRAC model. + """ + with open(result.metrics["model_path"], "rb") as file: + model = pickle.load(file) + + return model + + def load_xgb_model(result: Result) -> xgb.Booster: """ Load an XGBoost model from a given result object. 
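
A minimal sketch (not part of the diff) of how a pickled trac model can be
applied by hand, assuming the {"model": alpha-with-intercept DataFrame,
"matrix_a": A DataFrame} layout that load_trac_model returns; the path
"model.pkl" and the dummy counts `x_new` are illustrative only:

    import pickle

    import numpy as np

    with open("model.pkl", "rb") as fh:
        trac = pickle.load(fh)

    A = trac["matrix_a"].values             # (n_features, n_features + n_nodes)
    alpha = trac["model"].values.flatten()  # alpha[0] is the intercept
    x_new = np.random.poisson(5.0, size=(3, A.shape[0]))  # dummy count matrix
    log_geom = np.log(1e-6 + x_new).dot(A) / A.sum(axis=0)  # as in _preprocess_taxonomy_aggregation
    y_pred = log_geom.dot(alpha[1:]) + alpha[0]
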
@@ -77,6 +94,7 @@ def get_model(model_type: str, result) -> Any: """ model_loaders = { "linreg": load_sklearn_model, + "trac": load_trac_model, "rf": load_sklearn_model, "xgb": load_xgb_model, "nn_reg": load_nn_model, @@ -128,6 +146,13 @@ def predict(self, data): elif self.model.nn_type == "ordinal_regression": logits = self.model(transformed) predicted = corn_label_from_logits(logits).numpy() + elif isinstance(self.model, dict): + # trac model + log_geom, _ = _preprocess_taxonomy_aggregation( + transformed, self.model["matrix_a"].values + ) + alpha = self.model["model"].values + predicted = log_geom.dot(alpha[1:]) + alpha[0] else: predicted = self.model.predict(transformed).flatten() return predicted @@ -189,7 +214,6 @@ def plot_rmse_over_experiments(preds_dic, save_loc, dpi=400): path_to_save = os.path.join(save_loc, "rmse_over_experiments_train_test.png") plt.tight_layout() plt.savefig(path_to_save, dpi=dpi) - plt.show() def plot_rmse_over_time(preds_dic, ls_model_types, save_loc, dpi=300): @@ -232,7 +256,6 @@ def plot_rmse_over_time(preds_dic, ls_model_types, save_loc, dpi=300): save_loc, f"rmse_over_time_train_test_{model_type}.png" ) plt.savefig(path_to_save, dpi=dpi) - plt.show() def get_best_model_metrics_and_config( @@ -305,7 +328,6 @@ def plot_best_models_comparison( plt.tight_layout() path_to_save = os.path.join(save_loc, "rmse_over_experiments_train_val.png") plt.savefig(path_to_save, dpi=400) - plt.show() def plot_model_training_over_iterations(model_type, result_dic, labels, save_loc): diff --git a/q2_ritme/feature_space/_process_trac_specific.py b/q2_ritme/feature_space/_process_trac_specific.py new file mode 100644 index 0000000..fc6c55f --- /dev/null +++ b/q2_ritme/feature_space/_process_trac_specific.py @@ -0,0 +1,116 @@ +import os + +import numpy as np +import pandas as pd + + +def _verify_matrix_a(A, feature_columns, tree_phylo): + # no all 1 in one column + assert not np.any(np.all(A == 1.0, axis=0)) + + # shape should be = feature_count + node_count + nb_features = len(feature_columns) + nb_non_leaf_nodes = len(list(tree_phylo.non_tips())) + + assert nb_features + nb_non_leaf_nodes == A.shape[1] + + +def _get_leaves_and_index_map(tree): + leaves = list(tree.tips()) + leaf_names = [leaf.name for leaf in leaves] + # map each leaf name to unique index + leaf_index_map = {name: idx for idx, name in enumerate(leaf_names)} + return leaves, leaf_index_map + + +def _get_internal_nodes(tree): + # root is not included + return list(tree.non_tips()) + + +def _create_identity_matrix_for_leaves(num_leaves, tax, leaves): + A1 = np.eye(num_leaves) + # taxonomic name should include OTU name + tax_e = tax.copy() + tax_e["tax_ft"] = tax_e["Taxon"] + "; otu__" + tax_e.index + a1_node_names = tax_e.loc[[leaf.name for leaf in leaves], "tax_ft"].tolist() + return A1, a1_node_names + + +def _populate_A2_for_node(A2, node, leaf_index_map, j): + node_leaf_names = [] + # flag leaves that match to a node + descendant_leaves = {leaf.name for leaf in node.tips()} + for leaf_name in leaf_index_map: + if leaf_name in descendant_leaves: + node_leaf_names.append(leaf_name) + A2[leaf_index_map[leaf_name], j] = 1 + return A2, node_leaf_names + + +def _create_consensus_taxonomy(node_leaf_names, tax, a2_node_names, j): + tax_e = tax.copy() + tax_e["tax_ft"] = tax_e["Taxon"] + "; otu__" + tax_e.index + node_mapped_taxon = tax_e.loc[node_leaf_names, "tax_ft"].tolist() + str_consensus_taxon = os.path.commonprefix(node_mapped_taxon) + # get name before last ";" + node_consensus_taxon = 
str_consensus_taxon.rpartition(";")[0]
+    # if consensus name already exists, add index to make it unique
+    if node_consensus_taxon in a2_node_names:
+        node_consensus_taxon = node_consensus_taxon + "; n__" + str(j)
+    return node_consensus_taxon
+
+
+def _create_matrix_for_internal_nodes(num_leaves, internal_nodes, leaf_index_map, tax):
+    # initialise it with zeros
+    A2 = np.zeros((num_leaves, len(internal_nodes)))
+    a2_node_names = []
+    # Populate A2 with 1s for the leaves linked by each internal node:
+    # iterate over all internal nodes, find the descendants of each node and
+    # mark them accordingly
+    for j, node in enumerate(internal_nodes):
+        A2, node_leaf_names = _populate_A2_for_node(A2, node, leaf_index_map, j)
+        # create consensus taxonomy from all leaf names (needed since
+        # node.name is just a float)
+        node_consensus_taxon = _create_consensus_taxonomy(
+            node_leaf_names, tax, a2_node_names, j
+        )
+        a2_node_names.append(node_consensus_taxon)
+    return A2, a2_node_names
+
+
+def create_matrix_from_tree(tree, tax) -> pd.DataFrame:
+    # Get all leaves and create a mapping from leaf names to indices
+    leaves, leaf_index_map = _get_leaves_and_index_map(tree)
+    num_leaves = len(leaves)
+
+    # Get all internal nodes
+    internal_nodes = _get_internal_nodes(tree)
+
+    # Create the identity matrix for the leaves: A1 (num_leaves x num_leaves)
+    A1, a1_node_names = _create_identity_matrix_for_leaves(num_leaves, tax, leaves)
+
+    # Create the matrix for the internal nodes: A2 (num_leaves x num_internal_nodes)
+    A2, a2_node_names = _create_matrix_for_internal_nodes(
+        num_leaves, internal_nodes, leaf_index_map, tax
+    )
+
+    # Concatenate A1 and A2 to create the final matrix A
+    A = np.hstack((A1, A2))
+    df_a = pd.DataFrame(
+        A, columns=a1_node_names + a2_node_names, index=[leaf.name for leaf in leaves]
+    )
+    _verify_matrix_a(df_a.values, tax.index.tolist(), tree)
+
+    return df_a
+
+
+def _preprocess_taxonomy_aggregation(x, A):
+    pseudo_count = 0.000001
+
+    X = np.log(pseudo_count + x)
+    nleaves = np.sum(A, axis=0)
+    # safekeeping: the dot product would fail with wrong dimensions
+    # X: (n_samples, n_features), A: (n_features, n_features + n_nodes)
+    log_geom = X.dot(A) / nleaves
+
+    return log_geom, nleaves
diff --git a/q2_ritme/feature_space/_process_train.py b/q2_ritme/feature_space/_process_train.py
index b7a327a..64128dc 100644
--- a/q2_ritme/feature_space/_process_train.py
+++ b/q2_ritme/feature_space/_process_train.py
@@ -2,13 +2,10 @@
 from q2_ritme.process_data import split_data_by_host
 
-def process_train(config, train_val, target, host_id, seed_data):
-    # todo: adjust feature selection in future to include md
-    # todo: note -> must also be adjusted in run_n_eval_tune.py
+def _transform_features_in_complete_data(config, train_val, target):
     features = [x for x in train_val if x.startswith("F")]
     non_features = [x for x in train_val if x not in features]
 
-    # feature engineering method -> config
     ft_transformed = transform_features(
         train_val[features],
         config["data_transform"],
@@ -16,8 +13,15 @@
     )
     train_val_t = train_val[non_features].join(ft_transformed)
 
-    # train & val split - for training purposes
+    return train_val_t, ft_transformed.columns
+
+
+def process_train(config, train_val, target, host_id, seed_data):
+    train_val_t, feature_columns = _transform_features_in_complete_data(
+        config, train_val, target
+    )
+
     train, val = split_data_by_host(train_val_t, host_id, 0.8, seed_data)
-    X_train, y_train = train[ft_transformed.columns], train[target]
- X_val, y_val = val[ft_transformed.columns], val[target] - return X_train.values, y_train.values, X_val.values, y_val.values + X_train, y_train = train[feature_columns], train[target] + X_val, y_val = val[feature_columns], val[target] + return X_train.values, y_train.values, X_val.values, y_val.values, feature_columns diff --git a/q2_ritme/model_space/_model_trac_calc.py b/q2_ritme/model_space/_model_trac_calc.py new file mode 100644 index 0000000..8e20665 --- /dev/null +++ b/q2_ritme/model_space/_model_trac_calc.py @@ -0,0 +1,40 @@ +import numpy as np +from numpy import linalg + + +def solve_unpenalized_least_squares(cmatrices, intercept=False): + # adapted from classo > misc_functions.py > unpenalised + if intercept: + A1, C1, y = cmatrices + A = np.concatenate([np.ones((len(A1), 1)), A1], axis=1) + C = np.concatenate([np.zeros((len(C1), 1)), C1], axis=1) + else: + A, C, y = cmatrices + + k = len(C) + d = len(A[0]) + M1 = np.concatenate([A.T.dot(A), C.T], axis=1) + M2 = np.concatenate([C, np.zeros((k, k))], axis=1) + M = np.concatenate([M1, M2], axis=0) + b = np.concatenate([A.T.dot(y), np.zeros(k)]) + sol = linalg.lstsq(M, b, rcond=None)[0] + beta = sol[:d] + return beta + + +def min_least_squares_solution(matrices, selected, intercept=False): + """Minimum Least Squares solution for selected features.""" + # adapted from classo > misc_functions.py > min_LS + X, C, y = matrices + beta = np.zeros(len(selected)) + + if intercept: + beta[selected] = solve_unpenalized_least_squares( + (X[:, selected[1:]], C[:, selected[1:]], y), intercept=selected[0] + ) + else: + beta[selected] = solve_unpenalized_least_squares( + (X[:, selected], C[:, selected], y), intercept=False + ) + + return beta diff --git a/q2_ritme/model_space/static_searchspace.py b/q2_ritme/model_space/static_searchspace.py index 1d2e045..27d4f1a 100644 --- a/q2_ritme/model_space/static_searchspace.py +++ b/q2_ritme/model_space/static_searchspace.py @@ -56,7 +56,7 @@ def get_rf_space(train_val): model="rf", **data_eng_space, **{ - "n_estimators": tune.randint(100, 500), + "n_estimators": tune.randint(50, 300), "max_depth": tune.randint(2, 32), "min_samples_split": tune.choice([0.001, 0.01, 0.1]), "min_samples_leaf": tune.choice([0.0001, 0.001]), @@ -103,6 +103,20 @@ def get_xgb_space(train_val): ) +def get_trac_space(train_val): + # no feature_transformation to be used for trac + data_eng_space = {"data_transform": None, "data_alr_denom_idx": None} + return dict( + model="trac", + **data_eng_space, + **{ + # with loguniform: sampled values are more densely concentrated + # towards the lower end of the range + "lambda": tune.loguniform(1e-3, 1.0) + }, + ) + + def get_search_space(train_val): return { "xgb": get_xgb_space(train_val), @@ -111,4 +125,5 @@ def get_search_space(train_val): "nn_corn": get_nn_space(train_val, "nn_corn"), "linreg": get_linreg_space(train_val), "rf": get_rf_space(train_val), + "trac": get_trac_space(train_val), } diff --git a/q2_ritme/model_space/static_trainables.py b/q2_ritme/model_space/static_trainables.py index 438d07d..f3c7d0a 100644 --- a/q2_ritme/model_space/static_trainables.py +++ b/q2_ritme/model_space/static_trainables.py @@ -1,6 +1,7 @@ """Module with tune trainables of all static models""" import os +import pickle import random from typing import Any, Dict @@ -8,8 +9,10 @@ import numpy as np import pandas as pd import ray +import skbio import torch import xgboost as xgb +from classo import Classo from coral_pytorch.dataset import corn_label_from_logits from coral_pytorch.losses import 
diff --git a/q2_ritme/model_space/static_trainables.py b/q2_ritme/model_space/static_trainables.py
index 438d07d..f3c7d0a 100644
--- a/q2_ritme/model_space/static_trainables.py
+++ b/q2_ritme/model_space/static_trainables.py
@@ -1,6 +1,7 @@
 """Module with tune trainables of all static models"""
 
 import os
+import pickle
 import random
 from typing import Any, Dict
 
@@ -8,8 +9,10 @@
 import numpy as np
 import pandas as pd
 import ray
+import skbio
 import torch
 import xgboost as xgb
+from classo import Classo
 from coral_pytorch.dataset import corn_label_from_logits
 from coral_pytorch.losses import corn_loss
 from pytorch_lightning import LightningModule, Trainer, seed_everything
@@ -26,7 +29,12 @@
 from torch.optim import Adam
 from torch.utils.data import DataLoader, TensorDataset
 
+from q2_ritme.feature_space._process_trac_specific import (
+    _preprocess_taxonomy_aggregation,
+    create_matrix_from_tree,
+)
 from q2_ritme.feature_space._process_train import process_train
+from q2_ritme.model_space._model_trac_calc import min_least_squares_solution
 
 
 def _predict_rmse(model: BaseEstimator, X: np.ndarray, y: np.ndarray) -> float:
@@ -103,6 +111,8 @@
     host_id: str,
     seed_data: int,
     seed_model: int,
+    tax: pd.DataFrame = pd.DataFrame(),
+    tree_phylo: skbio.TreeNode = skbio.TreeNode(),
 ) -> None:
     """
     Train a linear regression model and report the results to Ray Tune.
@@ -119,7 +129,7 @@
     None
     """
     # ! process dataset: X with features & y with host_id
-    X_train, y_train, X_val, y_val = process_train(
+    X_train, y_train, X_val, y_val, ft_col = process_train(
         config, train_val, target, host_id, seed_data
     )
 
@@ -135,6 +145,97 @@
     _report_results_manually(linreg, X_train, y_train, X_val, y_val)
 
 
+def _predict_rmse_trac(alpha, log_geom_X, y):
+    y_pred = log_geom_X.dot(alpha[1:]) + alpha[0]
+    return mean_squared_error(y, y_pred, squared=False)
+
+
+def _report_results_manually_trac(
+    alpha, A_df, log_geom_train, y_train, log_geom_val, y_val
+):
+    # save coefficients w labels & matrix A with labels -> model_path
+    idx_alpha = ["intercept"] + A_df.columns.tolist()
+    df_alpha_with_labels = pd.DataFrame(alpha, columns=["alpha"], index=idx_alpha)
+
+    model = {"model": df_alpha_with_labels, "matrix_a": A_df}
+
+    # save model
+    path_to_save = ray.train.get_context().get_trial_dir()
+    model_path = os.path.join(path_to_save, "model.pkl")
+    with open(model_path, "wb") as file:
+        pickle.dump(model, file)
+
+    # calculate RMSE
+    score_train = _predict_rmse_trac(alpha, log_geom_train, y_train)
+    score_val = _predict_rmse_trac(alpha, log_geom_val, y_val)
+
+    session.report(
+        metrics={
+            "rmse_val": score_val,
+            "rmse_train": score_train,
+            "model_path": model_path,
+        }
+    )
+    return None
+
+
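The pickled artifact bundles everything needed to predict on new data: the labelled coefficient vector (intercept first) and the labelled matrix A. A hypothetical reload sketch — `model_path` and `x_new` are stand-ins for illustration, not names from this changeset:

```python
import pickle

from q2_ritme.feature_space._process_trac_specific import (
    _preprocess_taxonomy_aggregation,
)

# model_path: the value reported to Ray Tune above (assumed to exist);
# x_new: an abundance array with the same feature columns as in training.
with open(model_path, "rb") as f:
    saved = pickle.load(f)

alpha = saved["model"]["alpha"].values   # index 0 holds the intercept
a_matrix = saved["matrix_a"].values
log_geom_new, _ = _preprocess_taxonomy_aggregation(x_new, a_matrix)
y_pred = log_geom_new.dot(alpha[1:]) + alpha[0]
```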
+def train_trac(
+    config: Dict[str, Any],
+    train_val: pd.DataFrame,
+    target: str,
+    host_id: str,
+    seed_data: int,
+    seed_model: int,
+    tax: pd.DataFrame,
+    tree_phylo: skbio.TreeNode,
+) -> None:
+    """
+    Train a trac model and report the results to Ray Tune.
+
+    Parameters:
+    config (Dict[str, Any]): The configuration for the training.
+    train_val (DataFrame): The training and validation data.
+    target (str): The target variable.
+    host_id (str): The host ID.
+    seed_data (int): The seed for the data.
+    seed_model (int): The seed for the model.
+    tax (pd.DataFrame): The taxonomy matching the feature table.
+    tree_phylo (skbio.TreeNode): The phylogenetic tree matching the feature table.
+
+    Returns:
+    None
+    """
+    # ! process dataset: X with features & y with host_id
+    X_train, y_train, X_val, y_val, ft_col = process_train(
+        config, train_val, target, host_id, seed_data
+    )
+    # ! derive matrix A
+    a_df = create_matrix_from_tree(tree_phylo, tax)
+
+    # ! get log_geom
+    log_geom_train, nleaves = _preprocess_taxonomy_aggregation(X_train, a_df.values)
+    log_geom_val, _ = _preprocess_taxonomy_aggregation(X_val, a_df.values)
+
+    # ! model
+    np.random.seed(seed_model)
+    matrices_train = (log_geom_train, np.ones((1, len(log_geom_train[0]))), y_train)
+    intercept = True
+    alpha_norefit = Classo(
+        matrix=matrices_train,
+        lam=config["lambda"],
+        typ="R1",
+        meth="Path-Alg",
+        w=1 / nleaves,
+        intercept=intercept,
+    )
+    selected_param = abs(alpha_norefit) > 1e-5
+    alpha = min_least_squares_solution(
+        matrices_train, selected_param, intercept=intercept
+    )
+
+    _report_results_manually_trac(
+        alpha, a_df, log_geom_train, y_train, log_geom_val, y_val
+    )
+
+
 def train_rf(
     config: Dict[str, Any],
     train_val: pd.DataFrame,
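`train_trac` thus fits in two stages: the penalised `Classo` solve selects a sparse support, and `min_least_squares_solution` refits those coefficients without penalty, still under the zero-sum constraint `C beta = 0` (the `np.ones((1, n_columns))` row). The constraint handling of the underlying KKT solve can be sanity-checked in isolation on toy data (values assumed for illustration):

```python
import numpy as np

from q2_ritme.model_space._model_trac_calc import solve_unpenalized_least_squares

rng = np.random.default_rng(0)
A = rng.normal(size=(20, 4))   # stand-in for log_geom features
y = rng.normal(size=20)
C = np.ones((1, 4))            # same zero-sum constraint as in train_trac

beta = solve_unpenalized_least_squares((A, C, y), intercept=False)
print(np.allclose(C.dot(beta), 0.0))  # True: the constraint is honoured
```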
@@ -142,6 +243,8 @@
     host_id: str,
     seed_data: int,
     seed_model: int,
+    tax: pd.DataFrame = pd.DataFrame(),
+    tree_phylo: skbio.TreeNode = skbio.TreeNode(),
 ) -> None:
     """
     Train a random forest model and report the results to Ray Tune.
@@ -158,7 +261,7 @@
     None
     """
     # ! process dataset
-    X_train, y_train, X_val, y_val = process_train(
+    X_train, y_train, X_val, y_val, ft_col = process_train(
         config, train_val, target, host_id, seed_data
     )
 
@@ -275,13 +378,19 @@ def load_data(X_train, y_train, X_val, y_val, y_type, config):
 
 
 def train_nn(
-    config, train_val, target, host_id, seed_data, seed_model, nn_type="regression"
+    config,
+    train_val,
+    target,
+    host_id,
+    seed_data,
+    seed_model,
+    nn_type="regression",
 ):
     # Set the seed for reproducibility
     seed_everything(seed_model, workers=True)
 
     # Process dataset
-    X_train, y_train, X_val, y_val = process_train(
+    X_train, y_train, X_val, y_val, ft_col = process_train(
         config, train_val, target, host_id, seed_data
     )
 
@@ -362,13 +471,17 @@
     trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
 
 
-def train_nn_reg(config, train_val, target, host_id, seed_data, seed_model):
+def train_nn_reg(
+    config, train_val, target, host_id, seed_data, seed_model, tax, tree_phylo
+):
     train_nn(
         config, train_val, target, host_id, seed_data, seed_model, nn_type="regression"
     )
 
 
-def train_nn_class(config, train_val, target, host_id, seed_data, seed_model):
+def train_nn_class(
+    config, train_val, target, host_id, seed_data, seed_model, tax, tree_phylo
+):
     train_nn(
         config,
         train_val,
@@ -380,7 +493,9 @@
     )
 
 
-def train_nn_corn(config, train_val, target, host_id, seed_data, seed_model):
+def train_nn_corn(
+    config, train_val, target, host_id, seed_data, seed_model, tax, tree_phylo
+):
     # corn model from https://github.com/Raschka-research-group/coral-pytorch
     train_nn(
         config,
@@ -400,6 +515,8 @@
     host_id: str,
     seed_data: int,
     seed_model: int,
+    tax: pd.DataFrame = pd.DataFrame(),
+    tree_phylo: skbio.TreeNode = skbio.TreeNode(),
 ) -> None:
     """
     Train an XGBoost model and report the results to Ray Tune.
@@ -416,7 +533,7 @@
     None
     """
     # ! process dataset
-    X_train, y_train, X_val, y_val = process_train(
+    X_train, y_train, X_val, y_val, ft_col = process_train(
         config, train_val, target, host_id, seed_data
     )
     # Set seeds
diff --git a/q2_ritme/process_data.py b/q2_ritme/process_data.py
index a282aa1..fa1a89a 100644
--- a/q2_ritme/process_data.py
+++ b/q2_ritme/process_data.py
@@ -1,6 +1,8 @@
 import biom
 import pandas as pd
 import qiime2 as q2
+import skbio
+from qiime2.plugins import phylogeny
 from sklearn.model_selection import GroupShuffleSplit
 
 # todo: adjust to json file to be read in from user
@@ -37,7 +39,6 @@
         columns=ft_rel_biom.ids(axis="observation"),
     )
 
-    print(ft_rel.head())
     # round needed as certain 1.0 are represented in different digits 2e-16
     assert ft_rel[ft_cols].sum(axis=1).round(5).eq(1.0).all()
 
@@ -86,6 +87,53 @@
     return ft, md
 
 
+def load_tax_phylo(
+    path2tax: str, path2phylo: str, ft: pd.DataFrame
+) -> (pd.DataFrame, skbio.TreeNode):
+    """
+    Load taxonomy and phylogeny data.
+    """
+    # todo: add option for simulated data
+    if path2tax and path2phylo:
+        # taxonomy
+        art_taxonomy = q2.Artifact.load(path2tax)
+        df_tax = art_taxonomy.view(pd.DataFrame)
+        # rename taxonomy to match new "F" feature names
+        df_tax.index = df_tax.index.map(lambda x: "F" + str(x))
+
+        # Filter the taxonomy based on the feature table
+        df_tax_f = df_tax[df_tax.index.isin(ft.columns.tolist())]
+
+        if df_tax_f.shape[0] == 0:
+            raise ValueError("Taxonomy data does not match with feature table.")
+
+        # phylogeny
+        art_phylo = q2.Artifact.load(path2phylo)
+        # filter tree by feature table: this prunes a phylogenetic tree to match
+        # the input ids
+        # Remove the first letter of each column name: "F" to match phylotree
+        ft_i = ft.copy()
+        ft_i.columns = [col[1:] for col in ft_i.columns]
+        art_ft_i = q2.Artifact.import_data("FeatureTable[RelativeFrequency]", ft_i)
+
+        (art_phylo_f,) = phylogeny.actions.filter_tree(tree=art_phylo, table=art_ft_i)
+        tree_phylo_f = art_phylo_f.view(skbio.TreeNode)
+
+        # add prefix "F" to leaf names in tree to remain consistent with ft
+        for node in tree_phylo_f.tips():
+            node.name = "F" + node.name
+
+        # ensure that # leaves in tree == feature table dimension
+        num_leaves = tree_phylo_f.count(tips=True)
+        assert num_leaves == ft.shape[1]
+    else:
+        # load empty variables
+        df_tax_f = pd.DataFrame()
+        tree_phylo_f = skbio.TreeNode()
+
+    return df_tax_f, tree_phylo_f
+
+
 def filter_merge_n_sort(
     md: pd.DataFrame,
     ft: pd.DataFrame,
@@ -152,12 +200,14 @@
 def load_n_split_data(
     path2md: str,
     path2ft: str,
+    path2tax: str,
+    path2phylo: str,
     host_id: str,
     target: str,
     train_size: float,
     seed: int,
     filter_md_cols: list = None,
-) -> (pd.DataFrame, pd.DataFrame):
+) -> (pd.DataFrame, pd.DataFrame, pd.DataFrame, skbio.TreeNode):
     """
     Load, merge and sort data, then split into train-test sets by host_id.
 
@@ -166,6 +216,10 @@
         is used.
     path2ft (str, optional): Path to features file. If None, simulated data
         is used.
+    path2tax (str, optional): Path to taxonomy file. If None, model options
+        requiring taxonomy can't be run.
+    path2phylo (str, optional): Path to phylogeny file. If None, model
+        options requiring phylogeny can't be run.
     host_id (str, optional): ID of the host. Default is HOST_ID from config.
     target (str, optional): Name of target variable. Default is TARGET from
         config.
@@ -177,13 +231,16 @@
         SEED_DATA from config.
 
     Returns:
-        tuple: A tuple containing train and test dataframes.
+        tuple: Train and test dataframes as well as matching taxonomy and
+            phylogeny.
     """
     ft, md = load_data(path2md, path2ft, target)
+    # tax: n_features x ("Taxon", "Confidence")
+    # tree_phylo: n_features leaf nodes
+    tax, tree_phylo = load_tax_phylo(path2tax, path2phylo, ft)
 
     data = filter_merge_n_sort(md, ft, host_id, target, filter_md_cols)
 
     # todo: add split also by study_id
     train_val, test = split_data_by_host(data, host_id, train_size, seed)
-    return train_val, test
+    return train_val, test, tax, tree_phylo
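Putting the new loading path together, a call with real artifacts could look as follows (a sketch assuming the files from run_config.json exist locally):

```python
from q2_ritme.process_data import load_n_split_data

train_val, test, tax, tree_phylo = load_n_split_data(
    path2md="experiments/data/metadata_proc_v20240323_r0_r3_le_2yrs.tsv",
    path2ft="experiments/data/all_otu_table_filt.qza",
    path2tax="experiments/data/otu_taxonomy_all.qza",
    path2phylo="experiments/data/silva-138-99-rooted-tree.qza",
    host_id="host_id",
    target="age_months",
    train_size=0.8,
    seed=12,
)
# With path2tax=None and path2phylo=None, tax comes back as an empty DataFrame
# and tree_phylo as a bare TreeNode; trac is then skipped in run_all_trials.
```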
diff --git a/q2_ritme/run_config.json b/q2_ritme/run_config.json
index c20fb1b..1f4ca2f 100644
--- a/q2_ritme/run_config.json
+++ b/q2_ritme/run_config.json
@@ -1,8 +1,9 @@
 {
-    "experiment_tag": "test_synthetic",
+    "experiment_tag": "run_config",
     "host_id": "host_id",
     "ls_model_types": [
         "linreg",
+        "trac",
         "xgb",
         "nn_reg",
         "nn_class",
@@ -16,9 +17,11 @@
         "nn_class",
         "nn_corn"
     ],
-    "num_trials": 1,
-    "path_to_ft": null,
-    "path_to_md": null,
+    "num_trials": 10,
+    "path_to_ft": "experiments/data/all_otu_table_filt.qza",
+    "path_to_md": "experiments/data/metadata_proc_v20240323_r0_r3_le_2yrs.tsv",
+    "path_to_phylo": "experiments/data/silva-138-99-rooted-tree.qza",
+    "path_to_tax": "experiments/data/otu_taxonomy_all.qza",
     "seed_data": 12,
     "seed_model": 12,
     "target": "age_months",
diff --git a/q2_ritme/run_n_eval_tune.py b/q2_ritme/run_n_eval_tune.py
index e12ec63..a7aeac6 100644
--- a/q2_ritme/run_n_eval_tune.py
+++ b/q2_ritme/run_n_eval_tune.py
@@ -43,14 +43,15 @@
         "Please use another one."
     )
 
-    # todo: flag mlflow runs also with experiment tag somehow
     path_mlflow = os.path.join("experiments", config["mlflow_tracking_uri"])
     path_exp = os.path.join(base_path, config["experiment_tag"])
 
     # ! Load and split data
-    train_val, test = load_n_split_data(
+    train_val, test, tax, tree_phylo = load_n_split_data(
         config["path_to_md"],
         config["path_to_ft"],
+        config["path_to_tax"],
+        config["path_to_phylo"],
         config["host_id"],
         config["target"],
         config["train_size"],
@@ -64,6 +65,8 @@
         config["host_id"],
         config["seed_data"],
         config["seed_model"],
+        tax,
+        tree_phylo,
         path_mlflow,
         path_exp,
         # number of trials to run per model type * grid_search parameters in
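For runs without taxonomy or phylogeny, the two new keys can simply stay `null`; a minimal sketch of the relevant subset of the config (abridged, illustrative keys only):

```python
import json

config = json.loads("""
{
    "path_to_md": null,
    "path_to_ft": null,
    "path_to_tax": null,
    "path_to_phylo": null
}
""")
assert config["path_to_tax"] is None  # trac will be dropped downstream
```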
diff --git a/q2_ritme/tests/test_feature_space.py b/q2_ritme/tests/test_feature_space.py
index 58a8cd7..0b12576 100644
--- a/q2_ritme/tests/test_feature_space.py
+++ b/q2_ritme/tests/test_feature_space.py
@@ -5,8 +5,16 @@
 from pandas.testing import assert_frame_equal
 from qiime2.plugin.testing import TestPluginBase
 from scipy.stats.mstats import gmean
+from skbio import TreeNode
 from skbio.stats.composition import ilr
 
+from q2_ritme.feature_space._process_trac_specific import (
+    _create_identity_matrix_for_leaves,
+    _create_matrix_for_internal_nodes,
+    _get_internal_nodes,
+    _get_leaves_and_index_map,
+    create_matrix_from_tree,
+)
 from q2_ritme.feature_space._process_train import process_train
 from q2_ritme.feature_space.transform_features import (
     PSEUDOCOUNT,
@@ -16,7 +24,7 @@
 
 class TestTransformFeatures(TestPluginBase):
-    package = "q2_ritme.test"
+    package = "q2_ritme.tests"
 
     def setUp(self):
         super().setUp()
@@ -119,7 +127,7 @@ def test_transform_features_error(self):
 
 class TestProcessTrain(TestPluginBase):
-    package = "q2_ritme.test"
+    package = "q2_ritme.tests"
 
     def setUp(self):
         super().setUp()
@@ -157,7 +165,7 @@ def test_process_train(self, mock_split_data_by_host, mock_transform_features):
         )
 
         # Act
-        X_train, y_train, X_val, y_val = process_train(
+        X_train, y_train, X_val, y_val, ft_col = process_train(
             self.config, self.train_val, self.target, self.host_id, self.seed_data
         )
 
@@ -170,3 +178,107 @@
             0.8,
             0,
         )
+
+
+class TestProcessTracSpecific(TestPluginBase):
+    package = "q2_ritme.tests"
+
+    def setUp(self):
+        super().setUp()
+        self.tree = self._build_example_tree()
+        self.tax = self._build_example_taxonomy()
+
+    def _build_example_taxonomy(self):
+        tax = pd.DataFrame(
+            {
+                "Taxon": [
+                    "d__Bacteria; p__Chloroflexi; c__Anaerolineae; o__SBR1031; "
+                    "f__SBR1031; g__SBR1031; s__anaerobic_digester",
+                    "d__Bacteria; p__Chloroflexi; c__Anaerolineae; o__SBR1031; "
+                    "f__SBR1031; g__SBR1031; s__uncultured_bacterium",
+                    "d__Bacteria; p__Chloroflexi; c__Anaerolineae; o__SBR1031",
+                ],
+                "Confidence": [0.9, 0.9, 0.9],
+            }
+        )
+        tax.index = ["F1", "F2", "F3"]
+        tax.index.name = "Feature ID"
+        return tax
+
+    def _build_example_tree(self):
+        # Create the tree nodes with lengths
+        n1 = TreeNode(name="node1")
+        f1 = TreeNode(name="F1", length=1.0)
+        f2 = TreeNode(name="F2", length=1.0)
+        n2 = TreeNode(name="node2")
+        f3 = TreeNode(name="F3", length=1.0)
+
+        # Build the tree structure with lengths
+        n1.extend([f1, f2])
+        n2.extend([n1, f3])
+        n1.length = 1.0
+        n2.length = 1.0
+
+        # n2 is the root of this tree
+        tree = n2
+
+        return tree
+
+    def test_get_leaves_and_index_map(self):
+        leaves, leaf_index_map = _get_leaves_and_index_map(self.tree)
+        self.assertEqual(len(leaves), 3)
+        self.assertEqual(leaf_index_map, {"F1": 0, "F2": 1, "F3": 2})
+
+    def test_get_internal_nodes(self):
+        internal_nodes = _get_internal_nodes(self.tree)
+        self.assertEqual(len(internal_nodes), 1)
+        self.assertEqual(internal_nodes[0].name, "node1")
+
+    def test_create_identity_matrix_for_leaves(self):
+        leaves = list(self.tree.tips())
+        A1, a1_node_names = _create_identity_matrix_for_leaves(3, self.tax, leaves)
+        np.testing.assert_array_equal(A1, np.eye(3))
+        self.assertEqual(
+            a1_node_names,
+            [
+                "d__Bacteria; p__Chloroflexi; c__Anaerolineae; o__SBR1031; "
+                "f__SBR1031; g__SBR1031; s__anaerobic_digester; otu__F1",
+                "d__Bacteria; p__Chloroflexi; c__Anaerolineae; o__SBR1031; "
+                "f__SBR1031; g__SBR1031; s__uncultured_bacterium; otu__F2",
+                "d__Bacteria; p__Chloroflexi; c__Anaerolineae; o__SBR1031; otu__F3",
+            ],
+        )
+
+    def test_create_matrix_for_internal_nodes(self):
+        leaves, leaf_index_map = _get_leaves_and_index_map(self.tree)
+        internal_nodes = _get_internal_nodes(self.tree)
+        A2, a2_node_names = _create_matrix_for_internal_nodes(
+            3, internal_nodes, leaf_index_map, self.tax
+        )
+        np.testing.assert_array_equal(A2, np.array([[1], [1], [0]]))
+        self.assertEqual(
+            a2_node_names,
+            [
+                "d__Bacteria; p__Chloroflexi; c__Anaerolineae; o__SBR1031; "
+                "f__SBR1031; g__SBR1031"
+            ],
+        )
+
+    def test_create_matrix_from_tree(self):
+        ma_exp = np.array(
+            [[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0.0]]
+        )
+        node_taxon_names = [
+            "d__Bacteria; p__Chloroflexi; c__Anaerolineae; o__SBR1031; "
+            "f__SBR1031; g__SBR1031"
+        ]
+        leaf_names = (self.tax["Taxon"] + "; otu__" + self.tax.index).values.tolist()
+        ft_names = ["F1", "F2", "F3"]
+        ma_exp = pd.DataFrame(
+            ma_exp,
+            columns=leaf_names + node_taxon_names,
+            index=ft_names,
+        )
+        ma_act = create_matrix_from_tree(self.tree, self.tax)
+
+        assert_frame_equal(ma_exp, ma_act)
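The fixture tree could equivalently be expressed as a newick string, which is often easier to eyeball than node-by-node construction (equivalent apart from the root's branch length, which newick does not attach to the root here):

```python
from io import StringIO

from skbio import TreeNode

# Same topology and names as _build_example_tree above.
tree = TreeNode.read(StringIO("((F1:1.0,F2:1.0)node1:1.0,F3:1.0)node2;"))
print(tree.ascii_art())
```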
diff --git a/q2_ritme/tests/test_process_data.py b/q2_ritme/tests/test_process_data.py
index a61e858..97ceaea 100644
--- a/q2_ritme/tests/test_process_data.py
+++ b/q2_ritme/tests/test_process_data.py
@@ -16,7 +16,7 @@
 
 class TestProcessData(TestPluginBase):
-    package = "q2_ritme.test"
+    package = "q2_ritme.tests"
 
     def setUp(self):
         super().setUp()
@@ -169,9 +169,11 @@ def test_split_data_by_host_error_one_host(self):
 
     def test_load_n_split_data(self):
         # Call the function with the test paths
-        train_val, test = load_n_split_data(
+        train_val, test, tax, tree_phylo = load_n_split_data(
             self.tmp_md_path,
             self.tmp_ft_rel_path,
+            None,
+            None,
             host_id="host_id",
             target="supertarget",
             train_size=0.8,
@@ -179,10 +181,14 @@
             filter_md_cols=["host_id", "supertarget"],
         )
 
-        # Check that the dataframes are not empty
+        # Check that the train + test dataframes are not empty
         self.assertFalse(train_val.empty)
         self.assertFalse(test.empty)
 
         # Check that the dataframes do not overlap
         overlap = pd.merge(train_val, test, how="inner")
         self.assertEqual(len(overlap), 0)
+
+        # tax and phylo should be empty in this case
+        self.assertTrue(tax.empty)
+        self.assertTrue(tree_phylo.children == [])
diff --git a/q2_ritme/tests/test_static_searchspace.py b/q2_ritme/tests/test_static_searchspace.py
index 46cf4c4..86e93cb 100644
--- a/q2_ritme/tests/test_static_searchspace.py
+++ b/q2_ritme/tests/test_static_searchspace.py
@@ -72,6 +72,15 @@
         self.assertIn("alpha", linreg_space)
         self.assertIn("l1_ratio", linreg_space)
 
+    def test_get_trac_space(self):
+        trac_space = ss.get_trac_space(self.train_val)
+
+        self.assertIsInstance(trac_space, dict)
+        self.assertEqual(trac_space["model"], "trac")
+        self.assertEqual(trac_space["data_transform"], None)
+        self.assertEqual(trac_space["data_alr_denom_idx"], None)
+        self.assertIn("lambda", trac_space)
+
     def test_get_rf_space(self):
         rf_space = ss.get_rf_space(self.train_val)
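These assertions encode the sentinel convention used throughout the changeset: "no taxonomy/phylogeny provided" is represented by an empty DataFrame plus a childless TreeNode. In isolation:

```python
import pandas as pd
import skbio

# Sentinel values as returned by load_tax_phylo when paths are None.
tax, tree_phylo = pd.DataFrame(), skbio.TreeNode()
trac_possible = not tax.empty and tree_phylo.children != []
print(trac_possible)  # False -> trac would be skipped by run_all_trials
```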
diff --git a/q2_ritme/tests/test_static_trainables.py b/q2_ritme/tests/test_static_trainables.py
index 3530863..02d5e64 100644
--- a/q2_ritme/tests/test_static_trainables.py
+++ b/q2_ritme/tests/test_static_trainables.py
@@ -6,6 +6,7 @@
 
 import numpy as np
 import pandas as pd
+import skbio
 import torch
 from qiime2.plugin.testing import TestPluginBase
 from sklearn.linear_model import LinearRegression
@@ -80,7 +81,7 @@ def test_train_linreg(self, mock_report, mock_linreg, mock_process_train):
         # define input parameters
         config = {"fit_intercept": True, "alpha": 0.1, "l1_ratio": 0.5}
 
-        mock_process_train.return_value = (None, None, None, None)
+        mock_process_train.return_value = (None, None, None, None, None)
         mock_linreg_instance = mock_linreg.return_value
 
         # run model
@@ -105,6 +106,67 @@
         mock_linreg_instance.fit.assert_called_once()
         mock_report.assert_called_once()
 
+    @patch("q2_ritme.model_space.static_trainables.process_train")
+    @patch("q2_ritme.model_space.static_trainables.create_matrix_from_tree")
+    @patch("q2_ritme.model_space.static_trainables._preprocess_taxonomy_aggregation")
+    @patch("q2_ritme.model_space.static_trainables.Classo")
+    @patch("q2_ritme.model_space.static_trainables.min_least_squares_solution")
+    @patch("q2_ritme.model_space.static_trainables._report_results_manually_trac")
+    def test_train_trac(
+        self,
+        mock_report,
+        mock_min_least_squares,
+        mock_classo,
+        mock_preprocess_taxonomy,
+        mock_create_matrix,
+        mock_process_train,
+    ):
+        # Arrange
+        config = {"lambda": 0.1}
+        mock_process_train.return_value = (None, None, None, None, None)
+        mock_create_matrix.return_value = pd.DataFrame()
+        mock_preprocess_taxonomy.side_effect = [
+            (np.array([[1, 2], [3, 4]]), 2),
+            (np.array([[5, 6], [7, 8]]), 2),
+        ]
+        mock_classo.return_value = np.array([0.1, 0.2])
+        mock_min_least_squares.return_value = np.array([0.1, 0.2])
+
+        # Act
+        st.train_trac(
+            config,
+            self.train_val,
+            self.target,
+            self.host_id,
+            self.seed_data,
+            self.seed_model,
+            pd.DataFrame(),
+            skbio.TreeNode(),
+        )
+
+        # Assert
+        mock_process_train.assert_called_once_with(
+            config, self.train_val, self.target, self.host_id, self.seed_data
+        )
+        mock_create_matrix.assert_called_once()
+        assert mock_preprocess_taxonomy.call_count == 2
+
+        # mock_classo.assert_called_once_with doesn't work because matrix is a
+        # numpy array
+        kwargs = mock_classo.call_args.kwargs
+
+        self.assertTrue(np.array_equal(kwargs["matrix"][0], np.array([[1, 2], [3, 4]])))
+        self.assertTrue(np.array_equal(kwargs["matrix"][1], np.ones((1, 2))))
+        self.assertIsNone(kwargs["matrix"][2])
+        self.assertEqual(kwargs["lam"], config["lambda"])
+        self.assertEqual(kwargs["typ"], "R1")
+        self.assertEqual(kwargs["meth"], "Path-Alg")
+        self.assertEqual(kwargs["w"], 0.5)
+        self.assertEqual(kwargs["intercept"], True)
+
+        mock_min_least_squares.assert_called_once()
+        mock_report.assert_called_once()
+
     @patch("q2_ritme.model_space.static_trainables.process_train")
     @patch("q2_ritme.model_space.static_trainables.RandomForestRegressor")
     @patch("q2_ritme.model_space.static_trainables._report_results_manually")
@@ -112,7 +174,7 @@
     def test_train_rf(self, mock_report, mock_rf, mock_process_train):
         # Arrange
         config = {"n_estimators": 100, "max_depth": 10}
 
-        mock_process_train.return_value = (None, None, None, None)
+        mock_process_train.return_value = (None, None, None, None, None)
         mock_rf_instance = mock_rf.return_value
 
         # Act
@@ -159,11 +221,13 @@
         mock_train_y = mock_train[self.target].values
         mock_test_x = mock_test[["F1", "F2"]].values
         mock_test_y = mock_test[self.target].values
+        mock_ft_cols = ["F1", "F2"]
         mock_process_train.return_value = (
             mock_train_x,
             mock_train_y,
             mock_test_x,
             mock_test_y,
+            mock_ft_cols,
         )
         mock_dmatrix.return_value = None
 
@@ -209,6 +273,7 @@
             torch.rand(10),
             torch.rand(3, 5),
             torch.rand(3),
+            ["F1", "F2", "F3", "F4", "F5"],
         )
         mock_load_data.return_value = (MagicMock(), MagicMock())
         mock_trainer_instance = mock_trainer.return_value
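The `side_effect` list is what lets a single patched `_preprocess_taxonomy_aggregation` serve the train and validation calls in order; `unittest.mock` returns one item per successive call:

```python
from unittest.mock import Mock

# A list side_effect makes consecutive calls return successive items, which is
# how the test supplies separate train and validation log_geom values above.
m = Mock(side_effect=["first_call", "second_call"])
assert m() == "first_call"
assert m() == "second_call"
```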
diff --git a/q2_ritme/tune_models.py b/q2_ritme/tune_models.py
index 1212111..e40eeb8 100644
--- a/q2_ritme/tune_models.py
+++ b/q2_ritme/tune_models.py
@@ -3,6 +3,7 @@
 
 import numpy as np
 import pandas as pd
+import skbio
 import torch
 from ray import air, init, shutdown, tune
 from ray.air.integrations.mlflow import MLflowLoggerCallback
@@ -19,6 +20,7 @@
     "nn_corn": st.train_nn_corn,
     "linreg": st.train_linreg,
     "rf": st.train_rf,
+    "trac": st.train_trac,
 }
 
@@ -39,6 +41,8 @@
     host_id,
     seed_data,
     seed_model,
+    tax,
+    tree_phylo,
     path2exp,
     num_trials,
     fully_reproducible=False,  # if True hyperband instead of ASHA scheduler is used
@@ -97,6 +101,8 @@
                 host_id=host_id,
                 seed_data=seed_data,
                 seed_model=seed_model,
+                tax=tax,
+                tree_phylo=tree_phylo,
             ),
             resources,
         ),
@@ -147,14 +153,33 @@
 def run_all_trials(
     host_id: str,
     seed_data: int,
     seed_model: int,
+    tax: pd.DataFrame,
+    tree_phylo: skbio.TreeNode,
     mlflow_uri: str,
     path_exp: str,
     num_trials: int,
-    model_types: list = ["xgb", "nn_reg", "nn_class", "nn_corn", "linreg", "rf"],
+    model_types: list = [
+        "xgb",
+        "nn_reg",
+        "nn_class",
+        "nn_corn",
+        "linreg",
+        "rf",
+        "trac",
+    ],
     fully_reproducible: bool = False,
 ) -> dict:
     results_all = {}
     model_search_space = ss.get_search_space(train_val)
+
+    # if taxonomy + phylogeny are empty we can't run trac
+    if tax.empty or tree_phylo.children == []:
+        model_types.remove("trac")
+        print(
+            "Removing trac from model_types since no taxonomy and phylogeny were "
+            "provided."
+        )
+
     for model in model_types:
         # todo: parallelize this for loop
         if not os.path.exists(path_exp):
@@ -170,6 +195,8 @@
             host_id,
             seed_data,
             seed_model,
+            tax,
+            tree_phylo,
             path_exp,
             num_trials,
             fully_reproducible=fully_reproducible,
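One caveat in `run_all_trials`: `model_types` defaults to a mutable list, and `model_types.remove("trac")` mutates it in place, so once trac has been dropped it stays dropped for every later call in the same process that relies on the default. A sketch of a defensive variant, with a hypothetical slimmed-down signature for illustration only:

```python
def run_all_trials_safe(model_types=None):
    # hypothetical reduced signature; only the default-handling is the point
    if model_types is None:
        model_types = ["xgb", "nn_reg", "nn_class", "nn_corn", "linreg", "rf", "trac"]
    model_types = list(model_types)  # work on a copy; never mutate a shared default
    model_types.remove("trac")
    return model_types
```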