From 50de0d33c07f0ed384ba0ab7dd19a0057d0fed4d Mon Sep 17 00:00:00 2001
From: Valeriya Popova
Date: Thu, 30 Mar 2023 17:06:45 +0300
Subject: [PATCH] sqlalchemy: validate identifiers to prevent injection

---
Review notes (text in this section is not part of the commit):

* execute(): every parameter name is validated before it is interpolated
  into the query text as "$name".
* fetchone(): next() requires an iterator, so the fallback used when
  self.rows is None is iter(()), not a list.
* fetchmany(): size is compared against None, so an explicit size=0 still
  yields an empty list instead of falling back to arraysize.
* test_generate_type_str(): container expectations carry their type
  parameters, in line with f"Set<{nested_type}>" in cursor.py.
  NOTE(review): the exact "Struct<name: Type>" spelling is assumed; confirm
  it against _generate_type_str.
* NOTE(review): the From: header carries no e-mail address, and the blob
  ids on the "index" lines were not recomputed for the edited content.

 ydb/_dbapi/cursor.py      | 23 +++++++++++-----
 ydb/_dbapi/test_cursor.py | 55 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 72 insertions(+), 6 deletions(-)
 create mode 100644 ydb/_dbapi/test_cursor.py

diff --git a/ydb/_dbapi/cursor.py b/ydb/_dbapi/cursor.py
index 9a8243d2..2375e96f 100644
--- a/ydb/_dbapi/cursor.py
+++ b/ydb/_dbapi/cursor.py
@@ -3,6 +3,7 @@
 import logging
 import uuid
 import decimal
+import string
 
 import ydb
 from .errors import DatabaseError, ProgrammingError
@@ -11,6 +12,17 @@
 logger = logging.getLogger(__name__)
 
 
+identifier_starts = frozenset(string.ascii_letters + "_")
+valid_identifier_chars = identifier_starts | frozenset(string.digits)
+
+
+def check_identifier_valid(idt: str):
+    valid = idt and idt[0] in identifier_starts and all(c in valid_identifier_chars for c in idt)
+    if not valid:
+        raise ProgrammingError(f"Invalid identifier {idt}")
+    return valid
+
+
 def get_column_type(type_obj):
     return str(ydb.convert.type_to_native(type_obj))
 
@@ -48,7 +60,7 @@ def _generate_type_str(value):
         stype = f"Set<{nested_type}>"
 
     if stype is None:
-        raise ProgrammingError("Cannot translate python type to ydb type.", tvalue, value)
+        raise ProgrammingError(f"Cannot translate value {value} (type {tvalue}) to ydb type.")
 
     return stype
 
@@ -70,6 +82,8 @@ def execute(self, sql, parameters=None, context=None):
         sql_params = None
 
         if parameters:
+            for name in parameters:
+                check_identifier_valid(name)
             sql = sql % {k: f"${k}" for k, v in parameters.items()}
             sql_params = {f"${k}": v for k, v in parameters.items()}
             declare_stms = _generate_declare_stms(sql_params)
@@ -137,13 +151,10 @@ def executescript(self, script):
         return self.execute(script)
 
     def fetchone(self):
-        if self.rows is None:
-            return None
-        return next(self.rows, None)
+        return next(self.rows or iter(()), None)
 
     def fetchmany(self, size=None):
-        size = self.arraysize if size is None else size
-        return list(itertools.islice(self.rows, size))
+        return list(itertools.islice(self.rows, self.arraysize if size is None else size))
 
     def fetchall(self):
         return list(self.rows)
diff --git a/ydb/_dbapi/test_cursor.py b/ydb/_dbapi/test_cursor.py
new file mode 100644
index 00000000..d300186e
--- /dev/null
+++ b/ydb/_dbapi/test_cursor.py
@@ -0,0 +1,55 @@
+import pytest
+import uuid
+import decimal
+from datetime import date, datetime, timedelta
+
+from .cursor import _generate_type_str, check_identifier_valid, ProgrammingError
+
+
+def test_check_identifier_valid():
+    assert check_identifier_valid("id")
+    assert check_identifier_valid("_id")
+    assert check_identifier_valid("id0")
+    assert check_identifier_valid("foo_bar")
+    assert check_identifier_valid("foo_bar_1")
+
+    with pytest.raises(ProgrammingError):
+        check_identifier_valid("")
+
+    with pytest.raises(ProgrammingError):
+        check_identifier_valid("01")
+
+    with pytest.raises(ProgrammingError):
+        check_identifier_valid("(a)")
+
+    with pytest.raises(ProgrammingError):
+        check_identifier_valid("drop table")
+
+
+def test_generate_type_str():
+    assert _generate_type_str(True) == "Bool"
+    assert _generate_type_str(1) == "Int64"
+    assert _generate_type_str("foo") == "Utf8"
+    assert _generate_type_str(b"foo") == "String"
+    assert _generate_type_str(3.1415) == "Double"
+    assert _generate_type_str(uuid.uuid4()) == "Uuid"
+    assert _generate_type_str(decimal.Decimal("3.1415926535")) == "Decimal(22, 9)"
+
+    assert _generate_type_str([1, 2, 3]) == "List<Int64>"
+    assert _generate_type_str((1, "2", False)) == "Tuple<Int64, Utf8, Bool>"
+    assert _generate_type_str({1, 2, 3}) == "Set<Int64>"
+    assert _generate_type_str({"foo": 1, "bar": 2, "kek": 3.14}) == "Struct<foo: Int64, bar: Int64, kek: Double>"
+
+    assert _generate_type_str([[1], [2], [3]]) == "List<List<Int64>>"
+    assert _generate_type_str([{"a": 1, "b": 2}, {"a": 11, "b": 22}]) == "List<Struct<a: Int64, b: Int64>>"
+    assert _generate_type_str(("foo", [1], 3.14)) == "Tuple<Utf8, List<Int64>, Double>"
+
+    assert _generate_type_str(datetime.now()) == "Timestamp"
+    assert _generate_type_str(date.today()) == "Date"
+    assert _generate_type_str(timedelta(days=2)) == "Interval"
+
+    with pytest.raises(ProgrammingError):
+        assert _generate_type_str(None)
+
+    with pytest.raises(ProgrammingError):
+        assert _generate_type_str(object())