Skip to content

Commit

Permalink
Added annotations to auto-manage the sign/unsign input parameters and…
Browse files Browse the repository at this point in the history
… return (#2457)

* Adding vscode ruff extension and removing snyk

* Removing x-ray extra noise

* Fixed typo

* Expanded test coverage and fixed a few bugs in the sign/unsign annotations

* Fixed comment of the unsign_params annotation

* Forgot to add latest tests :X

Removed unused imports

* Apply black formatter formatting on misformatted files

* Fixed import ordering
  • Loading branch information
jimleroyer authored Feb 13, 2025
1 parent 5bb217c commit bc5ad54
Show file tree
Hide file tree
Showing 8 changed files with 215 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
},
"extensions": [
"bungcip.better-toml",
"charliermarsh.ruff",
"donjayamanne.python-extension-pack",
"eamodio.gitlens",
"fill-labs.dependi",
Expand All @@ -37,7 +38,6 @@
"ms-vsliveshare.vsliveshare",
"mtxr.sqltools",
"mtxr.sqltools-driver-pg",
"pmbenjamin.vscode-snyk",
"timonwong.shellcheck",
"usernamehw.errorlens",
"visualstudioexptteam.vscodeintellicode",
Expand Down
76 changes: 76 additions & 0 deletions app/annotations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from functools import wraps

# from flask import current_app
from inspect import signature

from app import signer_notification
from app.encryption import SignedNotification, SignedNotifications


def unsign_params(func):
"""
A decorator that verifies the SignedNotification|SignedNotifications typed
arguments of the decorated function using `CryptoSigner().verify`.
Args:
func (callable): The function to be decorated.
Returns:
callable: The wrapped function with verification, un-signing decorated
parameters typed with SignedNotification[s].
The decorated function should expect the first argument to be a signed string.
The decorator will verify this signed string before calling the decorated function.
"""

@wraps(func)
def wrapper(*args, **kwargs):
sig = signature(func)

# Find the parameter annotated with VerifyAndSign
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()

for param_name, param in sig.parameters.items():
if param.annotation in (SignedNotification, SignedNotifications):
signed = bound_args.arguments[param_name]

# Verify the signed string or list of signed strings
if param.annotation is SignedNotification:
verified_value = signer_notification.verify(signed)
elif param.annotation is SignedNotifications:
verified_value = [signer_notification.verify(item) for item in signed]

# Replace the signed value with the verified value
bound_args.arguments[param_name] = verified_value

# Call the decorated function with the verified value
result = func(*bound_args.args, **bound_args.kwargs)
return result

return wrapper


def sign_return(func):
"""
A decorator that signs the result of the decorated function using CryptoSigner.
Args:
func (callable): The function to be decorated.
Returns:
callable: The wrapped function that returns a signed result.
"""

@wraps(func)
def wrapper(*args, **kwargs):
# Call the decorated function with the verified value
result = func(*args, **kwargs)

if isinstance(result, str):
# Sign the str result of the decorated function
signed_result = signer_notification.sign(result)
elif isinstance(result, list):
# Sign the list result of the decorated function
signed_result = [signer_notification.sign(item) for item in result]
else:
signed_result = result

return signed_result

return wrapper
8 changes: 4 additions & 4 deletions app/aws/xray_celery_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
def xray_before_task_publish(
sender=None, headers=None, exchange=None, routing_key=None, properties=None, declare=None, retry_policy=None, **kwargs
):
logger.info(f"xray-celery: before publish: sender={sender}, headers={headers}, kwargs={kwargs}")
logger.debug(f"xray-celery: before publish: sender={sender}, headers={headers}, kwargs={kwargs}")
headers = headers if headers else {}
task_id = headers.get("id")
current_segment = xray_recorder.current_segment()
Expand All @@ -41,7 +41,7 @@ def xray_before_task_publish(


def xray_after_task_publish(headers=None, body=None, exchange=None, routing_key=None, **kwargs):
logger.info(
logger.debug(
f"xray-celery: after publish: headers={headers}, body={body}, exchange={exchange}, routing_key={routing_key}, kwargs={kwargs}"
)
if xray_recorder.current_subsegment():
Expand All @@ -51,7 +51,7 @@ def xray_after_task_publish(headers=None, body=None, exchange=None, routing_key=


def xray_task_prerun(task_id=None, task=None, args=None, **kwargs):
logger.info(f"xray-celery: prerun: task_id={task_id}, task={task}, kwargs={kwargs}")
logger.debug(f"xray-celery: prerun: task_id={task_id}, task={task}, kwargs={kwargs}")
xray_header = construct_xray_header(task.request)
segment = xray_recorder.begin_segment(name=task.name, traceid=xray_header.root, parent_id=xray_header.parent)
segment.save_origin_trace_header(xray_header)
Expand All @@ -61,7 +61,7 @@ def xray_task_prerun(task_id=None, task=None, args=None, **kwargs):


def xray_task_postrun(task_id=None, task=None, args=None, **kwargs):
logger.info(f"xray-celery: postrun: kwargs={kwargs}")
logger.debug(f"xray-celery: postrun: kwargs={kwargs}")
xray_recorder.end_segment()


Expand Down
1 change: 1 addition & 0 deletions app/encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing_extensions import NotRequired # type: ignore

SignedNotification = NewType("SignedNotification", str)
SignedNotifications = NewType("SignedNotifications", List[SignedNotification])


class NotificationDictToSign(TypedDict):
Expand Down
17 changes: 13 additions & 4 deletions app/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,12 @@ def inflight_prefix(self, suffix: Optional[str] = None, process_type: Optional[s
return f"{Buffer.IN_FLIGHT.value}::{str(process_type)}"
return f"{Buffer.IN_FLIGHT.value}"

def inflight_name(self, receipt: UUID = uuid4(), suffix: Optional[str] = None, process_type: Optional[str] = None) -> str:
def inflight_name(
self,
receipt: UUID = uuid4(),
suffix: Optional[str] = None,
process_type: Optional[str] = None,
) -> str:
return f"{self.inflight_prefix(suffix, process_type)}:{str(receipt)}"


Expand All @@ -69,8 +74,8 @@ def poll(self, count=10) -> tuple[UUID, list[str]]:
can later be used in conjunction with the `acknowledge` function
to confirm that the polled messages were properly processed.
This will delete the in-flight messages and these will not get
back into the main inbox. Failure to achknowledge the polled
messages will get these back into the inbox after a preconfigured
back into the main inbox. Failure to acknowledge the polled
messages will get these back into the inbox after a pre-configured
timeout has passed, ready to be retried.
Args:
Expand Down Expand Up @@ -159,7 +164,11 @@ def expire_inflights(self):
self._expire_inflight_after_seconds,
]
else:
args = [f"{Buffer.IN_FLIGHT.inflight_prefix()}:{self._suffix}*", self._inbox, self._expire_inflight_after_seconds]
args = [
f"{Buffer.IN_FLIGHT.inflight_prefix()}:{self._suffix}*",
self._inbox,
self._expire_inflight_after_seconds,
]
expired = self.scripts[self.LUA_EXPIRE_INFLIGHTS](args=args)
if expired:
put_batch_saving_expiry_metric(self.__metrics_logger, self, len(expired))
Expand Down
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ packages = []
[tool.poetry.scripts]
notify-api = ""

[tool.ruff]
ignore = ["D101", "D102", "D103"]

[tool.pylint]
disable = ["missing-class-docstring", "missing-function-docstring"]

[build-system]
requires = ["poetry>=1.3.2"]
build-backend = "poetry.core.masonry.api"
Expand Down
103 changes: 103 additions & 0 deletions tests/app/test_annotations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import pytest
from itsdangerous.exc import BadSignature

from app import signer_notification
from app.annotations import sign_return, unsign_params
from app.encryption import CryptoSigner, SignedNotification, SignedNotifications


class TestUnsignParamsAnnotation:
@pytest.fixture(scope="class", autouse=True)
def setup_class(self, notify_api):
# We just want to setup the notify_api flask app for tests within the class.
pass

def test_unsign_with_bad_signature_notification(self, notify_api):
@unsign_params
def annotated_unsigned_function(
signed_notification: SignedNotification,
) -> str:
return signed_notification

custom_signer = CryptoSigner()
custom_signer.init_app(notify_api, "shhhhh", "salty")

signed = custom_signer.sign("raw notification")
with pytest.raises(BadSignature):
annotated_unsigned_function(signed)

def test_unsign_with_one_signed_notification(self):
@unsign_params
def func_with_one_signed_notification(
signed_notification: SignedNotification,
) -> str:
return signed_notification

signed = signer_notification.sign("raw notification")
unsigned = func_with_one_signed_notification(signed)
assert unsigned == "raw notification"

def test_unsign_with_non_SignedNotification_parameter(self):
def func_with_one_signed_notification(signed_notification: str):
return signed_notification

signed = "raw notification"
unsigned = func_with_one_signed_notification(signed)
assert unsigned == "raw notification"

def test_unsign_with_list_of_signed_notifications(self):
@unsign_params
def func_with_list_of_signed_notifications(
signed_notifications: SignedNotifications,
):
return signed_notifications

signed = [signer_notification.sign(notification) for notification in ["raw notification 1", "raw notification 2"]]
unsigned = func_with_list_of_signed_notifications(signed)
assert unsigned == ["raw notification 1", "raw notification 2"]

def test_unsign_with_empty_list_of_signed_notifications(self):
@unsign_params
def func_with_list_of_signed_notifications(
signed_notifications: SignedNotifications,
):
return signed_notifications

signed = []
unsigned = func_with_list_of_signed_notifications(signed)
assert unsigned == []

def test_sign_return(self):
@sign_return
def func_to_sign_return():
return "raw notification"

signed = func_to_sign_return()
assert signer_notification.verify(signed) == "raw notification"

def test_sign_return_with_list(self):
@sign_return
def func_to_sign_return():
return ["raw notification 1", "raw notification 2"]

signed = func_to_sign_return()
assert [signer_notification.verify(notification) for notification in signed] == [
"raw notification 1",
"raw notification 2",
]

def test_sign_return_with_empty_list(self):
@sign_return
def func_to_sign_return():
return []

signed = func_to_sign_return()
assert signed == []

def test_sign_return_with_non_string_return(self):
@sign_return
def func_to_sign_return():
return 1

signed = func_to_sign_return()
assert signed == 1
12 changes: 11 additions & 1 deletion tests/app/test_encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@
from app.encryption import CryptoSigner


@pytest.fixture()
def crypto_signer(notify_api):
signer = CryptoSigner()
signer.init_app(notify_api, "secret", "salt")
yield signer


class TestEncryption:
def test_sign_and_verify(self, notify_api):
signer = CryptoSigner()
Expand Down Expand Up @@ -54,4 +61,7 @@ def test_sign_with_all_keys(self, notify_api):
signer2.init_app(notify_api, "s2", "salt")
signer12 = CryptoSigner()
signer12.init_app(notify_api, ["s1", "s2"], "salt")
assert signer12.sign_with_all_keys("this") == [signer2.sign("this"), signer1.sign("this")]
assert signer12.sign_with_all_keys("this") == [
signer2.sign("this"),
signer1.sign("this"),
]

0 comments on commit bc5ad54

Please sign in to comment.