Skip to content

Commit

Permalink
Update round_params too.
Browse files Browse the repository at this point in the history
  • Loading branch information
tsalo committed Feb 5, 2025
1 parent 10ce7c7 commit 9897b03
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 65 deletions.
104 changes: 47 additions & 57 deletions cubids/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,54 @@

import pandas as pd

from cubids.utils import cluster_single_parameters
from cubids import utils
from cubids.tests.utils import compare_group_assignments


def test_round_params():
"""Test the cubids.utils.round_params function."""
# Example DataFrame
df = pd.DataFrame(
{
"A": [1.12345, 2.23456, 3.34567],
"B": [[1.12345, 2.23456], [3.34567, 4.45678], [5.56789, 6.67890]],
"C": ["text", "more text", "even more text"],
"D": [1.12345, 2.23456, 3.34567],
}
)

# Example config
config = {
"sidecar_params": {
"modality1": {
"A": {"precision": 2},
"B": {"precision": 2},
},
},
"derived_params": {
"modality1": {},
},
}

# Expected DataFrame after rounding
expected_df = pd.DataFrame(
{
"A": [1.12, 2.23, 3.35],
"B": [[1.12, 2.23], [3.35, 4.46], [5.57, 6.68]],
"C": ["text", "more text", "even more text"],
"D": [1.12345, 2.23456, 3.34567],
}
)

# Round columns
rounded_df = utils.round_params(df, config, "modality1")

# Assert that the rounded DataFrame matches the expected DataFrame
pd.testing.assert_frame_equal(rounded_df, expected_df)


def test_cluster_single_parameters():
"""Test the cluster_single_parameters function.
"""Test the cubids.utils.cluster_single_parameters function.
We want to test that the function correctly clusters parameters based on the
configuration dictionary.
Expand Down Expand Up @@ -86,7 +129,7 @@ def test_cluster_single_parameters():
modality = "func"

# Run the function
out_df = cluster_single_parameters(
out_df = utils.cluster_single_parameters(
files_df=files_df,
config=config,
modality=modality,
Expand All @@ -113,7 +156,7 @@ def test_cluster_single_parameters():

# Change the tolerance for SliceTiming
config["sidecar_params"]["func"]["SliceTiming"]["tolerance"] = 0.5
out_df = cluster_single_parameters(
out_df = utils.cluster_single_parameters(
files_df=files_df,
config=config,
modality=modality,
Expand All @@ -139,56 +182,3 @@ def test_cluster_single_parameters():
out_df["Cluster_ImageType"].values.astype(int),
[0, 0, 0, 0, 0, 0, 1, 2],
)


def compare_group_assignments(list1, list2):
"""Compare two lists for equality based on group assignments.
This function checks if two lists can be considered equal based on their group assignments.
The actual values in the lists do not matter, only the group assignments do. Each unique value
in the first list is mapped to a unique value in the second list, and the function checks if
this mapping is consistent throughout the lists.
Parameters
----------
list1 : list
The first list to compare.
list2 : list
The second list to compare.
Returns
-------
bool
True if the lists are equal based on group assignments, False otherwise.
Examples
--------
>>> list1 = [1, 2, 1, 3, 2]
>>> list2 = ['a', 'b', 'a', 'c', 'b']
>>> compare_group_assignments(list1, list2)
True
>>> list1 = [1, 2, 1, 3, 2]
>>> list2 = ['b', 'd', 'b', 'q', 'd']
>>> compare_group_assignments(list1, list2)
True
>>> list1 = [1, 2, 1, 3, 2]
>>> list2 = ['a', 'b', 'a', 'c', 'd']
>>> compare_group_assignments(list1, list2)
False
"""
if len(list1) != len(list2):
return False

mapping = {}
for a, b in zip(list1, list2):
if a in mapping:
if mapping[a] != b:
return False
else:
if b in mapping.values():
return False
mapping[a] = b

return True
53 changes: 53 additions & 0 deletions cubids/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,56 @@ def chdir(path):
yield
finally:
os.chdir(oldpwd)


def compare_group_assignments(list1, list2):
"""Compare two lists for equality based on group assignments.
This function checks if two lists can be considered equal based on their group assignments.
The actual values in the lists do not matter, only the group assignments do. Each unique value
in the first list is mapped to a unique value in the second list, and the function checks if
this mapping is consistent throughout the lists.
Parameters
----------
list1 : list
The first list to compare.
list2 : list
The second list to compare.
Returns
-------
bool
True if the lists are equal based on group assignments, False otherwise.
Examples
--------
>>> list1 = [1, 2, 1, 3, 2]
>>> list2 = ['a', 'b', 'a', 'c', 'b']
>>> compare_group_assignments(list1, list2)
True
>>> list1 = [1, 2, 1, 3, 2]
>>> list2 = ['b', 'd', 'b', 'q', 'd']
>>> compare_group_assignments(list1, list2)
True
>>> list1 = [1, 2, 1, 3, 2]
>>> list2 = ['a', 'b', 'a', 'c', 'd']
>>> compare_group_assignments(list1, list2)
False
"""
if len(list1) != len(list2):
return False

mapping = {}
for a, b in zip(list1, list2):
if a in mapping:
if mapping[a] != b:
return False
else:
if b in mapping.values():
return False
mapping[a] = b

return True
19 changes: 11 additions & 8 deletions cubids/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,13 +377,13 @@ def _get_param_groups(
return ordered_labeled_files, param_groups_with_counts


def round_params(param_group_df, config, modality):
def round_params(df, config, modality):
"""Round columns' values in a DataFrame according to requested precision.
Parameters
----------
param_group_df : pandas.DataFrame
DataFrame containing the parameters to be rounded.
df : pandas.DataFrame
DataFrame containing the parameters to be rounded, with one row per file.
config : dict
Configuration dictionary containing rounding precision information.
modality : str
Expand All @@ -398,16 +398,19 @@ def round_params(param_group_df, config, modality):
to_format.update(config["derived_params"][modality])

for column_name, column_fmt in to_format.items():
if column_name not in param_group_df:
if column_name not in df:
continue

if "precision" in column_fmt:
if isinstance(param_group_df[column_name], float):
param_group_df[column_name] = param_group_df[column_name].round(
column_fmt["precision"]
precision = column_fmt["precision"]
if isinstance(df[column_name], float):
df[column_name] = df[column_name].round(precision)
elif df[column_name].apply(lambda x: isinstance(x, (list, np.ndarray))).any():
df[column_name] = df[column_name].apply(
lambda x: np.round(x, precision) if isinstance(x, (list, np.ndarray)) else x
)

return param_group_df
return df


def get_sidecar_metadata(json_file):
Expand Down

0 comments on commit 9897b03

Please sign in to comment.