Skip to content

Commit

Permalink
Pleasing linters
Browse files Browse the repository at this point in the history
  • Loading branch information
LSYS committed Dec 16, 2023
1 parent 003ec98 commit 5a5bd3e
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 11 deletions.
4 changes: 2 additions & 2 deletions forestplot/mplot_graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def mdraw_est_markers(
mcolor : Union[Sequence[str], None], optional
A sequence of colors for each model group, defaults to ["0", "0.4", ".8", "0.2"].
**kwargs : Any
Additional keyword arguments. Supported customizations include 'markersize' (default 40)
Additional keyword arguments. Supported customizations include 'markersize' (default 40)
and 'offset' for the spacing between markers of different model groups.
Returns
Expand All @@ -159,7 +159,7 @@ def mdraw_est_markers(
_y = base_y_vector + (ix * offset)
ax.scatter(y=_y, x=_df[estimate], marker=msymbols[ix], color=mcolor[ix], s=markersize)
return ax


def mdraw_ci(
dataframe: pd.core.frame.DataFrame,
Expand Down
24 changes: 15 additions & 9 deletions tests/test_mplot_graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
import pandas as pd
from matplotlib.pyplot import Axes

from forestplot.mplot_graph_utils import mdraw_ref_xline, mdraw_yticklabels, mdraw_est_markers
from forestplot.mplot_graph_utils import mdraw_est_markers, mdraw_ref_xline, mdraw_yticklabels

x, y = [0, 1, 2], [0, 1, 2]
str_vector = ["a", "b", "c"]
models_vector =["m1", "m1", "m2"]
models_vector = ["m1", "m1", "m2"]
input_df = pd.DataFrame(
{
"yticklabel": str_vector,
Expand Down Expand Up @@ -55,11 +55,17 @@ def test_mdraw_yticklabels():


def test_mdraw_est_markers():
_, ax = plt.subplots()
ax = mdraw_est_markers(input_df, estimate='estimate', model_col='model', models=list(set(models_vector)), ax=ax)
assert (all(isinstance(tick, int)) for tick in ax.get_yticks())
_, ax = plt.subplots()
ax = mdraw_est_markers(
input_df,
estimate="estimate",
model_col="model",
models=list(set(models_vector)),
ax=ax,
)
assert (all(isinstance(tick, int)) for tick in ax.get_yticks())

xmin, xmax = ax.get_xlim()
assert xmin <= input_df["estimate"].min()
assert xmax >= input_df["estimate"].max()
assert len(ax.collections) == len(set(models_vector))
xmin, xmax = ax.get_xlim()
assert xmin <= input_df["estimate"].min()
assert xmax >= input_df["estimate"].max()
assert len(ax.collections) == len(set(models_vector))

0 comments on commit 5a5bd3e

Please sign in to comment.