diff --git a/q2_ritme/model_space/_static_trainables.py b/q2_ritme/model_space/_static_trainables.py index 61db63e..934de93 100644 --- a/q2_ritme/model_space/_static_trainables.py +++ b/q2_ritme/model_space/_static_trainables.py @@ -106,8 +106,8 @@ def train_linreg( host_id: str, seed_data: int, seed_model: int, - tax: pd.DataFrame, - tree_phylo: skbio.TreeNode, + tax: pd.DataFrame = pd.DataFrame(), + tree_phylo: skbio.TreeNode = skbio.TreeNode(), ) -> None: """ Train a linear regression model and report the results to Ray Tune. @@ -249,8 +249,8 @@ def train_rf( host_id: str, seed_data: int, seed_model: int, - tax: pd.DataFrame, - tree_phylo: skbio.TreeNode, + tax: pd.DataFrame = pd.DataFrame(), + tree_phylo: skbio.TreeNode = skbio.TreeNode(), ) -> None: """ Train a random forest model and report the results to Ray Tune. @@ -521,8 +521,8 @@ def train_xgb( host_id: str, seed_data: int, seed_model: int, - tax: pd.DataFrame, - tree_phylo: skbio.TreeNode, + tax: pd.DataFrame = pd.DataFrame(), + tree_phylo: skbio.TreeNode = skbio.TreeNode(), ) -> None: """ Train an XGBoost model and report the results to Ray Tune. diff --git a/q2_ritme/process_data.py b/q2_ritme/process_data.py index d7571df..fa1a89a 100644 --- a/q2_ritme/process_data.py +++ b/q2_ritme/process_data.py @@ -127,9 +127,10 @@ def load_tax_phylo( num_leaves = tree_phylo_f.count(tips=True) assert num_leaves == ft.shape[1] else: - raise ValueError( - "Simulation of taxonomy and phylogeny data not implemented yet." - ) + # load empty variables + df_tax_f = pd.DataFrame() + tree_phylo_f = skbio.TreeNode() + return df_tax_f, tree_phylo_f @@ -215,10 +216,10 @@ def load_n_split_data( is used. path2ft (str, optional): Path to features file. If None, simulated data is used. - path2tax (str, optional): Path to taxonomy file. If None, simulated data - is used. - path2phylo (str, optional): Path to phylogeny file. If None, simulated data - is used. + path2tax (str, optional): Path to taxonomy file. 
If None, model options + requiring taxonomy can't be run. + path2phylo (str, optional): Path to phylogeny file. If None, model + options requiring phylogeny can't be run. host_id (str, optional): ID of the host. Default is HOST_ID from config. target (str, optional): Name of target variable. Default is TARGET from config. diff --git a/q2_ritme/tests/test_process_data.py b/q2_ritme/tests/test_process_data.py index a61e858..c0b0a4e 100644 --- a/q2_ritme/tests/test_process_data.py +++ b/q2_ritme/tests/test_process_data.py @@ -169,9 +169,11 @@ def test_split_data_by_host_error_one_host(self): def test_load_n_split_data(self): # Call the function with the test paths - train_val, test = load_n_split_data( + train_val, test, tax, tree_phylo = load_n_split_data( self.tmp_md_path, self.tmp_ft_rel_path, + None, + None, host_id="host_id", target="supertarget", train_size=0.8, @@ -179,10 +181,14 @@ def test_load_n_split_data(self): filter_md_cols=["host_id", "supertarget"], ) - # Check that the dataframes are not empty + # Check that the train + test dataframes are not empty self.assertFalse(train_val.empty) self.assertFalse(test.empty) # Check that the dataframes do not overlap overlap = pd.merge(train_val, test, how="inner") self.assertEqual(len(overlap), 0) + + # tax and phylo should be empty in this case + self.assertTrue(tax.empty) + self.assertTrue(tree_phylo.children == []) diff --git a/q2_ritme/tune_models.py b/q2_ritme/tune_models.py index 4447b5e..4037a0b 100644 --- a/q2_ritme/tune_models.py +++ b/q2_ritme/tune_models.py @@ -171,6 +171,15 @@ def run_all_trials( ) -> dict: results_all = {} model_search_space = ss.get_search_space(train_val) + + # if tax + phylogeny empty we can't run trac + if tax.empty or tree_phylo.children == []: + model_types.remove("trac") + print( + "Removing trac from model_types since no taxonomy and phylogeny were " + "provided." + ) + for model in model_types: # todo: parallelize this for loop if not os.path.exists(path_exp):