Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(asm): improve dependency for api security #11987

Merged
merged 11 commits into from
Jan 20, 2025
19 changes: 11 additions & 8 deletions ddtrace/appsec/_api_security/api_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@

from ddtrace import constants
from ddtrace._trace._limits import MAX_SPAN_META_VALUE_LEN
from ddtrace.appsec import _processor as appsec_processor
from ddtrace.appsec._asm_request_context import add_context_callback
from ddtrace.appsec._asm_request_context import call_waf_callback
from ddtrace.appsec._asm_request_context import remove_context_callback
from ddtrace.appsec._constants import API_SECURITY
from ddtrace.appsec._constants import SPAN_DATA_NAMES
from ddtrace.internal.logger import get_logger
Expand Down Expand Up @@ -55,6 +51,7 @@ def enable(cls) -> None:
log.debug("Enabling %s", cls.__name__)
cls._instance = cls()
cls._instance.start()

log.debug("%s enabled", cls.__name__)

@classmethod
Expand All @@ -75,12 +72,18 @@ def __init__(self) -> None:
log.debug("%s initialized", self.__class__.__name__)
self._hashtable: collections.OrderedDict[int, float] = collections.OrderedDict()

from ddtrace.appsec import _processor as appsec_processor
import ddtrace.appsec._asm_request_context as _asm_request_context
gnufede marked this conversation as resolved.
Show resolved Hide resolved

self._asm_context = _asm_request_context
self._appsec_processor = appsec_processor

def _stop_service(self) -> None:
remove_context_callback(self._schema_callback, global_callback=True)
self._asm_context.remove_context_callback(self._schema_callback, global_callback=True)
self._hashtable.clear()

def _start_service(self) -> None:
add_context_callback(self._schema_callback, global_callback=True)
self._asm_context.add_context_callback(self._schema_callback, global_callback=True)

def _should_collect_schema(self, env, priority: int) -> bool:
# Rate limit per route
Expand Down Expand Up @@ -143,7 +146,7 @@ def _schema_callback(self, env):
try:
headers = env.waf_addresses.get(SPAN_DATA_NAMES.REQUEST_HEADERS_NO_COOKIES, _sentinel)
if headers is not _sentinel:
appsec_processor._set_headers(root, headers, kind="request")
self._appsec_processor._set_headers(root, headers, kind="request")
except Exception:
log.debug("Failed to enrich request span with headers", exc_info=True)

Expand All @@ -159,7 +162,7 @@ def _schema_callback(self, env):
value = transform(value)
waf_payload[address] = value

result = call_waf_callback(waf_payload)
result = self._asm_context.call_waf_callback(waf_payload)
if result is None:
return
for meta, schema in result.derivatives.items():
Expand Down
6 changes: 4 additions & 2 deletions ddtrace/appsec/_exploit_prevention/stack_traces.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from ddtrace._trace.span import Span
from ddtrace.appsec._constants import STACK_TRACE
from ddtrace.settings.asm import config as asm_config
import ddtrace.tracer


def report_stack(
Expand All @@ -34,7 +33,10 @@ def report_stack(
return False

if span is None:
span = ddtrace.tracer.current_span()
from ddtrace import tracer

span = tracer.current_span()

if span is None or stack_id is None:
return False
root_span = span._local_root or span
Expand Down
4 changes: 2 additions & 2 deletions ddtrace/appsec/_metrics.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from ddtrace.appsec import _asm_request_context
from ddtrace.appsec import _constants
from ddtrace.appsec._ddwaf import version as _version
from ddtrace.appsec._deduplications import deduplication
from ddtrace.appsec._processor import ddwaf
from ddtrace.internal import telemetry
from ddtrace.internal.logger import get_logger
from ddtrace.internal.telemetry.constants import TELEMETRY_LOG_LEVEL
Expand All @@ -10,7 +10,7 @@

log = get_logger(__name__)

DDWAF_VERSION = _version()
DDWAF_VERSION = ddwaf.version()


@deduplication
Expand Down
35 changes: 22 additions & 13 deletions ddtrace/appsec/_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,7 @@
from ddtrace.appsec._constants import STACK_TRACE
from ddtrace.appsec._constants import WAF_ACTIONS
from ddtrace.appsec._constants import WAF_DATA_NAMES
from ddtrace.appsec._ddwaf import DDWaf_result
from ddtrace.appsec._ddwaf.ddwaf_types import ddwaf_context_capsule
from ddtrace.appsec._exploit_prevention.stack_traces import report_stack
from ddtrace.appsec._metrics import _set_waf_init_metric
from ddtrace.appsec._metrics import _set_waf_request_metrics
from ddtrace.appsec._metrics import _set_waf_updates_metric
from ddtrace.appsec._trace_utils import _asm_manual_keep
from ddtrace.appsec._utils import has_triggers
from ddtrace.constants import ORIGIN_KEY
Expand Down Expand Up @@ -141,12 +136,11 @@ def enabled(self):

def __post_init__(self) -> None:
from ddtrace.appsec import load_appsec
from ddtrace.appsec._ddwaf import DDWaf

load_appsec()
self.obfuscation_parameter_key_regexp = asm_config._asm_obfuscation_parameter_key_regexp.encode()
self.obfuscation_parameter_value_regexp = asm_config._asm_obfuscation_parameter_value_regexp.encode()
self._rules = None
self._rules: Optional[Dict[str, Any]] = None
try:
with open(self.rule_filename, "r") as f:
self._rules = json.load(f)
Expand All @@ -169,12 +163,15 @@ def __post_init__(self) -> None:
# TODO: try to log reasons
log.error("[DDAS-0001-03] ASM could not read the rule file %s.", self.rule_filename)
raise

def delayed_init(self) -> None:
try:
self._ddwaf = DDWaf(
self._rules, self.obfuscation_parameter_key_regexp, self.obfuscation_parameter_value_regexp
)
if self._rules is not None and not hasattr(self, "_ddwaf"):
self._ddwaf = ddwaf.DDWaf(
self._rules, self.obfuscation_parameter_key_regexp, self.obfuscation_parameter_value_regexp
)
_set_waf_init_metric(self._ddwaf.info)
except ValueError:
except Exception:
# Partial of DDAS-0005-00
log.warning("[DDAS-0005-00] WAF initialization failed")
raise
Expand All @@ -190,6 +187,8 @@ def _update_required(self):
self._addresses_to_keep.add(WAF_DATA_NAMES.RESPONSE_HEADERS_NO_COOKIES)

def _update_rules(self, new_rules: Dict[str, Any]) -> bool:
if not hasattr(self, "_ddwaf"):
self.delayed_init()
result = False
if asm_config._asm_static_rule_file is not None:
return result
Expand Down Expand Up @@ -221,6 +220,9 @@ def rasp_sqli_enabled(self) -> bool:
def on_span_start(self, span: Span) -> None:
from ddtrace.contrib import trace_utils

if not hasattr(self, "_ddwaf"):
self.delayed_init()

if span.span_type not in {SpanTypes.WEB, SpanTypes.GRPC}:
return

Expand Down Expand Up @@ -258,12 +260,12 @@ def waf_callable(custom_data=None, **kwargs):
def _waf_action(
self,
span: Span,
ctx: ddwaf_context_capsule,
ctx: "ddwaf.ddwaf_types.ddwaf_context_capsule",
custom_data: Optional[Dict[str, Any]] = None,
crop_trace: Optional[str] = None,
rule_type: Optional[str] = None,
force_sent: bool = False,
) -> Optional[DDWaf_result]:
) -> Optional["ddwaf.DDWaf_result"]:
"""
Call the `WAF` with the given parameters. If `custom_data_names` is specified as
a list of `(WAF_NAME, WAF_STR)` tuples specifying what values of the `WAF_DATA_NAMES`
Expand Down Expand Up @@ -434,3 +436,10 @@ def on_span_finish(self, span: Span) -> None:
del self._span_to_waf_ctx[s]
except Exception: # nosec B110
pass


# load waf at the end only to avoid possible circular imports with gevent
import ddtrace.appsec._ddwaf as ddwaf # noqa: E402
from ddtrace.appsec._metrics import _set_waf_init_metric # noqa: E402
from ddtrace.appsec._metrics import _set_waf_request_metrics # noqa: E402
from ddtrace.appsec._metrics import _set_waf_updates_metric # noqa: E402
gnufede marked this conversation as resolved.
Show resolved Hide resolved
8 changes: 8 additions & 0 deletions tests/appsec/appsec/test_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def test_enable(tracer):
def test_enable_custom_rules():
with override_global_config(dict(_asm_static_rule_file=rules.RULES_GOOD_PATH)):
processor = AppSecSpanProcessor()
processor.delayed_init()

assert processor.enabled
assert processor.rule_filename == rules.RULES_GOOD_PATH
Expand Down Expand Up @@ -345,13 +346,15 @@ def test_ddwaf_not_raises_exception():
def test_obfuscation_parameter_key_empty():
with override_global_config(dict(_asm_obfuscation_parameter_key_regexp="")):
processor = AppSecSpanProcessor()
processor.delayed_init()

assert processor.enabled


def test_obfuscation_parameter_value_empty():
with override_global_config(dict(_asm_obfuscation_parameter_value_regexp="")):
processor = AppSecSpanProcessor()
processor.delayed_init()

assert processor.enabled

Expand All @@ -361,20 +364,23 @@ def test_obfuscation_parameter_key_and_value_empty():
dict(_asm_obfuscation_parameter_key_regexp="", _asm_obfuscation_parameter_value_regexp="")
):
processor = AppSecSpanProcessor()
processor.delayed_init()

assert processor.enabled


def test_obfuscation_parameter_key_invalid_regex():
with override_global_config(dict(_asm_obfuscation_parameter_key_regexp="(")):
processor = AppSecSpanProcessor()
processor.delayed_init()

assert processor.enabled


def test_obfuscation_parameter_invalid_regex():
with override_global_config(dict(_asm_obfuscation_parameter_value_regexp="(")):
processor = AppSecSpanProcessor()
processor.delayed_init()

assert processor.enabled

Expand All @@ -384,6 +390,7 @@ def test_obfuscation_parameter_key_and_value_invalid_regex():
dict(_asm_obfuscation_parameter_key_regexp="(", _asm_obfuscation_parameter_value_regexp="(")
):
processor = AppSecSpanProcessor()
processor.delayed_init()

assert processor.enabled

Expand Down Expand Up @@ -662,6 +669,7 @@ def test_asm_context_registration(tracer):
def test_required_addresses():
with override_global_config(dict(_asm_static_rule_file=rules.RULES_GOOD_PATH)):
processor = AppSecSpanProcessor()
processor.delayed_init()

assert processor._addresses_to_keep == {
"grpc.server.request.message",
Expand Down
3 changes: 2 additions & 1 deletion tests/appsec/appsec/test_telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ def test_log_metric_error_ddwaf_init(telemetry_writer):
_asm_static_rule_file=os.path.join(rules.ROOT_DIR, "rules-with-2-errors.json"),
)
):
AppSecSpanProcessor()
processor = AppSecSpanProcessor()
processor.delayed_init()

list_metrics_logs = list(telemetry_writer._logs)
assert len(list_metrics_logs) == 1
Expand Down
Loading