Skip to content

Commit

Permalink
Merge pull request #267 from ydb-platform/sqlalchemy-fix
Browse files Browse the repository at this point in the history
sqlalchemy: validate identifiers to prevent injection
  • Loading branch information
Valeria1235 authored Apr 3, 2023
2 parents 5bad7e2 + 50de0d3 commit b41e3e2
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 6 deletions.
23 changes: 17 additions & 6 deletions ydb/_dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import uuid
import decimal
import string

import ydb
from .errors import DatabaseError, ProgrammingError
Expand All @@ -11,6 +12,17 @@
logger = logging.getLogger(__name__)


# Characters allowed as the first character of an identifier: ASCII letters and "_".
identifier_starts = set(string.ascii_letters) | {"_"}
# Characters allowed anywhere in an identifier: the start characters plus ASCII digits.
valid_identifier_chars = identifier_starts | set(string.digits)


def check_identifier_valid(idt: str) -> bool:
    """Validate that ``idt`` is a plain ASCII identifier.

    An identifier is accepted when it is a non-empty string whose first
    character is in ``identifier_starts`` and whose every character is in
    ``valid_identifier_chars``.

    :param idt: candidate identifier (for example a query parameter name).
    :return: ``True`` when the identifier is valid.
    :raises ProgrammingError: if ``idt`` is empty, is not a string, or
        contains a character outside the allowed sets.
    """
    # Require a real string: without this guard a tuple such as ("a",) passes
    # the per-character checks, and an int fails with TypeError on ``idt[0]``
    # instead of the ProgrammingError callers expect.
    valid = (
        isinstance(idt, str)
        and bool(idt)
        and idt[0] in identifier_starts
        and all(c in valid_identifier_chars for c in idt)
    )
    if not valid:
        raise ProgrammingError(f"Invalid identifier {idt}")
    return True


def get_column_type(type_obj):
    """Return the string form of ``ydb.convert.type_to_native(type_obj)``."""
    native_type = ydb.convert.type_to_native(type_obj)
    return str(native_type)

Expand Down Expand Up @@ -48,7 +60,7 @@ def _generate_type_str(value):
stype = f"Set<{nested_type}>"

if stype is None:
raise ProgrammingError("Cannot translate python type to ydb type.", tvalue, value)
raise ProgrammingError(f"Cannot translate value {value} (type {tvalue}) to ydb type.")

return stype

Expand All @@ -70,6 +82,8 @@ def execute(self, sql, parameters=None, context=None):
sql_params = None

if parameters:
for name in parameters.keys():
check_identifier_valid(name)
sql = sql % {k: f"${k}" for k, v in parameters.items()}
sql_params = {f"${k}": v for k, v in parameters.items()}
declare_stms = _generate_declare_stms(sql_params)
Expand Down Expand Up @@ -137,13 +151,10 @@ def executescript(self, script):
return self.execute(script)

def fetchone(self):
    """Return the next row from ``self.rows``, or ``None``.

    ``None`` is returned both when no result set is available
    (``self.rows is None``) and when the row iterator is exhausted.
    """
    # ``next()`` only accepts iterators, so a ``self.rows or []`` fallback
    # would raise TypeError for a missing result set; check for None instead.
    if self.rows is None:
        return None
    return next(self.rows, None)

def fetchmany(self, size=None):
    """Return up to ``size`` rows from ``self.rows`` as a list.

    :param size: maximum number of rows to fetch; ``self.arraysize`` is used
        when it is ``None``.
    :return: a list with at most ``size`` rows; empty when no result set
        is available (``self.rows is None``).
    """
    if self.rows is None:
        return []
    # Compare against None explicitly: ``size or self.arraysize`` would turn an
    # explicit ``size=0`` into ``arraysize`` rows.
    limit = self.arraysize if size is None else size
    return list(itertools.islice(self.rows, limit))

def fetchall(self):
    """Return all remaining rows from ``self.rows`` as a list.

    An empty list is returned when no result set is available
    (``self.rows is None``), mirroring how ``fetchone`` treats that state.
    """
    if self.rows is None:
        return []
    return list(self.rows)
Expand Down
55 changes: 55 additions & 0 deletions ydb/_dbapi/test_cursor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import pytest
import uuid
import decimal
from datetime import date, datetime, timedelta

from .cursor import _generate_type_str, check_identifier_valid, ProgrammingError


def test_check_identifier_valid():
    """Well-formed identifiers are accepted; malformed ones raise ProgrammingError."""
    accepted = ("id", "_id", "id0", "foo_bar", "foo_bar_1")
    for name in accepted:
        assert check_identifier_valid(name)

    # Empty, digit-leading, punctuated and space-containing names are rejected.
    rejected = ("", "01", "(a)", "drop table")
    for name in rejected:
        with pytest.raises(ProgrammingError):
            check_identifier_valid(name)


def test_generate_type_str():
    """Check the type string generated for scalars, containers and date/time values."""
    expectations = [
        # Scalars.
        (True, "Bool"),
        (1, "Int64"),
        ("foo", "Utf8"),
        (b"foo", "String"),
        (3.1415, "Double"),
        (uuid.uuid4(), "Uuid"),
        (decimal.Decimal("3.1415926535"), "Decimal(22, 9)"),
        # Flat containers.
        ([1, 2, 3], "List<Int64>"),
        ((1, "2", False), "Tuple<Int64, Utf8, Bool>"),
        ({1, 2, 3}, "Set<Int64>"),
        ({"foo": 1, "bar": 2, "kek": 3.14}, "Struct<foo: Int64, bar: Int64, kek: Double>"),
        # Nested containers.
        ([[1], [2], [3]], "List<List<Int64>>"),
        ([{"a": 1, "b": 2}, {"a": 11, "b": 22}], "List<Struct<a: Int64, b: Int64>>"),
        (("foo", [1], 3.14), "Tuple<Utf8, List<Int64>, Double>"),
        # Date and time values.
        (datetime.now(), "Timestamp"),
        (date.today(), "Date"),
        (timedelta(days=2), "Interval"),
    ]
    for value, expected in expectations:
        assert _generate_type_str(value) == expected

    # Values with no type mapping must be rejected.
    for unsupported in (None, object()):
        with pytest.raises(ProgrammingError):
            assert _generate_type_str(unsupported)

0 comments on commit b41e3e2

Please sign in to comment.