From 976f58d394c58ab040d818c2d0e47e5d2758558b Mon Sep 17 00:00:00 2001 From: "Lucas Shen Y. S" Date: Fri, 15 Dec 2023 12:16:41 +0800 Subject: [PATCH 1/2] Add tests for `mplot_dataframe_utils` (for #88) (#90) --- .github/workflows/CI.yml | 2 +- tests/test_mplot_dataframe_utils.py | 11 ++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 03c721f..0140192 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -2,7 +2,7 @@ name: CI on: push: - branches: [ "main", "docs", "patch", "feature", "mplot" ] + branches: [ "main", "docs", "patch", "feature", "mplot", "mplot-dev" ] pull_request: branches: [ "main" ] diff --git a/tests/test_mplot_dataframe_utils.py b/tests/test_mplot_dataframe_utils.py index e5f81ee..a146da1 100644 --- a/tests/test_mplot_dataframe_utils.py +++ b/tests/test_mplot_dataframe_utils.py @@ -1,5 +1,6 @@ import numpy as np import pandas as pd +from pandas.testing import assert_frame_equal from forestplot.mplot_dataframe_utils import ( _insert_headers_models, @@ -48,7 +49,7 @@ def test_insert_group_model(): result_df = insert_group_model(df, "groupvar", "varlabel", "model_col") # Assert - pd.testing.assert_frame_equal(result_df, expected_df) + assert_frame_equal(result_df, expected_df) def test_insert_headers_models(): @@ -72,9 +73,7 @@ def test_insert_headers_models(): result = _insert_headers_models(df, "model_col", None) # Verify - pd.testing.assert_frame_equal( - result.reset_index(drop=True), expected_output.reset_index(drop=True) - ) + assert_frame_equal(result.reset_index(drop=True), expected_output.reset_index(drop=True)) def test_make_multimodel_tableheaders(): @@ -205,4 +204,6 @@ def test_make_multimodel_tableheaders(): right_annoteheaders=None, ) # Verify - pd.testing.assert_frame_equal(df_result, df_expected) + assert_frame_equal(df_result.iloc[:, :4], df_expected.iloc[:, :4]) + assert pd.notna(df_result.loc[0, "yticklabel"]) + assert pd.notna(df_result.loc[0, "yticklabel2"]) From 794f49132b6bf352867e17766d4b7ca24c3d932f Mon Sep 17 00:00:00 2001 From: "Lucas Shen Y. S" Date: Sat, 16 Dec 2023 12:13:56 +0800 Subject: [PATCH 2/2] Add docstring & test for mdraw_yticklabels (#88, #89) (#91) * Troubleshooting workflow error (#87) Pytest showing nan==nan as error for py3.9 and py3.10. * Add branch to workflow (#88) * Testing make_multimodel_tableheaders (#88) * Add test_make_multimodel_tableheaders (#88) * Pleasing linters (#88) * Add docstring & test for (#88, #89) * Fix compatibility with newer mpl versions (#82) * Pleasing linters * Pleasing linters --- forestplot/mplot_graph_utils.py | 43 +++++++++++++++++++++------ tests/test_mplot_graph_utils.py | 52 +++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 9 deletions(-) create mode 100644 tests/test_mplot_graph_utils.py diff --git a/forestplot/mplot_graph_utils.py b/forestplot/mplot_graph_utils.py index a10bdd3..edd381a 100644 --- a/forestplot/mplot_graph_utils.py +++ b/forestplot/mplot_graph_utils.py @@ -50,18 +50,37 @@ def mdraw_ref_xline( return ax -# ============================================================================================= -# ============================================================================================= -# ============================================================================================= def mdraw_yticklabels( dataframe: pd.core.frame.DataFrame, yticklabel: str, - model_col: str, - models: Optional[Union[Sequence[str], None]], flush: bool, ax: Axes, **kwargs: Any, ) -> Axes: + """ + Set custom y-axis tick labels on a matplotlib Axes object using the yticklabel column in the provided + pandas dataframe. + + Parameters + ---------- + dataframe : pd.core.frame.DataFrame + The pandas DataFrame from which the y-axis tick labels are derived. + yticklabel : str + Column name in the DataFrame whose values are used as y-axis tick labels. + flush : bool + If True, aligns y-axis tick labels to the left with adjusted padding to prevent overlap. + If False, aligns labels to the right. + ax : Axes + The matplotlib Axes object to be modified. + **kwargs : Any + Additional keyword arguments for customizing the appearance of the tick labels. + Supported customizations include 'fontfamily' (default 'monospace') and 'fontsize' (default 12). + + Returns + ------- + Axes + The modified matplotlib Axes object with updated y-axis tick labels. + """ ax.set_yticks(range(len(dataframe))) fontfamily = kwargs.get("fontfamily", "monospace") @@ -72,10 +91,16 @@ def mdraw_yticklabels( ) yax = ax.get_yaxis() fig = plt.gcf() - pad = max( - T.label.get_window_extent(renderer=fig.canvas.get_renderer()).width - for T in yax.majorTicks - ) + try: + pad = max( + T.label.get_window_extent(renderer=fig.canvas.get_renderer()).width + for T in yax.majorTicks + ) + except AttributeError: + pad = max( + T.label1.get_window_extent(renderer=fig.canvas.get_renderer()).width + for T in yax.majorTicks + ) yax.set_tick_params(pad=pad) else: ax.set_yticklabels( diff --git a/tests/test_mplot_graph_utils.py b/tests/test_mplot_graph_utils.py new file mode 100644 index 0000000..bcc5e9b --- /dev/null +++ b/tests/test_mplot_graph_utils.py @@ -0,0 +1,52 @@ +import matplotlib.pyplot as plt +import pandas as pd +from matplotlib.pyplot import Axes + +from forestplot.mplot_graph_utils import mdraw_ref_xline, mdraw_yticklabels + +x, y = [0, 1, 2], [0, 1, 2] +str_vector = ["a", "b", "c"] +input_df = pd.DataFrame( + { + "yticklabel": str_vector, + "estimate": x, + "moerror": y, + "ll": x, + "hl": y, + "pval": y, + "formatted_pval": y, + "yticklabel1": str_vector, + "yticklabel2": str_vector, + } +) + + +def test_mdraw_ref_xline(): + _, ax = plt.subplots() + ax = mdraw_ref_xline( + ax, + dataframe=input_df, + model_col="yticklabel", + annoteheaders=None, + right_annoteheaders=None, + ) + assert isinstance(ax, Axes) + + +def test_mdraw_yticklabels(): + # Prepare the input DataFrame + str_vector = ["a", "b", "c"] + input_df = pd.DataFrame( + { + "yticklabel": str_vector, + } + ) + + # Create a matplotlib Axes object + _, ax = plt.subplots() + + # Call the function + ax = mdraw_yticklabels(input_df, yticklabel="yticklabel", flush=True, ax=ax) + + assert isinstance(ax, Axes) + assert [label.get_text() for label in ax.get_yticklabels()] == str_vector