diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 54a347cb..f5614db1 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -1,8 +1,12 @@ version: 2 updates: -- package-ecosystem: pip - directory: "/" - schedule: - interval: daily - time: "04:00" - open-pull-requests-limit: 100 + - package-ecosystem: pip + directory: "/" + schedule: + interval: weekly + open-pull-requests-limit: 100 + - package-ecosystem: "github-actions" + open-pull-requests-limit: 99 + directory: "/" + schedule: + interval: weekly diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 68bf5db8..9522cee2 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -65,7 +65,7 @@ jobs: python setup.py sdist bdist_wheel - name: Publish a Python distribution to PyPI if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') - uses: pypa/gh-action-pypi-publish@master + uses: pypa/gh-action-pypi-publish@release/v1 with: user: __token__ password: ${{ secrets.pypi_password }} diff --git a/Makefile b/Makefile index 22f134cf..f20a48d8 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ all: clean test dists .PHONY: test test: - pytest + pytest -v dists: python setup.py sdist diff --git a/dataset/chunked.py b/dataset/chunked.py index a5ca158a..772a2a7c 100644 --- a/dataset/chunked.py +++ b/dataset/chunked.py @@ -80,6 +80,6 @@ def flush(self): if self.callback is not None: self.callback(self.queue) self.queue.sort(key=dict.keys) - for fields, items in itertools.groupby(self.queue, key=dict.keys): + for _, items in itertools.groupby(self.queue, key=dict.keys): self.table.update_many(list(items), self.keys) super().flush() diff --git a/dataset/database.py b/dataset/database.py index d8a07ad6..627716b6 100644 --- a/dataset/database.py +++ b/dataset/database.py @@ -44,7 +44,6 @@ def __init__( self.lock = threading.RLock() self.local = threading.local() - self.connections = {} if len(parsed_url.query): query = parse_qs(parsed_url.query) @@ -54,9 +53,10 @@ def __init__( schema = schema_qs.pop() self.schema = schema - self.engine = create_engine(url, **engine_kwargs) - self.is_postgres = self.engine.dialect.name == "postgresql" - self.is_sqlite = self.engine.dialect.name == "sqlite" + self._engine = create_engine(url, **engine_kwargs) + self.is_postgres = self._engine.dialect.name == "postgresql" + self.is_sqlite = self._engine.dialect.name == "sqlite" + self.is_mysql = "mysql" in self._engine.dialect.dbapi.__name__ if on_connect_statements is None: on_connect_statements = [] @@ -72,7 +72,7 @@ def _run_on_connect(dbapi_con, con_record): on_connect_statements.append("PRAGMA journal_mode=WAL") if len(on_connect_statements): - event.listen(self.engine, "connect", _run_on_connect) + event.listen(self._engine, "connect", _run_on_connect) self.types = Types(is_postgres=self.is_postgres) self.url = url @@ -81,24 +81,25 @@ def _run_on_connect(dbapi_con, con_record): self._tables = {} @property - def executable(self): + def conn(self): """Connection against which statements will be executed.""" - with self.lock: - tid = threading.get_ident() - if tid not in self.connections: - self.connections[tid] = self.engine.connect() - return self.connections[tid] + try: + return self.local.conn + except AttributeError: + self.local.conn = self._engine.connect() + self.local.conn.begin() + return self.local.conn @property def op(self): """Get an alembic operations context.""" - ctx = MigrationContext.configure(self.executable) + ctx = MigrationContext.configure(self.conn) return Operations(ctx) @property def inspect(self): """Get a SQLAlchemy inspector.""" - return inspect(self.executable) + return inspect(self.conn) def has_table(self, name): return self.inspect.has_table(name, schema=self.schema) @@ -106,14 +107,12 @@ def has_table(self, name): @property def metadata(self): """Return a SQLAlchemy schema cache object.""" - return MetaData(schema=self.schema, bind=self.executable) + return MetaData(schema=self.schema) @property def in_transaction(self): """Check if this database is in a transactional context.""" - if not hasattr(self.local, "tx"): - return False - return len(self.local.tx) > 0 + return self.conn.in_transaction() def _flush_tables(self): """Clear the table metadata after transaction rollbacks.""" @@ -125,32 +124,22 @@ def begin(self): No data will be written until the transaction has been committed. """ - if not hasattr(self.local, "tx"): - self.local.tx = [] - self.local.tx.append(self.executable.begin()) + if not self.in_transaction: + self.conn.begin() def commit(self): """Commit the current transaction. Make all statements executed since the transaction was begun permanent. """ - if hasattr(self.local, "tx") and self.local.tx: - tx = self.local.tx.pop() - tx.commit() - # Removed in 2020-12, I'm a bit worried this means that some DDL - # operations in transactions won't cause metadata to refresh any - # more: - # self._flush_tables() + self.conn.commit() def rollback(self): """Roll back the current transaction. Discard all statements executed since the transaction was begun. """ - if hasattr(self.local, "tx") and self.local.tx: - tx = self.local.tx.pop() - tx.rollback() - self._flush_tables() + self.conn.rollback() def __enter__(self): """Start a transaction.""" @@ -170,13 +159,10 @@ def __exit__(self, error_type, error_value, traceback): def close(self): """Close database connections. Makes this object unusable.""" - with self.lock: - for conn in self.connections.values(): - conn.close() - self.connections.clear() - self.engine.dispose() + self.local = threading.local() + self._engine.dispose() self._tables = {} - self.engine = None + self._engine = None @property def tables(self): @@ -322,7 +308,7 @@ def query(self, query, *args, **kwargs): _step = kwargs.pop("_step", QUERY_STEP) if _step is False or _step == 0: _step = None - rp = self.executable.execute(query, *args, **kwargs) + rp = self.conn.execute(query, *args, **kwargs) return ResultIter(rp, row_type=self.row_type, step=_step) def __repr__(self): diff --git a/dataset/table.py b/dataset/table.py index 08b806b2..5b3d0297 100644 --- a/dataset/table.py +++ b/dataset/table.py @@ -116,7 +116,9 @@ def insert(self, row, ensure=None, types=None): Returns the inserted row's primary key. """ row = self._sync_columns(row, ensure, types=types) - res = self.db.executable.execute(self.table.insert(row)) + res = self.db.conn.execute(self.table.insert(), row) + # if not self.db.in_transaction: + # self.db.conn.commit() if len(res.inserted_primary_key) > 0: return res.inserted_primary_key[0] return True @@ -181,7 +183,7 @@ def insert_many(self, rows, chunk_size=1000, ensure=None, types=None): # Insert when chunk_size is fulfilled or this is the last row if len(chunk) == chunk_size or index == len(rows) - 1: chunk = pad_chunk_columns(chunk, columns) - self.table.insert().execute(chunk) + self.db.conn.execute(self.table.insert(), chunk) chunk = [] def update(self, row, keys, ensure=None, types=None, return_count=False): @@ -206,8 +208,8 @@ def update(self, row, keys, ensure=None, types=None, return_count=False): clause = self._args_to_clause(args) if not len(row): return self.count(clause) - stmt = self.table.update(whereclause=clause, values=row) - rp = self.db.executable.execute(stmt) + stmt = self.table.update().where(clause).values(row) + rp = self.db.conn.execute(stmt) if rp.supports_sane_rowcount(): return rp.rowcount if return_count: @@ -241,11 +243,12 @@ def update_many(self, rows, keys, chunk_size=1000, ensure=None, types=None): # Update when chunk_size is fulfilled or this is the last row if len(chunk) == chunk_size or index == len(rows) - 1: cl = [self.table.c[k] == bindparam("_%s" % k) for k in keys] - stmt = self.table.update( - whereclause=and_(True, *cl), - values={col: bindparam(col, required=False) for col in columns}, + stmt = ( + self.table.update() + .where(and_(True, *cl)) + .values({col: bindparam(col, required=False) for col in columns}) ) - self.db.executable.execute(stmt, chunk) + self.db.conn.execute(stmt, chunk) chunk = [] def upsert(self, row, keys, ensure=None, types=None): @@ -293,8 +296,8 @@ def delete(self, *clauses, **filters): if not self.exists: return False clause = self._args_to_clause(filters, clauses=clauses) - stmt = self.table.delete(whereclause=clause) - rp = self.db.executable.execute(stmt) + stmt = self.table.delete().where(clause) + rp = self.db.conn.execute(stmt) return rp.rowcount > 0 def _reflect_table(self): @@ -303,7 +306,10 @@ def _reflect_table(self): self._columns = None try: self._table = SQLATable( - self.name, self.db.metadata, schema=self.db.schema, autoload=True + self.name, + self.db.metadata, + schema=self.db.schema, + autoload_with=self.db.conn, ) except NoSuchTableError: self._table = None @@ -345,7 +351,7 @@ def _sync_table(self, columns): for column in columns: if not column.name == self._primary_id: self._table.append_column(column) - self._table.create(self.db.executable, checkfirst=True) + self._table.create(self.db.conn, checkfirst=True) self._columns = None elif len(columns): with self.db.lock: @@ -353,7 +359,7 @@ def _sync_table(self, columns): self._threading_warn() for column in columns: if not self.has_column(column.name): - self.db.op.add_column(self.name, column, self.db.schema) + self.db.op.add_column(self.name, column, schema=self.db.schema) self._reflect_table() def _sync_columns(self, row, ensure, types=None): @@ -378,6 +384,7 @@ def _sync_columns(self, row, ensure, types=None): _type = self.db.types.guess(value) sync_columns[name] = Column(name, _type) out[name] = value + self._sync_table(sync_columns.values()) return out @@ -500,7 +507,7 @@ def drop_column(self, name): table.drop_column('created_at') """ - if self.db.engine.dialect.name == "sqlite": + if self.db.is_sqlite: raise RuntimeError("SQLite does not support dropping columns.") name = self._get_column_name(name) with self.db.lock: @@ -509,7 +516,7 @@ def drop_column(self, name): return self._threading_warn() - self.db.op.drop_column(self.table.name, name, self.table.schema) + self.db.op.drop_column(self.table.name, name, schema=self.table.schema) self._reflect_table() def drop(self): @@ -520,7 +527,7 @@ def drop(self): with self.db.lock: if self.exists: self._threading_warn() - self.table.drop(self.db.executable, checkfirst=True) + self.table.drop(self.db.conn, checkfirst=True) self._table = None self._columns = None self.db._tables.pop(self.name, None) @@ -581,7 +588,7 @@ def create_index(self, columns, name=None, **kw): kw["mysql_length"] = mysql_length idx = Index(name, *columns, **kw) - idx.create(self.db.executable) + idx.create(self.db.conn) def find(self, *_clauses, **kwargs): """Perform a simple search on the table. @@ -625,14 +632,13 @@ def find(self, *_clauses, **kwargs): order_by = self._args_to_order_by(order_by) args = self._args_to_clause(kwargs, clauses=_clauses) - query = self.table.select(whereclause=args, limit=_limit, offset=_offset) + query = self.table.select().where(args).limit(_limit).offset(_offset) if len(order_by): query = query.order_by(*order_by) - conn = self.db.executable + conn = self.db.conn if _streamed: - conn = self.db.engine.connect() - conn = conn.execution_options(stream_results=True) + conn = self.db._engine.connect().execution_options(stream_results=True) return ResultIter(conn.execute(query), row_type=self.db.row_type, step=_step) @@ -666,9 +672,9 @@ def count(self, *_clauses, **kwargs): return 0 args = self._args_to_clause(kwargs, clauses=_clauses) - query = select([func.count()], whereclause=args) + query = select(func.count()).where(args) query = query.select_from(self.table) - rp = self.db.executable.execute(query) + rp = self.db.conn.execute(query) return rp.fetchone()[0] def __len__(self): @@ -703,11 +709,11 @@ def distinct(self, *args, **_filter): if not len(columns): return iter([]) - q = expression.select( - columns, - distinct=True, - whereclause=clause, - order_by=[c.asc() for c in columns], + q = ( + expression.select(*columns) + .where(clause) + .group_by(*columns) + .order_by(*(c.asc() for c in columns)) ) return self.db.query(q) diff --git a/dataset/util.py b/dataset/util.py index 4fa225d2..cc90ff47 100644 --- a/dataset/util.py +++ b/dataset/util.py @@ -6,23 +6,11 @@ QUERY_STEP = 1000 row_type = OrderedDict -try: - # SQLAlchemy > 1.4.0, new row model. - from sqlalchemy.engine import Row # noqa - def convert_row(row_type, row): - if row is None: - return None - return row_type(row._mapping.items()) - - -except ImportError: - # SQLAlchemy < 1.4.0, no _mapping. - - def convert_row(row_type, row): - if row is None: - return None - return row_type(row.items()) +def convert_row(row_type, row): + if row is None: + return None + return row_type(zip(row._fields, row)) class DatasetException(Exception): @@ -84,12 +72,12 @@ class ResultIter(object): """SQLAlchemy ResultProxies are not iterable to get a list of dictionaries. This is to wrap them.""" - def __init__(self, result_proxy, row_type=row_type, step=None): + def __init__(self, cursor, row_type=row_type, step=None): self.row_type = row_type - self.result_proxy = result_proxy + self.cursor = cursor try: - self.keys = list(result_proxy.keys()) - self._iter = iter_result_proxy(result_proxy, step=step) + self.keys = list(cursor.keys()) + self._iter = iter_result_proxy(cursor, step=step) except ResourceClosedError: self.keys = [] self._iter = iter([]) @@ -107,7 +95,7 @@ def __iter__(self): return self def close(self): - self.result_proxy.close() + self.cursor.close() def normalize_column_name(name): diff --git a/docs/conf.py b/docs/conf.py index 5ea37fa1..abc9fe61 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -37,8 +37,8 @@ master_doc = "index" # General information about the project. -project = u"dataset" -copyright = u"2013-2021, Friedrich Lindenberg, Gregor Aisch, Stefan Wehrmeyer" +project = "dataset" +copyright = "2013-2021, Friedrich Lindenberg, Gregor Aisch, Stefan Wehrmeyer" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the diff --git a/setup.py b/setup.py index 06913735..aea1b4b4 100644 --- a/setup.py +++ b/setup.py @@ -30,8 +30,8 @@ include_package_data=False, zip_safe=False, install_requires=[ - "sqlalchemy >= 1.3.2, < 2.0.0", - "alembic >= 0.6.2", + "sqlalchemy >= 2.0.15, < 3.0.0", + "alembic >= 1.11.1", "banal >= 1.0.1", ], extras_require={ diff --git a/test/sample_data.py b/test/conftest.py similarity index 63% rename from test/sample_data.py rename to test/conftest.py index f593fbc0..475c3dbb 100644 --- a/test/sample_data.py +++ b/test/conftest.py @@ -1,8 +1,8 @@ -# -*- encoding: utf-8 -*- -from __future__ import unicode_literals - +import pytest from datetime import datetime +import dataset + TEST_CITY_1 = "B€rkeley" TEST_CITY_2 = "G€lway" @@ -15,3 +15,21 @@ {"date": datetime(2011, 1, 2), "temperature": 8, "place": TEST_CITY_1}, {"date": datetime(2011, 1, 3), "temperature": 5, "place": TEST_CITY_1}, ] + + +@pytest.fixture(scope="function") +def db(): + db = dataset.connect() + yield db + db.close() + + +@pytest.fixture(scope="function") +def table(db): + tbl = db["weather"] + tbl.drop() + tbl.insert_many(TEST_DATA) + db.commit() + yield tbl + db.rollback() + db.commit() diff --git a/test/test_database.py b/test/test_database.py new file mode 100644 index 00000000..d3216cbd --- /dev/null +++ b/test/test_database.py @@ -0,0 +1,148 @@ +import os +import pytest +from datetime import datetime +from collections import OrderedDict +from sqlalchemy.exc import IntegrityError, SQLAlchemyError + +from dataset import connect + +from .conftest import TEST_DATA + + +def test_valid_database_url(db): + assert db.url, os.environ["DATABASE_URL"] + + +def test_database_url_query_string(db): + db = connect("sqlite:///:memory:/?cached_statements=1") + assert "cached_statements" in db.url, db.url + + +def test_tables(db, table): + assert db.tables == ["weather"], db.tables + + +def test_contains(db, table): + assert "weather" in db, db.tables + + +def test_create_table(db): + table = db["foo"] + assert db.has_table(table.table.name) + assert len(table.table.columns) == 1, table.table.columns + assert "id" in table.table.c, table.table.c + + +def test_create_table_no_ids(db): + if db.is_mysql or db.is_sqlite: + return + table = db.create_table("foo_no_id", primary_id=False) + assert table.table.name == "foo_no_id" + assert len(table.table.columns) == 0, table.table.columns + + +def test_create_table_custom_id1(db): + pid = "string_id" + table = db.create_table("foo2", pid, db.types.string(255)) + assert db.has_table(table.table.name) + assert len(table.table.columns) == 1, table.table.columns + assert pid in table.table.c, table.table.c + table.insert({pid: "foobar"}) + assert table.find_one(string_id="foobar")[pid] == "foobar" + + +def test_create_table_custom_id2(db): + pid = "string_id" + table = db.create_table("foo3", pid, db.types.string(50)) + assert db.has_table(table.table.name) + assert len(table.table.columns) == 1, table.table.columns + assert pid in table.table.c, table.table.c + + table.insert({pid: "foobar"}) + assert table.find_one(string_id="foobar")[pid] == "foobar" + + +def test_create_table_custom_id3(db): + pid = "int_id" + table = db.create_table("foo4", primary_id=pid) + assert db.has_table(table.table.name) + assert len(table.table.columns) == 1, table.table.columns + assert pid in table.table.c, table.table.c + + table.insert({pid: 123}) + table.insert({pid: 124}) + assert table.find_one(int_id=123)[pid] == 123 + assert table.find_one(int_id=124)[pid] == 124 + with pytest.raises(IntegrityError): + table.insert({pid: 123}) + db.rollback() + + +def test_create_table_shorthand1(db): + pid = "int_id" + table = db.get_table("foo5", pid) + assert len(table.table.columns) == 1, table.table.columns + assert pid in table.table.c, table.table.c + + table.insert({"int_id": 123}) + table.insert({"int_id": 124}) + assert table.find_one(int_id=123)["int_id"] == 123 + assert table.find_one(int_id=124)["int_id"] == 124 + with pytest.raises(IntegrityError): + table.insert({"int_id": 123}) + + +def test_create_table_shorthand2(db): + pid = "string_id" + table = db.get_table("foo6", primary_id=pid, primary_type=db.types.string(255)) + assert len(table.table.columns) == 1, table.table.columns + assert pid in table.table.c, table.table.c + + table.insert({"string_id": "foobar"}) + assert table.find_one(string_id="foobar")["string_id"] == "foobar" + + +def test_with(db, table): + init_length = len(table) + with pytest.raises(ValueError): + with db: + table.insert( + { + "date": datetime(2011, 1, 1), + "temperature": 1, + "place": "tmp_place", + } + ) + raise ValueError() + db.rollback() + assert len(table) == init_length + + +def test_invalid_values(db, table): + if db.is_mysql: + # WARNING: mysql seems to be doing some weird type casting + # upon insert. The mysql-python driver is not affected but + # it isn't compatible with Python 3 + # Conclusion: use postgresql. + return + with pytest.raises(SQLAlchemyError): + table.insert({"date": True, "temperature": "wrong_value", "place": "tmp_place"}) + + +def test_load_table(db, table): + tbl = db.load_table("weather") + assert tbl.table.name == table.table.name + + +def test_query(db, table): + r = db.query("SELECT COUNT(*) AS num FROM weather").next() + assert r["num"] == len(TEST_DATA), r + + +def test_table_cache_updates(db): + tbl1 = db.get_table("people") + data = OrderedDict([("first_name", "John"), ("last_name", "Smith")]) + tbl1.insert(data) + data["id"] = 1 + tbl2 = db.get_table("people") + assert dict(tbl2.all().next()) == dict(data), (tbl2.all().next(), data) diff --git a/test/test_dataset.py b/test/test_dataset.py deleted file mode 100644 index f7c94ebc..00000000 --- a/test/test_dataset.py +++ /dev/null @@ -1,602 +0,0 @@ -import os -import unittest -from datetime import datetime -from collections import OrderedDict -from sqlalchemy import TEXT, BIGINT -from sqlalchemy.exc import IntegrityError, SQLAlchemyError, ArgumentError - -from dataset import connect, chunked - -from .sample_data import TEST_DATA, TEST_CITY_1 - - -class DatabaseTestCase(unittest.TestCase): - def setUp(self): - self.db = connect() - self.tbl = self.db["weather"] - self.tbl.insert_many(TEST_DATA) - - def tearDown(self): - for table in self.db.tables: - self.db[table].drop() - - def test_valid_database_url(self): - assert self.db.url, os.environ["DATABASE_URL"] - - def test_database_url_query_string(self): - db = connect("sqlite:///:memory:/?cached_statements=1") - assert "cached_statements" in db.url, db.url - - def test_tables(self): - assert self.db.tables == ["weather"], self.db.tables - - def test_contains(self): - assert "weather" in self.db, self.db.tables - - def test_create_table(self): - table = self.db["foo"] - assert self.db.has_table(table.table.name) - assert len(table.table.columns) == 1, table.table.columns - assert "id" in table.table.c, table.table.c - - def test_create_table_no_ids(self): - if "mysql" in self.db.engine.dialect.dbapi.__name__: - return - if "sqlite" in self.db.engine.dialect.dbapi.__name__: - return - table = self.db.create_table("foo_no_id", primary_id=False) - assert table.table.exists() - assert len(table.table.columns) == 0, table.table.columns - - def test_create_table_custom_id1(self): - pid = "string_id" - table = self.db.create_table("foo2", pid, self.db.types.string(255)) - assert self.db.has_table(table.table.name) - assert len(table.table.columns) == 1, table.table.columns - assert pid in table.table.c, table.table.c - table.insert({pid: "foobar"}) - assert table.find_one(string_id="foobar")[pid] == "foobar" - - def test_create_table_custom_id2(self): - pid = "string_id" - table = self.db.create_table("foo3", pid, self.db.types.string(50)) - assert self.db.has_table(table.table.name) - assert len(table.table.columns) == 1, table.table.columns - assert pid in table.table.c, table.table.c - - table.insert({pid: "foobar"}) - assert table.find_one(string_id="foobar")[pid] == "foobar" - - def test_create_table_custom_id3(self): - pid = "int_id" - table = self.db.create_table("foo4", primary_id=pid) - assert self.db.has_table(table.table.name) - assert len(table.table.columns) == 1, table.table.columns - assert pid in table.table.c, table.table.c - - table.insert({pid: 123}) - table.insert({pid: 124}) - assert table.find_one(int_id=123)[pid] == 123 - assert table.find_one(int_id=124)[pid] == 124 - self.assertRaises(IntegrityError, lambda: table.insert({pid: 123})) - - def test_create_table_shorthand1(self): - pid = "int_id" - table = self.db.get_table("foo5", pid) - assert table.table.exists - assert len(table.table.columns) == 1, table.table.columns - assert pid in table.table.c, table.table.c - - table.insert({"int_id": 123}) - table.insert({"int_id": 124}) - assert table.find_one(int_id=123)["int_id"] == 123 - assert table.find_one(int_id=124)["int_id"] == 124 - self.assertRaises(IntegrityError, lambda: table.insert({"int_id": 123})) - - def test_create_table_shorthand2(self): - pid = "string_id" - table = self.db.get_table( - "foo6", primary_id=pid, primary_type=self.db.types.string(255) - ) - assert table.table.exists - assert len(table.table.columns) == 1, table.table.columns - assert pid in table.table.c, table.table.c - - table.insert({"string_id": "foobar"}) - assert table.find_one(string_id="foobar")["string_id"] == "foobar" - - def test_with(self): - init_length = len(self.db["weather"]) - with self.assertRaises(ValueError): - with self.db as tx: - tx["weather"].insert( - { - "date": datetime(2011, 1, 1), - "temperature": 1, - "place": "tmp_place", - } - ) - raise ValueError() - assert len(self.db["weather"]) == init_length - - def test_invalid_values(self): - if "mysql" in self.db.engine.dialect.dbapi.__name__: - # WARNING: mysql seems to be doing some weird type casting - # upon insert. The mysql-python driver is not affected but - # it isn't compatible with Python 3 - # Conclusion: use postgresql. - return - with self.assertRaises(SQLAlchemyError): - tbl = self.db["weather"] - tbl.insert( - {"date": True, "temperature": "wrong_value", "place": "tmp_place"} - ) - - def test_load_table(self): - tbl = self.db.load_table("weather") - assert tbl.table.name == self.tbl.table.name - - def test_query(self): - r = self.db.query("SELECT COUNT(*) AS num FROM weather").next() - assert r["num"] == len(TEST_DATA), r - - def test_table_cache_updates(self): - tbl1 = self.db.get_table("people") - data = OrderedDict([("first_name", "John"), ("last_name", "Smith")]) - tbl1.insert(data) - data["id"] = 1 - tbl2 = self.db.get_table("people") - assert dict(tbl2.all().next()) == dict(data), (tbl2.all().next(), data) - - -class TableTestCase(unittest.TestCase): - def setUp(self): - self.db = connect() - self.tbl = self.db["weather"] - for row in TEST_DATA: - self.tbl.insert(row) - - def tearDown(self): - self.tbl.drop() - - def test_insert(self): - assert len(self.tbl) == len(TEST_DATA), len(self.tbl) - last_id = self.tbl.insert( - {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"} - ) - assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) - assert self.tbl.find_one(id=last_id)["place"] == "Berlin" - - def test_insert_ignore(self): - self.tbl.insert_ignore( - {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}, - ["place"], - ) - assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) - self.tbl.insert_ignore( - {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}, - ["place"], - ) - assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) - - def test_insert_ignore_all_key(self): - for i in range(0, 4): - self.tbl.insert_ignore( - {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}, - ["date", "temperature", "place"], - ) - assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) - - def test_insert_json(self): - last_id = self.tbl.insert( - { - "date": datetime(2011, 1, 2), - "temperature": -10, - "place": "Berlin", - "info": { - "currency": "EUR", - "language": "German", - "population": 3292365, - }, - } - ) - assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) - assert self.tbl.find_one(id=last_id)["place"] == "Berlin" - - def test_upsert(self): - self.tbl.upsert( - {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}, - ["place"], - ) - assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) - self.tbl.upsert( - {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}, - ["place"], - ) - assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) - - def test_upsert_single_column(self): - table = self.db["banana_single_col"] - table.upsert({"color": "Yellow"}, ["color"]) - assert len(table) == 1, len(table) - table.upsert({"color": "Yellow"}, ["color"]) - assert len(table) == 1, len(table) - - def test_upsert_all_key(self): - assert len(self.tbl) == len(TEST_DATA), len(self.tbl) - for i in range(0, 2): - self.tbl.upsert( - {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}, - ["date", "temperature", "place"], - ) - assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) - - def test_upsert_id(self): - table = self.db["banana_with_id"] - data = dict(id=10, title="I am a banana!") - table.upsert(data, ["id"]) - assert len(table) == 1, len(table) - - def test_update_while_iter(self): - for row in self.tbl: - row["foo"] = "bar" - self.tbl.update(row, ["place", "date"]) - assert len(self.tbl) == len(TEST_DATA), len(self.tbl) - - def test_weird_column_names(self): - with self.assertRaises(ValueError): - self.tbl.insert( - { - "date": datetime(2011, 1, 2), - "temperature": -10, - "foo.bar": "Berlin", - "qux.bar": "Huhu", - } - ) - - def test_cased_column_names(self): - tbl = self.db["cased_column_names"] - tbl.insert({"place": "Berlin"}) - tbl.insert({"Place": "Berlin"}) - tbl.insert({"PLACE ": "Berlin"}) - assert len(tbl.columns) == 2, tbl.columns - assert len(list(tbl.find(Place="Berlin"))) == 3 - assert len(list(tbl.find(place="Berlin"))) == 3 - assert len(list(tbl.find(PLACE="Berlin"))) == 3 - - def test_invalid_column_names(self): - tbl = self.db["weather"] - with self.assertRaises(ValueError): - tbl.insert({None: "banana"}) - - with self.assertRaises(ValueError): - tbl.insert({"": "banana"}) - - with self.assertRaises(ValueError): - tbl.insert({"-": "banana"}) - - def test_delete(self): - self.tbl.insert( - {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"} - ) - original_count = len(self.tbl) - assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) - # Test bad use of API - with self.assertRaises(ArgumentError): - self.tbl.delete({"place": "Berlin"}) - assert len(self.tbl) == original_count, len(self.tbl) - - assert self.tbl.delete(place="Berlin") is True, "should return 1" - assert len(self.tbl) == len(TEST_DATA), len(self.tbl) - assert self.tbl.delete() is True, "should return non zero" - assert len(self.tbl) == 0, len(self.tbl) - - def test_repr(self): - assert ( - repr(self.tbl) == "" - ), "the representation should be " - - def test_delete_nonexist_entry(self): - assert ( - self.tbl.delete(place="Berlin") is False - ), "entry not exist, should fail to delete" - - def test_find_one(self): - self.tbl.insert( - {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"} - ) - d = self.tbl.find_one(place="Berlin") - assert d["temperature"] == -10, d - d = self.tbl.find_one(place="Atlantis") - assert d is None, d - - def test_count(self): - assert len(self.tbl) == 6, len(self.tbl) - length = self.tbl.count(place=TEST_CITY_1) - assert length == 3, length - - def test_find(self): - ds = list(self.tbl.find(place=TEST_CITY_1)) - assert len(ds) == 3, ds - ds = list(self.tbl.find(place=TEST_CITY_1, _limit=2)) - assert len(ds) == 2, ds - ds = list(self.tbl.find(place=TEST_CITY_1, _limit=2, _step=1)) - assert len(ds) == 2, ds - ds = list(self.tbl.find(place=TEST_CITY_1, _limit=1, _step=2)) - assert len(ds) == 1, ds - ds = list(self.tbl.find(_step=2)) - assert len(ds) == len(TEST_DATA), ds - ds = list(self.tbl.find(order_by=["temperature"])) - assert ds[0]["temperature"] == -1, ds - ds = list(self.tbl.find(order_by=["-temperature"])) - assert ds[0]["temperature"] == 8, ds - ds = list(self.tbl.find(self.tbl.table.columns.temperature > 4)) - assert len(ds) == 3, ds - - def test_find_dsl(self): - ds = list(self.tbl.find(place={"like": "%lw%"})) - assert len(ds) == 3, ds - ds = list(self.tbl.find(temperature={">": 5})) - assert len(ds) == 2, ds - ds = list(self.tbl.find(temperature={">=": 5})) - assert len(ds) == 3, ds - ds = list(self.tbl.find(temperature={"<": 0})) - assert len(ds) == 1, ds - ds = list(self.tbl.find(temperature={"<=": 0})) - assert len(ds) == 2, ds - ds = list(self.tbl.find(temperature={"!=": -1})) - assert len(ds) == 5, ds - ds = list(self.tbl.find(temperature={"between": [5, 8]})) - assert len(ds) == 3, ds - ds = list(self.tbl.find(place={"=": "G€lway"})) - assert len(ds) == 3, ds - ds = list(self.tbl.find(place={"ilike": "%LwAy"})) - assert len(ds) == 3, ds - - def test_offset(self): - ds = list(self.tbl.find(place=TEST_CITY_1, _offset=1)) - assert len(ds) == 2, ds - ds = list(self.tbl.find(place=TEST_CITY_1, _limit=2, _offset=2)) - assert len(ds) == 1, ds - - def test_streamed(self): - ds = list(self.tbl.find(place=TEST_CITY_1, _streamed=True, _step=1)) - assert len(ds) == 3, len(ds) - for row in self.tbl.find(place=TEST_CITY_1, _streamed=True, _step=1): - row["temperature"] = -1 - self.tbl.update(row, ["id"]) - - def test_distinct(self): - x = list(self.tbl.distinct("place")) - assert len(x) == 2, x - x = list(self.tbl.distinct("place", "date")) - assert len(x) == 6, x - x = list( - self.tbl.distinct( - "place", - "date", - self.tbl.table.columns.date >= datetime(2011, 1, 2, 0, 0), - ) - ) - assert len(x) == 4, x - - x = list(self.tbl.distinct("temperature", place="B€rkeley")) - assert len(x) == 3, x - x = list(self.tbl.distinct("temperature", place=["B€rkeley", "G€lway"])) - assert len(x) == 6, x - - def test_insert_many(self): - data = TEST_DATA * 100 - self.tbl.insert_many(data, chunk_size=13) - assert len(self.tbl) == len(data) + 6, (len(self.tbl), len(data)) - - def test_chunked_insert(self): - data = TEST_DATA * 100 - with chunked.ChunkedInsert(self.tbl) as chunk_tbl: - for item in data: - chunk_tbl.insert(item) - assert len(self.tbl) == len(data) + 6, (len(self.tbl), len(data)) - - def test_chunked_insert_callback(self): - data = TEST_DATA * 100 - N = 0 - - def callback(queue): - nonlocal N - N += len(queue) - - with chunked.ChunkedInsert(self.tbl, callback=callback) as chunk_tbl: - for item in data: - chunk_tbl.insert(item) - assert len(data) == N - assert len(self.tbl) == len(data) + 6 - - def test_update_many(self): - tbl = self.db["update_many_test"] - tbl.insert_many([dict(temp=10), dict(temp=20), dict(temp=30)]) - tbl.update_many([dict(id=1, temp=50), dict(id=3, temp=50)], "id") - - # Ensure data has been updated. - assert tbl.find_one(id=1)["temp"] == tbl.find_one(id=3)["temp"] - - def test_chunked_update(self): - tbl = self.db["update_many_test"] - tbl.insert_many( - [ - dict(temp=10, location="asdf"), - dict(temp=20, location="qwer"), - dict(temp=30, location="asdf"), - ] - ) - - chunked_tbl = chunked.ChunkedUpdate(tbl, "id") - chunked_tbl.update(dict(id=1, temp=50)) - chunked_tbl.update(dict(id=2, location="asdf")) - chunked_tbl.update(dict(id=3, temp=50)) - chunked_tbl.flush() - - # Ensure data has been updated. - assert tbl.find_one(id=1)["temp"] == tbl.find_one(id=3)["temp"] == 50 - assert ( - tbl.find_one(id=2)["location"] == tbl.find_one(id=3)["location"] == "asdf" - ) # noqa - - def test_upsert_many(self): - # Also tests updating on records with different attributes - tbl = self.db["upsert_many_test"] - - W = 100 - tbl.upsert_many([dict(age=10), dict(weight=W)], "id") - assert tbl.find_one(id=1)["age"] == 10 - - tbl.upsert_many([dict(id=1, age=70), dict(id=2, weight=W / 2)], "id") - assert tbl.find_one(id=2)["weight"] == W / 2 - - def test_drop_operations(self): - assert self.tbl._table is not None, "table shouldn't be dropped yet" - self.tbl.drop() - assert self.tbl._table is None, "table should be dropped now" - assert list(self.tbl.all()) == [], self.tbl.all() - assert self.tbl.count() == 0, self.tbl.count() - - def test_table_drop(self): - assert "weather" in self.db - self.db["weather"].drop() - assert "weather" not in self.db - - def test_table_drop_then_create(self): - assert "weather" in self.db - self.db["weather"].drop() - assert "weather" not in self.db - self.db["weather"].insert({"foo": "bar"}) - - def test_columns(self): - cols = self.tbl.columns - assert len(list(cols)) == 4, "column count mismatch" - assert "date" in cols and "temperature" in cols and "place" in cols - - def test_drop_column(self): - try: - self.tbl.drop_column("date") - assert "date" not in self.tbl.columns - except RuntimeError: - pass - - def test_iter(self): - c = 0 - for row in self.tbl: - c += 1 - assert c == len(self.tbl) - - def test_update(self): - date = datetime(2011, 1, 2) - res = self.tbl.update( - {"date": date, "temperature": -10, "place": TEST_CITY_1}, ["place", "date"] - ) - assert res, "update should return True" - m = self.tbl.find_one(place=TEST_CITY_1, date=date) - assert m["temperature"] == -10, ( - "new temp. should be -10 but is %d" % m["temperature"] - ) - - def test_create_column(self): - tbl = self.tbl - flt = self.db.types.float - tbl.create_column("foo", flt) - assert "foo" in tbl.table.c, tbl.table.c - assert isinstance(tbl.table.c["foo"].type, flt), tbl.table.c["foo"].type - assert "foo" in tbl.columns, tbl.columns - - def test_ensure_column(self): - tbl = self.tbl - flt = self.db.types.float - tbl.create_column_by_example("foo", 0.1) - assert "foo" in tbl.table.c, tbl.table.c - assert isinstance(tbl.table.c["foo"].type, flt), tbl.table.c["bar"].type - tbl.create_column_by_example("bar", 1) - assert "bar" in tbl.table.c, tbl.table.c - assert isinstance(tbl.table.c["bar"].type, BIGINT), tbl.table.c["bar"].type - tbl.create_column_by_example("pippo", "test") - assert "pippo" in tbl.table.c, tbl.table.c - assert isinstance(tbl.table.c["pippo"].type, TEXT), tbl.table.c["pippo"].type - tbl.create_column_by_example("bigbar", 11111111111) - assert "bigbar" in tbl.table.c, tbl.table.c - assert isinstance(tbl.table.c["bigbar"].type, BIGINT), tbl.table.c[ - "bigbar" - ].type - tbl.create_column_by_example("littlebar", -11111111111) - assert "littlebar" in tbl.table.c, tbl.table.c - assert isinstance(tbl.table.c["littlebar"].type, BIGINT), tbl.table.c[ - "littlebar" - ].type - - def test_key_order(self): - res = self.db.query("SELECT temperature, place FROM weather LIMIT 1") - keys = list(res.next().keys()) - assert keys[0] == "temperature" - assert keys[1] == "place" - - def test_empty_query(self): - empty = list(self.tbl.find(place="not in data")) - assert len(empty) == 0, empty - - -class Constructor(dict): - """Very simple low-functionality extension to ``dict`` to - provide attribute access to dictionary contents""" - - def __getattr__(self, name): - return self[name] - - -class RowTypeTestCase(unittest.TestCase): - def setUp(self): - self.db = connect(row_type=Constructor) - self.tbl = self.db["weather"] - for row in TEST_DATA: - self.tbl.insert(row) - - def tearDown(self): - for table in self.db.tables: - self.db[table].drop() - - def test_find_one(self): - self.tbl.insert( - {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"} - ) - d = self.tbl.find_one(place="Berlin") - assert d["temperature"] == -10, d - assert d.temperature == -10, d - d = self.tbl.find_one(place="Atlantis") - assert d is None, d - - def test_find(self): - ds = list(self.tbl.find(place=TEST_CITY_1)) - assert len(ds) == 3, ds - for item in ds: - assert isinstance(item, Constructor), item - ds = list(self.tbl.find(place=TEST_CITY_1, _limit=2)) - assert len(ds) == 2, ds - for item in ds: - assert isinstance(item, Constructor), item - - def test_distinct(self): - x = list(self.tbl.distinct("place")) - assert len(x) == 2, x - for item in x: - assert isinstance(item, Constructor), item - x = list(self.tbl.distinct("place", "date")) - assert len(x) == 6, x - for item in x: - assert isinstance(item, Constructor), item - - def test_iter(self): - c = 0 - for row in self.tbl: - c += 1 - assert isinstance(row, Constructor), row - assert c == len(self.tbl) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/test_row_type.py b/test/test_row_type.py new file mode 100644 index 00000000..801a5158 --- /dev/null +++ b/test/test_row_type.py @@ -0,0 +1,54 @@ +from datetime import datetime + +from .conftest import TEST_CITY_1 + + +class Constructor(dict): + """Very simple low-functionality extension to ``dict`` to + provide attribute access to dictionary contents""" + + def __getattr__(self, name): + return self[name] + + +def test_find_one(db, table): + db.row_type = Constructor + table.insert({"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}) + d = table.find_one(place="Berlin") + assert d["temperature"] == -10, d + assert d.temperature == -10, d + d = table.find_one(place="Atlantis") + assert d is None, d + + +def test_find(db, table): + db.row_type = Constructor + ds = list(table.find(place=TEST_CITY_1)) + assert len(ds) == 3, ds + for item in ds: + assert isinstance(item, Constructor), item + ds = list(table.find(place=TEST_CITY_1, _limit=2)) + assert len(ds) == 2, ds + for item in ds: + assert isinstance(item, Constructor), item + + +def test_distinct(db, table): + db.row_type = Constructor + x = list(table.distinct("place")) + assert len(x) == 2, x + for item in x: + assert isinstance(item, Constructor), item + x = list(table.distinct("place", "date")) + assert len(x) == 6, x + for item in x: + assert isinstance(item, Constructor), item + + +def test_iter(db, table): + db.row_type = Constructor + c = 0 + for row in table: + c += 1 + assert isinstance(row, Constructor), row + assert c == len(table) diff --git a/test/test_table.py b/test/test_table.py new file mode 100644 index 00000000..16d444c1 --- /dev/null +++ b/test/test_table.py @@ -0,0 +1,424 @@ +import os +import pytest +from datetime import datetime +from collections import OrderedDict +from sqlalchemy.types import BIGINT, INTEGER, VARCHAR, TEXT +from sqlalchemy.exc import IntegrityError, SQLAlchemyError, ArgumentError + +from dataset import chunked + +from .conftest import TEST_DATA, TEST_CITY_1 + + +def test_insert(table): + assert len(table) == len(TEST_DATA), len(table) + last_id = table.insert( + {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"} + ) + assert len(table) == len(TEST_DATA) + 1, len(table) + assert table.find_one(id=last_id)["place"] == "Berlin" + + +def test_insert_ignore(table): + table.insert_ignore( + {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}, + ["place"], + ) + assert len(table) == len(TEST_DATA) + 1, len(table) + table.insert_ignore( + {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}, + ["place"], + ) + assert len(table) == len(TEST_DATA) + 1, len(table) + + +def test_insert_ignore_all_key(table): + for i in range(0, 4): + table.insert_ignore( + {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}, + ["date", "temperature", "place"], + ) + assert len(table) == len(TEST_DATA) + 1, len(table) + + +def test_insert_json(table): + last_id = table.insert( + { + "date": datetime(2011, 1, 2), + "temperature": -10, + "place": "Berlin", + "info": { + "currency": "EUR", + "language": "German", + "population": 3292365, + }, + } + ) + assert len(table) == len(TEST_DATA) + 1, len(table) + assert table.find_one(id=last_id)["place"] == "Berlin" + + +def test_upsert(table): + table.upsert( + {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}, + ["place"], + ) + assert len(table) == len(TEST_DATA) + 1, len(table) + table.upsert( + {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}, + ["place"], + ) + assert len(table) == len(TEST_DATA) + 1, len(table) + + +def test_upsert_single_column(db): + table = db["banana_single_col"] + table.upsert({"color": "Yellow"}, ["color"]) + assert len(table) == 1, len(table) + table.upsert({"color": "Yellow"}, ["color"]) + assert len(table) == 1, len(table) + + +def test_upsert_all_key(table): + assert len(table) == len(TEST_DATA), len(table) + for i in range(0, 2): + table.upsert( + {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}, + ["date", "temperature", "place"], + ) + assert len(table) == len(TEST_DATA) + 1, len(table) + + +def test_upsert_id(db): + table = db["banana_with_id"] + data = dict(id=10, title="I am a banana!") + table.upsert(data, ["id"]) + assert len(table) == 1, len(table) + + +def test_update_while_iter(table): + for row in table: + row["foo"] = "bar" + table.update(row, ["place", "date"]) + assert len(table) == len(TEST_DATA), len(table) + + +def test_weird_column_names(table): + with pytest.raises(ValueError): + table.insert( + { + "date": datetime(2011, 1, 2), + "temperature": -10, + "foo.bar": "Berlin", + "qux.bar": "Huhu", + } + ) + + +def test_cased_column_names(db): + tbl = db["cased_column_names"] + tbl.insert({"place": "Berlin"}) + tbl.insert({"Place": "Berlin"}) + tbl.insert({"PLACE ": "Berlin"}) + assert len(tbl.columns) == 2, tbl.columns + assert len(list(tbl.find(Place="Berlin"))) == 3 + assert len(list(tbl.find(place="Berlin"))) == 3 + assert len(list(tbl.find(PLACE="Berlin"))) == 3 + + +def test_invalid_column_names(db): + tbl = db["weather"] + with pytest.raises(ValueError): + tbl.insert({None: "banana"}) + + with pytest.raises(ValueError): + tbl.insert({"": "banana"}) + + with pytest.raises(ValueError): + tbl.insert({"-": "banana"}) + + +def test_delete(table): + table.insert({"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}) + original_count = len(table) + assert len(table) == len(TEST_DATA) + 1, len(table) + # Test bad use of API + with pytest.raises(ArgumentError): + table.delete({"place": "Berlin"}) + assert len(table) == original_count, len(table) + + assert table.delete(place="Berlin") is True, "should return 1" + assert len(table) == len(TEST_DATA), len(table) + assert table.delete() is True, "should return non zero" + assert len(table) == 0, len(table) + + +def test_repr(table): + assert ( + repr(table) == "" + ), "the representation should be " + + +def test_delete_nonexist_entry(table): + assert ( + table.delete(place="Berlin") is False + ), "entry not exist, should fail to delete" + + +def test_find_one(table): + table.insert({"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}) + d = table.find_one(place="Berlin") + assert d["temperature"] == -10, d + d = table.find_one(place="Atlantis") + assert d is None, d + + +def test_count(table): + assert len(table) == 6, len(table) + length = table.count(place=TEST_CITY_1) + assert length == 3, length + + +def test_find(table): + ds = list(table.find(place=TEST_CITY_1)) + assert len(ds) == 3, ds + ds = list(table.find(place=TEST_CITY_1, _limit=2)) + assert len(ds) == 2, ds + ds = list(table.find(place=TEST_CITY_1, _limit=2, _step=1)) + assert len(ds) == 2, ds + ds = list(table.find(place=TEST_CITY_1, _limit=1, _step=2)) + assert len(ds) == 1, ds + ds = list(table.find(_step=2)) + assert len(ds) == len(TEST_DATA), ds + ds = list(table.find(order_by=["temperature"])) + assert ds[0]["temperature"] == -1, ds + ds = list(table.find(order_by=["-temperature"])) + assert ds[0]["temperature"] == 8, ds + ds = list(table.find(table.table.columns.temperature > 4)) + assert len(ds) == 3, ds + + +def test_find_dsl(table): + ds = list(table.find(place={"like": "%lw%"})) + assert len(ds) == 3, ds + ds = list(table.find(temperature={">": 5})) + assert len(ds) == 2, ds + ds = list(table.find(temperature={">=": 5})) + assert len(ds) == 3, ds + ds = list(table.find(temperature={"<": 0})) + assert len(ds) == 1, ds + ds = list(table.find(temperature={"<=": 0})) + assert len(ds) == 2, ds + ds = list(table.find(temperature={"!=": -1})) + assert len(ds) == 5, ds + ds = list(table.find(temperature={"between": [5, 8]})) + assert len(ds) == 3, ds + ds = list(table.find(place={"=": "G€lway"})) + assert len(ds) == 3, ds + ds = list(table.find(place={"ilike": "%LwAy"})) + assert len(ds) == 3, ds + + +def test_offset(table): + ds = list(table.find(place=TEST_CITY_1, _offset=1)) + assert len(ds) == 2, ds + ds = list(table.find(place=TEST_CITY_1, _limit=2, _offset=2)) + assert len(ds) == 1, ds + + +def test_streamed_update(table): + ds = list(table.find(place=TEST_CITY_1, _streamed=True, _step=1)) + assert len(ds) == 3, len(ds) + for row in table.find(place=TEST_CITY_1, _streamed=True, _step=1): + row["temperature"] = -1 + table.update(row, ["id"]) + + +def test_distinct(table): + x = list(table.distinct("place")) + assert len(x) == 2, x + x = list(table.distinct("place", "date")) + assert len(x) == 6, x + x = list( + table.distinct( + "place", + "date", + table.table.columns.date >= datetime(2011, 1, 2, 0, 0), + ) + ) + assert len(x) == 4, x + + x = list(table.distinct("temperature", place="B€rkeley")) + assert len(x) == 3, x + x = list(table.distinct("temperature", place=["B€rkeley", "G€lway"])) + assert len(x) == 6, x + + +def test_insert_many(table): + data = TEST_DATA * 100 + table.insert_many(data, chunk_size=13) + assert len(table) == len(data) + 6, (len(table), len(data)) + + +def test_chunked_insert(table): + data = TEST_DATA * 100 + with chunked.ChunkedInsert(table) as chunk_tbl: + for item in data: + chunk_tbl.insert(item) + assert len(table) == len(data) + 6, (len(table), len(data)) + + +def test_chunked_insert_callback(table): + data = TEST_DATA * 100 + N = 0 + + def callback(queue): + nonlocal N + N += len(queue) + + with chunked.ChunkedInsert(table, callback=callback) as chunk_tbl: + for item in data: + chunk_tbl.insert(item) + assert len(data) == N + assert len(table) == len(data) + 6 + + +def test_update_many(db): + tbl = db["update_many_test"] + tbl.insert_many([dict(temp=10), dict(temp=20), dict(temp=30)]) + tbl.update_many([dict(id=1, temp=50), dict(id=3, temp=50)], "id") + + # Ensure data has been updated. + assert tbl.find_one(id=1)["temp"] == tbl.find_one(id=3)["temp"] + + +def test_chunked_update(db): + tbl = db["update_many_test"] + tbl.insert_many( + [ + dict(temp=10, location="asdf"), + dict(temp=20, location="qwer"), + dict(temp=30, location="asdf"), + ] + ) + db.commit() + + chunked_tbl = chunked.ChunkedUpdate(tbl, ["id"]) + chunked_tbl.update(dict(id=1, temp=50)) + chunked_tbl.update(dict(id=2, location="asdf")) + chunked_tbl.update(dict(id=3, temp=50)) + chunked_tbl.flush() + db.commit() + + # Ensure data has been updated. + assert tbl.find_one(id=1)["temp"] == tbl.find_one(id=3)["temp"] == 50 + assert tbl.find_one(id=2)["location"] == tbl.find_one(id=3)["location"] == "asdf" + + +def test_upsert_many(db): + # Also tests updating on records with different attributes + tbl = db["upsert_many_test"] + + W = 100 + tbl.upsert_many([dict(age=10), dict(weight=W)], "id") + assert tbl.find_one(id=1)["age"] == 10 + + tbl.upsert_many([dict(id=1, age=70), dict(id=2, weight=W / 2)], "id") + assert tbl.find_one(id=2)["weight"] == W / 2 + + +def test_drop_operations(table): + assert table._table is not None, "table shouldn't be dropped yet" + table.drop() + assert table._table is None, "table should be dropped now" + assert list(table.all()) == [], table.all() + assert table.count() == 0, table.count() + + +def test_table_drop(db, table): + assert "weather" in db + db["weather"].drop() + assert "weather" not in db + + +def test_table_drop_then_create(db, table): + assert "weather" in db + db["weather"].drop() + assert "weather" not in db + db["weather"].insert({"foo": "bar"}) + + +def test_columns(table): + cols = table.columns + assert len(list(cols)) == 4, "column count mismatch" + assert "date" in cols and "temperature" in cols and "place" in cols + + +def test_drop_column(table): + try: + table.drop_column("date") + assert "date" not in table.columns + except RuntimeError: + pass + + +def test_iter(table): + c = 0 + for row in table: + c += 1 + assert c == len(table) + + +def test_update(table): + date = datetime(2011, 1, 2) + res = table.update( + {"date": date, "temperature": -10, "place": TEST_CITY_1}, ["place", "date"] + ) + assert res, "update should return True" + m = table.find_one(place=TEST_CITY_1, date=date) + assert m["temperature"] == -10, ( + "new temp. should be -10 but is %d" % m["temperature"] + ) + + +def test_create_column(db, table): + flt = db.types.float + table.create_column("foo", flt) + assert "foo" in table.table.c, table.table.c + assert isinstance(table.table.c["foo"].type, flt), table.table.c["foo"].type + assert "foo" in table.columns, table.columns + + +def test_ensure_column(db, table): + flt = db.types.float + table.create_column_by_example("foo", 0.1) + assert "foo" in table.table.c, table.table.c + assert isinstance(table.table.c["foo"].type, flt), table.table.c["bar"].type + table.create_column_by_example("bar", 1) + assert "bar" in table.table.c, table.table.c + assert isinstance(table.table.c["bar"].type, BIGINT), table.table.c["bar"].type + table.create_column_by_example("pippo", "test") + assert "pippo" in table.table.c, table.table.c + assert isinstance(table.table.c["pippo"].type, TEXT), table.table.c["pippo"].type + table.create_column_by_example("bigbar", 11111111111) + assert "bigbar" in table.table.c, table.table.c + assert isinstance(table.table.c["bigbar"].type, BIGINT), table.table.c[ + "bigbar" + ].type + table.create_column_by_example("littlebar", -11111111111) + assert "littlebar" in table.table.c, table.table.c + assert isinstance(table.table.c["littlebar"].type, BIGINT), table.table.c[ + "littlebar" + ].type + + +def test_key_order(db, table): + res = db.query("SELECT temperature, place FROM weather LIMIT 1") + keys = list(res.next().keys()) + assert keys[0] == "temperature" + assert keys[1] == "place" + + +def test_empty_query(table): + empty = list(table.find(place="not in data")) + assert len(empty) == 0, empty