Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor ptable plotters and add ptable_heatmap with diagonally-split tiles #131

Merged
merged 103 commits into from
May 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
103 commits
Select commit Hold shift + click to select a range
283fb3d
`sourcery` clean ups
DanielYang59 Apr 6, 2024
3f7994c
remove finished TODO tag
DanielYang59 Apr 6, 2024
d2cf219
remove TODO tag
DanielYang59 Apr 6, 2024
abd4c55
remove accidentally tracked test file
DanielYang59 Apr 6, 2024
8ae397b
fix example notebook imports in matbench_dielectric_eda.ipynb and mat…
janosh Apr 6, 2024
3269710
add module doc strings
janosh Apr 6, 2024
974fb25
fix wbm-summary.csv download URL in explore_wbm.py
janosh Apr 6, 2024
96a5773
add sketch (color not working)
DanielYang59 Apr 10, 2024
14b8e96
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 10, 2024
d849814
fix color
DanielYang59 Apr 10, 2024
75ca44c
hide all border and remove alpha
DanielYang59 Apr 10, 2024
0e09d09
increase default font size
DanielYang59 Apr 10, 2024
13be543
update readme
DanielYang59 Apr 10, 2024
bf2ac9d
add condition to skip running svg compression on PRs from forks
janosh Apr 10, 2024
079895f
integrate rectangle plotter
DanielYang59 Apr 12, 2024
099379b
simplify rectangle plotter
DanielYang59 Apr 12, 2024
1d01d67
in plot_split_rectangle ax.pie set wedgeprops=dict(clip_on=True)
janosh Apr 12, 2024
4007f29
change default angle
DanielYang59 Apr 12, 2024
7678f93
fix data length check
DanielYang59 Apr 12, 2024
e1bb25d
finish adding plotter
DanielYang59 Apr 12, 2024
a5da1f9
adjust plotter to alphabet order
DanielYang59 Apr 12, 2024
0918544
add unit test
DanielYang59 Apr 12, 2024
199ebca
test_ptable_splits assert len fig.axes and cbar_ax title
janosh Apr 12, 2024
58e19a3
add DanielYang59 to author list in citation.cff and readme.md
janosh Apr 12, 2024
e5c12f5
remove affiliation
DanielYang59 Apr 13, 2024
45e137f
remove unused element type color
DanielYang59 Apr 16, 2024
caba1b2
update precommit hooks
DanielYang59 Apr 16, 2024
a0926f5
use 2-split as example
DanielYang59 Apr 16, 2024
d7028da
duplicate ptable
DanielYang59 Apr 16, 2024
110bcd6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 16, 2024
c1357e3
nest ELEM_CLASS_COLORS for simplicity
DanielYang59 Apr 16, 2024
8ca94e6
add projector sketch
DanielYang59 Apr 16, 2024
3ae4af7
remove add_element_type_legend
DanielYang59 Apr 16, 2024
1a188b6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 16, 2024
2a53ab4
remove legacy plotters
DanielYang59 Apr 16, 2024
588b875
remove unused axis.on
DanielYang59 Apr 16, 2024
f97a73f
update sketch
DanielYang59 Apr 16, 2024
c60b2e7
add norm for colorbar
DanielYang59 Apr 16, 2024
cf4f0be
update some docstring
DanielYang59 Apr 16, 2024
21759f9
add data preprocessor sketch
DanielYang59 Apr 16, 2024
f8dbc12
fix type alias
DanielYang59 Apr 16, 2024
5229e81
fix dataprocessor for df and dict
DanielYang59 Apr 17, 2024
1daebbd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2024
22fd367
add series processor
DanielYang59 Apr 17, 2024
adcc5d9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2024
a53d6db
format code
DanielYang59 Apr 17, 2024
d8e7cab
fix data fetch
DanielYang59 Apr 17, 2024
bc9d495
hide empty elements
DanielYang59 Apr 17, 2024
7aad43f
norm data, first working version
DanielYang59 Apr 17, 2024
03223ef
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2024
45cde17
remove legacy split plotter
DanielYang59 Apr 17, 2024
ede431c
add migration place holders
DanielYang59 Apr 17, 2024
897036e
relocate norm
DanielYang59 Apr 17, 2024
1e16990
convert all value to nparray and handle single float
DanielYang59 Apr 17, 2024
4d4f02e
add some type error
DanielYang59 Apr 17, 2024
771e747
replcate ptable_splits with new implement
DanielYang59 Apr 21, 2024
3057f47
remove test for `add_element_type_legend`
DanielYang59 Apr 21, 2024
90bb348
add placeholders
DanielYang59 Apr 21, 2024
e814a51
add docstring for split plotter
DanielYang59 Apr 21, 2024
0819eba
merge ptable projector
DanielYang59 Apr 21, 2024
eb38aa3
move ELEM_CLASS_COLORS to utils
DanielYang59 Apr 21, 2024
cb90cd5
merge tests
DanielYang59 Apr 21, 2024
b594de9
bump pre-commit
DanielYang59 Apr 21, 2024
0322e13
swap funcion
DanielYang59 Apr 21, 2024
dba12c2
update sketch for line plotter
DanielYang59 Apr 21, 2024
4d08743
pass ax_kwds and child_args
DanielYang59 Apr 23, 2024
0e83500
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 23, 2024
db86584
update line plot style
DanielYang59 Apr 23, 2024
f648e84
Merge branch 'ptable_split' of github.com:DanielYang59/pymatviz into …
DanielYang59 Apr 23, 2024
78c427d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 23, 2024
7b658ee
increase symbol font size
DanielYang59 Apr 23, 2024
6f4b962
migrate scatter plotter
DanielYang59 Apr 23, 2024
d216d20
TEMP SAVE: save asset generator
DanielYang59 Apr 23, 2024
819f91c
merge
DanielYang59 Apr 23, 2024
2e01053
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 23, 2024
3b58a6c
BREAKING: rename kwds to kwargs globally
DanielYang59 Apr 23, 2024
7184960
add some docstring for ptable projector
DanielYang59 Apr 23, 2024
e28114e
move optional kwargs pack to the end
DanielYang59 Apr 23, 2024
9a37846
separate utils from ptable plotter
DanielYang59 Apr 23, 2024
b603f9e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 23, 2024
1b6a2df
merge utils
DanielYang59 Apr 24, 2024
ebf71fe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 24, 2024
8132b3a
update TEMP asset generator
DanielYang59 Apr 24, 2024
8198259
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 24, 2024
ea79902
fix handling of nested arrays
DanielYang59 Apr 24, 2024
ff023ea
add some docstring
DanielYang59 Apr 24, 2024
4ce9ec8
allow numpy int
DanielYang59 Apr 24, 2024
d6bc953
fix child arg passing
DanielYang59 Apr 24, 2024
07a4d9c
update asset generation for lines/splits/scatters
DanielYang59 Apr 24, 2024
3dfd64d
revise docstring
DanielYang59 Apr 24, 2024
48e422e
suppress type errors
DanielYang59 Apr 24, 2024
10f804e
fix union operator
DanielYang59 Apr 24, 2024
df91201
update README
DanielYang59 Apr 24, 2024
664766c
skip eslint in pre-commit
DanielYang59 Apr 25, 2024
4f83105
try downgrade `eslint`
DanielYang59 Apr 25, 2024
e8669fb
allow legacy config file for eslint
DanielYang59 Apr 25, 2024
6d64c5d
Merge branch 'main' into ptable_split
DanielYang59 Apr 29, 2024
91964da
remove TypeAlias
janosh May 2, 2024
befcb06
rename ptable_splits to ptable_heatmap_splits
DanielYang59 May 3, 2024
0779f30
Merge branch 'ptable_split' of github.com:DanielYang59/pymatviz into …
DanielYang59 May 3, 2024
80f5783
Revert "remove test for `add_element_type_legend`"
DanielYang59 May 3, 2024
efbcd49
Revert "remove add_element_type_legend"
DanielYang59 May 3, 2024
0e26145
fix func name in readme
DanielYang59 May 3, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/svgo.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ on:

jobs:
tests:
# don't run on PRs from forks
if: github.event.pull_request.head.repo.fork == false
runs-on: ubuntu-latest
steps:
- name: Check out repo
Expand Down
1 change: 1 addition & 0 deletions assets/ptable-heatmap-splits.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions assets/ptable-lines.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions assets/ptable-scatters.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
9 changes: 8 additions & 1 deletion citation.cff
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,23 @@ authors:
- family-names: Riebesell
given-names: Janosh
affiliation: University of Cambridge
email: janosh@lbl.gov
email: janosh[email protected]
orcid: https://orcid.org/0000-0001-5233-3462
github: janosh
- family-names: Goodall
given-names: Rhys
affiliation: University of Cambridge
orcid: https://orcid.org/0000-0002-6589-1700
github: comprhys
- family-names: Baird
given-names: Sterling G.
affiliation: University of Utah
orcid: https://orcid.org/0000-0002-4491-6876
github: sgbaird
- family-names: Yang
given-names: Haoyu (Daniel)
email: [email protected]
github: DanielYang59
license: MIT
license-url: https://github.com/janosh/pymatviz/blob/main/license"
repository-code: https://github.com/janosh/pymatviz
Expand Down
66 changes: 53 additions & 13 deletions examples/_generate_assets.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# %%
import json
import random
from glob import glob

import matplotlib.pyplot as plt
Expand Down Expand Up @@ -37,8 +38,10 @@
ptable_heatmap,
ptable_heatmap_plotly,
ptable_heatmap_ratio,
ptable_heatmap_splits,
ptable_hists,
ptable_plots,
ptable_lines,
ptable_scatters,
)
from pymatviz.relevance import precision_recall_curve, roc_curve
from pymatviz.sankey import sankey_from_2_df_cols
Expand All @@ -48,7 +51,7 @@
from pymatviz.utils import TEST_FILES, df_ptable


# %% configure matplotlib and load test data
# %% Configure matplotlib and load test data
plt.rc("font", size=14)
plt.rc("savefig", bbox="tight", dpi=200)
plt.rc("axes", titlesize=16, titleweight="bold")
Expand All @@ -58,7 +61,7 @@
px.defaults.template = "pymatviz_white"
pio.templates.default = "pymatviz_white"

# random classification data
# Random classification data
np.random.seed(42)
rand_clf_size = 100
y_binary = np.random.choice([0, 1], size=rand_clf_size)
Expand All @@ -67,7 +70,7 @@
)


# random regression data
# Random regression data
rand_regression_size = 500
y_true = np.random.normal(5, 4, rand_regression_size)
y_pred = 1.2 * y_true - 2 * np.random.normal(0, 1, rand_regression_size)
Expand Down Expand Up @@ -164,7 +167,7 @@


# %% Histograms laid out in as a periodic table
# generate random parity data with y \approx x with some noise
# Generate random parity data with y \approx x with some noise
data_dict = {
elem.symbol: np.random.randn(100) + np.random.randn(100) for elem in Element
}
Expand All @@ -179,20 +182,57 @@
elem.symbol: [
np.random.randint(0, 20, 10),
np.random.randint(0, 20, 10),
np.random.randint(0, 20, 10),
# np.random.randint(0, 20, 10), # TODO: allow 3rd dim
]
for elem in Element
}

fig = ptable_plots(
fig = ptable_scatters(
data_dict,
colormap="coolwarm",
cbar_title="Periodic Table Scatter Plots",
plot_kwds=dict(marker="o", linestyle=""),
# colormap="coolwarm",
# cbar_title="Periodic Table Scatter Plots",
child_args=dict(marker="o", linestyle=""),
symbol_pos=(0.5, 1.2),
symbol_kwargs=dict(fontsize=14),
)
save_and_compress_svg(fig, "ptable-scatters")


# %% Line plots laid out as a periodic table
data_dict = {
elem.symbol: [
np.linspace(0, 10, 10),
np.sin(2 * np.pi * np.linspace(0, 10, 10)) + np.random.normal(0, 0.2, 10),
]
for elem in Element
}

fig = ptable_lines(
data_dict,
symbol_pos=(0.5, 1.2),
symbol_kwargs=dict(fontsize=14),
)
save_and_compress_svg(fig, "ptable-lines")


# %% Evenly-split tile plots laid out as a periodic table
data_dict = {
elem.symbol: [
random.randint(0, 10),
random.randint(10, 20),
]
for elem in Element
}

fig = ptable_heatmap_splits(
data=data_dict,
colormap="coolwarm",
start_angle=135,
cbar_title="Periodic Table Evenly-Split Heatmap Plots",
)
save_and_compress_svg(fig, "ptable-heatmap-splits")


# %% Uncertainty Plots
ax = qq_gaussian(y_pred, y_true, y_std, identity_line={"line_kwds": {"color": "red"}})
save_and_compress_svg(ax, "normal-prob-plot")
Expand Down Expand Up @@ -261,7 +301,7 @@


# %% Correlation Plots
# plot eigenvalue distribution of a pure-noise correlation matrix
# Plot eigenvalue distribution of a pure-noise correlation matrix
# i.e. the correlation matrix contains no significant correlations
# beyond the spurious correlation that occurs randomly
n_rows, n_cols = 500, 1000
Expand All @@ -272,7 +312,7 @@
ax = marchenko_pastur(corr_mat, gamma=n_cols / n_rows)
save_and_compress_svg(ax, "marchenko-pastur")

# plot eigenvalue distribution of a correlation matrix with significant
# Plot eigenvalue distribution of a correlation matrix with significant
# (i.e. non-noise) eigenvalue
n_rows, n_cols = 50, 400
linear_matrix = np.arange(n_rows * n_cols).reshape(n_rows, n_cols) / n_cols
Expand All @@ -282,7 +322,7 @@
ax = marchenko_pastur(corr_mat, gamma=n_cols / n_rows)
save_and_compress_svg(ax, "marchenko-pastur-significant-eval")

# plot eigenvalue distribution of a rank-deficient correlation matrix
# Plot eigenvalue distribution of a rank-deficient correlation matrix
n_rows, n_cols = 600, 500
rand_tall_mat = np.random.normal(0, 1, size=(n_rows, n_cols))

Expand Down
8 changes: 4 additions & 4 deletions examples/dataset_exploration/wbm/explore_wbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
__date__ = "2022-08-18"


# %% download wbm-steps-summary.csv (23.31 MB)
df_wbm = pd.read_csv(
"https://figshare.com/files/37542841?private_link=ff0ad14505f9624f0c05"
).set_index("material_id", drop=False)
# %% download wbm-summary.csv (12 MB)
df_wbm = pd.read_csv("https://figshare.com/ndownloader/files/44225498").set_index(
"material_id", drop=False
)

df_wbm["batch_idx"] = df_wbm.index.str.split("-").str[2].astype(int)
df_wbm["spg_num"] = df_wbm.wyckoff.str.split("_").str[2].astype(int)
Expand Down
2 changes: 1 addition & 1 deletion examples/diatomics/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from matplotlib.backends.backend_pdf import PdfPages
from pymatgen.core import Element

from pymatviz import ptable_plots
from pymatviz import ptable_plots # type: ignore[attr-defined] # TODO:
from pymatviz.io import save_fig
from pymatviz.utils import df_ptable

Expand Down
4 changes: 2 additions & 2 deletions examples/matbench_dielectric_eda.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
"from tqdm import tqdm\n",
"\n",
"from pymatviz import ptable_heatmap, spacegroup_hist, spacegroup_sunburst\n",
"from pymatviz.utils import get_crystal_sys\n",
"from pymatviz.utils import crystal_sys_from_spg_num\n",
"\n",
"\n",
"__author__ = \"Janosh Riebesell\"\n",
Expand Down Expand Up @@ -64,7 +64,7 @@
"df_diel[[\"spg_symbol\", \"spg_num\"]] = [\n",
" struct.get_space_group_info() for struct in tqdm(df_diel.structure)\n",
"]\n",
"df_diel[\"crys_sys\"] = df_diel.spg_num.map(get_crystal_sys)"
"df_diel[\"crys_sys\"] = df_diel.spg_num.map(crystal_sys_from_spg_num)"
]
},
{
Expand Down
33 changes: 13 additions & 20 deletions examples/matbench_perovskites_eda.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,21 @@
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"import plotly.express as px\n",
"import plotly.io as pio\n",
"\n",
"# see https://github.com/CompRhys/aviary/blob/main/aviary/wren/utils.py\n",
"from aviary.wren.utils import get_aflow_label_aflow\n",
"from matminer.datasets import load_dataset\n",
"from tqdm import tqdm\n",
"\n",
"from pymatviz import (\n",
" annotate_bars,\n",
" plot_structure_2d,\n",
" ptable_heatmap_plotly,\n",
" spacegroup_sunburst,\n",
")\n",
"from pymatviz.plot_defaults import plt, px\n",
"from pymatviz import plot_structure_2d, ptable_heatmap_plotly, spacegroup_sunburst\n",
"from pymatviz.io import save_and_compress_svg\n",
"from pymatviz.powerups import annotate_bars\n",
"from pymatviz.sankey import sankey_from_2_df_cols\n",
"from pymatviz.utils import get_crystal_sys, save_and_compress_svg\n",
"from pymatviz.utils import crystal_sys_from_spg_num\n",
"\n",
"\n",
"__author__ = \"Janosh Riebesell\"\n",
Expand Down Expand Up @@ -79,7 +77,7 @@
"\n",
"df_perov[\"formula\"] = df_perov.structure.map(lambda cryst: cryst.formula)\n",
"\n",
"df_perov[\"spglib_crys_sys\"] = df_perov.spglib_spg_num.map(get_crystal_sys)"
"df_perov[\"spglib_crys_sys\"] = df_perov.spglib_spg_num.map(crystal_sys_from_spg_num)"
]
},
{
Expand Down Expand Up @@ -291,7 +289,7 @@
"outputs": [],
"source": [
"df_perov[\"aflow_spg_num\"] = df_perov.aflow_wyckoff.str.split(\"_\").str[2].astype(int)\n",
"df_perov[\"aflow_crys_sys\"] = df_perov.aflow_spg_num.map(get_crystal_sys)"
"df_perov[\"aflow_crys_sys\"] = df_perov.aflow_spg_num.map(crystal_sys_from_spg_num)"
]
},
{
Expand Down Expand Up @@ -332,11 +330,9 @@
"source": [
"fig = sankey_from_2_df_cols(df_perov, [\"spglib_spg_num\", \"aflow_spg_num\"])\n",
"\n",
"fig.update_layout(\n",
" title=\"Spacegroups as determined by Spglib vs Aflow<br>\"\n",
" \"for the Matbench Perovskites dataset\",\n",
" title_x=0.5,\n",
")\n",
"title = \"Spglib vs Aflow Spacegroups<br>for the Matbench Perovskites dataset\"\n",
"\n",
"fig.layout.title.update(text=title, x=0.5)\n",
"\n",
"save_and_compress_svg(fig, \"sankey-spglib-vs-aflow-spacegroups\")\n",
"\n",
Expand All @@ -361,11 +357,8 @@
"source": [
"fig = sankey_from_2_df_cols(df_perov, [\"spglib_crys_sys\", \"aflow_crys_sys\"])\n",
"\n",
"fig.update_layout(\n",
" title=\"Crystal systems as determined by Spglib vs Aflow<br>\"\n",
" \"for the Matbench Perovskites dataset\",\n",
" title_x=0.5,\n",
")"
"title = \"Spglib vs Aflow Crystal systems<br>for the Matbench Perovskites dataset\"\n",
"fig.layout.title.update(text=title, x=0.5)"
]
}
],
Expand Down
6 changes: 5 additions & 1 deletion pymatviz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,16 @@
plot_phonon_dos,
)
from pymatviz.ptable import (
ChildPlotters,
PTableProjector,
count_elements,
ptable_heatmap,
ptable_heatmap_plotly,
ptable_heatmap_ratio,
ptable_heatmap_splits,
ptable_hists,
ptable_plots,
ptable_lines,
ptable_scatters,
)
from pymatviz.relevance import precision_recall_curve, roc_curve
from pymatviz.sankey import sankey_from_2_df_cols
Expand Down
2 changes: 2 additions & 0 deletions pymatviz/correlation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Plot distributions of correlation matrix eigenvalues."""

from __future__ import annotations

from typing import TYPE_CHECKING
Expand Down
2 changes: 2 additions & 0 deletions pymatviz/cumulative.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Plot the cumulative distribution of residuals and absolute errors."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any
Expand Down
2 changes: 2 additions & 0 deletions pymatviz/histograms.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Histograms and bar charts."""

from __future__ import annotations

from collections.abc import Sequence
Expand Down
2 changes: 2 additions & 0 deletions pymatviz/io.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""I/O utilities for saving figures and dataframes to various image formats."""

from __future__ import annotations

import copy
Expand Down
2 changes: 2 additions & 0 deletions pymatviz/parity.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Parity, residual and density plots."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any
Expand Down
4 changes: 3 additions & 1 deletion pymatviz/phonons.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Plotting functions for pymatgen phonon band structures and density of states."""

from __future__ import annotations

import sys
Expand Down Expand Up @@ -384,8 +386,8 @@ def plot_phonon_dos(
fig.layout.font.size = 16 * (fig.layout.width or 800) / 800
fig.layout.legend.update(x=0.005, y=0.99, orientation="h", yanchor="top")

qual_colors = px.colors.qualitative.Plotly
if last_peak_anno:
qual_colors = px.colors.qualitative.Plotly
for idx, (key, dos) in enumerate(doses.items()):
last_peak = dos.get_last_peak()
color = (
Expand Down
8 changes: 5 additions & 3 deletions pymatviz/powerups.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""Powerups/enhancements such as parity lines, annotations and marginals for matplotlib
and plotly figures.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -196,12 +200,10 @@ def annotate_metrics(
"MSE": lambda x, y: ((x - y) ** 2).mean(),
"MAPE": mape,
"R2": r2_score,
# TODO: check this for correctness
"R2_adj": lambda x, y: 1 - (1 - r2_score(x, y)) * (len(x) - 1) / (len(x) - 2),
}
for key in set(metrics) - set(funcs):
func = getattr(sklearn.metrics, key, None)
if func:
if func := getattr(sklearn.metrics, key, None):
funcs[key] = func
if bad_keys := set(metrics) - set(funcs):
raise ValueError(f"Unrecognized metrics: {bad_keys}")
Expand Down
Loading