Skip to content

Commit

Permalink
Allow to use pathlib.Path in metatensor.save_xxx (metatensor#789)
Browse files Browse the repository at this point in the history
  • Loading branch information
tulga-rdn authored Nov 20, 2024
1 parent 5f63ee8 commit 0fbc23d
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 39 deletions.
33 changes: 19 additions & 14 deletions python/metatensor-core/metatensor/io/_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,25 +188,30 @@ def _save_block(
):
assert isinstance(block, TensorBlock)

if isinstance(file, (str, pathlib.Path)):
if not file.endswith(".npz"):
file += ".npz"
warnings.warn(
message=f"adding '.npz' extension, the file will be saved at '{file}'",
stacklevel=1,
)

if use_numpy:
all_entries = _block_to_dict(block, prefix="", is_gradient=False)
np.savez(file, **all_entries)
else:
lib = _get_library()
if isinstance(file, (str, pathlib.Path)):
if isinstance(file, str):
path = file.encode("utf8")
elif isinstance(file, pathlib.Path):
path = bytes(file)

if isinstance(file, str):
if not file.endswith(".npz"):
file += ".npz"
warnings.warn(
message="adding '.npz' extension,"
f" the file will be saved at '{file}'",
stacklevel=1,
)
path = file.encode("utf8")
lib.mts_block_save(path, block._ptr)
elif isinstance(file, pathlib.Path):
if not file.name.endswith(".npz"):
file = file.with_name(file.name + ".npz")
warnings.warn(
message="adding '.npz' extension,"
f" the file will be saved at '{file.name}'",
stacklevel=1,
)
path = bytes(file)
lib.mts_block_save(path, block._ptr)
else:
# assume we have a file-like object
Expand Down
27 changes: 16 additions & 11 deletions python/metatensor-core/metatensor/io/_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,21 +77,26 @@ def _save_labels(
"""
assert isinstance(labels, Labels)

if isinstance(file, (str, pathlib.Path)):
lib = _get_library()
if isinstance(file, str):
if not file.endswith(".npy"):
file += ".npy"
file += ".npz"
warnings.warn(
message=f"adding '.npy' extension, the file will be saved at '{file}'",
message="adding '.npy' extension,"
f" the file will be saved at '{file}'",
stacklevel=1,
)

lib = _get_library()
if isinstance(file, (str, pathlib.Path)):
if isinstance(file, str):
path = file.encode("utf8")
elif isinstance(file, pathlib.Path):
path = bytes(file)

path = file.encode("utf8")
lib.mts_labels_save(path, labels._labels)
elif isinstance(file, pathlib.Path):
if not file.name.endswith(".npy"):
file = file.with_name(file.name + ".npy")
warnings.warn(
message="adding '.npy' extension,"
f" the file will be saved at '{file.name}'",
stacklevel=1,
)
path = bytes(file)
lib.mts_labels_save(path, labels._labels)
else:
# assume we have a file-like object
Expand Down
33 changes: 19 additions & 14 deletions python/metatensor-core/metatensor/io/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,25 +149,30 @@ def _save_tensor(
):
assert isinstance(tensor, TensorMap)

if isinstance(file, (str, pathlib.Path)):
if not file.endswith(".npz"):
file += ".npz"
warnings.warn(
message=f"adding '.npz' extension, the file will be saved at '{file}'",
stacklevel=1,
)

if use_numpy:
all_entries = _tensor_to_dict(tensor)
np.savez(file, **all_entries)
else:
lib = _get_library()
if isinstance(file, (str, pathlib.Path)):
if isinstance(file, str):
path = file.encode("utf8")
elif isinstance(file, pathlib.Path):
path = bytes(file)

if isinstance(file, str):
if not file.endswith(".npz"):
file += ".npz"
warnings.warn(
message="adding '.npz' extension,"
f" the file will be saved at '{file}'",
stacklevel=1,
)
path = file.encode("utf8")
lib.mts_tensormap_save(path, tensor._ptr)
elif isinstance(file, pathlib.Path):
if not file.name.endswith(".npz"):
file = file.with_name(file.name + ".npz")
warnings.warn(
message="adding '.npz' extension,"
f" the file will be saved at '{file.name}'",
stacklevel=1,
)
path = bytes(file)
lib.mts_tensormap_save(path, tensor._ptr)
else:
# assume we have a file-like object
Expand Down
21 changes: 21 additions & 0 deletions python/metatensor-core/tests/serialization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import io
import os
import pickle
from pathlib import Path

import numpy as np
import pytest
Expand Down Expand Up @@ -193,6 +194,26 @@ def test_save_warning_errors(tmpdir, tensor):
metatensor.save(tmpfile, tensor.block(0).values)


def test_save_pathlib(tmpdir, tensor):
# does not have .npz ending and causes warning
tmpfile = Path("serialize-test")

expected = f"adding '.npz' extension, the file will be saved at '{tmpfile}.npz'"
with tmpdir.as_cwd():
with pytest.warns(UserWarning, match=expected):
metatensor.save(tmpfile, tensor)

tmpfile = "serialize-test.npz"

message = (
"`data` must be one of 'Labels', 'TensorBlock' or 'TensorMap', "
"not <class 'numpy.ndarray'>"
)
with pytest.raises(TypeError, match=message):
with tmpdir.as_cwd():
metatensor.save(tmpfile, tensor.block(0).values)


@pytest.mark.parametrize("protocol", PICKLE_PROTOCOLS)
def test_pickle(protocol, tmpdir, tensor):
"""
Expand Down

0 comments on commit 0fbc23d

Please sign in to comment.