Skip to content

Commit

Permalink
callbacks: wrap non-dvc callbacks passed via fsspec
Browse files Browse the repository at this point in the history
  • Loading branch information
pmrowla committed Nov 2, 2023
1 parent fdb3851 commit 6bf2d72
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 8 deletions.
52 changes: 45 additions & 7 deletions src/dvc_objects/fs/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from contextlib import ExitStack
from functools import wraps
from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar, cast, overload
from typing import TYPE_CHECKING, Any, Dict, Optional, Protocol, TypeVar, cast, overload

import fsspec

Expand All @@ -17,9 +17,21 @@
_R = TypeVar("_R")


class Callback(fsspec.Callback):
"""Callback usable as a context manager, and a few helper methods."""
class _CallbackProtocol(Protocol):
def relative_update(self, inc: int = 1) -> None:
...

def branch(
self,
path_1: "Union[str, BinaryIO]",
path_2: str,
kwargs: Dict[str, Any],
child: Optional["Callback"] = None,
) -> "Callback":
...


class _DVCCallbackMixin(_CallbackProtocol):
@overload
def wrap_attr(self, fobj: "BinaryIO", method: str = "read") -> "BinaryIO":
...
Expand Down Expand Up @@ -95,6 +107,10 @@ def __exit__(self, *exc_args):
def close(self):
"""Handle here on exit."""


class Callback(fsspec.Callback, _DVCCallbackMixin):
"""Callback usable as a context manager, and a few helper methods."""

def relative_update(self, inc: int = 1) -> None:
inc = inc if inc is not None else 0
return super().relative_update(inc)
Expand All @@ -104,18 +120,26 @@ def absolute_update(self, value: int) -> None:
return super().absolute_update(value)

@classmethod
def as_callback(cls, maybe_callback: Optional["Callback"] = None) -> "Callback":
def as_callback(
cls, maybe_callback: Optional[fsspec.callbacks.Callback] = None
) -> "Callback":
if maybe_callback is None:
return DEFAULT_CALLBACK
return maybe_callback
if isinstance(maybe_callback, Callback):
return maybe_callback
return _FsspecCallbackWrapper(maybe_callback)

@classmethod
def as_tqdm_callback(
cls,
callback: Optional["Callback"] = None,
callback: Optional[fsspec.callbacks.Callback] = None,
**tqdm_kwargs: Any,
) -> "Callback":
return callback or TqdmCallback(**tqdm_kwargs)
if callback is None:
return TqdmCallback(**tqdm_kwargs)
if isinstance(callback, Callback):
return callback
return cast("Callback", _FsspecCallbackWrapper(callback))

def branch( # pylint: disable=arguments-differ
self,
Expand Down Expand Up @@ -185,4 +209,18 @@ def branch(
return super().branch(path_1, path_2, kwargs, child=child)


class _FsspecCallbackWrapper(fsspec.callbacks.Callback, _DVCCallbackMixin):
def __init__(self, callback: fsspec.callbacks.Callback):
object.__setattr__(self, "_callback", callback)

def __getattr__(self, name: str):
return getattr(self._callback, name)

def __setattr__(self, name: str, value: Any):
setattr(self._callback, name, value)

def branch(self, *args, **kwargs):
return _FsspecCallbackWrapper(self._callback.branch(*args, **kwargs))


DEFAULT_CALLBACK = NoOpCallback()
24 changes: 23 additions & 1 deletion tests/fs/test_callbacks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import Optional

import fsspec
import pytest

from dvc_objects.fs.callbacks import DEFAULT_CALLBACK, Callback
from dvc_objects.fs.callbacks import DEFAULT_CALLBACK, Callback, TqdmCallback


@pytest.mark.parametrize("api", ["set_size", "relative_update", "absolute_update"])
Expand Down Expand Up @@ -29,3 +32,22 @@ def test_callback_with_none(request, api, callback_factory, kwargs, mocker):
if callback is not DEFAULT_CALLBACK:
assert callback.size is None
assert callback.value == 0


def test_wrap_fsspec():
def _branch_fn(*args, callback: Optional["Callback"] = None, **kwargs):
pass

callback = fsspec.callbacks.Callback()
assert callback.value == 0
with Callback.as_tqdm_callback(callback) as cb:
assert not isinstance(cb, TqdmCallback)
assert cb.value == 0
cb.relative_update()
assert cb.value == 1
assert callback.value == 1

fn = cb.wrap_and_branch(_branch_fn)
fn("foo", "bar", callback=callback)
assert cb.value == 2
assert callback.value == 2

0 comments on commit 6bf2d72

Please sign in to comment.