Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: add cache strategy option #161

Merged
merged 2 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/usage/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ github = GitHub(
user_agent="GitHubKit/Python",
follow_redirects=True,
timeout=None,
cache_strategy=None,
http_cache=True,
auto_retry=True,
rest_api_validate_body=True,
Expand All @@ -24,13 +25,15 @@ Or, you can pass the config object directly (not recommended):
import httpx
from githubkit import GitHub, Config
from githubkit.retry import RETRY_DEFAULT
from githubkit.cache import DEFAULT_CACHE_STRATEGY

config = Config(
base_url="https://api.github.com/",
accept="application/vnd.github+json",
user_agent="GitHubKit/Python",
follow_redirects=True,
timeout=httpx.Timeout(None),
cache_strategy=DEFAULT_CACHE_STRATEGY,
http_cache=True,
auto_retry=RETRY_DEFAULT,
rest_api_validate_body=True,
Expand Down Expand Up @@ -65,6 +68,10 @@ The `follow_redirects` option is used to enable or disable the HTTP redirect fol

The `timeout` option is used to set the request timeout. You can pass a float, `None` or `httpx.Timeout` to this field. By default, the requests will never timeout. See [Timeout](https://www.python-httpx.org/advanced/timeouts/) for more information.

### `cache_strategy`

The `cache_strategy` option defines how to cache the tokens or http responses. You can provide a githubkit built-in cache strategy or a custom one that implements the `BaseCacheStrategy` interface. By default, githubkit uses the `MemCacheStrategy` to cache the data in memory.

### `http_cache`

The `http_cache` option enables the http caching feature powered by [Hishel](https://hishel.com/) for HTTPX. GitHub API limits the number of requests that you can make within a specific amount of time. This feature is useful to reduce the number of requests to GitHub API and avoid hitting the rate limit.
Expand Down
35 changes: 16 additions & 19 deletions githubkit/auth/app.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Union, Optional
from typing_extensions import LiteralString
from datetime import datetime, timezone, timedelta
from collections.abc import Generator, AsyncGenerator
from typing import TYPE_CHECKING, Union, ClassVar, Optional

import httpx

from githubkit.exception import AuthCredentialError
from githubkit.cache import DEFAULT_CACHE, BaseCache
from githubkit.utils import UNSET, Unset, exclude_unset
from githubkit.compat import model_dump, type_validate_python

Expand Down Expand Up @@ -38,10 +38,9 @@ class AppAuth(httpx.Auth):
repositories: Union[Unset, list[str]] = UNSET
repository_ids: Union[Unset, list[int]] = UNSET
permissions: Union[Unset, "AppPermissionsType"] = UNSET
cache: "BaseCache" = DEFAULT_CACHE

JWT_CACHE_KEY = "githubkit:auth:app:{issuer}:jwt"
INSTALLATION_CACHE_KEY = (
JWT_CACHE_KEY: ClassVar[LiteralString] = "githubkit:auth:app:{issuer}:jwt"
INSTALLATION_CACHE_KEY: ClassVar[LiteralString] = (
"githubkit:auth:app:{issuer}:installation:"
"{installation_id}:{permissions}:{repositories}:{repository_ids}"
)
Expand Down Expand Up @@ -89,17 +88,19 @@ def _get_jwt_cache_key(self) -> str:
return self.JWT_CACHE_KEY.format(issuer=self.issuer)

def get_jwt(self) -> str:
cache = self.github.config.cache_strategy.get_cache_storage()
cache_key = self._get_jwt_cache_key()
if not (token := self.cache.get(cache_key)):
if not (token := cache.get(cache_key)):
token = self._create_jwt()
self.cache.set(cache_key, token, timedelta(minutes=8))
cache.set(cache_key, token, timedelta(minutes=8))
return token

async def aget_jwt(self) -> str:
cache = self.github.config.cache_strategy.get_async_cache_storage()
cache_key = self._get_jwt_cache_key()
if not (token := await self.cache.aget(cache_key)):
if not (token := await cache.aget(cache_key)):
token = self._create_jwt()
await self.cache.aset(cache_key, token, timedelta(minutes=8))
await cache.aset(cache_key, token, timedelta(minutes=8))
return token

def _build_installation_auth_request(self) -> httpx.Request:
Expand Down Expand Up @@ -202,8 +203,9 @@ def sync_auth_flow(
).sync_auth_flow(request)
return

cache = self.github.config.cache_strategy.get_cache_storage()
key = self._get_installation_cache_key()
if not (token := self.cache.get(key)):
if not (token := cache.get(key)):
token_request = self._build_installation_auth_request()
token_request.headers["Authorization"] = f"Bearer {self.get_jwt()}"
response = yield token_request
Expand All @@ -213,7 +215,7 @@ def sync_auth_flow(
expire = datetime.strptime(
response.parsed_data.expires_at, "%Y-%m-%dT%H:%M:%SZ"
).replace(tzinfo=timezone.utc) - datetime.now(timezone.utc)
self.cache.set(key, token, expire)
cache.set(key, token, expire)
request.headers["Authorization"] = f"token {token}"
yield request

Expand All @@ -239,8 +241,9 @@ async def async_auth_flow(
yield request
return

cache = self.github.config.cache_strategy.get_async_cache_storage()
key = self._get_installation_cache_key()
if not (token := await self.cache.aget(key)):
if not (token := await cache.aget(key)):
token_request = self._build_installation_auth_request()
token_request.headers["Authorization"] = f"Bearer {await self.aget_jwt()}"
response = yield token_request
Expand All @@ -250,7 +253,7 @@ async def async_auth_flow(
expire = datetime.strptime(
response.parsed_data.expires_at, "%Y-%m-%dT%H:%M:%SZ"
).replace(tzinfo=timezone.utc) - datetime.now(timezone.utc)
await self.cache.aset(key, token, expire)
await cache.aset(key, token, expire)
request.headers["Authorization"] = f"token {token}"
yield request

Expand All @@ -263,7 +266,6 @@ class AppAuthStrategy(BaseAuthStrategy):
private_key: str
client_id: Optional[str] = None
client_secret: Optional[str] = None
cache: "BaseCache" = DEFAULT_CACHE

def __post_init__(self):
# either app_id or client_id must be provided
Expand All @@ -288,7 +290,6 @@ def as_installation(
repositories,
repository_ids,
permissions,
self.cache,
)

def as_oauth_app(self) -> OAuthAppAuthStrategy:
Expand All @@ -305,7 +306,6 @@ def get_auth_flow(self, github: "GitHubCore") -> httpx.Auth:
self.private_key,
self.client_id,
self.client_secret,
cache=self.cache,
)


Expand All @@ -321,7 +321,6 @@ class AppInstallationAuthStrategy(BaseAuthStrategy):
repositories: Union[Unset, list[str]] = UNSET
repository_ids: Union[Unset, list[int]] = UNSET
permissions: Union[Unset, "AppPermissionsType"] = UNSET
cache: "BaseCache" = DEFAULT_CACHE

def __post_init__(self):
# either app_id or client_id must be provided
Expand All @@ -336,7 +335,6 @@ def as_app(self) -> AppAuthStrategy:
self.private_key,
self.client_id,
self.client_secret,
self.cache,
)

def get_auth_flow(self, github: "GitHubCore") -> httpx.Auth:
Expand All @@ -350,5 +348,4 @@ def get_auth_flow(self, github: "GitHubCore") -> httpx.Auth:
self.repositories,
self.repository_ids,
self.permissions,
cache=self.cache,
)
5 changes: 4 additions & 1 deletion githubkit/cache/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from .base import BaseCache as BaseCache
from .mem_cache import MemCache as MemCache
from .base import AsyncBaseCache as AsyncBaseCache
from .base import BaseCacheStrategy as BaseCacheStrategy
from .mem_cache import MemCacheStrategy as MemCacheStrategy

DEFAULT_CACHE = MemCache()
DEFAULT_CACHE_STRATEGY = MemCacheStrategy()
26 changes: 24 additions & 2 deletions githubkit/cache/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,42 @@
from typing import Optional
from datetime import timedelta

from hishel import BaseStorage, AsyncBaseStorage


class BaseCache(abc.ABC):
@abc.abstractmethod
def get(self, key: str) -> Optional[str]:
raise NotImplementedError

@abc.abstractmethod
async def aget(self, key: str) -> Optional[str]:
def set(self, key: str, value: str, ex: timedelta) -> None:
raise NotImplementedError


class AsyncBaseCache(abc.ABC):
@abc.abstractmethod
def set(self, key: str, value: str, ex: timedelta) -> None:
async def aget(self, key: str) -> Optional[str]:
raise NotImplementedError

@abc.abstractmethod
async def aset(self, key: str, value: str, ex: timedelta) -> None:
raise NotImplementedError


class BaseCacheStrategy(abc.ABC):
@abc.abstractmethod
def get_cache_storage(self) -> BaseCache:
raise NotImplementedError

@abc.abstractmethod
def get_async_cache_storage(self) -> AsyncBaseCache:
raise NotImplementedError

@abc.abstractmethod
def get_hishel_storage(self) -> BaseStorage:
raise NotImplementedError

@abc.abstractmethod
def get_async_hishel_storage(self) -> AsyncBaseStorage:
raise NotImplementedError
31 changes: 29 additions & 2 deletions githubkit/cache/mem_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from dataclasses import dataclass
from datetime import datetime, timezone, timedelta

from .base import BaseCache
from hishel import InMemoryStorage, AsyncInMemoryStorage

from .base import BaseCache, AsyncBaseCache, BaseCacheStrategy


@dataclass(frozen=True)
Expand All @@ -11,7 +13,7 @@ class _Item:
expire_at: Optional[datetime] = None


class MemCache(BaseCache):
class MemCache(AsyncBaseCache, BaseCache):
"""Simple Memory Cache with Expiration Support"""

def __init__(self):
Expand All @@ -36,3 +38,28 @@ def set(self, key: str, value: str, ex: timedelta) -> None:

async def aset(self, key: str, value: str, ex: timedelta) -> None:
return self.set(key, value, ex)


class MemCacheStrategy(BaseCacheStrategy):
def __init__(self) -> None:
self._cache: Optional[MemCache] = None
self._hishel_storage: Optional[InMemoryStorage] = None
self._hishel_async_storage: Optional[AsyncInMemoryStorage] = None

def get_cache_storage(self) -> MemCache:
if self._cache is None:
self._cache = MemCache()
return self._cache

def get_async_cache_storage(self) -> MemCache:
return self.get_cache_storage()

def get_hishel_storage(self) -> InMemoryStorage:
if self._hishel_storage is None:
self._hishel_storage = InMemoryStorage()
return self._hishel_storage

def get_async_hishel_storage(self) -> AsyncInMemoryStorage:
if self._hishel_async_storage is None:
self._hishel_async_storage = AsyncInMemoryStorage()
return self._hishel_async_storage
11 changes: 11 additions & 0 deletions githubkit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from .retry import RETRY_DEFAULT
from .typing import RetryDecisionFunc
from .cache import DEFAULT_CACHE_STRATEGY, BaseCacheStrategy


@dataclass(frozen=True)
Expand All @@ -15,6 +16,7 @@ class Config:
user_agent: str
follow_redirects: bool
timeout: httpx.Timeout
cache_strategy: BaseCacheStrategy
http_cache: bool
auto_retry: Optional[RetryDecisionFunc]
rest_api_validate_body: bool
Expand Down Expand Up @@ -64,6 +66,12 @@ def build_timeout(
return timeout if isinstance(timeout, httpx.Timeout) else httpx.Timeout(timeout)


def build_cache_strategy(
cache_strategy: Optional[BaseCacheStrategy],
) -> BaseCacheStrategy:
return cache_strategy or DEFAULT_CACHE_STRATEGY


def build_auto_retry(
auto_retry: Union[bool, RetryDecisionFunc] = True,
) -> Optional[RetryDecisionFunc]:
Expand All @@ -76,12 +84,14 @@ def build_auto_retry(


def get_config(
*,
base_url: Optional[Union[str, httpx.URL]] = None,
accept_format: Optional[str] = None,
previews: Optional[list[str]] = None,
user_agent: Optional[str] = None,
follow_redirects: bool = True,
timeout: Optional[Union[float, httpx.Timeout]] = None,
cache_strategy: Optional[BaseCacheStrategy] = None,
http_cache: bool = True,
auto_retry: Union[bool, RetryDecisionFunc] = True,
rest_api_validate_body: bool = True,
Expand All @@ -92,6 +102,7 @@ def get_config(
build_user_agent(user_agent),
follow_redirects,
build_timeout(timeout),
build_cache_strategy(cache_strategy),
http_cache,
build_auto_retry(auto_retry),
rest_api_validate_body,
Expand Down
Loading