Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Better handling when no cached token exists, with more decriptive error #252

Merged
merged 2 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions polaris/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def login(

This CLI will use the OAuth2 protocol to gain token-based access to the Polaris Hub API.
"""
with PolarisHubClient(settings=PolarisHubSettings(_env_file=client_env_file)) as client:
client.login(auto_open_browser=auto_open_browser, overwrite=overwrite)
client = PolarisHubClient(settings=PolarisHubSettings(_env_file=client_env_file))
client.login(auto_open_browser=auto_open_browser, overwrite=overwrite)


@app.command(hidden=True)
Expand Down
27 changes: 17 additions & 10 deletions polaris/hub/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from authlib.oauth2.rfc6749 import OAuth2Token
from httpx import HTTPStatusError, Response
from loguru import logger
from typing_extensions import Self

from polaris.benchmark import (
BenchmarkV1Specification,
Expand Down Expand Up @@ -120,6 +121,15 @@ def __init__(
settings=self.settings, cache_auth_token=cache_auth_token, **kwargs
)

def __enter__(self: Self) -> Self:
"""
When used as a context manager, automatically check that authentication is valid.
"""
super().__enter__()
if not self.ensure_active_token():
raise PolarisUnauthorizedError()
return self

@property
def has_user_password(self) -> bool:
return bool(self.settings.username and self.settings.password)
Expand All @@ -143,16 +153,13 @@ def ensure_active_token(self, token: OAuth2Token | None = None) -> bool:
"""
Override the active check to trigger a refetch of the token if it is not active.
"""
if token is None:
# This won't be needed with if we set a lower bound for authlib: >=1.3.2
# See https://github.com/lepture/authlib/pull/625
# As of now, this latest version is not available on Conda though.
token = self.token

if token:
is_active = super().ensure_active_token(token)
if is_active:
return True
# This won't be needed with if we set a lower bound for authlib: >=1.3.2
# See https://github.com/lepture/authlib/pull/625
# As of now, this latest version is not available on Conda though.
token = token or self.token
is_active = super().ensure_active_token(token) if token else False
if is_active:
return True

# Check if external token is still valid, or we're using password auth
if not (self.has_user_password or self.external_client.ensure_active_token()):
Expand Down
7 changes: 3 additions & 4 deletions polaris/hub/external_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,12 @@ def fetch_token(self, **kwargs) -> dict:
) from error

def ensure_active_token(self, token: OAuth2Token | None = None) -> bool:
if token is None:
try:
# This won't be needed with if we set a lower bound for authlib: >=1.3.2
# See https://github.com/lepture/authlib/pull/625
# As of now, this latest version is not available on Conda though.
token = self.token
try:
return super().ensure_active_token(token) or False
token = token or self.token
return super().ensure_active_token(token) if token else False
except OAuthError:
# The refresh attempt can fail with this error
return False
Expand Down
2 changes: 0 additions & 2 deletions polaris/loader/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def load_dataset(path: str, verify_checksum: ChecksumStrategy = "verify_unless_z
if not is_file:
# Load from the Hub
with PolarisHubClient() as client:
client.ensure_active_token()
return client.get_dataset(*path.split("/"), verify_checksum=verify_checksum)

# Load from local file
Expand Down Expand Up @@ -73,7 +72,6 @@ def load_benchmark(path: str, verify_checksum: ChecksumStrategy = "verify_unless
if not is_file:
# Load from the Hub
with PolarisHubClient() as client:
client.ensure_active_token()
return client.get_benchmark(*path.split("/"), verify_checksum=verify_checksum)

with fsspec.open(path, "r") as fd:
Expand Down
13 changes: 0 additions & 13 deletions polaris/utils/context.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,8 @@
from contextlib import contextmanager

from halo import Halo

from polaris.mixins import FormattingMixin


@contextmanager
def tmp_attribute_change(obj, attribute, value):
"""Temporarily set and reset an attribute of an object."""
original_value = getattr(obj, attribute)
try:
setattr(obj, attribute, value)
yield obj
finally:
setattr(obj, attribute, original_value)


class ProgressIndicator(FormattingMixin):
def __init__(self, success_msg: str, error_msg: str, start_msg: str = "In progress..."):
self._start_msg = start_msg
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ dependencies = [
"datamol >=0.12.1",
"fastpdb",
"fsspec[http]",
"halo",
"httpx",
"loguru",
"numcodecs[msgpack]>=0.13.1",
Expand All @@ -55,7 +56,6 @@ dependencies = [
"tqdm",
"typer",
"typing-extensions>=4.12.0",
"halo",
"zarr >=2,<3",
]

Expand Down
10 changes: 5 additions & 5 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading