diff --git a/forestplot/mplot_graph_utils.py b/forestplot/mplot_graph_utils.py index d9e287a..e32f5df 100644 --- a/forestplot/mplot_graph_utils.py +++ b/forestplot/mplot_graph_utils.py @@ -142,7 +142,7 @@ def mdraw_est_markers( mcolor : Union[Sequence[str], None], optional A sequence of colors for each model group, defaults to ["0", "0.4", ".8", "0.2"]. **kwargs : Any - Additional keyword arguments. Supported customizations include 'markersize' (default 40) + Additional keyword arguments. Supported customizations include 'markersize' (default 40) and 'offset' for the spacing between markers of different model groups. Returns @@ -159,7 +159,7 @@ def mdraw_est_markers( _y = base_y_vector + (ix * offset) ax.scatter(y=_y, x=_df[estimate], marker=msymbols[ix], color=mcolor[ix], s=markersize) return ax - + def mdraw_ci( dataframe: pd.core.frame.DataFrame, diff --git a/tests/test_mplot_graph_utils.py b/tests/test_mplot_graph_utils.py index 7292e45..b7b3626 100644 --- a/tests/test_mplot_graph_utils.py +++ b/tests/test_mplot_graph_utils.py @@ -2,11 +2,11 @@ import pandas as pd from matplotlib.pyplot import Axes -from forestplot.mplot_graph_utils import mdraw_ref_xline, mdraw_yticklabels, mdraw_est_markers +from forestplot.mplot_graph_utils import mdraw_est_markers, mdraw_ref_xline, mdraw_yticklabels x, y = [0, 1, 2], [0, 1, 2] str_vector = ["a", "b", "c"] -models_vector =["m1", "m1", "m2"] +models_vector = ["m1", "m1", "m2"] input_df = pd.DataFrame( { "yticklabel": str_vector, @@ -55,11 +55,17 @@ def test_mdraw_yticklabels(): def test_mdraw_est_markers(): - _, ax = plt.subplots() - ax = mdraw_est_markers(input_df, estimate='estimate', model_col='model', models=list(set(models_vector)), ax=ax) - assert (all(isinstance(tick, int)) for tick in ax.get_yticks()) + _, ax = plt.subplots() + ax = mdraw_est_markers( + input_df, + estimate="estimate", + model_col="model", + models=list(set(models_vector)), + ax=ax, + ) + assert (all(isinstance(tick, int)) for tick in ax.get_yticks()) - xmin, xmax = ax.get_xlim() - assert xmin <= input_df["estimate"].min() - assert xmax >= input_df["estimate"].max() - assert len(ax.collections) == len(set(models_vector)) \ No newline at end of file + xmin, xmax = ax.get_xlim() + assert xmin <= input_df["estimate"].min() + assert xmax >= input_df["estimate"].max() + assert len(ax.collections) == len(set(models_vector))