Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sqlalchemy: validate identifiers to prevent injection #267

Merged
merged 1 commit into from
Apr 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions ydb/_dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import uuid
import decimal
import string

import ydb
from .errors import DatabaseError, ProgrammingError
Expand All @@ -11,6 +12,17 @@
logger = logging.getLogger(__name__)


# Characters permitted as the first character of an identifier:
# ASCII letters and the underscore.
identifier_starts = set(string.ascii_letters) | {"_"}
# Characters permitted anywhere in an identifier: the start characters
# plus the ASCII digits.
valid_identifier_chars = identifier_starts | set(string.digits)


def check_identifier_valid(idt: str) -> bool:
    """Validate that ``idt`` is an acceptable identifier.

    An identifier is accepted when it is a non-empty ``str`` whose first
    character is in ``identifier_starts`` and whose every character is in
    ``valid_identifier_chars``.

    Args:
        idt: Candidate identifier.

    Returns:
        ``True`` when the identifier is valid.

    Raises:
        ProgrammingError: If ``idt`` is not a string, is empty, or contains
            a character outside the allowed sets.
    """
    # Require a real string: a list or tuple made of valid single characters
    # would otherwise pass the per-character checks below, and other
    # non-string values (e.g. ints) would fail with TypeError instead of
    # ProgrammingError.
    valid = (
        isinstance(idt, str)
        and bool(idt)
        and idt[0] in identifier_starts
        and all(c in valid_identifier_chars for c in idt)
    )
    if not valid:
        raise ProgrammingError(f"Invalid identifier {idt}")
    return True


def get_column_type(type_obj):
    """Return ``str()`` of what ``ydb.convert.type_to_native`` yields for ``type_obj``."""
    native_type = ydb.convert.type_to_native(type_obj)
    return str(native_type)

Expand Down Expand Up @@ -48,7 +60,7 @@ def _generate_type_str(value):
stype = f"Set<{nested_type}>"

if stype is None:
raise ProgrammingError("Cannot translate python type to ydb type.", tvalue, value)
raise ProgrammingError(f"Cannot translate value {value} (type {tvalue}) to ydb type.")

return stype

Expand All @@ -70,6 +82,8 @@ def execute(self, sql, parameters=None, context=None):
sql_params = None

if parameters:
for name in parameters.keys():
check_identifier_valid(name)
sql = sql % {k: f"${k}" for k, v in parameters.items()}
sql_params = {f"${k}": v for k, v in parameters.items()}
declare_stms = _generate_declare_stms(sql_params)
Expand Down Expand Up @@ -137,13 +151,10 @@ def executescript(self, script):
return self.execute(script)

def fetchone(self):
    """Return the next row from ``self.rows``, or ``None`` if there is none.

    ``None`` is returned both when ``self.rows`` is ``None`` and when the
    row iterator is exhausted.
    """
    # Guard explicitly: ``next()`` requires an iterator, so falling back to a
    # list (``self.rows or []``) would raise TypeError instead of returning
    # None when no rows are available.
    if self.rows is None:
        return None
    return next(self.rows, None)

def fetchmany(self, size=None):
    """Return up to ``size`` rows from ``self.rows`` as a list.

    Args:
        size: Maximum number of rows to return. ``None`` means use
            ``self.arraysize``; ``0`` yields an empty list.

    Returns:
        A list with at most ``size`` rows; empty when ``self.rows`` is
        ``None`` or already exhausted.
    """
    # Compare against None rather than using ``size or self.arraysize`` so an
    # explicit ``size=0`` is honoured instead of being replaced by arraysize.
    limit = self.arraysize if size is None else size
    # With no rows available, return an empty list rather than failing
    # inside islice with a TypeError.
    if self.rows is None:
        return []
    return list(itertools.islice(self.rows, limit))

def fetchall(self):
    """Return all remaining rows from ``self.rows`` as a list.

    Returns:
        A list of the remaining rows; empty when ``self.rows`` is ``None``
        or already exhausted.
    """
    # ``list(None)`` raises TypeError; treat a missing row source as empty.
    if self.rows is None:
        return []
    return list(self.rows)
Expand Down
55 changes: 55 additions & 0 deletions ydb/_dbapi/test_cursor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import pytest
import uuid
import decimal
from datetime import date, datetime, timedelta

from .cursor import _generate_type_str, check_identifier_valid, ProgrammingError


def test_check_identifier_valid():
    """Check that well-formed identifiers pass and malformed ones raise."""
    accepted = ("id", "_id", "id0", "foo_bar", "foo_bar_1")
    for name in accepted:
        assert check_identifier_valid(name)

    # Empty string, leading digit, punctuation, and embedded whitespace.
    rejected = ("", "01", "(a)", "drop table")
    for name in rejected:
        with pytest.raises(ProgrammingError):
            check_identifier_valid(name)


def test_generate_type_str():
    """Exercise ``_generate_type_str`` on scalars, containers and temporal values.

    The final loop checks that ``None`` and a bare ``object()`` are rejected
    with ``ProgrammingError``.
    """
    expectations = [
        # Scalars.
        (True, "Bool"),
        (1, "Int64"),
        ("foo", "Utf8"),
        (b"foo", "String"),
        (3.1415, "Double"),
        (uuid.uuid4(), "Uuid"),
        (decimal.Decimal("3.1415926535"), "Decimal(22, 9)"),
        # Flat containers.
        ([1, 2, 3], "List<Int64>"),
        ((1, "2", False), "Tuple<Int64, Utf8, Bool>"),
        ({1, 2, 3}, "Set<Int64>"),
        ({"foo": 1, "bar": 2, "kek": 3.14}, "Struct<foo: Int64, bar: Int64, kek: Double>"),
        # Nested containers.
        ([[1], [2], [3]], "List<List<Int64>>"),
        ([{"a": 1, "b": 2}, {"a": 11, "b": 22}], "List<Struct<a: Int64, b: Int64>>"),
        (("foo", [1], 3.14), "Tuple<Utf8, List<Int64>, Double>"),
        # Date and time values.
        (datetime.now(), "Timestamp"),
        (date.today(), "Date"),
        (timedelta(days=2), "Interval"),
    ]
    for value, expected in expectations:
        assert _generate_type_str(value) == expected

    for unsupported in (None, object()):
        with pytest.raises(ProgrammingError):
            assert _generate_type_str(unsupported)