Skip to content

Commit

Permalink
Update type hints (#408)
Browse files Browse the repository at this point in the history
Go from `Optional[str]` to `str | None`

Use `list`, `dict` and so on directly

Use `str | int` instead of `Union[str, int]`

And do all of it consistently.

Co-authored-by: Ozan Göktan <[email protected]>
  • Loading branch information
mathialo and ozangoktan authored Dec 13, 2024
1 parent adb07f5 commit 7a663f8
Show file tree
Hide file tree
Showing 24 changed files with 350 additions and 369 deletions.
2 changes: 2 additions & 0 deletions cognite/extractorutils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,5 @@

__version__ = "7.5.4"
from .base import Extractor

__all__ = ["Extractor"]
6 changes: 3 additions & 3 deletions cognite/extractorutils/_inner_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import json
from decimal import Decimal
from typing import Any, Dict, Union
from typing import Any


def _resolve_log_level(level: str) -> int:
Expand All @@ -37,7 +37,7 @@ def resolve_log_level_for_httpx(level: str) -> str:


class _DecimalEncoder(json.JSONEncoder):
def default(self, obj: Any) -> Dict[str, str]:
def default(self, obj: Any) -> dict[str, str]:
if isinstance(obj, Decimal):
return {"type": "decimal_encoded", "value": str(obj)}
return super(_DecimalEncoder, self).default(obj)
Expand All @@ -47,7 +47,7 @@ class _DecimalDecoder(json.JSONDecoder):
def __init__(self, *args: Any, **kwargs: Any) -> None:
json.JSONDecoder.__init__(self, *args, object_hook=self.object_hook, **kwargs)

def object_hook(self, obj_dict: Dict[str, str]) -> Union[Dict[str, str], Decimal]:
def object_hook(self, obj_dict: dict[str, str]) -> dict[str, str] | Decimal:
if obj_dict.get("type") == "decimal_encoded":
return Decimal(obj_dict["value"])
return obj_dict
29 changes: 14 additions & 15 deletions cognite/extractorutils/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from enum import Enum
from threading import Thread
from types import TracebackType
from typing import Any, Callable, Dict, Generic, Optional, Type, TypeVar
from typing import Any, Callable, Generic, Type, TypeVar

from dotenv import find_dotenv, load_dotenv

Expand All @@ -40,6 +40,7 @@ class ReloadConfigAction(Enum):


CustomConfigClass = TypeVar("CustomConfigClass", bound=BaseConfig)
RunHandle = Callable[[CogniteClient, AbstractStateStore, CustomConfigClass, CancellationToken], None]


class Extractor(Generic[CustomConfigClass]):
Expand Down Expand Up @@ -68,27 +69,25 @@ class Extractor(Generic[CustomConfigClass]):
heartbeat_waiting_time: Time interval between each heartbeat to the extraction pipeline in seconds.
"""

_config_singleton: Optional[CustomConfigClass] = None
_statestore_singleton: Optional[AbstractStateStore] = None
_config_singleton: CustomConfigClass | None = None
_statestore_singleton: AbstractStateStore | None = None

def __init__(
self,
*,
name: str,
description: str,
version: Optional[str] = None,
run_handle: Optional[
Callable[[CogniteClient, AbstractStateStore, CustomConfigClass, CancellationToken], None]
] = None,
version: str | None = None,
run_handle: RunHandle | None = None,
config_class: Type[CustomConfigClass],
metrics: Optional[BaseMetrics] = None,
metrics: BaseMetrics | None = None,
use_default_state_store: bool = True,
cancellation_token: Optional[CancellationToken] = None,
config_file_path: Optional[str] = None,
cancellation_token: CancellationToken | None = None,
config_file_path: str | None = None,
continuous_extractor: bool = False,
heartbeat_waiting_time: int = 600,
handle_interrupts: bool = True,
reload_config_interval: Optional[int] = 300,
reload_config_interval: int | None = 300,
reload_config_action: ReloadConfigAction = ReloadConfigAction.DO_NOTHING,
):
self.name = name
Expand All @@ -111,7 +110,7 @@ def __init__(
self.cognite_client: CogniteClient
self.state_store: AbstractStateStore
self.config: CustomConfigClass
self.extraction_pipeline: Optional[ExtractionPipeline]
self.extraction_pipeline: ExtractionPipeline | None
self.logger: logging.Logger

self.should_be_restarted = False
Expand All @@ -121,7 +120,7 @@ def __init__(
else:
self.metrics = BaseMetrics(extractor_name=name, extractor_version=self.version)

def _initial_load_config(self, override_path: Optional[str] = None) -> None:
def _initial_load_config(self, override_path: str | None = None) -> None:
"""
Load a configuration file, either from the specified path, or by a path specified by the user in a command line
arg. Will quit further execution of no path is given.
Expand Down Expand Up @@ -177,7 +176,7 @@ def _load_state_store(self) -> None:
Either way, the state_store attribute is guaranteed to be set after calling this method.
"""

def recursive_find_state_store(d: Dict[str, Any]) -> Optional[StateStoreConfig]:
def recursive_find_state_store(d: dict[str, Any]) -> StateStoreConfig | None:
for k in d:
if is_dataclass(d[k]):
res = recursive_find_state_store(d[k].__dict__)
Expand Down Expand Up @@ -323,7 +322,7 @@ def heartbeat_loop() -> None:
return self

def __exit__(
self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
self, exc_type: Type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
) -> bool:
"""
Shuts down the extractor. Makes sure states are preserved, that all uploads of data and metrics are done, etc.
Expand Down
16 changes: 7 additions & 9 deletions cognite/extractorutils/configtools/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import base64
import re
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable

from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization as serialization
Expand All @@ -24,7 +24,7 @@
from cognite.extractorutils.exceptions import InvalidConfigError


def _to_snake_case(dictionary: Dict[str, Any], case_style: str) -> Dict[str, Any]:
def _to_snake_case(dictionary: dict[str, Any], case_style: str) -> dict[str, Any]:
"""
Ensure that all keys in the dictionary follows the snake casing convention (recursively, so any sub-dictionaries are
changed too).
Expand All @@ -37,11 +37,11 @@ def _to_snake_case(dictionary: Dict[str, Any], case_style: str) -> Dict[str, Any
An updated dictionary with keys in the given convention.
"""

def fix_list(list_: List[Any], key_translator: Callable[[str], str]) -> List[Any]:
def fix_list(list_: list[Any], key_translator: Callable[[str], str]) -> list[Any]:
if list_ is None:
return []

new_list: List[Any] = [None] * len(list_)
new_list: list[Any] = [None] * len(list_)
for i, element in enumerate(list_):
if isinstance(element, dict):
new_list[i] = fix_dict(element, key_translator)
Expand All @@ -51,11 +51,11 @@ def fix_list(list_: List[Any], key_translator: Callable[[str], str]) -> List[Any
new_list[i] = element
return new_list

def fix_dict(dict_: Dict[str, Any], key_translator: Callable[[str], str]) -> Dict[str, Any]:
def fix_dict(dict_: dict[str, Any], key_translator: Callable[[str], str]) -> dict[str, Any]:
if dict_ is None:
return {}

new_dict: Dict[str, Any] = {}
new_dict: dict[str, Any] = {}
for key in dict_:
if isinstance(dict_[key], dict):
new_dict[key_translator(key)] = fix_dict(dict_[key], key_translator)
Expand All @@ -81,9 +81,7 @@ def translate_camel(key: str) -> str:
raise ValueError(f"Invalid case style: {case_style}")


def _load_certificate_data(
cert_path: str | Path, password: Optional[str]
) -> Union[Tuple[str, str], Tuple[bytes, bytes]]:
def _load_certificate_data(cert_path: str | Path, password: str | None) -> tuple[str, str] | tuple[bytes, bytes]:
path = Path(cert_path) if isinstance(cert_path, str) else cert_path
cert_data = Path(path).read_bytes()

Expand Down
Loading

0 comments on commit 7a663f8

Please sign in to comment.