diff --git a/.github/dependabot.yml b/.github/dependabot.yml
index 54a347cb..f5614db1 100644
--- a/.github/dependabot.yml
+++ b/.github/dependabot.yml
@@ -1,8 +1,12 @@
 version: 2
 updates:
-- package-ecosystem: pip
-  directory: "/"
-  schedule:
-    interval: daily
-    time: "04:00"
-  open-pull-requests-limit: 100
+  - package-ecosystem: pip
+    directory: "/"
+    schedule:
+      interval: weekly
+    open-pull-requests-limit: 100
+  - package-ecosystem: "github-actions"
+    open-pull-requests-limit: 99
+    directory: "/"
+    schedule:
+      interval: weekly
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 68bf5db8..9522cee2 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -65,7 +65,7 @@ jobs:
         python setup.py sdist bdist_wheel
     - name: Publish a Python distribution to PyPI
       if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
-      uses: pypa/gh-action-pypi-publish@master
+      uses: pypa/gh-action-pypi-publish@release/v1
       with:
         user: __token__
         password: ${{ secrets.pypi_password }}
diff --git a/Makefile b/Makefile
index 22f134cf..f20a48d8 100644
--- a/Makefile
+++ b/Makefile
@@ -3,7 +3,7 @@ all: clean test dists
 
 .PHONY: test
 test:
-	pytest
+	pytest -v
 
 dists:
 	python setup.py sdist
diff --git a/dataset/chunked.py b/dataset/chunked.py
index a5ca158a..772a2a7c 100644
--- a/dataset/chunked.py
+++ b/dataset/chunked.py
@@ -80,6 +80,6 @@ def flush(self):
         if self.callback is not None:
             self.callback(self.queue)
         self.queue.sort(key=dict.keys)
-        for fields, items in itertools.groupby(self.queue, key=dict.keys):
+        for _, items in itertools.groupby(self.queue, key=dict.keys):
             self.table.update_many(list(items), self.keys)
         super().flush()
diff --git a/dataset/database.py b/dataset/database.py
index d8a07ad6..627716b6 100644
--- a/dataset/database.py
+++ b/dataset/database.py
@@ -44,7 +44,6 @@ def __init__(
 
         self.lock = threading.RLock()
         self.local = threading.local()
-        self.connections = {}
 
         if len(parsed_url.query):
            query = parse_qs(parsed_url.query)
@@ -54,9 +53,10 @@
                schema = schema_qs.pop()
 
         self.schema = schema
-        self.engine = create_engine(url, **engine_kwargs)
-        self.is_postgres = self.engine.dialect.name == "postgresql"
-        self.is_sqlite = self.engine.dialect.name == "sqlite"
+        self._engine = create_engine(url, **engine_kwargs)
+        self.is_postgres = self._engine.dialect.name == "postgresql"
+        self.is_sqlite = self._engine.dialect.name == "sqlite"
+        self.is_mysql = "mysql" in self._engine.dialect.dbapi.__name__
         if on_connect_statements is None:
             on_connect_statements = []
 
@@ -72,7 +72,7 @@ def _run_on_connect(dbapi_con, con_record):
             on_connect_statements.append("PRAGMA journal_mode=WAL")
 
         if len(on_connect_statements):
-            event.listen(self.engine, "connect", _run_on_connect)
+            event.listen(self._engine, "connect", _run_on_connect)
 
         self.types = Types(is_postgres=self.is_postgres)
         self.url = url
@@ -81,24 +81,25 @@ def _run_on_connect(dbapi_con, con_record):
         self._tables = {}
 
     @property
-    def executable(self):
+    def conn(self):
         """Connection against which statements will be executed."""
-        with self.lock:
-            tid = threading.get_ident()
-            if tid not in self.connections:
-                self.connections[tid] = self.engine.connect()
-            return self.connections[tid]
+        try:
+            return self.local.conn
+        except AttributeError:
+            self.local.conn = self._engine.connect()
+            self.local.conn.begin()
+            return self.local.conn
 
     @property
     def op(self):
         """Get an alembic operations context."""
-        ctx = MigrationContext.configure(self.executable)
+        ctx = MigrationContext.configure(self.conn)
         return Operations(ctx)
 
     @property
     def inspect(self):
         """Get a SQLAlchemy inspector."""
-        return inspect(self.executable)
+        return inspect(self.conn)
 
     def has_table(self, name):
         return self.inspect.has_table(name, schema=self.schema)
@@ -106,14 +107,12 @@ def has_table(self, name):
     @property
     def metadata(self):
         """Return a SQLAlchemy schema cache object."""
-        return MetaData(schema=self.schema, bind=self.executable)
+        return MetaData(schema=self.schema)
 
     @property
     def in_transaction(self):
         """Check if this database is in a transactional context."""
-        if not hasattr(self.local, "tx"):
-            return False
-        return len(self.local.tx) > 0
+        return self.conn.in_transaction()
 
     def _flush_tables(self):
         """Clear the table metadata after transaction rollbacks."""
@@ -125,32 +124,22 @@ def begin(self):
 
         No data will be written until the transaction has been committed.
         """
-        if not hasattr(self.local, "tx"):
-            self.local.tx = []
-        self.local.tx.append(self.executable.begin())
+        if not self.in_transaction:
+            self.conn.begin()
 
     def commit(self):
         """Commit the current transaction.
 
         Make all statements executed since the transaction was begun permanent.
         """
-        if hasattr(self.local, "tx") and self.local.tx:
-            tx = self.local.tx.pop()
-            tx.commit()
-            # Removed in 2020-12, I'm a bit worried this means that some DDL
-            # operations in transactions won't cause metadata to refresh any
-            # more:
-            # self._flush_tables()
+        self.conn.commit()
 
     def rollback(self):
         """Roll back the current transaction.
 
         Discard all statements executed since the transaction was begun.
         """
-        if hasattr(self.local, "tx") and self.local.tx:
-            tx = self.local.tx.pop()
-            tx.rollback()
-        self._flush_tables()
+        self.conn.rollback()
 
     def __enter__(self):
         """Start a transaction."""
@@ -170,13 +159,10 @@ def __exit__(self, error_type, error_value, traceback):
 
     def close(self):
         """Close database connections. Makes this object unusable."""
-        with self.lock:
-            for conn in self.connections.values():
-                conn.close()
-            self.connections.clear()
-            self.engine.dispose()
+        self.local = threading.local()
+        self._engine.dispose()
         self._tables = {}
-        self.engine = None
+        self._engine = None
 
     @property
     def tables(self):
@@ -322,7 +308,7 @@ def query(self, query, *args, **kwargs):
         _step = kwargs.pop("_step", QUERY_STEP)
         if _step is False or _step == 0:
             _step = None
-        rp = self.executable.execute(query, *args, **kwargs)
+        rp = self.conn.execute(query, *args, **kwargs)
         return ResultIter(rp, row_type=self.row_type, step=_step)
 
     def __repr__(self):
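
The connection handling above is the heart of the migration: the old `executable` property kept a dict of connections keyed by thread id, while the new `conn` property stores one lazily-opened connection per thread on `threading.local`, and SQLAlchemy 2.0 connections never autocommit. A minimal sketch of the pattern outside of dataset (the `ThreadLocalDB` name is illustrative, not part of the library):

    import threading

    from sqlalchemy import create_engine, text

    class ThreadLocalDB:
        def __init__(self, url):
            self._engine = create_engine(url)
            self._local = threading.local()

        @property
        def conn(self):
            try:
                return self._local.conn
            except AttributeError:
                # First access on this thread: open a connection and start
                # a transaction, as the new `conn` property does above.
                self._local.conn = self._engine.connect()
                self._local.conn.begin()
                return self._local.conn

    db = ThreadLocalDB("sqlite://")
    db.conn.execute(text("CREATE TABLE t (x INTEGER)"))
    db.conn.execute(text("INSERT INTO t (x) VALUES (1)"))
    db.conn.commit()  # 2.0 has no autocommit; nothing persists until here

Because every thread sees only its own `conn`, `close()` can simply reset `self.local` and dispose of the engine instead of walking a shared dict under a lock.
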
diff --git a/dataset/table.py b/dataset/table.py
index 08b806b2..5b3d0297 100644
--- a/dataset/table.py
+++ b/dataset/table.py
@@ -116,7 +116,9 @@ def insert(self, row, ensure=None, types=None):
         Returns the inserted row's primary key.
         """
         row = self._sync_columns(row, ensure, types=types)
-        res = self.db.executable.execute(self.table.insert(row))
+        res = self.db.conn.execute(self.table.insert(), row)
+        # if not self.db.in_transaction:
+        #     self.db.conn.commit()
         if len(res.inserted_primary_key) > 0:
             return res.inserted_primary_key[0]
         return True
@@ -181,7 +183,7 @@ def insert_many(self, rows, chunk_size=1000, ensure=None, types=None):
             # Insert when chunk_size is fulfilled or this is the last row
             if len(chunk) == chunk_size or index == len(rows) - 1:
                 chunk = pad_chunk_columns(chunk, columns)
-                self.table.insert().execute(chunk)
+                self.db.conn.execute(self.table.insert(), chunk)
                 chunk = []
 
     def update(self, row, keys, ensure=None, types=None, return_count=False):
@@ -206,8 +208,8 @@
         clause = self._args_to_clause(args)
         if not len(row):
             return self.count(clause)
-        stmt = self.table.update(whereclause=clause, values=row)
-        rp = self.db.executable.execute(stmt)
+        stmt = self.table.update().where(clause).values(row)
+        rp = self.db.conn.execute(stmt)
         if rp.supports_sane_rowcount():
             return rp.rowcount
         if return_count:
@@ -241,11 +243,12 @@ def update_many(self, rows, keys, chunk_size=1000, ensure=None, types=None):
             # Update when chunk_size is fulfilled or this is the last row
             if len(chunk) == chunk_size or index == len(rows) - 1:
                 cl = [self.table.c[k] == bindparam("_%s" % k) for k in keys]
-                stmt = self.table.update(
-                    whereclause=and_(True, *cl),
-                    values={col: bindparam(col, required=False) for col in columns},
+                stmt = (
+                    self.table.update()
+                    .where(and_(True, *cl))
+                    .values({col: bindparam(col, required=False) for col in columns})
                 )
-                self.db.executable.execute(stmt, chunk)
+                self.db.conn.execute(stmt, chunk)
                 chunk = []
 
     def upsert(self, row, keys, ensure=None, types=None):
@@ -293,8 +296,8 @@
         if not self.exists:
             return False
         clause = self._args_to_clause(filters, clauses=clauses)
-        stmt = self.table.delete(whereclause=clause)
-        rp = self.db.executable.execute(stmt)
+        stmt = self.table.delete().where(clause)
+        rp = self.db.conn.execute(stmt)
         return rp.rowcount > 0
 
     def _reflect_table(self):
@@ -303,7 +306,10 @@
         self._columns = None
         try:
             self._table = SQLATable(
-                self.name, self.db.metadata, schema=self.db.schema, autoload=True
+                self.name,
+                self.db.metadata,
+                schema=self.db.schema,
+                autoload_with=self.db.conn,
             )
         except NoSuchTableError:
             self._table = None
@@ -345,7 +351,7 @@ def _sync_table(self, columns):
             for column in columns:
                 if not column.name == self._primary_id:
                     self._table.append_column(column)
-            self._table.create(self.db.executable, checkfirst=True)
+            self._table.create(self.db.conn, checkfirst=True)
             self._columns = None
         elif len(columns):
             with self.db.lock:
@@ -353,7 +359,7 @@
                 self._threading_warn()
                 for column in columns:
                     if not self.has_column(column.name):
-                        self.db.op.add_column(self.name, column, self.db.schema)
+                        self.db.op.add_column(self.name, column, schema=self.db.schema)
                 self._reflect_table()
 
     def _sync_columns(self, row, ensure, types=None):
@@ -378,6 +384,7 @@
                 _type = self.db.types.guess(value)
                 sync_columns[name] = Column(name, _type)
                 out[name] = value
+        self._sync_table(sync_columns.values())
         return out
@@ -500,7 +507,7 @@ def drop_column(self, name):
 
             table.drop_column('created_at')
         """
-        if self.db.engine.dialect.name == "sqlite":
+        if self.db.is_sqlite:
             raise RuntimeError("SQLite does not support dropping columns.")
         name = self._get_column_name(name)
         with self.db.lock:
@@ -509,7 +516,7 @@
                 return
             self._threading_warn()
-            self.db.op.drop_column(self.table.name, name, self.table.schema)
+            self.db.op.drop_column(self.table.name, name, schema=self.table.schema)
             self._reflect_table()
 
     def drop(self):
@@ -520,7 +527,7 @@
         with self.db.lock:
             if self.exists:
                 self._threading_warn()
-                self.table.drop(self.db.executable, checkfirst=True)
+                self.table.drop(self.db.conn, checkfirst=True)
                 self._table = None
                 self._columns = None
                 self.db._tables.pop(self.name, None)
@@ -581,7 +588,7 @@ def create_index(self, columns, name=None, **kw):
                 kw["mysql_length"] = mysql_length
 
             idx = Index(name, *columns, **kw)
-            idx.create(self.db.executable)
+            idx.create(self.db.conn)
 
     def find(self, *_clauses, **kwargs):
         """Perform a simple search on the table.
@@ -625,14 +632,13 @@
 
         order_by = self._args_to_order_by(order_by)
         args = self._args_to_clause(kwargs, clauses=_clauses)
-        query = self.table.select(whereclause=args, limit=_limit, offset=_offset)
+        query = self.table.select().where(args).limit(_limit).offset(_offset)
         if len(order_by):
             query = query.order_by(*order_by)
 
-        conn = self.db.executable
+        conn = self.db.conn
         if _streamed:
-            conn = self.db.engine.connect()
-            conn = conn.execution_options(stream_results=True)
+            conn = self.db._engine.connect().execution_options(stream_results=True)
 
         return ResultIter(conn.execute(query), row_type=self.db.row_type, step=_step)
@@ -666,9 +672,9 @@ def count(self, *_clauses, **kwargs):
             return 0
 
         args = self._args_to_clause(kwargs, clauses=_clauses)
-        query = select([func.count()], whereclause=args)
+        query = select(func.count()).where(args)
         query = query.select_from(self.table)
-        rp = self.db.executable.execute(query)
+        rp = self.db.conn.execute(query)
         return rp.fetchone()[0]
 
     def __len__(self):
@@ -703,11 +709,11 @@ def distinct(self, *args, **_filter):
         if not len(columns):
             return iter([])
 
-        q = expression.select(
-            columns,
-            distinct=True,
-            whereclause=clause,
-            order_by=[c.asc() for c in columns],
+        q = (
+            expression.select(*columns)
+            .where(clause)
+            .group_by(*columns)
+            .order_by(*(c.asc() for c in columns))
         )
         return self.db.query(q)
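
Every statement change in this file follows the same SQLAlchemy 2.0 recipe: keyword-style constructors (`table.update(whereclause=..., values=...)`, `table.select(whereclause=...)`, `select([...])`) become generative chains, and `stmt.execute()` disappears in favour of `conn.execute(stmt)`. A self-contained before/after sketch (the `people` table is illustrative, not part of dataset):

    from sqlalchemy import (
        Column, Integer, MetaData, String, Table, create_engine, func, select,
    )

    engine = create_engine("sqlite://")
    metadata = MetaData()
    people = Table(
        "people", metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String),
    )
    metadata.create_all(engine)

    with engine.connect() as conn:
        # executemany: pass the parameter list to conn.execute(), since
        # stmt.execute(chunk) is gone in 2.0.
        conn.execute(people.insert(), [{"name": "ada"}, {"name": "bob"}])
        # 1.x: people.update(whereclause=..., values=...)
        stmt = people.update().where(people.c.name == "bob").values(name="rob")
        conn.execute(stmt)
        # 1.x: select([func.count()], whereclause=...)
        total = conn.execute(select(func.count()).select_from(people)).scalar_one()
        conn.commit()
        assert total == 2
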
""" row = self._sync_columns(row, ensure, types=types) - res = self.db.executable.execute(self.table.insert(row)) + res = self.db.conn.execute(self.table.insert(), row) + # if not self.db.in_transaction: + # self.db.conn.commit() if len(res.inserted_primary_key) > 0: return res.inserted_primary_key[0] return True @@ -181,7 +183,7 @@ def insert_many(self, rows, chunk_size=1000, ensure=None, types=None): # Insert when chunk_size is fulfilled or this is the last row if len(chunk) == chunk_size or index == len(rows) - 1: chunk = pad_chunk_columns(chunk, columns) - self.table.insert().execute(chunk) + self.db.conn.execute(self.table.insert(), chunk) chunk = [] def update(self, row, keys, ensure=None, types=None, return_count=False): @@ -206,8 +208,8 @@ def update(self, row, keys, ensure=None, types=None, return_count=False): clause = self._args_to_clause(args) if not len(row): return self.count(clause) - stmt = self.table.update(whereclause=clause, values=row) - rp = self.db.executable.execute(stmt) + stmt = self.table.update().where(clause).values(row) + rp = self.db.conn.execute(stmt) if rp.supports_sane_rowcount(): return rp.rowcount if return_count: @@ -241,11 +243,12 @@ def update_many(self, rows, keys, chunk_size=1000, ensure=None, types=None): # Update when chunk_size is fulfilled or this is the last row if len(chunk) == chunk_size or index == len(rows) - 1: cl = [self.table.c[k] == bindparam("_%s" % k) for k in keys] - stmt = self.table.update( - whereclause=and_(True, *cl), - values={col: bindparam(col, required=False) for col in columns}, + stmt = ( + self.table.update() + .where(and_(True, *cl)) + .values({col: bindparam(col, required=False) for col in columns}) ) - self.db.executable.execute(stmt, chunk) + self.db.conn.execute(stmt, chunk) chunk = [] def upsert(self, row, keys, ensure=None, types=None): @@ -293,8 +296,8 @@ def delete(self, *clauses, **filters): if not self.exists: return False clause = self._args_to_clause(filters, clauses=clauses) - stmt = self.table.delete(whereclause=clause) - rp = self.db.executable.execute(stmt) + stmt = self.table.delete().where(clause) + rp = self.db.conn.execute(stmt) return rp.rowcount > 0 def _reflect_table(self): @@ -303,7 +306,10 @@ def _reflect_table(self): self._columns = None try: self._table = SQLATable( - self.name, self.db.metadata, schema=self.db.schema, autoload=True + self.name, + self.db.metadata, + schema=self.db.schema, + autoload_with=self.db.conn, ) except NoSuchTableError: self._table = None @@ -345,7 +351,7 @@ def _sync_table(self, columns): for column in columns: if not column.name == self._primary_id: self._table.append_column(column) - self._table.create(self.db.executable, checkfirst=True) + self._table.create(self.db.conn, checkfirst=True) self._columns = None elif len(columns): with self.db.lock: @@ -353,7 +359,7 @@ def _sync_table(self, columns): self._threading_warn() for column in columns: if not self.has_column(column.name): - self.db.op.add_column(self.name, column, self.db.schema) + self.db.op.add_column(self.name, column, schema=self.db.schema) self._reflect_table() def _sync_columns(self, row, ensure, types=None): @@ -378,6 +384,7 @@ def _sync_columns(self, row, ensure, types=None): _type = self.db.types.guess(value) sync_columns[name] = Column(name, _type) out[name] = value + self._sync_table(sync_columns.values()) return out @@ -500,7 +507,7 @@ def drop_column(self, name): table.drop_column('created_at') """ - if self.db.engine.dialect.name == "sqlite": + if self.db.is_sqlite: raise RuntimeError("SQLite 
diff --git a/docs/conf.py b/docs/conf.py
index 5ea37fa1..abc9fe61 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -37,8 +37,8 @@
 master_doc = "index"
 
 # General information about the project.
-project = u"dataset"
-copyright = u"2013-2021, Friedrich Lindenberg, Gregor Aisch, Stefan Wehrmeyer"
+project = "dataset"
+copyright = "2013-2021, Friedrich Lindenberg, Gregor Aisch, Stefan Wehrmeyer"
 
 # The version info for the project you're documenting, acts as replacement for
 # |version| and |release|, also used in various other places throughout the
diff --git a/setup.py b/setup.py
index 06913735..aea1b4b4 100644
--- a/setup.py
+++ b/setup.py
@@ -30,8 +30,8 @@
     include_package_data=False,
     zip_safe=False,
     install_requires=[
-        "sqlalchemy >= 1.3.2, < 2.0.0",
-        "alembic >= 0.6.2",
+        "sqlalchemy >= 2.0.15, < 3.0.0",
+        "alembic >= 1.11.1",
         "banal >= 1.0.1",
     ],
    extras_require={
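
This dependency bump is what forces every change above: SQLAlchemy 2.0 removed `Engine.execute()` and connection-level autocommit, so statements must run on an explicit `Connection` and be committed explicitly. A minimal sketch of the difference:

    from sqlalchemy import create_engine, text

    engine = create_engine("sqlite://")
    # 1.x allowed: engine.execute("CREATE TABLE t (x INTEGER)")  # gone in 2.0
    with engine.connect() as conn:
        conn.execute(text("CREATE TABLE t (x INTEGER)"))
        conn.execute(text("INSERT INTO t (x) VALUES (1)"))
        conn.commit()  # required: 2.0 never autocommits
        assert conn.execute(text("SELECT x FROM t")).scalar_one() == 1
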
diff --git a/test/sample_data.py b/test/conftest.py
similarity index 63%
rename from test/sample_data.py
rename to test/conftest.py
index f593fbc0..475c3dbb 100644
--- a/test/sample_data.py
+++ b/test/conftest.py
@@ -1,8 +1,8 @@
-# -*- encoding: utf-8 -*-
-from __future__ import unicode_literals
-
+import pytest
 from datetime import datetime
 
+import dataset
+
 TEST_CITY_1 = "B€rkeley"
 TEST_CITY_2 = "G€lway"
@@ -15,3 +15,21 @@
     {"date": datetime(2011, 1, 2), "temperature": 8, "place": TEST_CITY_1},
     {"date": datetime(2011, 1, 3), "temperature": 5, "place": TEST_CITY_1},
 ]
+
+
+@pytest.fixture(scope="function")
+def db():
+    db = dataset.connect()
+    yield db
+    db.close()
+
+
+@pytest.fixture(scope="function")
+def table(db):
+    tbl = db["weather"]
+    tbl.drop()
+    tbl.insert_many(TEST_DATA)
+    db.commit()
+    yield tbl
+    db.rollback()
+    db.commit()
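
With the `unittest` setUp/tearDown replaced by fixtures, any test that names `db` or `table` as a parameter receives a fresh `dataset.connect()` handle and a `weather` table pre-populated with `TEST_DATA`. A hypothetical test showing the injection (`test_rows_present` is not part of the suite):

    def test_rows_present(db, table):
        # `table` is the "weather" table, freshly populated by the fixture.
        assert "weather" in db.tables
        assert len(table) == 3  # TEST_DATA holds three rows
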
diff --git a/test/test_database.py b/test/test_database.py
new file mode 100644
index 00000000..d3216cbd
--- /dev/null
+++ b/test/test_database.py
@@ -0,0 +1,148 @@
+import os
+import pytest
+from datetime import datetime
+from collections import OrderedDict
+from sqlalchemy.exc import IntegrityError, SQLAlchemyError
+
+from dataset import connect
+
+from .conftest import TEST_DATA
+
+
+def test_valid_database_url(db):
+    assert db.url, os.environ["DATABASE_URL"]
+
+
+def test_database_url_query_string(db):
+    db = connect("sqlite:///:memory:/?cached_statements=1")
+    assert "cached_statements" in db.url, db.url
+
+
+def test_tables(db, table):
+    assert db.tables == ["weather"], db.tables
+
+
+def test_contains(db, table):
+    assert "weather" in db, db.tables
+
+
+def test_create_table(db):
+    table = db["foo"]
+    assert db.has_table(table.table.name)
+    assert len(table.table.columns) == 1, table.table.columns
+    assert "id" in table.table.c, table.table.c
+
+
+def test_create_table_no_ids(db):
+    if db.is_mysql or db.is_sqlite:
+        return
+    table = db.create_table("foo_no_id", primary_id=False)
+    assert table.table.name == "foo_no_id"
+    assert len(table.table.columns) == 0, table.table.columns
+
+
+def test_create_table_custom_id1(db):
+    pid = "string_id"
+    table = db.create_table("foo2", pid, db.types.string(255))
+    assert db.has_table(table.table.name)
+    assert len(table.table.columns) == 1, table.table.columns
+    assert pid in table.table.c, table.table.c
+    table.insert({pid: "foobar"})
+    assert table.find_one(string_id="foobar")[pid] == "foobar"
+
+
+def test_create_table_custom_id2(db):
+    pid = "string_id"
+    table = db.create_table("foo3", pid, db.types.string(50))
+    assert db.has_table(table.table.name)
+    assert len(table.table.columns) == 1, table.table.columns
+    assert pid in table.table.c, table.table.c
+
+    table.insert({pid: "foobar"})
+    assert table.find_one(string_id="foobar")[pid] == "foobar"
+
+
+def test_create_table_custom_id3(db):
+    pid = "int_id"
+    table = db.create_table("foo4", primary_id=pid)
+    assert db.has_table(table.table.name)
+    assert len(table.table.columns) == 1, table.table.columns
+    assert pid in table.table.c, table.table.c
+
+    table.insert({pid: 123})
+    table.insert({pid: 124})
+    assert table.find_one(int_id=123)[pid] == 123
+    assert table.find_one(int_id=124)[pid] == 124
+    with pytest.raises(IntegrityError):
+        table.insert({pid: 123})
+    db.rollback()
+
+
+def test_create_table_shorthand1(db):
+    pid = "int_id"
+    table = db.get_table("foo5", pid)
+    assert len(table.table.columns) == 1, table.table.columns
+    assert pid in table.table.c, table.table.c
+
+    table.insert({"int_id": 123})
+    table.insert({"int_id": 124})
+    assert table.find_one(int_id=123)["int_id"] == 123
+    assert table.find_one(int_id=124)["int_id"] == 124
+    with pytest.raises(IntegrityError):
+        table.insert({"int_id": 123})
+
+
+def test_create_table_shorthand2(db):
+    pid = "string_id"
+    table = db.get_table("foo6", primary_id=pid, primary_type=db.types.string(255))
+    assert len(table.table.columns) == 1, table.table.columns
+    assert pid in table.table.c, table.table.c
+
+    table.insert({"string_id": "foobar"})
+    assert table.find_one(string_id="foobar")["string_id"] == "foobar"
+
+
+def test_with(db, table):
+    init_length = len(table)
+    with pytest.raises(ValueError):
+        with db:
+            table.insert(
+                {
+                    "date": datetime(2011, 1, 1),
+                    "temperature": 1,
+                    "place": "tmp_place",
+                }
+            )
+            raise ValueError()
+    db.rollback()
+    assert len(table) == init_length
+
+
+def test_invalid_values(db, table):
+    if db.is_mysql:
+        # WARNING: mysql seems to be doing some weird type casting
+        # upon insert. The mysql-python driver is not affected but
+        # it isn't compatible with Python 3
+        # Conclusion: use postgresql.
+        return
+    with pytest.raises(SQLAlchemyError):
+        table.insert({"date": True, "temperature": "wrong_value", "place": "tmp_place"})
+
+
+def test_load_table(db, table):
+    tbl = db.load_table("weather")
+    assert tbl.table.name == table.table.name
+
+
+def test_query(db, table):
+    r = db.query("SELECT COUNT(*) AS num FROM weather").next()
+    assert r["num"] == len(TEST_DATA), r
+
+
+def test_table_cache_updates(db):
+    tbl1 = db.get_table("people")
+    data = OrderedDict([("first_name", "John"), ("last_name", "Smith")])
+    tbl1.insert(data)
+    data["id"] = 1
+    tbl2 = db.get_table("people")
+    assert dict(tbl2.all().next()) == dict(data), (tbl2.all().next(), data)
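
`test_with` above pins down the new transaction semantics: `with db:` begins a transaction, commits on a clean exit, and rolls back when the block raises; the extra `db.rollback()` afterwards clears the thread-local connection's failed transaction state. A sketch of the same flow outside the test suite (assumes the post-migration behaviour described here):

    import dataset

    db = dataset.connect("sqlite:///:memory:")
    table = db["events"]
    try:
        with db:  # begin; commit on success, roll back on exception
            table.insert({"name": "first"})
            raise ValueError("abort")
    except ValueError:
        pass
    db.rollback()  # mirror test_with: reset the connection's transaction
    assert len(table) == 0  # the insert was rolled back
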
diff --git a/test/test_dataset.py b/test/test_dataset.py
deleted file mode 100644
index f7c94ebc..00000000
--- a/test/test_dataset.py
+++ /dev/null
@@ -1,602 +0,0 @@
-import os
-import unittest
-from datetime import datetime
-from collections import OrderedDict
-from sqlalchemy import TEXT, BIGINT
-from sqlalchemy.exc import IntegrityError, SQLAlchemyError, ArgumentError
-
-from dataset import connect, chunked
-
-from .sample_data import TEST_DATA, TEST_CITY_1
-
-
-class DatabaseTestCase(unittest.TestCase):
-    def setUp(self):
-        self.db = connect()
-        self.tbl = self.db["weather"]
-        self.tbl.insert_many(TEST_DATA)
-
-    def tearDown(self):
-        for table in self.db.tables:
-            self.db[table].drop()
-
-    def test_valid_database_url(self):
-        assert self.db.url, os.environ["DATABASE_URL"]
-
-    def test_database_url_query_string(self):
-        db = connect("sqlite:///:memory:/?cached_statements=1")
-        assert "cached_statements" in db.url, db.url
-
-    def test_tables(self):
-        assert self.db.tables == ["weather"], self.db.tables
-
-    def test_contains(self):
-        assert "weather" in self.db, self.db.tables
-
-    def test_create_table(self):
-        table = self.db["foo"]
-        assert self.db.has_table(table.table.name)
-        assert len(table.table.columns) == 1, table.table.columns
-        assert "id" in table.table.c, table.table.c
-
-    def test_create_table_no_ids(self):
-        if "mysql" in self.db.engine.dialect.dbapi.__name__:
-            return
-        if "sqlite" in self.db.engine.dialect.dbapi.__name__:
-            return
-        table = self.db.create_table("foo_no_id", primary_id=False)
-        assert table.table.exists()
-        assert len(table.table.columns) == 0, table.table.columns
-
-    def test_create_table_custom_id1(self):
-        pid = "string_id"
-        table = self.db.create_table("foo2", pid, self.db.types.string(255))
-        assert self.db.has_table(table.table.name)
-        assert len(table.table.columns) == 1, table.table.columns
-        assert pid in table.table.c, table.table.c
-        table.insert({pid: "foobar"})
-        assert table.find_one(string_id="foobar")[pid] == "foobar"
-
-    def test_create_table_custom_id2(self):
-        pid = "string_id"
-        table = self.db.create_table("foo3", pid, self.db.types.string(50))
-        assert self.db.has_table(table.table.name)
-        assert len(table.table.columns) == 1, table.table.columns
-        assert pid in table.table.c, table.table.c
-
-        table.insert({pid: "foobar"})
-        assert table.find_one(string_id="foobar")[pid] == "foobar"
-
-    def test_create_table_custom_id3(self):
-        pid = "int_id"
-        table = self.db.create_table("foo4", primary_id=pid)
-        assert self.db.has_table(table.table.name)
-        assert len(table.table.columns) == 1, table.table.columns
-        assert pid in table.table.c, table.table.c
-
-        table.insert({pid: 123})
-        table.insert({pid: 124})
-        assert table.find_one(int_id=123)[pid] == 123
-        assert table.find_one(int_id=124)[pid] == 124
-        self.assertRaises(IntegrityError, lambda: table.insert({pid: 123}))
-
-    def test_create_table_shorthand1(self):
-        pid = "int_id"
-        table = self.db.get_table("foo5", pid)
-        assert table.table.exists
-        assert len(table.table.columns) == 1, table.table.columns
-        assert pid in table.table.c, table.table.c
-
-        table.insert({"int_id": 123})
-        table.insert({"int_id": 124})
-        assert table.find_one(int_id=123)["int_id"] == 123
-        assert table.find_one(int_id=124)["int_id"] == 124
-        self.assertRaises(IntegrityError, lambda: table.insert({"int_id": 123}))
-
-    def test_create_table_shorthand2(self):
-        pid = "string_id"
-        table = self.db.get_table(
-            "foo6", primary_id=pid, primary_type=self.db.types.string(255)
-        )
-        assert table.table.exists
-        assert len(table.table.columns) == 1, table.table.columns
-        assert pid in table.table.c, table.table.c
-
-        table.insert({"string_id": "foobar"})
-        assert table.find_one(string_id="foobar")["string_id"] == "foobar"
-
-    def test_with(self):
-        init_length = len(self.db["weather"])
-        with self.assertRaises(ValueError):
-            with self.db as tx:
-                tx["weather"].insert(
-                    {
-                        "date": datetime(2011, 1, 1),
-                        "temperature": 1,
-                        "place": "tmp_place",
-                    }
-                )
-                raise ValueError()
-        assert len(self.db["weather"]) == init_length
-
-    def test_invalid_values(self):
-        if "mysql" in self.db.engine.dialect.dbapi.__name__:
-            # WARNING: mysql seems to be doing some weird type casting
-            # upon insert. The mysql-python driver is not affected but
-            # it isn't compatible with Python 3
-            # Conclusion: use postgresql.
-            return
-        with self.assertRaises(SQLAlchemyError):
-            tbl = self.db["weather"]
-            tbl.insert(
-                {"date": True, "temperature": "wrong_value", "place": "tmp_place"}
-            )
-
-    def test_load_table(self):
-        tbl = self.db.load_table("weather")
-        assert tbl.table.name == self.tbl.table.name
-
-    def test_query(self):
-        r = self.db.query("SELECT COUNT(*) AS num FROM weather").next()
-        assert r["num"] == len(TEST_DATA), r
-
-    def test_table_cache_updates(self):
-        tbl1 = self.db.get_table("people")
-        data = OrderedDict([("first_name", "John"), ("last_name", "Smith")])
-        tbl1.insert(data)
-        data["id"] = 1
-        tbl2 = self.db.get_table("people")
-        assert dict(tbl2.all().next()) == dict(data), (tbl2.all().next(), data)
-
-
-class TableTestCase(unittest.TestCase):
-    def setUp(self):
-        self.db = connect()
-        self.tbl = self.db["weather"]
-        for row in TEST_DATA:
-            self.tbl.insert(row)
-
-    def tearDown(self):
-        self.tbl.drop()
-
-    def test_insert(self):
-        assert len(self.tbl) == len(TEST_DATA), len(self.tbl)
-        last_id = self.tbl.insert(
-            {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}
-        )
-        assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl)
-        assert self.tbl.find_one(id=last_id)["place"] == "Berlin"
-
-    def test_insert_ignore(self):
-        self.tbl.insert_ignore(
-            {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
-            ["place"],
-        )
-        assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl)
-        self.tbl.insert_ignore(
-            {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
-            ["place"],
-        )
-        assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl)
-
-    def test_insert_ignore_all_key(self):
-        for i in range(0, 4):
-            self.tbl.insert_ignore(
-                {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
-                ["date", "temperature", "place"],
-            )
-        assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl)
-
-    def test_insert_json(self):
-        last_id = self.tbl.insert(
-            {
-                "date": datetime(2011, 1, 2),
-                "temperature": -10,
-                "place": "Berlin",
-                "info": {
-                    "currency": "EUR",
-                    "language": "German",
"population": 3292365, - }, - } - ) - assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) - assert self.tbl.find_one(id=last_id)["place"] == "Berlin" - - def test_upsert(self): - self.tbl.upsert( - {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}, - ["place"], - ) - assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) - self.tbl.upsert( - {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}, - ["place"], - ) - assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) - - def test_upsert_single_column(self): - table = self.db["banana_single_col"] - table.upsert({"color": "Yellow"}, ["color"]) - assert len(table) == 1, len(table) - table.upsert({"color": "Yellow"}, ["color"]) - assert len(table) == 1, len(table) - - def test_upsert_all_key(self): - assert len(self.tbl) == len(TEST_DATA), len(self.tbl) - for i in range(0, 2): - self.tbl.upsert( - {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}, - ["date", "temperature", "place"], - ) - assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) - - def test_upsert_id(self): - table = self.db["banana_with_id"] - data = dict(id=10, title="I am a banana!") - table.upsert(data, ["id"]) - assert len(table) == 1, len(table) - - def test_update_while_iter(self): - for row in self.tbl: - row["foo"] = "bar" - self.tbl.update(row, ["place", "date"]) - assert len(self.tbl) == len(TEST_DATA), len(self.tbl) - - def test_weird_column_names(self): - with self.assertRaises(ValueError): - self.tbl.insert( - { - "date": datetime(2011, 1, 2), - "temperature": -10, - "foo.bar": "Berlin", - "qux.bar": "Huhu", - } - ) - - def test_cased_column_names(self): - tbl = self.db["cased_column_names"] - tbl.insert({"place": "Berlin"}) - tbl.insert({"Place": "Berlin"}) - tbl.insert({"PLACE ": "Berlin"}) - assert len(tbl.columns) == 2, tbl.columns - assert len(list(tbl.find(Place="Berlin"))) == 3 - assert len(list(tbl.find(place="Berlin"))) == 3 - assert len(list(tbl.find(PLACE="Berlin"))) == 3 - - def test_invalid_column_names(self): - tbl = self.db["weather"] - with self.assertRaises(ValueError): - tbl.insert({None: "banana"}) - - with self.assertRaises(ValueError): - tbl.insert({"": "banana"}) - - with self.assertRaises(ValueError): - tbl.insert({"-": "banana"}) - - def test_delete(self): - self.tbl.insert( - {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"} - ) - original_count = len(self.tbl) - assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) - # Test bad use of API - with self.assertRaises(ArgumentError): - self.tbl.delete({"place": "Berlin"}) - assert len(self.tbl) == original_count, len(self.tbl) - - assert self.tbl.delete(place="Berlin") is True, "should return 1" - assert len(self.tbl) == len(TEST_DATA), len(self.tbl) - assert self.tbl.delete() is True, "should return non zero" - assert len(self.tbl) == 0, len(self.tbl) - - def test_repr(self): - assert ( - repr(self.tbl) == "