diff --git a/polaris/cli.py b/polaris/cli.py index c5682c4b..a796009a 100644 --- a/polaris/cli.py +++ b/polaris/cli.py @@ -27,8 +27,8 @@ def login( This CLI will use the OAuth2 protocol to gain token-based access to the Polaris Hub API. """ - with PolarisHubClient(settings=PolarisHubSettings(_env_file=client_env_file)) as client: - client.login(auto_open_browser=auto_open_browser, overwrite=overwrite) + client = PolarisHubClient(settings=PolarisHubSettings(_env_file=client_env_file)) + client.login(auto_open_browser=auto_open_browser, overwrite=overwrite) @app.command(hidden=True) diff --git a/polaris/hub/client.py b/polaris/hub/client.py index 0d06adba..32b2e036 100644 --- a/polaris/hub/client.py +++ b/polaris/hub/client.py @@ -15,6 +15,7 @@ from authlib.oauth2.rfc6749 import OAuth2Token from httpx import HTTPStatusError, Response from loguru import logger +from typing_extensions import Self from polaris.benchmark import ( BenchmarkV1Specification, @@ -120,6 +121,15 @@ def __init__( settings=self.settings, cache_auth_token=cache_auth_token, **kwargs ) + def __enter__(self: Self) -> Self: + """ + When used as a context manager, automatically check that authentication is valid. + """ + super().__enter__() + if not self.ensure_active_token(): + raise PolarisUnauthorizedError() + return self + @property def has_user_password(self) -> bool: return bool(self.settings.username and self.settings.password) @@ -143,16 +153,13 @@ def ensure_active_token(self, token: OAuth2Token | None = None) -> bool: """ Override the active check to trigger a refetch of the token if it is not active. """ - if token is None: - # This won't be needed with if we set a lower bound for authlib: >=1.3.2 - # See https://github.com/lepture/authlib/pull/625 - # As of now, this latest version is not available on Conda though. - token = self.token - - if token: - is_active = super().ensure_active_token(token) - if is_active: - return True + # This won't be needed with if we set a lower bound for authlib: >=1.3.2 + # See https://github.com/lepture/authlib/pull/625 + # As of now, this latest version is not available on Conda though. + token = token or self.token + is_active = super().ensure_active_token(token) if token else False + if is_active: + return True # Check if external token is still valid, or we're using password auth if not (self.has_user_password or self.external_client.ensure_active_token()): diff --git a/polaris/hub/external_client.py b/polaris/hub/external_client.py index 2d7ab72b..278a4f55 100644 --- a/polaris/hub/external_client.py +++ b/polaris/hub/external_client.py @@ -78,13 +78,12 @@ def fetch_token(self, **kwargs) -> dict: ) from error def ensure_active_token(self, token: OAuth2Token | None = None) -> bool: - if token is None: + try: # This won't be needed with if we set a lower bound for authlib: >=1.3.2 # See https://github.com/lepture/authlib/pull/625 # As of now, this latest version is not available on Conda though. - token = self.token - try: - return super().ensure_active_token(token) or False + token = token or self.token + return super().ensure_active_token(token) if token else False except OAuthError: # The refresh attempt can fail with this error return False diff --git a/polaris/loader/load.py b/polaris/loader/load.py index 306911c4..60108b4e 100644 --- a/polaris/loader/load.py +++ b/polaris/loader/load.py @@ -34,7 +34,6 @@ def load_dataset(path: str, verify_checksum: ChecksumStrategy = "verify_unless_z if not is_file: # Load from the Hub with PolarisHubClient() as client: - client.ensure_active_token() return client.get_dataset(*path.split("/"), verify_checksum=verify_checksum) # Load from local file @@ -73,7 +72,6 @@ def load_benchmark(path: str, verify_checksum: ChecksumStrategy = "verify_unless if not is_file: # Load from the Hub with PolarisHubClient() as client: - client.ensure_active_token() return client.get_benchmark(*path.split("/"), verify_checksum=verify_checksum) with fsspec.open(path, "r") as fd: diff --git a/polaris/utils/context.py b/polaris/utils/context.py index c53a80bd..582032dd 100644 --- a/polaris/utils/context.py +++ b/polaris/utils/context.py @@ -1,21 +1,8 @@ -from contextlib import contextmanager - from halo import Halo from polaris.mixins import FormattingMixin -@contextmanager -def tmp_attribute_change(obj, attribute, value): - """Temporarily set and reset an attribute of an object.""" - original_value = getattr(obj, attribute) - try: - setattr(obj, attribute, value) - yield obj - finally: - setattr(obj, attribute, original_value) - - class ProgressIndicator(FormattingMixin): def __init__(self, success_msg: str, error_msg: str, start_msg: str = "In progress..."): self._start_msg = start_msg diff --git a/pyproject.toml b/pyproject.toml index 757ab896..8130efda 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ dependencies = [ "datamol >=0.12.1", "fastpdb", "fsspec[http]", + "halo", "httpx", "loguru", "numcodecs[msgpack]>=0.13.1", @@ -55,7 +56,6 @@ dependencies = [ "tqdm", "typer", "typing-extensions>=4.12.0", - "halo", "zarr >=2,<3", ] diff --git a/uv.lock b/uv.lock index a700988f..f375b89f 100644 --- a/uv.lock +++ b/uv.lock @@ -451,7 +451,7 @@ name = "click" version = "8.1.8" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 } wheels = [ @@ -1006,7 +1006,7 @@ name = "ipykernel" version = "6.29.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "appnope", marker = "platform_system == 'Darwin'" }, + { name = "appnope", marker = "sys_platform == 'darwin'" }, { name = "comm" }, { name = "debugpy" }, { name = "ipython" }, @@ -1641,7 +1641,7 @@ version = "1.6.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, { name = "ghp-import" }, { name = "jinja2" }, { name = "markdown" }, @@ -2237,7 +2237,7 @@ wheels = [ [[package]] name = "polaris-lib" -version = "0.11.3.dev0+g817f3a5.d20250118" +version = "0.11.4.dev0+g1649472.d20250123" source = { editable = "." } dependencies = [ { name = "authlib" }, @@ -3377,7 +3377,7 @@ name = "tqdm" version = "4.67.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737 } wheels = [