Skip to content

Commit

Permalink
fix empty tax + phylo
Browse files Browse the repository at this point in the history
  • Loading branch information
adamovanja committed May 21, 2024
1 parent b353309 commit 9195392
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 15 deletions.
12 changes: 6 additions & 6 deletions q2_ritme/model_space/_static_trainables.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ def train_linreg(
host_id: str,
seed_data: int,
seed_model: int,
tax: pd.DataFrame,
tree_phylo: skbio.TreeNode,
tax: pd.DataFrame = pd.DataFrame(),
tree_phylo: skbio.TreeNode = skbio.TreeNode(),
) -> None:
"""
Train a linear regression model and report the results to Ray Tune.
Expand Down Expand Up @@ -249,8 +249,8 @@ def train_rf(
host_id: str,
seed_data: int,
seed_model: int,
tax: pd.DataFrame,
tree_phylo: skbio.TreeNode,
tax: pd.DataFrame = pd.DataFrame(),
tree_phylo: skbio.TreeNode = skbio.TreeNode(),
) -> None:
"""
Train a random forest model and report the results to Ray Tune.
Expand Down Expand Up @@ -521,8 +521,8 @@ def train_xgb(
host_id: str,
seed_data: int,
seed_model: int,
tax: pd.DataFrame,
tree_phylo: skbio.TreeNode,
tax: pd.DataFrame = pd.DataFrame(),
tree_phylo: skbio.TreeNode = skbio.TreeNode(),
) -> None:
"""
Train an XGBoost model and report the results to Ray Tune.
Expand Down
15 changes: 8 additions & 7 deletions q2_ritme/process_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,10 @@ def load_tax_phylo(
num_leaves = tree_phylo_f.count(tips=True)
assert num_leaves == ft.shape[1]

Check warning on line 128 in q2_ritme/process_data.py

View check run for this annotation

Codecov / codecov/patch

q2_ritme/process_data.py#L127-L128

Added lines #L127 - L128 were not covered by tests
else:
raise ValueError(
"Simulation of taxonomy and phylogeny data not implemented yet."
)
# load empty variables
df_tax_f = pd.DataFrame()
tree_phylo_f = skbio.TreeNode()

return df_tax_f, tree_phylo_f


Expand Down Expand Up @@ -215,10 +216,10 @@ def load_n_split_data(
is used.
path2ft (str, optional): Path to features file. If None, simulated data
is used.
path2tax (str, optional): Path to taxonomy file. If None, simulated data
is used.
path2phylo (str, optional): Path to phylogeny file. If None, simulated data
is used.
path2tax (str, optional): Path to taxonomy file. If None, model options
requiring taxonomy can't be run.
path2phylo (str, optional): Path to phylogeny file. If None, model
options requiring taxonomy can't be run.
host_id (str, optional): ID of the host. Default is HOST_ID from config.
target (str, optional): Name of target variable. Default is TARGET from
config.
Expand Down
10 changes: 8 additions & 2 deletions q2_ritme/tests/test_process_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,20 +169,26 @@ def test_split_data_by_host_error_one_host(self):

def test_load_n_split_data(self):
# Call the function with the test paths
train_val, test = load_n_split_data(
train_val, test, tax, tree_phylo = load_n_split_data(
self.tmp_md_path,
self.tmp_ft_rel_path,
None,
None,
host_id="host_id",
target="supertarget",
train_size=0.8,
seed=123,
filter_md_cols=["host_id", "supertarget"],
)

# Check that the dataframes are not empty
# Check that the train + test dataframes are not empty
self.assertFalse(train_val.empty)
self.assertFalse(test.empty)

# Check that the dataframes do not overlap
overlap = pd.merge(train_val, test, how="inner")
self.assertEqual(len(overlap), 0)

# tax and phylo should be empty in this case
self.assertTrue(tax.empty)
self.assertTrue(tree_phylo.children == [])
9 changes: 9 additions & 0 deletions q2_ritme/tune_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,15 @@ def run_all_trials(
) -> dict:
results_all = {}
model_search_space = ss.get_search_space(train_val)

# if tax + phylogeny empty we can't run trac
if tax.empty or tree_phylo.children == []:
model_types.remove("trac")
print(

Check warning on line 178 in q2_ritme/tune_models.py

View check run for this annotation

Codecov / codecov/patch

q2_ritme/tune_models.py#L177-L178

Added lines #L177 - L178 were not covered by tests
"Removing trac from model_types since no taxonomy and phylogeny were "
"provided."
)

for model in model_types:
# todo: parallelize this for loop
if not os.path.exists(path_exp):
Expand Down

0 comments on commit 9195392

Please sign in to comment.